diff --git a/nb_jax/multilayer_perceptron.ipynb b/nb_jax/multilayer_perceptron.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..9227ac1a676a5a25a724d8c787bd059ddfa97703
--- /dev/null
+++ b/nb_jax/multilayer_perceptron.ipynb
@@ -0,0 +1,435 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "2359f207-049a-4089-ba71-caafcf5f27ca",
+   "metadata": {},
+   "source": [
+    "Example of multilayer perceptron with JAX\n",
+    "Adapted from https://towardsdatascience.com/getting-started-with-jax-mlps-cnns-rnns-d0bc389bd683 "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "id": "81abb335",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "%config InlineBackend.figure_format = 'retina'\n",
+    "\n",
+    "import numpy as onp\n",
+    "import struct, time\n",
+    "\n",
+    "import jax.numpy as np\n",
+    "from jax import grad, jit, vmap, value_and_grad\n",
+    "from jax import random\n",
+    "from jax.scipy.special import logsumexp\n",
+    "from jax.example_libraries import optimizers\n",
+    "\n",
+    "# Generate key which is used to generate random numbers\n",
+    "key = random.PRNGKey(1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "61ddcafe",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def ReLU(x):\n",
+    "    \"\"\" Rectified Linear Unit (ReLU) activation function \"\"\"\n",
+    "    return np.maximum(0, x)\n",
+    "\n",
+    "jit_ReLU = jit(ReLU)"
+   ]
+  },
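+  {
+   "cell_type": "markdown",
+   "id": "jit-relu-note",
+   "metadata": {},
+   "source": [
+    "Optional sanity check (not executed here): the jitted version should return the same values as the plain `ReLU`; the first call triggers compilation, later calls reuse the cached executable."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "jit-relu-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Confirm the jitted ReLU matches the plain version on a random vector\n",
+    "z = random.normal(key, (1000,))\n",
+    "assert np.allclose(jit_ReLU(z), ReLU(z))\n",
+    "\n",
+    "# First call compiles; time the cached, compiled version\n",
+    "_ = jit_ReLU(z).block_until_ready()\n",
+    "%timeit jit_ReLU(z).block_until_ready()"
+   ]
+  },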
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "a177dfea",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "batch_dim = 32\n",
+    "feature_dim = 28*28\n",
+    "hidden_dim = 512\n",
+    "\n",
+    "# Generate a batch of vectors to process\n",
+    "X = random.normal(key, (batch_dim, feature_dim))\n",
+    "\n",
+    "# Generate Gaussian weights and biases\n",
+    "params = [random.normal(key, (hidden_dim, feature_dim)),\n",
+    "          random.normal(key, (hidden_dim, ))] \n",
+    "\n",
+    "def relu_layer(params, x):\n",
+    "    \"\"\" Simple ReLu layer for single sample \"\"\"\n",
+    "    return ReLU(np.dot(params[0], x) + params[1])"
+   ]
+  },
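+  {
+   "cell_type": "markdown",
+   "id": "vmap-relu-note",
+   "metadata": {},
+   "source": [
+    "`relu_layer` operates on a single example; `vmap` maps it over the leading batch axis of `X`. A quick shape check (not executed here):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "vmap-relu-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Map the single-sample layer over the batch dimension of X\n",
+    "vmap_relu_layer = vmap(relu_layer, in_axes=(None, 0), out_axes=0)\n",
+    "out = vmap_relu_layer(params, X)\n",
+    "print(out.shape)  # expected (batch_dim, hidden_dim) = (32, 512)"
+   ]
+  },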
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "id": "2217c914",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(10000, 28, 28) (10000,)\n"
+     ]
+    }
+   ],
+   "source": [
+    "data_dir = '../nb_bioinspired_pca/'\n",
+    "\n",
+    "# use test set as data here\n",
+    "with open(data_dir+'t10k-images-idx3-ubyte','rb') as f:\n",
+    "    magic, size = struct.unpack(\">II\", f.read(8))\n",
+    "    nrows, ncols = struct.unpack(\">II\", f.read(8))\n",
+    "    image_data = onp.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))\n",
+    "    image_data = image_data.reshape((size, nrows, ncols))\n",
+    "\n",
+    "with open(data_dir+'t10k-labels-idx1-ubyte','rb') as f:\n",
+    "    magic, size = struct.unpack(\">II\", f.read(8))\n",
+    "    labels = onp.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))\n",
+    "\n",
+    "\n",
+    "image_data = onp.array(image_data, dtype=float) / 255\n",
+    "labels = onp.array(labels, dtype=int)\n",
+    "\n",
+    "print(image_data.shape, labels.shape)\n",
+    "\n",
+    "# batch function\n",
+    "def data_loader(d, l, batch_size=100):\n",
+    "    n = d.shape[0]\n",
+    "    shuf_ind = onp.random.permutation(np.arange(n))\n",
+    "    for ind_start in range(0,n,batch_size):\n",
+    "        sel_ind = shuf_ind[range(ind_start,ind_start+batch_size)]\n",
+    "        yield d[sel_ind,:,:], l[sel_ind]"
+   ]
+  },
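+  {
+   "cell_type": "markdown",
+   "id": "mnist-peek-note",
+   "metadata": {},
+   "source": [
+    "Optional check that the raw MNIST files were parsed correctly: display the first test image and its label (assumes `matplotlib` is installed; not executed here)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "mnist-peek",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "# Show the first test image together with its label\n",
+    "plt.imshow(image_data[0], cmap='gray')\n",
+    "plt.title(f'label: {labels[0]}')\n",
+    "plt.axis('off')\n",
+    "plt.show()"
+   ]
+  },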
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "id": "995f5e41-057c-4331-97a8-8a3aa00e1923",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0 (100, 28, 28) (100,)\n",
+      "1 (100, 28, 28) (100,)\n",
+      "2 (100, 28, 28) (100,)\n",
+      "3 (100, 28, 28) (100,)\n",
+      "4 (100, 28, 28) (100,)\n",
+      "5 (100, 28, 28) (100,)\n",
+      "6 (100, 28, 28) (100,)\n",
+      "7 (100, 28, 28) (100,)\n",
+      "8 (100, 28, 28) (100,)\n",
+      "9 (100, 28, 28) (100,)\n",
+      "10 (100, 28, 28) (100,)\n",
+      "11 (100, 28, 28) (100,)\n",
+      "12 (100, 28, 28) (100,)\n",
+      "13 (100, 28, 28) (100,)\n",
+      "14 (100, 28, 28) (100,)\n",
+      "15 (100, 28, 28) (100,)\n",
+      "16 (100, 28, 28) (100,)\n",
+      "17 (100, 28, 28) (100,)\n",
+      "18 (100, 28, 28) (100,)\n",
+      "19 (100, 28, 28) (100,)\n",
+      "20 (100, 28, 28) (100,)\n",
+      "21 (100, 28, 28) (100,)\n",
+      "22 (100, 28, 28) (100,)\n",
+      "23 (100, 28, 28) (100,)\n",
+      "24 (100, 28, 28) (100,)\n",
+      "25 (100, 28, 28) (100,)\n",
+      "26 (100, 28, 28) (100,)\n",
+      "27 (100, 28, 28) (100,)\n",
+      "28 (100, 28, 28) (100,)\n",
+      "29 (100, 28, 28) (100,)\n",
+      "30 (100, 28, 28) (100,)\n",
+      "31 (100, 28, 28) (100,)\n",
+      "32 (100, 28, 28) (100,)\n",
+      "33 (100, 28, 28) (100,)\n",
+      "34 (100, 28, 28) (100,)\n",
+      "35 (100, 28, 28) (100,)\n",
+      "36 (100, 28, 28) (100,)\n",
+      "37 (100, 28, 28) (100,)\n",
+      "38 (100, 28, 28) (100,)\n",
+      "39 (100, 28, 28) (100,)\n",
+      "40 (100, 28, 28) (100,)\n",
+      "41 (100, 28, 28) (100,)\n",
+      "42 (100, 28, 28) (100,)\n",
+      "43 (100, 28, 28) (100,)\n",
+      "44 (100, 28, 28) (100,)\n",
+      "45 (100, 28, 28) (100,)\n",
+      "46 (100, 28, 28) (100,)\n",
+      "47 (100, 28, 28) (100,)\n",
+      "48 (100, 28, 28) (100,)\n",
+      "49 (100, 28, 28) (100,)\n",
+      "50 (100, 28, 28) (100,)\n",
+      "51 (100, 28, 28) (100,)\n",
+      "52 (100, 28, 28) (100,)\n",
+      "53 (100, 28, 28) (100,)\n",
+      "54 (100, 28, 28) (100,)\n",
+      "55 (100, 28, 28) (100,)\n",
+      "56 (100, 28, 28) (100,)\n",
+      "57 (100, 28, 28) (100,)\n",
+      "58 (100, 28, 28) (100,)\n",
+      "59 (100, 28, 28) (100,)\n",
+      "60 (100, 28, 28) (100,)\n",
+      "61 (100, 28, 28) (100,)\n",
+      "62 (100, 28, 28) (100,)\n",
+      "63 (100, 28, 28) (100,)\n",
+      "64 (100, 28, 28) (100,)\n",
+      "65 (100, 28, 28) (100,)\n",
+      "66 (100, 28, 28) (100,)\n",
+      "67 (100, 28, 28) (100,)\n",
+      "68 (100, 28, 28) (100,)\n",
+      "69 (100, 28, 28) (100,)\n",
+      "70 (100, 28, 28) (100,)\n",
+      "71 (100, 28, 28) (100,)\n",
+      "72 (100, 28, 28) (100,)\n",
+      "73 (100, 28, 28) (100,)\n",
+      "74 (100, 28, 28) (100,)\n",
+      "75 (100, 28, 28) (100,)\n",
+      "76 (100, 28, 28) (100,)\n",
+      "77 (100, 28, 28) (100,)\n",
+      "78 (100, 28, 28) (100,)\n",
+      "79 (100, 28, 28) (100,)\n",
+      "80 (100, 28, 28) (100,)\n",
+      "81 (100, 28, 28) (100,)\n",
+      "82 (100, 28, 28) (100,)\n",
+      "83 (100, 28, 28) (100,)\n",
+      "84 (100, 28, 28) (100,)\n",
+      "85 (100, 28, 28) (100,)\n",
+      "86 (100, 28, 28) (100,)\n",
+      "87 (100, 28, 28) (100,)\n",
+      "88 (100, 28, 28) (100,)\n",
+      "89 (100, 28, 28) (100,)\n",
+      "90 (100, 28, 28) (100,)\n",
+      "91 (100, 28, 28) (100,)\n",
+      "92 (100, 28, 28) (100,)\n",
+      "93 (100, 28, 28) (100,)\n",
+      "94 (100, 28, 28) (100,)\n",
+      "95 (100, 28, 28) (100,)\n",
+      "96 (100, 28, 28) (100,)\n",
+      "97 (100, 28, 28) (100,)\n",
+      "98 (100, 28, 28) (100,)\n",
+      "99 (100, 28, 28) (100,)\n"
+     ]
+    }
+   ],
+   "source": [
+    "for ind, (d,l) in enumerate(data_loader(image_data, labels)):\n",
+    "    print(ind, d.shape, l.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "id": "f6deb492",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def initialize_mlp(sizes, key):\n",
+    "    \"\"\" Initialize the weights of all layers of a linear layer network \"\"\"\n",
+    "    keys = random.split(key, len(sizes))\n",
+    "    # Initialize a single layer with Gaussian weights -  helper function\n",
+    "    def initialize_layer(m, n, key, scale=1e-2):\n",
+    "        w_key, b_key = random.split(key)\n",
+    "        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n",
+    "    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n",
+    "\n",
+    "layer_sizes = [784, 512, 512, 10]\n",
+    "# Return a list of tuples of layer weights\n",
+    "params = initialize_mlp(layer_sizes, key)"
+   ]
+  },
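+  {
+   "cell_type": "markdown",
+   "id": "param-shapes-note",
+   "metadata": {},
+   "source": [
+    "`params` is a list of `(weights, biases)` tuples, one per layer; printing the shapes confirms they match `layer_sizes` (not executed here)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "param-shapes-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Each entry of params is a (weights, biases) tuple for one layer\n",
+    "for i, (w, b) in enumerate(params):\n",
+    "    print(f'layer {i}: W {w.shape}, b {b.shape}')"
+   ]
+  },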
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "9c34e864",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def forward_pass(params, in_array):\n",
+    "    \"\"\" Compute the forward pass for each example individually \"\"\"\n",
+    "    activations = in_array\n",
+    "    \n",
+    "    # Loop over the ReLU hidden layers\n",
+    "    for w, b in params[:-1]:\n",
+    "        activations = relu_layer([w, b], activations)\n",
+    "    \n",
+    "    # Perform final trafo to logits\n",
+    "    final_w, final_b = params[-1]\n",
+    "    logits = np.dot(final_w, activations) + final_b\n",
+    "    return logits - logsumexp(logits)\n",
+    "\n",
+    "# Make a batched version of the `predict` function\n",
+    "batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)"
+   ]
+  },
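+  {
+   "cell_type": "markdown",
+   "id": "forward-shape-note",
+   "metadata": {},
+   "source": [
+    "`forward_pass` returns log-probabilities (logits shifted by `logsumexp`), so each row of the batched output should exponentiate and sum to one. A quick check on the random batch `X` from above (not executed here):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "forward-shape-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# batch_forward maps forward_pass over the leading axis of X\n",
+    "log_probs = batch_forward(params, X)\n",
+    "print(log_probs.shape)                    # (batch_dim, 10)\n",
+    "print(np.exp(log_probs).sum(axis=1)[:5])  # each row sums to ~1"
+   ]
+  },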
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "id": "040ba6ec",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def one_hot(x, k, dtype=np.float32):\n",
+    "    \"\"\"Create a one-hot encoding of x of size k \"\"\"\n",
+    "    return np.array(x[:, None] == np.arange(k), dtype)\n",
+    "\n",
+    "def loss(params, in_arrays, targets):\n",
+    "    \"\"\" Compute the multi-class cross-entropy loss \"\"\"\n",
+    "    preds = batch_forward(params, in_arrays)\n",
+    "    return -np.sum(preds * targets)\n",
+    "  \n",
+    "def accuracy(params, data_loader):\n",
+    "    \"\"\" Compute the accuracy for a provided dataloader \"\"\"\n",
+    "    acc_total = 0\n",
+    "    for batch_idx, (data, target) in enumerate(data_loader):\n",
+    "        images = np.array(data).reshape(data.shape[0], 28*28)\n",
+    "        targets = one_hot(np.array(target), num_classes)\n",
+    "    \n",
+    "        target_class = np.argmax(targets, axis=1)\n",
+    "        predicted_class = np.argmax(batch_forward(params, images), axis=1)\n",
+    "        acc_total += np.sum(predicted_class == target_class)\n",
+    "    return acc_total/data.shape[0]"
+   ]
+  },
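+  {
+   "cell_type": "markdown",
+   "id": "loss-sanity-note",
+   "metadata": {},
+   "source": [
+    "Quick loss check on one real batch (not executed here): with the small random initialization the per-example cross-entropy is close to log(10) ≈ 2.3, so the summed loss over a batch of 100 should be roughly 230. `num_classes = 10` is used directly because it is only set in the optimizer cell below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "loss-sanity-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Evaluate the summed cross-entropy on one MNIST batch\n",
+    "d, l = next(data_loader(image_data, labels))\n",
+    "x = np.array(d).reshape(d.shape[0], 28*28)\n",
+    "y = one_hot(np.array(l), 10)\n",
+    "print(loss(params, x, y))  # roughly 100 * log(10) for untrained params"
+   ]
+  },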
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "id": "d485b894",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "@jit\n",
+    "def update(params, x, y, opt_state):\n",
+    "    \"\"\" Compute the gradient for a batch and update the parameters \"\"\"\n",
+    "    value, grads = value_and_grad(loss)(params, x, y)\n",
+    "    opt_state = opt_update(0, grads, opt_state)\n",
+    "    return get_params(opt_state), opt_state, value\n",
+    "\n",
+    "# Defining an optimizer in Jax\n",
+    "step_size = 1e-3\n",
+    "opt_init, opt_update, get_params = optimizers.adam(step_size)\n",
+    "opt_state = opt_init(params)\n",
+    "\n",
+    "num_epochs = 10\n",
+    "num_classes = 10"
+   ]
+  },
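+  {
+   "cell_type": "markdown",
+   "id": "update-smoke-note",
+   "metadata": {},
+   "source": [
+    "A single manual optimization step as a smoke test (not executed here): it checks that gradients flow through the jitted `update` and that a scalar loss value comes back before running the full training loop. The global `opt_state` is not modified."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "update-smoke-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# One training step on a single batch; results are assigned to new variables\n",
+    "d, l = next(data_loader(image_data, labels))\n",
+    "x = np.array(d).reshape(d.shape[0], 28*28)\n",
+    "y = one_hot(np.array(l), num_classes)\n",
+    "new_params, new_opt_state, batch_loss = update(get_params(opt_state), x, y, opt_state)\n",
+    "print(batch_loss)"
+   ]
+  },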
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "id": "f7b3bee3-240b-4be9-9943-e2885dacf01a",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1 | T: 0.34 | Train A: 92.340\n",
+      "Epoch 2 | T: 0.37 | Train A: 95.420\n",
+      "Epoch 3 | T: 0.35 | Train A: 96.610\n",
+      "Epoch 4 | T: 0.36 | Train A: 97.270\n",
+      "Epoch 5 | T: 0.37 | Train A: 98.510\n",
+      "Epoch 6 | T: 0.36 | Train A: 99.090\n",
+      "Epoch 7 | T: 0.36 | Train A: 99.170\n",
+      "Epoch 8 | T: 0.36 | Train A: 99.580\n",
+      "Epoch 9 | T: 0.37 | Train A: 99.720\n",
+      "Epoch 10 | T: 0.37 | Train A: 99.830\n"
+     ]
+    }
+   ],
+   "source": [
+    "def run_mnist_training_loop(num_epochs, opt_state, net_type=\"MLP\"):\n",
+    "    \"\"\" Implements a learning loop over epochs. \"\"\"\n",
+    "    # Initialize placeholder for loggin\n",
+    "    log_acc_train, log_acc_test, train_loss = [], [], []\n",
+    "    \n",
+    "    # Get the initial set of parameters \n",
+    "    params = get_params(opt_state)\n",
+    "    \n",
+    "    # Get initial accuracy after random init\n",
+    "    train_acc = accuracy(params, data_loader(image_data, labels))\n",
+    "    log_acc_train.append(train_acc)\n",
+    "    \n",
+    "    # Loop over the training epochs\n",
+    "    for epoch in range(num_epochs):\n",
+    "        start_time = time.time()\n",
+    "        for batch_idx, (data, target) in enumerate(data_loader(image_data, labels)):\n",
+    "            if net_type == \"MLP\":\n",
+    "                # Flatten the image into 784 vectors for the MLP\n",
+    "                x = np.array(data).reshape(data.shape[0], 28*28)\n",
+    "            elif net_type == \"CNN\":\n",
+    "                # No flattening of the input required for the CNN\n",
+    "                x = np.array(data)\n",
+    "            y = one_hot(np.array(target), num_classes)\n",
+    "            params, opt_state, loss = update(params, x, y, opt_state)\n",
+    "            train_loss.append(loss)\n",
+    "\n",
+    "        epoch_time = time.time() - start_time\n",
+    "        train_acc = accuracy(params, data_loader(image_data, labels))\n",
+    "        log_acc_train.append(train_acc)\n",
+    "        print(\"Epoch {} | time: {:0.2f} | Train A: {:0.3f}\".format(epoch+1, epoch_time,train_acc))\n",
+    "    \n",
+    "    return train_loss, log_acc_train\n",
+    "\n",
+    "\n",
+    "train_loss, train_log = run_mnist_training_loop(num_epochs,\n",
+    "                                                opt_state,\n",
+    "                                                net_type=\"MLP\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}