GILSON Matthieu / courseML_phd2023 · Commits · 8d9ac605

Commit 8d9ac605, authored 2 years ago by GILSON Matthieu (parent f3b97fb8)
"Upload New File"

1 changed file: nb_jax/multilayer_perceptron.ipynb (new file, mode 100644), 435 additions and 0 deletions.
%% Cell type:markdown id:2359f207-049a-4089-ba71-caafcf5f27ca tags:
Example of a multilayer perceptron with JAX.
Adapted from https://towardsdatascience.com/getting-started-with-jax-mlps-cnns-rnns-d0bc389bd683
%% Cell type:code id:81abb335 tags:
```python
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as onp
import struct, time

import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers

# Generate the key which is used to generate random numbers
key = random.PRNGKey(1)
```
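
JAX's random number generation is stateless: the same key always reproduces the same draw, and fresh randomness comes from splitting keys. A minimal sketch illustrating this, reusing `key`, `random` and `np` from the cell above (not part of the original notebook):

```python
a = random.normal(key, (3,))
b = random.normal(key, (3,))     # same key -> identical samples
_, subkey = random.split(key)    # derive a fresh subkey instead
c = random.normal(subkey, (3,))

print(np.allclose(a, b))  # True
print(np.allclose(a, c))  # False (with overwhelming probability)
```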
%% Cell type:code id:61ddcafe tags:
```python
def ReLU(x):
    """ Rectified Linear Unit (ReLU) activation function """
    return np.maximum(0, x)

jit_ReLU = jit(ReLU)
```
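
To see what `jit` buys, one can time the plain and compiled versions on a large array, reusing `key` from above. A sketch (timings are machine-dependent; `block_until_ready()` forces JAX's asynchronous dispatch to finish so the measurement is meaningful):

```python
x = random.normal(key, (1000, 1000))
jit_ReLU(x).block_until_ready()  # first call compiles; exclude it from timing

t0 = time.time()
for _ in range(100):
    ReLU(x).block_until_ready()
print('plain ReLU : {:.4f} s'.format(time.time() - t0))

t0 = time.time()
for _ in range(100):
    jit_ReLU(x).block_until_ready()
print('jitted ReLU: {:.4f} s'.format(time.time() - t0))
```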
%% Cell type:code id:a177dfea tags:
```python
batch_dim = 32
feature_dim = 28*28
hidden_dim = 512

# Generate a batch of vectors to process
X = random.normal(key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases
params = [random.normal(key, (hidden_dim, feature_dim)),
          random.normal(key, (hidden_dim, ))]

def relu_layer(params, x):
    """ Simple ReLU layer for a single sample """
    return ReLU(np.dot(params[0], x) + params[1])
```
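
`relu_layer` operates on a single sample; `vmap` turns it into a batched version without an explicit Python loop. A sketch applying it to the batch `X` generated above (`in_axes=(None, 0)` keeps `params` fixed and maps over the rows of `X`):

```python
batched_relu_layer = vmap(relu_layer, in_axes=(None, 0))
out = batched_relu_layer(params, X)
print(out.shape)  # (32, 512): one hidden activation vector per sample
```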
%% Cell type:code id:2217c914 tags:
```python
data_dir = '../nb_bioinspired_pca/'

# use the MNIST test set as data here
with open(data_dir+'t10k-images-idx3-ubyte', 'rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    nrows, ncols = struct.unpack(">II", f.read(8))
    # read big-endian unsigned bytes with plain numpy (host-side I/O)
    image_data = onp.fromfile(f, dtype=onp.dtype(onp.uint8).newbyteorder('>'))
    image_data = image_data.reshape((size, nrows, ncols))

with open(data_dir+'t10k-labels-idx1-ubyte', 'rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    labels = onp.fromfile(f, dtype=onp.dtype(onp.uint8).newbyteorder('>'))

image_data = onp.array(image_data, dtype=float) / 255
labels = onp.array(labels, dtype=int)

print(image_data.shape, labels.shape)

# batch function: yields shuffled mini-batches of (images, labels)
def data_loader(d, l, batch_size=100):
    n = d.shape[0]
    shuf_ind = onp.random.permutation(n)
    for ind_start in range(0, n, batch_size):
        # slicing (rather than indexing with a range) keeps the last,
        # possibly smaller, batch valid when batch_size does not divide n
        sel_ind = shuf_ind[ind_start:ind_start+batch_size]
        yield d[sel_ind,:,:], l[sel_ind]
```
%% Output
(10000, 28, 28) (10000,)
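
The IDX files parsed above start with a big-endian magic number (2051 for the image file, 2049 for the label file) followed by the item counts; a quick sanity check on the headers just read, as a sketch:

```python
# after the cell above, `magic` and `size` hold the label file's header
assert magic == 2049                 # label-file magic (image file: 2051)
assert size == labels.shape[0] == image_data.shape[0]
```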
%% Cell type:code id:995f5e41-057c-4331-97a8-8a3aa00e1923 tags:
```python
for ind, (d, l) in enumerate(data_loader(image_data, labels)):
    print(ind, d.shape, l.shape)
```
%% Output
0 (100, 28, 28) (100,)
1 (100, 28, 28) (100,)
2 (100, 28, 28) (100,)
... (96 more batches with the same shapes)
99 (100, 28, 28) (100,)
%% Cell type:code id:f6deb492 tags:
```python
def initialize_mlp(sizes, key):
    """ Initialize the weights of all layers of a fully-connected network """
    keys = random.split(key, len(sizes))
    # Initialize a single layer with Gaussian weights - helper function
    def initialize_layer(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
# Return a list of tuples of layer weights
params = initialize_mlp(layer_sizes, key)
```
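
A quick check that the initialization produced the expected weight and bias shapes for the 784-512-512-10 architecture (a sketch):

```python
for i, (w, b) in enumerate(params):
    print('layer {}: W {}, b {}'.format(i, w.shape, b.shape))
# layer 0: W (512, 784), b (512,)
# layer 1: W (512, 512), b (512,)
# layer 2: W (10, 512), b (10,)
```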
%% Cell type:code id:9c34e864 tags:
```python
def forward_pass(params, in_array):
    """ Compute the forward pass for each example individually """
    activations = in_array

    # Loop over the ReLU hidden layers
    for w, b in params[:-1]:
        activations = relu_layer([w, b], activations)

    # Apply the final transformation to get the logits
    final_w, final_b = params[-1]
    logits = np.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

# Make a batched version of the `forward_pass` function
batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)
```
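
Subtracting `logsumexp(logits)` makes the outputs log-probabilities, so each row of the batched output exponentiates to a distribution over the 10 classes. A sketch checking this on random inputs:

```python
log_probs = batch_forward(params, random.normal(key, (4, 784)))
print(log_probs.shape)                # (4, 10)
print(np.exp(log_probs).sum(axis=1))  # each row sums to ~1.0
```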
%% Cell type:code id:040ba6ec tags:
```python
def one_hot(x, k, dtype=np.float32):
    """ Create a one-hot encoding of x of size k """
    return np.array(x[:, None] == np.arange(k), dtype)

def loss(params, in_arrays, targets):
    """ Compute the multi-class cross-entropy loss """
    preds = batch_forward(params, in_arrays)
    return -np.sum(preds * targets)

def accuracy(params, data_loader):
    """ Compute the accuracy (in percent) for a provided dataloader """
    acc_total = 0
    n_total = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data).reshape(data.shape[0], 28*28)
        targets = one_hot(np.array(target), num_classes)

        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(batch_forward(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
        n_total += data.shape[0]
    # normalize by the total number of samples, not the last batch size
    return 100. * acc_total / n_total
```
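
For reference, `one_hot` maps integer labels to indicator rows, which is what makes the elementwise product in `loss` pick out the log-probability of the correct class. A small example:

```python
print(one_hot(np.array([1, 3]), 5))
# [[0. 1. 0. 0. 0.]
#  [0. 0. 0. 1. 0.]]
```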
%% Cell type:code id:d485b894 tags:
```python
@jit
def update(params, x, y, opt_state):
    """ Compute the gradient for a batch and update the parameters """
    value, grads = value_and_grad(loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, value

# Defining an optimizer in JAX
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

num_epochs = 10
num_classes = 10
```
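
A single optimizer step on one mini-batch, as a sketch of how `update` is used before running the full loop below (new names are used so the actual training state is left untouched):

```python
d, l = next(data_loader(image_data, labels))
x = np.array(d).reshape(d.shape[0], 28*28)
y = one_hot(np.array(l), num_classes)
new_params, new_opt_state, batch_loss = update(params, x, y, opt_state)
print(batch_loss)  # cross-entropy summed over the batch
```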
%% Cell type:code id:f7b3bee3-240b-4be9-9943-e2885dacf01a tags:
```python
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    """ Implements a learning loop over epochs. """
    # Initialize placeholders for logging
    log_acc_train, log_acc_test, train_loss = [], [], []

    # Get the initial set of parameters
    params = get_params(opt_state)

    # Get initial accuracy after random init
    train_acc = accuracy(params, data_loader(image_data, labels))
    log_acc_train.append(train_acc)

    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(data_loader(image_data, labels)):
            if net_type == "MLP":
                # Flatten the images into 784-dim vectors for the MLP
                x = np.array(data).reshape(data.shape[0], 28*28)
            elif net_type == "CNN":
                # No flattening of the input required for the CNN
                x = np.array(data)
            y = one_hot(np.array(target), num_classes)
            # avoid shadowing the global `loss` function with the batch value
            params, opt_state, batch_loss = update(params, x, y, opt_state)
            train_loss.append(batch_loss)

        epoch_time = time.time() - start_time
        train_acc = accuracy(params, data_loader(image_data, labels))
        log_acc_train.append(train_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f}".format(epoch+1, epoch_time, train_acc))

    return train_loss, log_acc_train


train_loss, train_log = run_mnist_training_loop(num_epochs,
                                                opt_state,
                                                net_type="MLP")
```
%% Output
Epoch 1 | T: 0.34 | Train A: 92.340
Epoch 2 | T: 0.37 | Train A: 95.420
Epoch 3 | T: 0.35 | Train A: 96.610
Epoch 4 | T: 0.36 | Train A: 97.270
Epoch 5 | T: 0.37 | Train A: 98.510
Epoch 6 | T: 0.36 | Train A: 99.090
Epoch 7 | T: 0.36 | Train A: 99.170
Epoch 8 | T: 0.36 | Train A: 99.580
Epoch 9 | T: 0.37 | Train A: 99.720
Epoch 10 | T: 0.37 | Train A: 99.830
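
The returned `train_loss` (one value per mini-batch) and `train_log` (accuracy per epoch, including the value at initialization) can be plotted to inspect convergence; a minimal sketch with matplotlib (already enabled by `%matplotlib inline` above):

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
ax1.plot(train_loss)
ax1.set_xlabel('mini-batch')
ax1.set_ylabel('cross-entropy loss')
ax2.plot(train_log)
ax2.set_xlabel('epoch')
ax2.set_ylabel('train accuracy (%)')
plt.tight_layout()
plt.show()
```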