Marco Hoefle
Published © MIT

Demystifying AI/ML Development for Edge Devices

Learn how to design, train, and deploy a custom CNN on a DEEPX-accelerated edge device.

Intermediate · Full instructions provided · 2 hours

Things used in this project

Hardware components

AMD-based ASRock 4X4 BOX-8840U
×1
DEEPX-M1 M.2 Module
×1

Software apps and online services

Ubuntu

Story


Code

Dockerfiles

Dockerfile
Dockerfile + Python requirements
No preview (download only).

C++ MNIST Edge sources

C/C++
Application Sources
No preview (download only).

Jupyter Notebook

JSON
Backup copy of the Jupyter Notebook; it is recommended that you recreate it from scratch yourself.
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8a53043c-5a87-4a58-a69b-33b78f6b5dd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import TensorDataset, DataLoader, random_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "38ac524e-9168-4fbc-bf7d-3c3c5ce2080f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------\n",
    "# Config\n",
    "# ----------------------------\n",
    "BATCH_SIZE = 64\n",
    "EPOCHS = 10\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2355d64c-ea62-452f-ab61-4ae1915d4181",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|| 9.91M/9.91M [00:01<00:00, 7.23MB/s]\n",
      "100%|| 28.9k/28.9k [00:00<00:00, 320kB/s]\n",
      "100%|| 1.65M/1.65M [00:00<00:00, 2.15MB/s]\n",
      "100%|| 4.54k/4.54k [00:00<00:00, 14.5MB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training samples:   60000\n",
      "Validation samples: 9000\n",
      "Test samples:       1000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------\n",
    "# Dataset / transforms\n",
    "# ----------------------------\n",
    "transform = transforms.Compose([\n",
    "  transforms.ToTensor(),\n",
    "  transforms.Normalize((0.1307,), (0.3081,))\n",
    "])\n",
    "\n",
    "train_dataset = datasets.MNIST(\n",
    "    root=\"./data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform\n",
    ")\n",
    "\n",
    "full_valid_dataset = datasets.MNIST(\n",
    "    root=\"./data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transform\n",
    ")\n",
    "\n",
    "# Split original test set into validation + test\n",
    "test_size = 1000\n",
    "valid_size = len(full_valid_dataset) - test_size\n",
    "\n",
    "generator = torch.Generator().manual_seed(42)\n",
    "valid_dataset, test_dataset = random_split(\n",
    "    full_valid_dataset,\n",
    "    [valid_size, test_size],\n",
    "    generator=generator\n",
    ")\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "print(\"Training samples:  \", len(train_dataset))\n",
    "print(\"Validation samples:\", len(valid_dataset))\n",
    "print(\"Test samples:      \", len(test_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d8e5dd7-059d-467d-8f24-052af6016044",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAD0tJREFUeJzt3HvM1/P/x/Hnp6RChVZOLdYIjUQH/ZF1OWxJNpkwM61/zMTWjJAl2YyxDs40h9Gy5Uzm9M9V/cNKS4w55NAM0YF1GDKuz/cPP8/pd4Xr9e46ldtt88+n96P3S9K9d4d3rV6v1wMAIqJLRx8AgM5DFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFNgrrVu3Lmq1WsyZM6fVvs5ly5ZFrVaLZcuWtdrXCZ2NKNBpPPHEE1Gr1WLVqlUdfZQ2880338RFF10UBx54YPTu3TvOO++8+OKLLzr6WJD26egDwH/F9u3b4/TTT48tW7bETTfdFN26dYv58+fH2LFjY82aNdG3b9+OPiKIArSXBx98MNauXRsrV66MkSNHRkTE+PHj44QTToi5c+fG7bff3sEnBL98xB7m119/jVmzZsXw4cOjT58+sf/++8dpp50WS5cu/dvN/Pnz48gjj4yePXvG2LFj44MPPmh2zccffxyTJk2Kgw8+OHr06BEjRoyIJUuW/Ot5fvrpp/j4449j06ZN/3rtc889FyNHjswgREQcd9xxceaZZ8Yzzzzzr3toD6LAHmXr1q3x6KOPRkNDQ9x5550xe/bs2LhxY4wbNy7WrFnT7PqFCxfGvffeG1dddVXMmDEjPvjggzjjjDPi+++/z2s+/PDDGD16dHz00Udx4403xty5c2P//fePiRMnxosvvviP51m5cmUcf/zxcf/99//jdU1NTfH+++/HiBEjmn3ZqFGj4vPPP49t27a17BsB2pBfPmKPctBBB8W6deti3333zc8uv/zyOO644+K+++6Lxx57bKfrP/vss1i7dm0cccQRERFx9tlnx6mnnhp33nlnzJs3LyIipk2bFgMHDox33nknunfvHhERU6dOjTFjxsQNN9wQ559//m6f+4cffogdO3bEYYcd1uzL/vzs22+/jWOPPXa37wW7w5MCe5SuXbtmEJqamuKHH36I3377LUaMGBGrV69udv3EiRMzCBF//Kz81FNPjddeey0i/vjBurGxMS666KLYtm1bbNq0KTZt2hSbN2+OcePGxdq1a+Obb7752/M0NDREvV6P2bNn/+O5f/7554iIjM5f9ejRY6droCOJAnucJ598MoYOHRo9evSIvn37Rr9+/eLVV1+NLVu2NLv2mGOOafbZ4MGDY926dRHxx5NEvV6Pm2++Ofr167fTP7fccktERGzYsGG3z9yzZ8+IiNixY0ezL/vll192ugY6kl8+Yo+yaNGimDJlSkycODGmT58e/fv3j65du8Ydd9wRn3/+efHX19TUFBER1113XYwbN26X1xx99NG7deaIiIMPPji6d+8e69evb/Zlf352+OGH7/Z9YHeJAnuU5557LgYNGhQvvPBC1Gq1/PzPn9X/f2vXrm322aeffhpHHXVUREQMGjQoIiK6desWZ511Vusf+P906dIlTjzxxF3+xbwVK1bEoEGDolevXm12f2gpv3zEHqVr164REVGv1/OzFStWxNtvv73L61966aWdfk9g5cqVsWLFihg/fnxERPTv3z8aGhpiwYIFu/xZ/MaNG//xPCV/JHXSpEnxzjvv7BSGTz75JBobG+PCCy/81z20B08KdDqPP/54vPHGG80+nzZtWpx77rnxwgsvxPnnnx8TJkyIL7/8Mh5++OEYMmRIbN++vdnm6KOPjjFjxsSVV14ZO3bsiLvvvjv69u0b119/fV7zwAMPxJgxY+LEE0+Myy+/PAYNGhTff/99vP322/H111/He++997dnXblyZZx++ulxyy23/OtvNk+dOjUeeeSRmDBhQlx33XXRrVu3mDdvXhxyyCFx7bXXtvwbCNqQKNDpPPTQQ7v8fMqUKTFlypT47rvvYsGCBfHmm2/GkCFDYtGiRfHss8/u8kV1kydPji5dusTdd98dGzZsiFGjRsX999+/0x8NHTJkSKxatSpuvfXWeOKJJ2Lz5s3Rv3//OPnkk2PWrFmt9u/Vq1evWLZsWVxzzTVx2223RVNTUzQ0NMT8+fOjX79+rXYf2B21+l+fwwH4T/N7CgAkUQAgiQIASRQASKIAQBIFAFKL/57CX18pAMCepyV/A8GTAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUA0j4dfQDYkw0fPrx4c/XVV1e61+TJk4s3CxcuLN7cd999xZvVq1cXb+icPCkAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACDV6vV6vUUX1mptfRboUMOGDSveNDY2Fm969+5dvGlPW7ZsKd707du3DU5Ca2vJD/eeFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkPbp6ANAWxg1alTx5vnnny/e9OnTp3jTwndQNrNt27biza+//lq8qfJyu9GjRxdvVq9eXbyJqPbvRMt5UgAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQKrVW/h2rlqt1tZnYS+33377VdqdcsopxZtFixYVbwYMGFC8qfL/RdUX4lV5gdxdd91VvFm8eHHxpsq3w8yZM4s3ERF33HFHpR0t+77nSQGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEj7dPQB+O9YsGBBpd0ll1zSyifZM1V5W+wBBxxQvFm+fHnxpqGhoXgzdOjQ4g1tz5MCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSF+JRyfDhw4s3EyZMqHSvWq1WaVeqyovgXnnlleLNnDlzijcREd9++23x5t133y3e/Pjjj8WbM844o3jTXv9dKeNJAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIAqVav1+stutDLq/Zaw4YNK940NjYWb3r37l28qer1118v3lxyySXFm7FjxxZvhg4dWryJiHj00UeLNxs3bqx0r1K///578eann36qdK8q3+arV6+udK+9TUt+uPekAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAtE9HH4DWNXjw4OLN9OnTizd9+vQp3mzatKl4ExGxfv364s2TTz5ZvNm+fXvx5tVXX22Xzd6
oZ8+elXbXXntt8ebSSy+tdK//Ik8KACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBA8pbUTqp79+6VdnPmzCnenHPOOcWbbdu2FW8mT55cvImIWLVqVfGm6hs46fwGDhzY0UfYq3lSACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBA8kK8Turkk0+utKvycrsqzjvvvOLN8uXL2+AkQGvypABAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgOSFeJ3UvHnzKu1qtVrxpsqL6rzcjr/q0qX855dNTU1tcBJ2lycFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkL8RrB+eee27xZtiwYZXuVa/XizdLliypdC/4U5WX21X5vhoRsWbNmko7WsaTAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkhfitYOePXsWb/bdd99K99qwYUPx5umnn650Lzq/7t27F29mz57d+gfZhcbGxkq7GTNmtPJJ+CtPCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQPKW1L3Mjh07ijfr169vg5PQ2qq88XTmzJnFm+nTpxdvvv766+LN3LlzizcREdu3b6+0o2U8KQCQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIHkh3l5myZIlHX0E/sWwYcMq7aq8qO7iiy8u3rz88svFmwsuuKB4Q+fkSQGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAMkL8dpBrVZrl01ExMSJE4s306ZNq3QvIq655prizc0331zpXn369CnePPXUU8WbyZMnF2/Ye3hSACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBA8kK8dlCv19tlExFx6KGHFm/uvffe4s3jjz9evNm8eXPxJiJi9OjRxZvLLruseHPSSScVbwYMGFC8+eqrr4o3ERFvvvlm8ebBBx+sdC/+uzwpAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgeSHeXqZr167Fm6lTpxZvLrjgguLN1q1bizcREcccc0ylXXt46623ijdLly6tdK9Zs2ZV2kEJTwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAECq1ev1eosurNXa+ix7rQEDBhRvnn322Ur3GjlyZKVdqSrfH1r4Xa1VbN68uXizePHi4s20adOKN9BRWvL/oCcFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkL8TrpA477LBKuyuuuKJ4M3PmzOJNe74Q75577inePPTQQ8Wbzz77rHgDexIvxAOgiCgAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACQvxAP4j/BCPACKiAIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQNqnpRfW6/W2PAcAnYAnBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQDS/wA2Ze50d0dnCgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Get one sample from the training dataset\n",
    "image, label = train_dataset[1]\n",
    "\n",
    "# Convert tensor to numpy for plotting\n",
    "image_np = image.squeeze().numpy()\n",
    "\n",
    "# Display the image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.imshow(image_np, cmap=\"gray\")\n",
    "plt.title(f\"Label: {label}\")\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aff59f71-296e-4ddf-aeea-36fe94bc51d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------\n",
    "# Model\n",
    "# ----------------------------\n",
    "class MNISTModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            in_channels=1,\n",
    "            out_channels=16,\n",
    "            kernel_size=3,\n",
    "            bias=True\n",
    "        )\n",
    "        self.bn1 = nn.BatchNorm2d(16)\n",
    "\n",
    "        self.dwconv = nn.Conv2d(\n",
    "            in_channels=16,\n",
    "            out_channels=16,\n",
    "            kernel_size=3,\n",
    "            groups=16,\n",
    "            bias=True\n",
    "        )\n",
    "        self.bn2 = nn.BatchNorm2d(16)\n",
    "\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "        self.fc = nn.Linear(16 * 24 * 24, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = F.relu(x)\n",
    "\n",
    "        x = self.dwconv(x)\n",
    "        x = self.bn2(x)\n",
    "        x = F.relu(x)\n",
    "\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.dropout(x)\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "\n",
    "model = MNISTModel().to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e638954d-1e65-4148-8104-4a1710dabecf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity check output shape: torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# Quick sanity check\n",
    "x = torch.randn(1, 1, 28, 28).to(DEVICE)\n",
    "y = model(x)\n",
    "print(\"Sanity check output shape:\", y.shape)   # [1, 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "467a74d5-986f-4c49-a87e-66937d9a98f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "  Train Loss: 0.1511, Acc: 0.9546\n",
      "  Valid Loss: 0.0706, Acc: 0.9788\n",
      "Epoch 2/10\n",
      "  Train Loss: 0.0588, Acc: 0.9815\n",
      "  Valid Loss: 0.0818, Acc: 0.9777\n",
      "Epoch 3/10\n",
      "  Train Loss: 0.0380, Acc: 0.9875\n",
      "  Valid Loss: 0.0839, Acc: 0.9757\n",
      "Epoch 4/10\n",
      "  Train Loss: 0.0276, Acc: 0.9909\n",
      "  Valid Loss: 0.0768, Acc: 0.9810\n",
      "Epoch 5/10\n",
      "  Train Loss: 0.0226, Acc: 0.9926\n",
      "  Valid Loss: 0.0816, Acc: 0.9792\n",
      "Epoch 6/10\n",
      "  Train Loss: 0.0174, Acc: 0.9940\n",
      "  Valid Loss: 0.0800, Acc: 0.9790\n",
      "Epoch 7/10\n",
      "  Train Loss: 0.0151, Acc: 0.9949\n",
      "  Valid Loss: 0.0752, Acc: 0.9842\n",
      "Epoch 8/10\n",
      "  Train Loss: 0.0126, Acc: 0.9957\n",
      "  Valid Loss: 0.0778, Acc: 0.9827\n",
      "Epoch 9/10\n",
      "  Train Loss: 0.0115, Acc: 0.9962\n",
      "  Valid Loss: 0.0832, Acc: 0.9811\n",
      "Epoch 10/10\n",
      "  Train Loss: 0.0109, Acc: 0.9962\n",
      "  Valid Loss: 0.0766, Acc: 0.9823\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------\n",
    "# Loss / optimizer\n",
    "# ----------------------------\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "\n",
    "# ----------------------------\n",
    "# Training loop\n",
    "# ----------------------------\n",
    "for epoch in range(EPOCHS):\n",
    "    # Training\n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    train_correct = 0\n",
    "    train_total = 0\n",
    "\n",
    "    for x, y in train_loader:\n",
    "        x = x.to(DEVICE)\n",
    "        y = y.to(DEVICE)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(x)\n",
    "        loss = criterion(outputs, y)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item() * x.size(0)\n",
    "        preds = outputs.argmax(dim=1)\n",
    "        train_correct += (preds == y).sum().item()\n",
    "        train_total += y.size(0)\n",
    "\n",
    "    train_loss /= train_total\n",
    "    train_acc = train_correct / train_total\n",
    "\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    valid_loss = 0.0\n",
    "    valid_correct = 0\n",
    "    valid_total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in valid_loader:\n",
    "            x = x.to(DEVICE)\n",
    "            y = y.to(DEVICE)\n",
    "\n",
    "            outputs = model(x)\n",
    "            loss = criterion(outputs, y)\n",
    "\n",
    "            valid_loss += loss.item() * x.size(0)\n",
    "            preds = outputs.argmax(dim=1)\n",
    "            valid_correct += (preds == y).sum().item()\n",
    "            valid_total += y.size(0)\n",
    "\n",
    "    valid_loss /= valid_total\n",
    "    valid_acc = valid_correct / valid_total\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}\")\n",
    "    print(f\"  Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}\")\n",
    "    print(f\"  Valid Loss: {valid_loss:.4f}, Acc: {valid_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "523ec3a1-d9cd-4cae-9cb9-b26c3c9b4da3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Acc: 0.9800\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------\n",
    "# Final test evaluation\n",
    "# ----------------------------\n",
    "model.eval()\n",
    "test_correct = 0\n",
    "test_total = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for x, y in test_loader:\n",
    "        x = x.to(DEVICE)\n",
    "        y = y.to(DEVICE)\n",
    "\n",
    "        outputs = model(x)\n",
    "        preds = outputs.argmax(dim=1)\n",
    "\n",
    "        test_correct += (preds == y).sum().item()\n",
    "        test_total += y.size(0)\n",
    "\n",
    "test_acc = test_correct / test_total\n",
    "print(f\"\\nTest Acc: {test_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "64aff733-0707-423e-ac96-5136db2c9ce8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ground truth: 5\n",
      "Prediction : 5\n",
      "Confidence : 100.00%\n"
     ]
    }
   ],
   "source": [
    "# Take one sample from test set\n",
    "\n",
    "x, y = next(iter(test_loader))   # get a batch\n",
    "x = x.to(DEVICE)\n",
    "y = y.to(DEVICE)\n",
    "\n",
    "# Pick first image in batch\n",
    "img = x[0].unsqueeze(0)   # keep batch dim  [1, 1, 28, 28]\n",
    "label = y[0].item()\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(img)\n",
    "    probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "    pred_class = probs.argmax(dim=1).item()\n",
    "    confidence = probs.max().item() * 100\n",
    "\n",
    "print(f\"Ground truth: {label}\")\n",
    "print(f\"Prediction : {pred_class}\")\n",
    "print(f\"Confidence : {confidence:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2faa0e59-59dc-4019-8eb4-c2afc8eb8769",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGLJJREFUeJzt3XtwVPX9//HXQghJgASChAKRgIZLEiApor0YFB0amIZbCwpUaAhQFLAIqBQFhOEShbRMrZeowxRoJVRUbqWAoKLVwRbQotAQ29CEWxCwQe6XkHy+f/DL+0dICHsCScA+HzMZp2fPe89no7PPnD27W59zzgkAAEm1anoBAIAbB1EAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFHANZsxY4ZatWpV08vwy7Bhw2p8rcuWLVN4eLhOnjxZo+u4ERQWFurWW2/Vyy+/XNNLwf9DFG4gubm5evTRR9W2bVuFhIQoJCREsbGxGjt2rL744gtJUrdu3eTz+a76M2PGDL+PO2zYsHLvo3379tf18eXl5ZW6/9q1a6tly5b6yU9+ou3bt1/XY1WVVq1alfu7euSRR/yaLyoq0vTp0/XLX/5S9evXt+0bNmzQiBEj1KFDB9WuXbvCcBUXF2vevHlq3bq1goKC1KlTJy1durTcfXft2qWePXuqfv36Cg8P19ChQ3XkyBG/H+/q1avVuXNnBQUFqWXLlpo+fbouXLhQap+srCx17dpVDRo0UJcuXfTJJ5+UuZ/58+crLi6uzGydOnU0ceJEzZkzR2fPnvV7Xag6ATW9AFy0Zs0aDRw4UAEBAXrooYcUHx+vWrVqKTs7W8uXL1dGRoZyc3M1ZcoUjRw50ua2bt2q3/3ud3r66acVExNj2zt16uTp+HXr1tWCBQtKbQsLC7u2B3UFgwcP1o9//GMVFRVp165dysjI0Lp16/S3v/1NCQkJVXLM6ykhIUGPP/54qW1t27b1a/bPf/6zvvzyS40aNarU9szMTL3xxhvq3LmzmjdvXuF9TJkyRc8995x+8Ytf6M4779SqVav0s5/9TD6fT4MGDbL99u/fr3vuuUdhYWFKS0vTyZMn9etf/1o7duzQli1bFBgYWOFx1q1bp379+qlbt2564YUXtGPHDs2ePVuHDx9WRkaGpIuR++lPf6rw8HClp6dr9erV6tu3r3JychQaGipJOnz4sGbOnKlly5YpIKDsU05qaqomT56szMxMDR8+3K/fI6qQQ43Lyclx9erVczExMS4/P7/M7YWFhe755593e/fuLXPbm2++6SS5TZs2Vfr4KSkprl69epWenz59uouKirrqfrm5uU6SS09PL7V99erVTpIbNWrUFWdPnjxZ6fVdKiUlxa+1XklUVJRLTk6u9HyfPn1cYmJime0HDhxw58+fd845l5ycfMU17t+/39WpU8eNHTvWthUXF7uuXbu6yMhId+HCBds+evRoFxwc7Pbs2WPbNm7c6CS5V1999aprjY2NdfHx8a6wsNC2TZkyxfl8Prdr1y7nnHO7du1ykuwYp06dcsHBwW79+vU2M2LECNe7d+8Kj9WrVy/XtWvXq64JVY+Xj24A8+bN06lTp7Rw4UI1a9aszO0BAQEaN26cbr31Vr/v89ixY8rOztaxY8f8nikqKtLx48f93v96uf/++yVdfPlMkhYtWiSfz6cPP/xQY8aMUUREhCIjI23/devWqWvXrqpXr54aNGig5ORk/fOf/yxzvytXrlSHDh0UFBSkDh06aMWKFeUe/+DBg8rOzlZhYaHfaz5//rxOnTrl5WHq7NmzWr9+vbp3717mtubNm6tOnTpXvY9Vq1apsLBQY8aMsW0+n0+jR4/W/v37S7108/bbb6tXr15q2bKlbevevbvatm2rZcuWVXicrKwsZWVladSoUaX+uh8zZoycc3rrrbckSWfOnJEkNWrUSJIUEhKi4OBgnT59WpL02WefacmSJZo/f36Fx/vRj36kjz/+WAUFBVf9HaBqEYUbwJo1axQdHa3vfe971+0+V6xYoZiYmCs+EV7u9OnTCg0NVVhYmMLDwzV27NhquxC6e/duSVLjxo1LbR8zZoyysrL0zDPPaPLkyZKkP/7xj0pOTlb9+vU1d+5cTZs2TVlZWUpMTFReXp7NbtiwQf3795fP59Ozzz6rfv36KTU1Vdu2bStz/KeeekoxMTE6cOCAX+t9//33FRISovr166tVq1Z6/vnn/Zr79NNPdf78eXXu3Nmv/cvzj3/8Q/Xq1Sv1UqEk3XXXXXa7JB04cECHDx9Wly5dytzHXXfdZftVdBxJZeabN2+uyMhIu71t27YKCwvTjBkztGfPHqWnp+v48eP2GMeNG6dHH31U0dHRFR7vjjvukHNOmzdvrnA/VD2uKdSw48ePKz8/X/369Stz2zfffFPqwly9evUUHBx83dfQrFkzTZo0SZ07d1ZxcbHWr1+vl19+WZ9//rk++OCDcl8HvhanT5/W119/raKiImVnZ2vChAmSpAceeKDUfuHh4XrvvfdUu3ZtSdLJkyc1btw4jRw5Uq+99prtl5KSonbt2iktLc22/+pXv1LTpk318ccf27WRe++9V0lJSYqKiqr02jt16qTExES1a9dO//3vf7Vo0SKNHz9e+fn5mjt3boWz2dnZkqTWrVtX+vgHDx5U06ZN5fP5Sm0vOcPMz8+3/S7dfvm+BQUFOnfunOrWrXvF41Q0X3KcevXqKSMjQyNGjND8+fNVu3ZtzZ07V1FRUcrMzFROTo7Wrl171cd12223Sbp4htKrV6+r7o8qVNOvX/2v27dvn5PkhgwZUua2+Ph4J8l+Ln8t3rnrc02hPHPmzHGS3NKlS6+6r9drCpf/hIaGurlz59p+CxcudJLc4sWLS80vX77cSXLvv/++O3LkSKmfpKQkFx0d7ZxzLj8/30lykydPLrOG2NjYa7qmcLni4mLXo0cPFxAQ4Pbt21fhvnPnznWS3P79+yvcr6JrCvfff7+LiYkps72oqMhJco899phzzrm//vWvTpJ74403yuw7bdo0J8kdPXr0imuYOXOmk+QOHTpU5rauXbu6+Pj4UtsKCgrcJ5984r766ivn3MVrC5GRkW7BggWuqKjIzZgxw7Vu3dp17NjRLV++vMx9njlzxklyTz755BXXhOrBmUINa9CggSSV+1LNq6++qhMnTujQoUMaMmRIta5rwoQJmjZtmt59991S72i5HkaNGqUHHnhAtWrVUsOGDRUXF1fuX6yX/0X973//W9L/vwZxuZJ3u+zZs0eS1KZNmzL7tGvXTp999tk1rf9SPp9PEyZM0DvvvKMPPvjAr39P7hr+zw6Dg4N17ty5MttL3s5ZciZZ8k9/9r3ScSqav3y2UaNG+v73v2//+9lnn1VERIRSU1P1+9//Xq+88oqWLFmivLw8DRw4UFlZWaVeUir5nVx+BoTqRxRqWFhYmJo1a6adO3eWua3kGsOlr5VXl+DgYDVu3LhKLvy1adOm3Iut5a3hUsXFxZIuXlf4zne+U2b
/6/0yl79K3gBwtd9VyTWTo0ePlrpw7kWzZs20adMmOedKPYGWvNxT8nbWkpd9SrZf6uDBgwoPD7/iS0eXz1/+BoeDBw/aNYzy5OXl6Te/+Y02bNigWrVqaenSpXr44Yct5osXL9af/vQnTZ061WaOHj0qSbrllluu/OBRLbjQfANITk5WTk6OtmzZUtNLMSdOnNDXX3+tJk2a1PRSzO233y5JioiIUPfu3cv8dOvWTZLsmkHJmcWlvvzyy+u+rv/85z+SdNXfVcmHAUveZVUZCQkJOn36tHbt2lVq+9///ne7XZJatGihJk2alHthfcuWLVf9PEjJ7ZfP5+fna//+/RXOP/HEE+rTp48SExNt5tLPXjRv3rzMRf2S38nlF9BR/YjCDWDSpEkKCQnR8OHDdejQoTK3V+blBn/fknr27FmdOHGizPZZs2bJOaeePXt6PnZV6dGjh0JDQ5WWllbu20dLPqnbrFkzJSQkaPHixaUe/8aNG5WVlVVmzt+3pBYUFKioqKjUtsLCQj333HMKDAzUfffdV+H8HXfcocDAwHKfqP3Vt29f1alTp9TXQjjn9Morr6hFixb64Q9/aNv79++vNWvWaN++fbbtvffe07/+9a9SF/ULCwuVnZ1d6qwiLi5O7du312uvvVbqMWdkZMjn82nAgAHlrm/Tpk1au3at5s2bZ9uaNm1qF9mli5+yvvxM79NPP5XP59MPfvADL78OVAFeProBtGnTRpmZmRo8eLDatWtnn2h2zik3N1eZmZmqVauWp5ccVqxYodTUVC1cuFDDhg274n5fffWVvvvd72rw4MH2l+w777yjtWvXqmfPnurbt++1PrzrJjQ0VBkZGRo6dKg6d+6sQYMGqUmTJtq7d6/+8pe/6O6779aLL74o6eJr2snJyUpMTNTw4cNVUFCgF154QXFxcWWu3zz11FNavHixcnNzK/x6idWrV2v27NkaMGCAWrdurYKCAmVmZmrnzp1KS0sr9yWtSwUFBSkpKUnvvvuuZs6cWeq2L774QqtXr5Yk5eTk6NixY5o9e7YkKT4+Xr1795YkRUZGavz48UpPT1dhYaHuvPNOrVy5Uh999JGWLFli79SSpKefflpvvvmm7rvvPj322GM6efKk0tPT1bFjR6Wmptp+Bw4cUExMjFJSUrRo0SLbnp6erj59+igpKUmDBg3Szp079eKLL2rkyJHl/kVfVFSk8ePH68knnyz12YgBAwZo0qRJatKkifbs2aMdO3ZoyZIlpWY3btyou+++u8zbklEDavAiNy6Tk5PjRo8e7aKjo11QUJALDg527du3d4888ojbvn17uTNXevdRyTt4Fi5cWOExjx496oYMGeKio6NdSEiIq1u3rouLi3NpaWn2CdurudZPNF+uZO1bt24t9/ZNmza5Hj16uLCwMBcUFORuv/12N2zYMLdt27ZS+7399tsuJibG1a1b18XGxrrly5eX+4nmlJQUJ8nl5uZWuK5t27a53r17uxYtWrjAwEBXv359l5iY6JYtW3bVx15i+fLlzufzlfl0esljLu8nJSWl1L5FRUUuLS3NRUVFucDAQBcXF+def/31co+3c+dOl5SU5EJCQlzDhg3dQw89ZO8QKlHy7+Xy4zjn3IoVK1xCQoKrW7eui4yMdFOnTr3ifxcvvfSSi4yMdKdOnSq1vbCw0E2cONHdcsstLioqqsy7yr755hsXGBjoFixYUO79onr5nLuGt0IAuvgtqYsWLaqRC+I3m6KiIsXGxurBBx/UrFmzano5N4Tf/va3mjdvnnbv3l0ln8OBN1xTAKpR7dq1NXPmTL300kt8dbYuXs+YP3++pk6dShBuEFxTAKrZwIEDNXDgwJpexg2hTp062rt3b00vA5fgTAEAYLimAAAwnCkAAAxRAAAYvy8080VVAHBz8+dqAWcKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGACanoBAOBFUFCQ55nHH3/c80y7du08z/z85z/3PHOj4UwBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAADDF+IBuGYDBgzwPNO4ceNKHWvEiBGeZ5o2bep55oknnvA8823AmQIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAM35IKVLPw8PBKzUVERHiemTVrlueZLl26eJ5p1KiR55kLFy54npGkoUOHep45duyY55nNmzd7nvk24EwBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAADDF+LhWykhIcHzTMOGDT3PJCUleZ4ZM2aM5xlJCgsL8zxTXFzseeb111/3PJORkeF5Jisry/OMVLkvBszJyanUsf4XcaYAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIDxOeecXzv6fFW9FtxEUlNTPc9MnDixUseqVcv73y4tW7b0PBMSEuJ5pjp99NFHnmdmz57teebzzz/3PHPkyBHPM6h+/jzdc6YAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAJqOkF4OZ07733ep5p3LhxpY5VmS9jPHHihOeZt956y/PMqlWrPM+sXLnS8wxQXThTAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgPE555xfO1bimypxc4iOjvY8k5CQ4HmmTZs2nmck6Q9/+EOl5rw6cOBAtRwHqCn+PN1zpgAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgAmo6QXg+oqLi/M8079/f88zeXl5nmc2bdrkeUbii+qA6sSZAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhi/E+5aJjY31PNOrVy/PM6NGjfI8s337ds8zAKoXZwoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABifc875taPPV9VrwXUQGhrqeSY3N9
fzTGBgoOeZhIQEzzOStHv37krNASjNn6d7zhQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADABNb0AXF8RERGeZzZv3ux5plevXp5nVqxY4XlGko4ePep5ZsSIEZ5ncnJyPM8A3zacKQAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMD4nHPOrx19vqpeC2pIq1atPM9ERUVd/4VcweTJkz3P7Nq1y/PMxIkTPc8ANxN/nu45UwAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwATU9AJQ8/Ly8qplprIefPBBzzP33HOP55ng4GDPM2fOnPE8A9zIOFMAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMD4nHPOrx19vqpeC1Cu4uJizzP79u3zPNOxY0fPM8ePH/c8A9QUf57uOVMAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAE1PQCgKowbdo0zzOnT5+ugpUANxfOFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMD7nnPNrR5+vqtcClCsnJ8fzTHx8vOeZU6dOeZ4Bbib+PN1zpgAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAAATUNMLwM0pIiLC80xGRkaljnXbbbd5nunSpYvnmQ8//NDzDPBtw5kCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGL8SDatXy/rfBww8/7HmmX79+nmckae/evZ5n9uzZU6ljAf/rOFMAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMD4nHPOrx19vqpeC2pIRESE55mDBw96ntm6davnGUkaMmSI55mcnJxKHQv4NvPn6Z4zBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADF+IBwUEBHieeeaZZzzPzJkzx/OMJJ07d65ScwBK4wvxAACeEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAzfkgoA/yP4llQAgCdEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADAB/u7onKvKdQAAbgCcKQAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAAzP8BNPKLgqw97B0AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# ----------------------------\n",
    "# Show image\n",
    "# ----------------------------\n",
    "img_np = img.cpu().squeeze().numpy()\n",
    "\n",
    "plt.imshow(img_np, cmap=\"gray\")\n",
    "plt.title(f\"GT: {label} | Pred: {pred_class} ({confidence:.1f}%)\")\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7cfd48bd-41c0-47a4-8e66-0546b2bab626",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_32/546841718.py:11: UserWarning: # 'dynamic_axes' is not recommended when dynamo=True, and may lead to 'torch._dynamo.exc.UserError: Constraints violated.' Supply the 'dynamic_shapes' argument instead if export is unsuccessful.\n",
      "  torch.onnx.export(\n",
      "W0408 11:04:58.114000 32 torch/onnx/_internal/exporter/_compat.py:133] Setting ONNX exporter to use operator set version 18 because the requested opset_version 13 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[torch.onnx] Obtain model graph for `MNISTModel([...]` with `torch.export.export(..., strict=False)`...\n",
      "[torch.onnx] Obtain model graph for `MNISTModel([...]` with `torch.export.export(..., strict=False)`... \n",
      "[torch.onnx] Run decompositions...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.\n",
      "  return cls.__new__(cls, *args)\n",
      "The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 13).\n",
      "Failed to convert the model to the target version 13 using the ONNX C API. The model was not modified\n",
      "Traceback (most recent call last):\n",
      "  File \"/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py\", line 120, in call\n",
      "    converted_proto = _c_api_utils.call_onnx_api(\n",
      "                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py\", line 65, in call_onnx_api\n",
      "    result = func(proto)\n",
      "             ^^^^^^^^^^^\n",
      "  File \"/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py\", line 115, in _partial_convert_version\n",
      "    return onnx.version_converter.convert_version(\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py\", line 39, in convert_version\n",
      "    converted_model_str = C.convert_version(model_str, target_version)\n",
      "                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "RuntimeError: /github/workspace/onnx/version_converter/BaseConverter.h:67: adapter_lookup: Assertion `false` failed: No Adapter From Version $16 for Identity\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[torch.onnx] Run decompositions... \n",
      "[torch.onnx] Translate the graph into ONNX...\n",
      "[torch.onnx] Translate the graph into ONNX... \n",
      "[torch.onnx] Optimize the ONNX graph...\n",
      "Applied 3 of general pattern rewrite rules.\n",
      "[torch.onnx] Optimize the ONNX graph... \n",
      " Save ONNX: models/mnist_model.onnx\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------\n",
    "# Export to onnx\n",
    "# ----------------------------\n",
    "\n",
    "os.makedirs(\"models\", exist_ok=True)\n",
    "onnx_path = \"models/mnist_model.onnx\"\n",
    "\n",
    "# Dummy input (must match your model input!)\n",
    "dummy_input = torch.randn(1, 1, 28, 28).to(DEVICE)\n",
    "\n",
    "torch.onnx.export(\n",
    "    model,                         # model to export\n",
    "    dummy_input,                   # example input\n",
    "    onnx_path,                     # output file\n",
    "    export_params=True,            # store trained weights\n",
    "    opset_version=13,              # good default\n",
    "    do_constant_folding=True,      # optimization\n",
    "    input_names=[\"input\"],         # optional\n",
    "    output_names=[\"output\"],       # optional\n",
    "    dynamic_axes={                 # allow variable batch size\n",
    "        \"input\": {0: \"batch_size\"},\n",
    "        \"output\": {0: \"batch_size\"},\n",
    "    }\n",
    ")\n",
    "\n",
    "print(\" Save ONNX:\", onnx_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cb55b7c1-9ae4-4582-8a07-48f3fc25cc6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved individual images\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------\n",
    "# Save test images\n",
    "# ----------------------------\n",
    "os.makedirs(\"images/mnist_test\", exist_ok=True)\n",
    "\n",
    "for i in range(len(test_dataset)):\n",
    "    img, label = test_dataset[i]\n",
    "    img_np = img.squeeze().numpy()\n",
    "    plt.imsave(f\"images/mnist_test/{i}_label_{label}.png\", img_np, cmap=\"gray\")\n",
    "\n",
    "print(\"Saved individual images\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8de2bc92-b89b-4a1c-ab8c-36902729cdc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------\n",
    "# save calibration images\n",
    "# ----------------------------\n",
    "# Raw MNIST without normalization, so files contain normal pixel values\n",
    "train_raw = datasets.MNIST(\n",
    "    root=\"./data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=None\n",
    ")\n",
    "\n",
    "out_dir = \"images/mnist_calibration\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "num_calib = 1000\n",
    "\n",
    "for i in range(num_calib):\n",
    "    img, label = train_raw[i]   # PIL image, label int\n",
    "    filename = os.path.join(out_dir, f\"{i:05d}_label_{label}.png\")\n",
    "    img.save(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f636d054-047e-4a50-8777-74d0ab47f66d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist.json created\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "config = {\n",
    "    \"inputs\": {\"input\": [1, 1, 28, 28]},\n",
    "    \"calibration_num\": 10,\n",
    "    \"calibration_method\": \"ema\",\n",
    "    \"default_loader\": {\n",
    "        \"dataset_path\": \"images/mnist_calibration/\",\n",
    "        \"file_extensions\": [\"png\"],\n",
    "        \"preprocessings\": [\n",
    "            {\"convertColor\": {\"form\": \"RGB2GRAY\"}},\n",
    "            {\"resize\": {\"width\": 28, \"height\": 28}},\n",
    "            {\"div\": {\"x\": 255.0}},\n",
    "            {\"normalize\": {\"mean\": [0.1307], \"std\": [0.3081]}},\n",
    "            {\"expandDim\": {\"axis\": 0}},\n",
    "            {\"expandDim\": {\"axis\": 0}}\n",
    "        ]\n",
    "    }\n",
    "}\n",
    "\n",
    "with open(\"mnist.json\", \"w\") as f:\n",
    "    json.dump(config, f, indent=2)\n",
    "\n",
    "print(\"mnist.json created\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5e00ae-7d3e-4790-8b8e-88fba96000f4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
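
Before compiling the model for the DEEPX NPU, it is worth sanity-checking the exported ONNX file on the host. A minimal sketch, assuming onnxruntime is installed (pip install onnxruntime) and using the "input"/"output" names set during export:

import numpy as np
import onnxruntime as ort

# Load the exported model and run one dummy inference on the CPU provider
sess = ort.InferenceSession("models/mnist_model.onnx")
dummy = np.random.randn(1, 1, 28, 28).astype(np.float32)
(logits,) = sess.run(["output"], {"input": dummy})
print("ONNX output shape:", logits.shape)  # expected: (1, 10)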

Credits

Marco Hoefle
Working for Avnet Silica as a Software Specialist in the Field Application Team. Technical subjects: Linux on SoCs, GStreamer, AI/ML.
