%% Cell type:markdown id:2359f207-049a-4089-ba71-caafcf5f27ca tags:
# Example of a multilayer perceptron with JAX

Adapted from https://towardsdatascience.com/getting-started-with-jax-mlps-cnns-rnns-d0bc389bd683
%% Cell type:code id:81abb335 tags:
``` python
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import numpy as onp
import struct, time
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers
# Create a PRNG key used for reproducible random number generation
key = random.PRNGKey(1)
```
%% Cell type:code id:61ddcafe tags:
``` python
def ReLU(x):
    """ Rectified Linear Unit (ReLU) activation function """
    return np.maximum(0, x)

jit_ReLU = jit(ReLU)
```
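%% Cell type:markdown id tags:
A minimal timing sketch (not part of the original notebook): it compares the plain and jitted ReLU on a test array created here purely for illustration. `block_until_ready()` forces JAX's asynchronous dispatch to finish so the timings are meaningful; the exact numbers depend on your hardware.
%% Cell type:code id tags:
``` python
# hypothetical test array, only used to time the two ReLU variants
x_test = random.normal(key, (1000, 1000))
jit_ReLU(x_test).block_until_ready()  # warm-up call triggers compilation

t0 = time.time()
ReLU(x_test).block_until_ready()
print('plain ReLU:  {:.5f} s'.format(time.time() - t0))

t0 = time.time()
jit_ReLU(x_test).block_until_ready()
print('jitted ReLU: {:.5f} s'.format(time.time() - t0))
```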
%% Cell type:code id:a177dfea tags:
``` python
batch_dim = 32
feature_dim = 28*28
hidden_dim = 512
# Generate a batch of vectors to process
X = random.normal(key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases
params = [random.normal(key, (hidden_dim, feature_dim)),
          random.normal(key, (hidden_dim,))]

def relu_layer(params, x):
    """ Simple ReLU layer for a single sample """
    return ReLU(np.dot(params[0], x) + params[1])
```
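%% Cell type:markdown id tags:
Since `relu_layer` operates on a single sample, `vmap` can lift it to the whole batch `X` without rewriting it. The sketch below is an addition: `in_axes=(None, 0)` keeps `params` fixed and maps over the first axis of `X`.
%% Cell type:code id tags:
``` python
# batched version of the single-sample layer via vmap
batch_relu_layer = vmap(relu_layer, in_axes=(None, 0))
out = batch_relu_layer(params, X)
print(out.shape)  # expected: (batch_dim, hidden_dim) = (32, 512)
```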
%% Cell type:code id:2217c914 tags:
``` python
data_dir = '../nb_bioinspired_pca/'
# use the MNIST test set as the dataset here
with open(data_dir + 't10k-images-idx3-ubyte', 'rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    nrows, ncols = struct.unpack(">II", f.read(8))
    image_data = onp.fromfile(f, dtype=onp.dtype(onp.uint8).newbyteorder('>'))
    image_data = image_data.reshape((size, nrows, ncols))

with open(data_dir + 't10k-labels-idx1-ubyte', 'rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    labels = onp.fromfile(f, dtype=onp.dtype(onp.uint8).newbyteorder('>'))

# rescale pixel values to [0, 1] and cast labels to integers
image_data = onp.array(image_data, dtype=float) / 255
labels = onp.array(labels, dtype=int)
print(image_data.shape, labels.shape)

# batch generator: shuffles the data and yields mini-batches
def data_loader(d, l, batch_size=100):
    n = d.shape[0]
    shuf_ind = onp.random.permutation(n)
    for ind_start in range(0, n, batch_size):
        # slicing (rather than range) also handles a smaller final batch
        sel_ind = shuf_ind[ind_start:ind_start + batch_size]
        yield d[sel_ind, :, :], l[sel_ind]
```
%% Output
(10000, 28, 28) (10000,)
%% Cell type:code id:995f5e41-057c-4331-97a8-8a3aa00e1923 tags:
``` python
# sanity check: iterate once over the loader and print the batch shapes
for ind, (d, l) in enumerate(data_loader(image_data, labels)):
    print(ind, d.shape, l.shape)
```
%% Output
0 (100, 28, 28) (100,)
1 (100, 28, 28) (100,)
2 (100, 28, 28) (100,)
...
99 (100, 28, 28) (100,)
%% Cell type:code id:f6deb492 tags:
``` python
def initialize_mlp(sizes, key):
    """ Initialize the weights and biases of all layers of a fully-connected network """
    keys = random.split(key, len(sizes))

    # Helper function: initialize a single layer with Gaussian weights
    def initialize_layer(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
# Returns a list of (weights, biases) tuples, one per layer
params = initialize_mlp(layer_sizes, key)
```
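%% Cell type:markdown id tags:
A quick sanity check (added here, not from the original notebook): list the weight and bias shapes per layer, which should follow `layer_sizes` from input to output.
%% Cell type:code id tags:
``` python
# inspect the (weights, biases) shapes of each layer
for i, (w, b) in enumerate(params):
    print('layer {}: W {}, b {}'.format(i, w.shape, b.shape))
# expected: W (512, 784), (512, 512), (10, 512) with matching bias shapes
```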
%% Cell type:code id:9c34e864 tags:
``` python
def forward_pass(params, in_array):
    """ Compute the forward pass for a single example """
    activations = in_array

    # Loop over the ReLU hidden layers
    for w, b in params[:-1]:
        activations = relu_layer([w, b], activations)

    # Final linear transformation to logits, normalized to log-probabilities
    final_w, final_b = params[-1]
    logits = np.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

# Batched version of `forward_pass`: map over the first axis of the input
batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)
```
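%% Cell type:markdown id tags:
`forward_pass` returns log-probabilities (a log-softmax) for one sample, and `batch_forward` maps it over a batch. The sketch below is an addition, reusing the data loader defined above: it checks the output shape and that each row exponentiates to a distribution summing to 1.
%% Cell type:code id tags:
``` python
# run the batched forward pass on one flattened batch
d, l = next(data_loader(image_data, labels))
log_probs = batch_forward(params, np.array(d).reshape(d.shape[0], 28 * 28))
print(log_probs.shape)                # (batch_size, 10)
print(np.exp(log_probs).sum(axis=1))  # each entry should be close to 1
```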
%% Cell type:code id:040ba6ec tags:
``` python
def one_hot(x, k, dtype=np.float32):
    """ Create a one-hot encoding of x with k classes """
    return np.array(x[:, None] == np.arange(k), dtype)

def loss(params, in_arrays, targets):
    """ Compute the multi-class cross-entropy loss """
    preds = batch_forward(params, in_arrays)
    return -np.sum(preds * targets)

def accuracy(params, data_loader):
    """ Compute the accuracy (in percent) over a provided data loader """
    acc_total, n_total = 0, 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data).reshape(data.shape[0], 28 * 28)
        targets = one_hot(np.array(target), num_classes)
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(batch_forward(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
        n_total += data.shape[0]
    # normalize by the total number of samples, not the size of the last batch
    return 100. * acc_total / n_total
```
%% Cell type:code id:d485b894 tags:
``` python
@jit
def update(params, x, y, opt_state):
    """ Compute the gradients for a batch and update the parameters """
    value, grads = value_and_grad(loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, value

# Define an Adam optimizer in JAX
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

num_epochs = 10
num_classes = 10
```
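%% Cell type:markdown id tags:
Before launching the full loop, a single call to `update` on one batch can verify the wiring. This sketch is an addition; the returned parameters and optimizer state are deliberately discarded so the training below starts from the same initialization.
%% Cell type:code id tags:
``` python
# one trial update step; the updated params/state are not kept
d, l = next(data_loader(image_data, labels))
x = np.array(d).reshape(d.shape[0], 28 * 28)
y = one_hot(np.array(l), num_classes)
_, _, batch_loss = update(params, x, y, opt_state)
print('initial batch loss: {:.3f}'.format(batch_loss))
```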
%% Cell type:code id:f7b3bee3-240b-4be9-9943-e2885dacf01a tags:
``` python
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    """ Implements a learning loop over epochs. """
    # Initialize placeholders for logging
    log_acc_train, train_loss = [], []

    # Get the initial set of parameters
    params = get_params(opt_state)

    # Get the initial accuracy after random initialization
    train_acc = accuracy(params, data_loader(image_data, labels))
    log_acc_train.append(train_acc)

    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(data_loader(image_data, labels)):
            if net_type == "MLP":
                # Flatten the images into 784-dimensional vectors for the MLP
                x = np.array(data).reshape(data.shape[0], 28 * 28)
            elif net_type == "CNN":
                # No flattening of the input required for the CNN
                x = np.array(data)
            y = one_hot(np.array(target), num_classes)
            # `batch_loss` avoids shadowing the global `loss` function
            params, opt_state, batch_loss = update(params, x, y, opt_state)
            train_loss.append(batch_loss)

        epoch_time = time.time() - start_time
        train_acc = accuracy(params, data_loader(image_data, labels))
        log_acc_train.append(train_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f}".format(epoch + 1, epoch_time, train_acc))

    return train_loss, log_acc_train

train_loss, train_log = run_mnist_training_loop(num_epochs,
                                                opt_state,
                                                net_type="MLP")
```
%% Output
Epoch 1 | T: 0.34 | Train A: 92.340
Epoch 2 | T: 0.37 | Train A: 95.420
Epoch 3 | T: 0.35 | Train A: 96.610
Epoch 4 | T: 0.36 | Train A: 97.270
Epoch 5 | T: 0.37 | Train A: 98.510
Epoch 6 | T: 0.36 | Train A: 99.090
Epoch 7 | T: 0.36 | Train A: 99.170
Epoch 8 | T: 0.36 | Train A: 99.580
Epoch 9 | T: 0.37 | Train A: 99.720
Epoch 10 | T: 0.37 | Train A: 99.830
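%% Cell type:markdown id tags:
As a follow-up (not part of the original notebook), the per-batch losses collected in `train_loss` can be plotted; `%matplotlib inline` was already enabled in the first cell, so only the matplotlib import is added here.
%% Cell type:code id tags:
``` python
import matplotlib.pyplot as plt

# plot the training loss over update steps
plt.figure()
plt.plot(onp.array(train_loss))
plt.xlabel('update step')
plt.ylabel('training loss')
plt.title('MLP training on MNIST (JAX)')
plt.show()
```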