diff --git a/nb_jax/multilayer_perceptron.ipynb b/nb_jax/multilayer_perceptron.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9227ac1a676a5a25a724d8c787bd059ddfa97703 --- /dev/null +++ b/nb_jax/multilayer_perceptron.ipynb @@ -0,0 +1,435 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2359f207-049a-4089-ba71-caafcf5f27ca", + "metadata": {}, + "source": [ + "Example of multilayer perceptron with JAX\n", + "Adapted from https://towardsdatascience.com/getting-started-with-jax-mlps-cnns-rnns-d0bc389bd683 " + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "81abb335", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'retina'\n", + "\n", + "import numpy as onp\n", + "import struct, time\n", + "\n", + "import jax.numpy as np\n", + "from jax import grad, jit, vmap, value_and_grad\n", + "from jax import random\n", + "from jax.scipy.special import logsumexp\n", + "from jax.example_libraries import optimizers\n", + "\n", + "# Generate key which is used to generate random numbers\n", + "key = random.PRNGKey(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "61ddcafe", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def ReLU(x):\n", + " \"\"\" Rectified Linear Unit (ReLU) activation function \"\"\"\n", + " return np.maximum(0, x)\n", + "\n", + "jit_ReLU = jit(ReLU)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a177dfea", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "batch_dim = 32\n", + "feature_dim = 28*28\n", + "hidden_dim = 512\n", + "\n", + "# Generate a batch of vectors to process\n", + "X = random.normal(key, (batch_dim, feature_dim))\n", + "\n", + "# Generate Gaussian weights and biases\n", + "params = [random.normal(key, (hidden_dim, feature_dim)),\n", + " random.normal(key, (hidden_dim, ))] \n", + "\n", + "def relu_layer(params, x):\n", + " \"\"\" Simple ReLu layer for single sample \"\"\"\n", + " return ReLU(np.dot(params[0], x) + params[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "2217c914", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(10000, 28, 28) (10000,)\n" + ] + } + ], + "source": [ + "data_dir = '../nb_bioinspired_pca/'\n", + "\n", + "# use test set as data here\n", + "with open(data_dir+'t10k-images-idx3-ubyte','rb') as f:\n", + " magic, size = struct.unpack(\">II\", f.read(8))\n", + " nrows, ncols = struct.unpack(\">II\", f.read(8))\n", + " image_data = onp.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))\n", + " image_data = image_data.reshape((size, nrows, ncols))\n", + "\n", + "with open(data_dir+'t10k-labels-idx1-ubyte','rb') as f:\n", + " magic, size = struct.unpack(\">II\", f.read(8))\n", + " labels = onp.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))\n", + "\n", + "\n", + "image_data = onp.array(image_data, dtype=float) / 255\n", + "labels = onp.array(labels, dtype=int)\n", + "\n", + "print(image_data.shape, labels.shape)\n", + "\n", + "# batch function\n", + "def data_loader(d, l, batch_size=100):\n", + " n = d.shape[0]\n", + " shuf_ind = onp.random.permutation(np.arange(n))\n", + " for ind_start in range(0,n,batch_size):\n", + " sel_ind = shuf_ind[range(ind_start,ind_start+batch_size)]\n", + " yield d[sel_ind,:,:], l[sel_ind]" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "995f5e41-057c-4331-97a8-8a3aa00e1923", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 (100, 28, 28) (100,)\n", + "1 (100, 28, 28) (100,)\n", + "2 (100, 28, 28) (100,)\n", + "3 (100, 28, 28) (100,)\n", + "4 (100, 28, 28) (100,)\n", + "5 (100, 28, 28) (100,)\n", + "6 (100, 28, 28) (100,)\n", + "7 (100, 28, 28) (100,)\n", + "8 (100, 28, 28) (100,)\n", + "9 (100, 28, 28) (100,)\n", + "10 (100, 28, 28) (100,)\n", + "11 (100, 28, 28) (100,)\n", + "12 (100, 28, 28) (100,)\n", + "13 (100, 28, 28) (100,)\n", + "14 (100, 28, 28) (100,)\n", + "15 (100, 28, 28) (100,)\n", + "16 (100, 28, 28) (100,)\n", + "17 (100, 28, 28) (100,)\n", + "18 (100, 28, 28) (100,)\n", + "19 (100, 28, 28) (100,)\n", + "20 (100, 28, 28) (100,)\n", + "21 (100, 28, 28) (100,)\n", + "22 (100, 28, 28) (100,)\n", + "23 (100, 28, 28) (100,)\n", + "24 (100, 28, 28) (100,)\n", + "25 (100, 28, 28) (100,)\n", + "26 (100, 28, 28) (100,)\n", + "27 (100, 28, 28) (100,)\n", + "28 (100, 28, 28) (100,)\n", + "29 (100, 28, 28) (100,)\n", + "30 (100, 28, 28) (100,)\n", + "31 (100, 28, 28) (100,)\n", + "32 (100, 28, 28) (100,)\n", + "33 (100, 28, 28) (100,)\n", + "34 (100, 28, 28) (100,)\n", + "35 (100, 28, 28) (100,)\n", + "36 (100, 28, 28) (100,)\n", + "37 (100, 28, 28) (100,)\n", + "38 (100, 28, 28) (100,)\n", + "39 (100, 28, 28) (100,)\n", + "40 (100, 28, 28) (100,)\n", + "41 (100, 28, 28) (100,)\n", + "42 (100, 28, 28) (100,)\n", + "43 (100, 28, 28) (100,)\n", + "44 (100, 28, 28) (100,)\n", + "45 (100, 28, 28) (100,)\n", + "46 (100, 28, 28) (100,)\n", + "47 (100, 28, 28) (100,)\n", + "48 (100, 28, 28) (100,)\n", + "49 (100, 28, 28) (100,)\n", + "50 (100, 28, 28) (100,)\n", + "51 (100, 28, 28) (100,)\n", + "52 (100, 28, 28) (100,)\n", + "53 (100, 28, 28) (100,)\n", + "54 (100, 28, 28) (100,)\n", + "55 (100, 28, 28) (100,)\n", + "56 (100, 28, 28) (100,)\n", + "57 (100, 28, 28) (100,)\n", + "58 (100, 28, 28) (100,)\n", + "59 (100, 28, 28) (100,)\n", + "60 (100, 28, 28) (100,)\n", + "61 (100, 28, 28) (100,)\n", + "62 (100, 28, 28) (100,)\n", + "63 (100, 28, 28) (100,)\n", + "64 (100, 28, 28) (100,)\n", + "65 (100, 28, 28) (100,)\n", + "66 (100, 28, 28) (100,)\n", + "67 (100, 28, 28) (100,)\n", + "68 (100, 28, 28) (100,)\n", + "69 (100, 28, 28) (100,)\n", + "70 (100, 28, 28) (100,)\n", + "71 (100, 28, 28) (100,)\n", + "72 (100, 28, 28) (100,)\n", + "73 (100, 28, 28) (100,)\n", + "74 (100, 28, 28) (100,)\n", + "75 (100, 28, 28) (100,)\n", + "76 (100, 28, 28) (100,)\n", + "77 (100, 28, 28) (100,)\n", + "78 (100, 28, 28) (100,)\n", + "79 (100, 28, 28) (100,)\n", + "80 (100, 28, 28) (100,)\n", + "81 (100, 28, 28) (100,)\n", + "82 (100, 28, 28) (100,)\n", + "83 (100, 28, 28) (100,)\n", + "84 (100, 28, 28) (100,)\n", + "85 (100, 28, 28) (100,)\n", + "86 (100, 28, 28) (100,)\n", + "87 (100, 28, 28) (100,)\n", + "88 (100, 28, 28) (100,)\n", + "89 (100, 28, 28) (100,)\n", + "90 (100, 28, 28) (100,)\n", + "91 (100, 28, 28) (100,)\n", + "92 (100, 28, 28) (100,)\n", + "93 (100, 28, 28) (100,)\n", + "94 (100, 28, 28) (100,)\n", + "95 (100, 28, 28) (100,)\n", + "96 (100, 28, 28) (100,)\n", + "97 (100, 28, 28) (100,)\n", + "98 (100, 28, 28) (100,)\n", + "99 (100, 28, 28) (100,)\n" + ] + } + ], + "source": [ + "for ind, (d,l) in enumerate(data_loader(image_data, labels)):\n", + " print(ind, d.shape, l.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "f6deb492", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def initialize_mlp(sizes, key):\n", + " \"\"\" Initialize the weights of all layers of a linear layer network \"\"\"\n", + " keys = random.split(key, len(sizes))\n", + " # Initialize a single layer with Gaussian weights - helper function\n", + " def initialize_layer(m, n, key, scale=1e-2):\n", + " w_key, b_key = random.split(key)\n", + " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", + " return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", + "\n", + "layer_sizes = [784, 512, 512, 10]\n", + "# Return a list of tuples of layer weights\n", + "params = initialize_mlp(layer_sizes, key)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "9c34e864", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def forward_pass(params, in_array):\n", + " \"\"\" Compute the forward pass for each example individually \"\"\"\n", + " activations = in_array\n", + " \n", + " # Loop over the ReLU hidden layers\n", + " for w, b in params[:-1]:\n", + " activations = relu_layer([w, b], activations)\n", + " \n", + " # Perform final trafo to logits\n", + " final_w, final_b = params[-1]\n", + " logits = np.dot(final_w, activations) + final_b\n", + " return logits - logsumexp(logits)\n", + "\n", + "# Make a batched version of the `predict` function\n", + "batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "040ba6ec", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def one_hot(x, k, dtype=np.float32):\n", + " \"\"\"Create a one-hot encoding of x of size k \"\"\"\n", + " return np.array(x[:, None] == np.arange(k), dtype)\n", + "\n", + "def loss(params, in_arrays, targets):\n", + " \"\"\" Compute the multi-class cross-entropy loss \"\"\"\n", + " preds = batch_forward(params, in_arrays)\n", + " return -np.sum(preds * targets)\n", + " \n", + "def accuracy(params, data_loader):\n", + " \"\"\" Compute the accuracy for a provided dataloader \"\"\"\n", + " acc_total = 0\n", + " for batch_idx, (data, target) in enumerate(data_loader):\n", + " images = np.array(data).reshape(data.shape[0], 28*28)\n", + " targets = one_hot(np.array(target), num_classes)\n", + " \n", + " target_class = np.argmax(targets, axis=1)\n", + " predicted_class = np.argmax(batch_forward(params, images), axis=1)\n", + " acc_total += np.sum(predicted_class == target_class)\n", + " return acc_total/data.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "d485b894", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "@jit\n", + "def update(params, x, y, opt_state):\n", + " \"\"\" Compute the gradient for a batch and update the parameters \"\"\"\n", + " value, grads = value_and_grad(loss)(params, x, y)\n", + " opt_state = opt_update(0, grads, opt_state)\n", + " return get_params(opt_state), opt_state, value\n", + "\n", + "# Defining an optimizer in Jax\n", + "step_size = 1e-3\n", + "opt_init, opt_update, get_params = optimizers.adam(step_size)\n", + "opt_state = opt_init(params)\n", + "\n", + "num_epochs = 10\n", + "num_classes = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "f7b3bee3-240b-4be9-9943-e2885dacf01a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 | T: 0.34 | Train A: 92.340\n", + "Epoch 2 | T: 0.37 | Train A: 95.420\n", + "Epoch 3 | T: 0.35 | Train A: 96.610\n", + "Epoch 4 | T: 0.36 | Train A: 97.270\n", + "Epoch 5 | T: 0.37 | Train A: 98.510\n", + "Epoch 6 | T: 0.36 | Train A: 99.090\n", + "Epoch 7 | T: 0.36 | Train A: 99.170\n", + "Epoch 8 | T: 0.36 | Train A: 99.580\n", + "Epoch 9 | T: 0.37 | Train A: 99.720\n", + "Epoch 10 | T: 0.37 | Train A: 99.830\n" + ] + } + ], + "source": [ + "def run_mnist_training_loop(num_epochs, opt_state, net_type=\"MLP\"):\n", + " \"\"\" Implements a learning loop over epochs. \"\"\"\n", + " # Initialize placeholder for loggin\n", + " log_acc_train, log_acc_test, train_loss = [], [], []\n", + " \n", + " # Get the initial set of parameters \n", + " params = get_params(opt_state)\n", + " \n", + " # Get initial accuracy after random init\n", + " train_acc = accuracy(params, data_loader(image_data, labels))\n", + " log_acc_train.append(train_acc)\n", + " \n", + " # Loop over the training epochs\n", + " for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for batch_idx, (data, target) in enumerate(data_loader(image_data, labels)):\n", + " if net_type == \"MLP\":\n", + " # Flatten the image into 784 vectors for the MLP\n", + " x = np.array(data).reshape(data.shape[0], 28*28)\n", + " elif net_type == \"CNN\":\n", + " # No flattening of the input required for the CNN\n", + " x = np.array(data)\n", + " y = one_hot(np.array(target), num_classes)\n", + " params, opt_state, loss = update(params, x, y, opt_state)\n", + " train_loss.append(loss)\n", + "\n", + " epoch_time = time.time() - start_time\n", + " train_acc = accuracy(params, data_loader(image_data, labels))\n", + " log_acc_train.append(train_acc)\n", + " print(\"Epoch {} | time: {:0.2f} | Train A: {:0.3f}\".format(epoch+1, epoch_time,train_acc))\n", + " \n", + " return train_loss, log_acc_train\n", + "\n", + "\n", + "train_loss, train_log = run_mnist_training_loop(num_epochs,\n", + " opt_state,\n", + " net_type=\"MLP\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}