Dobrea Dan Marius
Created February 18, 2020 © GPL3+

A video human detection system for lives saving

The goal of this project is to study the ability of the NVIDIA Jetson Nano to detect humans in an unconstrained environment.

Advanced • Full instructions provided • Over 4 days • 84
A video human detection system for lives saving

Things used in this project

Hardware components

NVIDIA Jetson Nano Developer Kit
NVIDIA Jetson Nano Developer Kit
×1

Software apps and online services

PyTorch
WinSCP
Snappy Ubuntu Core
Snappy Ubuntu Core

Story

Read more

Code

Code for human detection system

Python
It is used to acquire the training database, to train the deep learning neural network and, in the end, to test the neural model.
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All steps of the first stage finished !!!!\n"
     ]
    }
   ],
   "source": [
    "import torchvision.transforms as transforms\n",
    "from dataset import ImageClassificationDataset\n",
    "\n",
    "# Task name; it prefixes the name of every dataset created below.\n",
    "TASK = 'humanDetect'\n",
    "\n",
    "# Class labels (a three-class variant is kept here for reference).\n",
    "# CATEGORIES = ['no_human', 'single_human', 'multiple_humans']\n",
    "CATEGORIES = ['no_human', 'human(s)']\n",
    "\n",
    "# Short names of the datasets that can be selected in the user interface.\n",
    "DATASETS = ['A', 'B', 'C']\n",
    "\n",
    "# Per-image preprocessing: colour jitter, resize to 224x224, conversion to a tensor\n",
    "# and channel-wise normalisation with the mean/std values listed below.\n",
    "TRANSFORMS = transforms.Compose([\n",
    "    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "# One ImageClassificationDataset per entry of DATASETS, keyed by its short name.\n",
    "datasets = {}\n",
    "for name in DATASETS:\n",
    "    datasets[name] = ImageClassificationDataset('{}_{}'.format(TASK, name), CATEGORIES, TRANSFORMS)\n",
    "\n",
    "# Real class of each test video, indexed by (video number - 1):\n",
    "# 0 => no_human in frame\n",
    "# 1 => human(s) in frame\n",
    "realClasses = [0, 1, 0, 1, 0, 1]\n",
    "\n",
    "# display ending message for the user\n",
    "print(\"All steps of the first stage finished !!!!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The video streams created !!!!\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "\n",
    "# Training footage: path of the first clip and the capture opened on it\n",
    "train_inputFile = \"train1.mp4\"\n",
    "videoStream = cv2.VideoCapture(train_inputFile)\n",
    "\n",
    "# Test footage: path of the first clip and the capture opened on it\n",
    "test_inputFile = \"test1.mp4\"\n",
    "videoStreamTest = cv2.VideoCapture(test_inputFile)\n",
    "\n",
    "# display ending message for the user\n",
    "print(\"The video streams created !!!!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Collection\n",
    "\n",
    "The cell below creates the data collection tool widget."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data collection widget created !!!!\n"
     ]
    }
   ],
   "source": [
    "import ipywidgets\n",
    "from jetcam.utils import bgr8_to_jpeg\n",
    "\n",
    "# initialize active dataset\n",
    "dataset = datasets[DATASETS[0]]\n",
    "\n",
    "currentVideo = 1 \n",
    "\n",
    "# create image preview\n",
    "image_widget = ipywidgets.Image(width=600, height=400)\n",
    "ret,frame = videoStream.read()  \n",
    "if ret:\n",
    "    image_widget.value = bgr8_to_jpeg(frame) \n",
    "\n",
    "# create widgets\n",
    "dataset_widget  = ipywidgets.Dropdown(options=DATASETS, description='Dataset:')\n",
    "category_widget = ipywidgets.Dropdown(options=dataset.categories, description='Category:')\n",
    "next_widget     = ipywidgets.Button(description='Next frame')\n",
    "next5_widget    = ipywidgets.Button(description='Go forward 5 frames')\n",
    "count_widget    = ipywidgets.IntText(description='Count:')\n",
    "save_widget     = ipywidgets.Button(description='Add')\n",
    "restart_widget  = ipywidgets.Button(description='Restart training acq.')\n",
    "# 'success', 'info', 'warning', 'danger' or ''\n",
    "sep1_widget     = ipywidgets.Button(description='                     ', button_style='danger')\n",
    "sep2_widget     = ipywidgets.Button(description='                     ', button_style='danger')\n",
    "sep1_widget.disabled = True\n",
    "sep2_widget.disabled = True\n",
    "\n",
    "# manually update counts at initialization\n",
    "count_widget.value = dataset.get_count(category_widget.value)\n",
    "\n",
    "# sets the active dataset\n",
    "def set_dataset(change):\n",
    "    \"\"\"Switch the active dataset to the one picked in the 'Dataset:' dropdown.\n",
    "\n",
    "    `change` is the change dictionary delivered by `observe`; only its 'new'\n",
    "    entry (the selected dataset name) is read. The count field is refreshed for\n",
    "    the category that is currently selected.\n",
    "    \"\"\"\n",
    "    global dataset\n",
    "    selected_name = change['new']\n",
    "    dataset = datasets[selected_name]\n",
    "    current_category = category_widget.value\n",
    "    count_widget.value = dataset.get_count(current_category)\n",
    "\n",
    "# re-run set_dataset whenever the dropdown value changes\n",
    "dataset_widget.observe(set_dataset, names='value')\n",
    "\n",
    "# update counts when we select a new category\n",
    "def update_counts(change):\n",
    "    \"\"\"Show the active dataset's count for the newly selected category.\"\"\"\n",
    "    picked_category = change['new']\n",
    "    count_widget.value = dataset.get_count(picked_category)\n",
    "\n",
    "# keep the count field in sync with the category dropdown\n",
    "category_widget.observe(update_counts, names='value')\n",
    "\n",
    "#display the next image from the video\n",
    "def next(c):\n",
    "    \"\"\"Advance the training preview by one frame ('Next frame' button handler).\n",
    "\n",
    "    When the current video has no frames left, the following file of the\n",
    "    train<N>.mp4 series is opened. If that file cannot be read either, the\n",
    "    current frame, stream and video index are left untouched, so the preview\n",
    "    and the 'Add' button keep working on the last valid frame instead of\n",
    "    receiving an empty image.\n",
    "\n",
    "    NOTE: the name shadows the builtin next(); it is kept so that existing\n",
    "    references to this handler keep working.\n",
    "\n",
    "    c -- the clicked button (unused).\n",
    "    \"\"\"\n",
    "    global frame\n",
    "    global videoStream\n",
    "    global currentVideo\n",
    "\n",
    "    # reading from video stream one frame\n",
    "    retval, new_frame = videoStream.read()\n",
    "    if not retval:\n",
    "        # no frames left, try the next video of the series\n",
    "        following = cv2.VideoCapture(\"train{}.mp4\".format(currentVideo + 1))\n",
    "        retval, new_frame = following.read()\n",
    "        if not retval:\n",
    "            # there is no further readable video: keep the current state\n",
    "            following.release()\n",
    "            return\n",
    "        videoStream.release()\n",
    "        videoStream = following\n",
    "        currentVideo += 1\n",
    "\n",
    "    # commit the frame only once it is known to be valid\n",
    "    frame = new_frame\n",
    "    image_widget.value = bgr8_to_jpeg(frame)\n",
    "next_widget.on_click(next)\n",
    "\n",
    "def next5(c):\n",
    "    \"\"\"Advance the training preview by five frames ('Go forward 5 frames' button handler).\n",
    "\n",
    "    Frames are read one by one; when the current video runs out, the following\n",
    "    file of the train<N>.mp4 series is opened. If no further video can be read,\n",
    "    skipping stops at the last valid frame instead of opening more missing files\n",
    "    and handing an empty image to the preview.\n",
    "\n",
    "    c -- the clicked button (unused).\n",
    "    \"\"\"\n",
    "    global frame\n",
    "    global videoStream\n",
    "    global currentVideo\n",
    "\n",
    "    latest = None\n",
    "    # reading from video stream five frames\n",
    "    for _ in range(5):\n",
    "        retval, candidate = videoStream.read()\n",
    "        if not retval:\n",
    "            # no frames left, try the next video of the series\n",
    "            following = cv2.VideoCapture(\"train{}.mp4\".format(currentVideo + 1))\n",
    "            retval, candidate = following.read()\n",
    "            if not retval:\n",
    "                # there is no further readable video: stop skipping\n",
    "                following.release()\n",
    "                break\n",
    "            videoStream.release()\n",
    "            videoStream = following\n",
    "            currentVideo += 1\n",
    "        latest = candidate\n",
    "\n",
    "    # update the preview only when at least one valid frame was read\n",
    "    if latest is not None:\n",
    "        frame = latest\n",
    "        image_widget.value = bgr8_to_jpeg(frame)\n",
    "next5_widget.on_click(next5)\n",
    "\n",
    "# save image for category and update counts\n",
    "def save(c):\n",
    "    \"\"\"Store the current global frame under the selected category ('Add' button handler).\"\"\"\n",
    "    chosen_category = category_widget.value\n",
    "    dataset.save_entry(frame, chosen_category)\n",
    "    # refresh the counter after the entry has been added\n",
    "    count_widget.value = dataset.get_count(chosen_category)\n",
    "save_widget.on_click(save)\n",
    "\n",
    "# restart the training data acq. process\n",
    "def restart(c):\n",
    "    \"\"\"Restart the training data acquisition from the first frame of the first video.\n",
    "\n",
    "    The current capture is released, train_inputFile is reopened and the video\n",
    "    index is reset to 1. The first frame is stored in the global `frame` (not\n",
    "    only displayed), so that the 'Add' button saves the image that is shown\n",
    "    rather than the frame left over from before the restart.\n",
    "\n",
    "    c -- the clicked button (unused).\n",
    "    \"\"\"\n",
    "    global frame\n",
    "    global videoStream\n",
    "    global currentVideo\n",
    "\n",
    "    currentVideo = 1\n",
    "\n",
    "    videoStream.release()\n",
    "    videoStream = cv2.VideoCapture(train_inputFile)\n",
    "    ret, first_frame = videoStream.read()\n",
    "    if ret:\n",
    "        # keep the previous frame when the first video cannot be read\n",
    "        frame = first_frame\n",
    "        image_widget.value = bgr8_to_jpeg(frame)\n",
    "restart_widget.on_click(restart)\n",
    "\n",
    "data_collection_widget = ipywidgets.VBox([\n",
    "    ipywidgets.HBox([image_widget]), dataset_widget, category_widget, \n",
    "    ipywidgets.HBox([next_widget, next5_widget]), \n",
    "    ipywidgets.HBox([save_widget, restart_widget]), \n",
    "    count_widget, \n",
    "    ipywidgets.HBox([sep1_widget, sep2_widget])\n",
    "])\n",
    "\n",
    "# display(data_collection_widget)\n",
    "print(\"Data collection widget created !!!!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "The following cell defines the neural network and adjusts the fully connected layer (fc) to match the outputs required for the project.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deep Learning model configured and model_widget created !!!\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "device = torch.device('cuda')\n",
    "\n",
    "# ALEXNET\n",
    "model = torchvision.models.alexnet(pretrained=True)\n",
    "model.classifier[-1] = torch.nn.Linear(4096, len(dataset.categories))\n",
    "\n",
    "# SQUEEZENET \n",
    "# model = torchvision.models.squeezenet1_1(pretrained=True)\n",
    "# model.classifier[1] = torch.nn.Conv2d(512, len(dataset.categories), kernel_size=1)\n",
    "# model.num_classes = len(dataset.categories)\n",
    "\n",
    "# RESNET 18\n",
    "#model = torchvision.models.resnet18(pretrained=True)\n",
    "#model.fc = torch.nn.Linear(512, len(dataset.categories))\n",
    "\n",
    "# RESNET 34\n",
    "#model = torchvision.models.resnet34(pretrained=True)\n",
    "#model.fc = torch.nn.Linear(512, len(dataset.categories))\n",
    "    \n",
    "model = model.to(device)\n",
    "\n",
    "# ======================================================================================\n",
    "\n",
    "model_save_button = ipywidgets.Button(description='Save model')\n",
    "model_load_button = ipywidgets.Button(description='Load model')\n",
    "model_path_widget = ipywidgets.Text(description='model path', value='my_model.pth')\n",
    "\n",
    "def load_model(c):\n",
    "    \"\"\"Load into the model the state stored at the path typed in the 'model path' field.\"\"\"\n",
    "    weights_path = model_path_widget.value\n",
    "    stored_state = torch.load(weights_path)\n",
    "    model.load_state_dict(stored_state)\n",
    "model_load_button.on_click(load_model)\n",
    "    \n",
    "def save_model(c):\n",
    "    \"\"\"Write the model's state dict to the path typed in the 'model path' field.\"\"\"\n",
    "    target_path = model_path_widget.value\n",
    "    current_state = model.state_dict()\n",
    "    torch.save(current_state, target_path)\n",
    "model_save_button.on_click(save_model)\n",
    "\n",
    "model_widget = ipywidgets.VBox([\n",
    "    model_path_widget,\n",
    "    ipywidgets.HBox([model_load_button, model_save_button])\n",
    "])\n",
    "\n",
    "# display message to the user\n",
    "print(\"Deep Learning model configured and model_widget created !!!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Live Execution\n",
    "\n",
    "Execute the cell below to set up the live execution widget."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Live execution widget created\n"
     ]
    }
   ],
   "source": [
    "import threading\n",
    "import time\n",
    "import statistics\n",
    "from ipywidgets import Layout\n",
    "from utils import preprocess\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# create test image preview\n",
    "imageTest_widget = ipywidgets.Image(width=600, height=400)\n",
    "ret,frame = videoStreamTest.read()  \n",
    "if ret:\n",
    "    imageTest_widget.value = bgr8_to_jpeg(frame) \n",
    "    \n",
    "state_widget      = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')\n",
    "prediction_widget = ipywidgets.Text(description='prediction')\n",
    "score_widgets = []\n",
    "for category in dataset.categories:\n",
    "    score_widget = ipywidgets.FloatSlider(min=0.0, max=1.0, description=category, orientation='vertical')\n",
    "    score_widgets.append(score_widget)\n",
    "frameskip_widget    = ipywidgets.IntSlider(value=2, min=1, max=10, description='Skip frames:', orientation='horizontal')\n",
    "restartTest_widget  = ipywidgets.Button(description='Restart testing')\n",
    "confusion00_widget  = ipywidgets.IntText(value=0, layout=Layout(width='30%'))\n",
    "confusion01_widget  = ipywidgets.IntText(value=0, layout=Layout(width='30%'))\n",
    "confusion10_widget  = ipywidgets.IntText(value=0, layout=Layout(width='30%'))\n",
    "confusion11_widget  = ipywidgets.IntText(value=0, layout=Layout(width='30%'))\n",
    "timeTest_widget     = ipywidgets.FloatText(description='T_classif. [s]:')\n",
    "accuracyTest_widget = ipywidgets.FloatText(description='Accuracy [%]:')\n",
    "meanClassifT_widget = ipywidgets.FloatText(description='Mean T. [s]:')\n",
    "stdvClassifT_widget = ipywidgets.FloatText(description='Stdv T. [s]:')\n",
    "sepT1_widget     = ipywidgets.Button(description='                     ', button_style='danger')\n",
    "sepT2_widget     = ipywidgets.Button(description='                     ', button_style='danger')\n",
    "sepT1_widget.disabled = True\n",
    "sepT2_widget.disabled = True\n",
    "\n",
    "# first video is selected    \n",
    "currentVideoTest = 1    \n",
    "\n",
    "# number of frames to skip\n",
    "no_frameSkip = 1\n",
    "frameskip_widget.value = no_frameSkip\n",
    "\n",
    "# confusion matrix - classification process performances\n",
    "confusion = [[0, 0], [0, 0]]\n",
    "\n",
    "# create a list to keep all the execution times \n",
    "execTimes = []\n",
    "\n",
    "def live(state_widget, model, prediction_widget, score_widget, imageTest_widget):\n",
    "    \"\"\"Classify test-video frames in a loop for as long as the state toggle reads 'live'.\n",
    "\n",
    "    Each iteration advances no_frameSkip frames, classifies the last one,\n",
    "    refreshes the prediction and score widgets, updates the confusion matrix\n",
    "    (first index: predicted class, second index: class taken from realClasses\n",
    "    for the current test video) and records the iteration time in execTimes.\n",
    "    When the loop ends, the mean and the standard deviation of the recorded\n",
    "    times are displayed, provided that enough samples exist.\n",
    "\n",
    "    state_widget      -- toggle polled to decide whether to keep running\n",
    "    model             -- network applied to the preprocessed frame\n",
    "    prediction_widget -- text field receiving the predicted category name\n",
    "    score_widget      -- unused; the scores go to the global score_widgets list\n",
    "    imageTest_widget  -- image widget showing the frame being classified\n",
    "    \"\"\"\n",
    "    while state_widget.value == 'live':\n",
    "        startTest = time.time()\n",
    "\n",
    "        # advance no_frameSkip frames; only the last one is classified\n",
    "        image = None\n",
    "        for _ in range(no_frameSkip):\n",
    "            image = next_frame()\n",
    "            if image is None:\n",
    "                # next_frame() ran out of videos: do not keep reading\n",
    "                break\n",
    "        if image is None or state_widget.value == 'stop':\n",
    "            break\n",
    "        imageTest_widget.value = bgr8_to_jpeg(image)\n",
    "\n",
    "        preprocessed = preprocess(image)\n",
    "        # inference only: no autograd graph is needed\n",
    "        with torch.no_grad():\n",
    "            output = model(preprocessed)\n",
    "        output = F.softmax(output, dim=1).detach().cpu().numpy().flatten()\n",
    "        category_index = int(output.argmax())\n",
    "        prediction_widget.value = dataset.categories[category_index]\n",
    "        for i, score in enumerate(output):\n",
    "            score_widgets[i].value = float(score)\n",
    "\n",
    "        # fill the confusion matrix accordingly (only classes 0 and 1 are tracked)\n",
    "        real_class = realClasses[currentVideoTest - 1]\n",
    "        if category_index in (0, 1) and real_class in (0, 1):\n",
    "            confusion[category_index][real_class] += 1\n",
    "        confusion00_widget.value  = confusion[0][0]\n",
    "        confusion01_widget.value  = confusion[0][1]\n",
    "        confusion10_widget.value  = confusion[1][0]\n",
    "        confusion11_widget.value  = confusion[1][1]\n",
    "\n",
    "        # the widget is labelled 'Accuracy [%]', so report a percentage\n",
    "        classified = sum(sum(row) for row in confusion)\n",
    "        if classified:\n",
    "            accuracyTest_widget.value = 100.0 * (confusion[0][0] + confusion[1][1]) / classified\n",
    "\n",
    "        classificationTime = time.time() - startTest\n",
    "        timeTest_widget.value = classificationTime\n",
    "        execTimes.append(classificationTime)\n",
    "\n",
    "    # mean execution time & standard deviation execution time;\n",
    "    # work on a snapshot because restart_test() may rebind execTimes meanwhile\n",
    "    times = list(execTimes)\n",
    "    if times:\n",
    "        meanClassifT_widget.value = sum(times) / float(len(times))\n",
    "        # statistics.stdev() needs at least two samples\n",
    "        stdvClassifT_widget.value = statistics.stdev(times) if len(times) > 1 else 0.0\n",
    "\n",
    "# the running thread status changed    \n",
    "def start_live(change):\n",
    "    \"\"\"Start a thread running live() when the state toggle switches to 'live'.\"\"\"\n",
    "    if change['new'] != 'live':\n",
    "        return\n",
    "    worker_args = (state_widget, model, prediction_widget, score_widget, imageTest_widget)\n",
    "    execute_thread = threading.Thread(target=live, args=worker_args)\n",
    "    execute_thread.start()\n",
    "state_widget.observe(start_live, names='value')\n",
    "        \n",
    "def next_frame():\n",
    "    \"\"\"Return the next test frame, moving on to the following test<N>.mp4 when needed.\n",
    "\n",
    "    If the following video yields no frame either, the state toggle is set to\n",
    "    'stop' and the value produced by the failed read is returned.\n",
    "    \"\"\"\n",
    "    global videoStreamTest\n",
    "    global currentVideoTest\n",
    "\n",
    "    # common case: the current video still has frames\n",
    "    got_frame, frameT = videoStreamTest.read()\n",
    "    if got_frame:\n",
    "        return frameT\n",
    "\n",
    "    # current video exhausted: release it and open the next one\n",
    "    videoStreamTest.release()\n",
    "    currentVideoTest += 1\n",
    "    videoStreamTest = cv2.VideoCapture(\"test{}.mp4\".format(currentVideoTest))\n",
    "\n",
    "    got_frame, frameT = videoStreamTest.read()\n",
    "    if not got_frame:\n",
    "        # nothing left to play: ask the live loop to stop\n",
    "        state_widget.value = 'stop'\n",
    "    return frameT\n",
    "  \n",
    "def frames_skip(change=None):\n",
    "    \"\"\"Copy the 'Skip frames:' slider value into no_frameSkip.\n",
    "\n",
    "    change -- change dictionary passed by `observe`. It is accepted so that the\n",
    "    slider can actually invoke the handler (a zero-argument handler raises\n",
    "    TypeError when called with it) and it defaults to None so that direct calls\n",
    "    without arguments keep working. The value is read from the widget itself.\n",
    "    \"\"\"\n",
    "    global no_frameSkip\n",
    "    no_frameSkip = frameskip_widget.value\n",
    "frameskip_widget.observe(frames_skip, names='value')\n",
    "\n",
    "# widget used to restart the testing process\n",
    "def restart_test(c):\n",
    "    \"\"\"Reset the testing statistics and rewind to the first test video.\n",
    "\n",
    "    The confusion matrix, the accuracy, the timing fields and the execTimes\n",
    "    list are cleared, the current capture is released and test_inputFile is\n",
    "    reopened. The preview is refreshed only when the first frame could be read,\n",
    "    so a missing test video no longer hands an empty image to the preview.\n",
    "\n",
    "    c -- the clicked button (unused).\n",
    "    \"\"\"\n",
    "    global videoStreamTest\n",
    "    global currentVideoTest\n",
    "    global execTimes\n",
    "\n",
    "    # clear the confusion matrix in place and mirror it in its four cells\n",
    "    for row in range(2):\n",
    "        for col in range(2):\n",
    "            confusion[row][col] = 0\n",
    "    confusion00_widget.value = confusion[0][0]\n",
    "    confusion01_widget.value = confusion[0][1]\n",
    "    confusion10_widget.value = confusion[1][0]\n",
    "    confusion11_widget.value = confusion[1][1]\n",
    "    accuracyTest_widget.value = 0\n",
    "    timeTest_widget.value = 0\n",
    "\n",
    "    # first video is selected\n",
    "    currentVideoTest = 1\n",
    "\n",
    "    # reinit execTime list\n",
    "    execTimes = []\n",
    "    meanClassifT_widget.value = 0\n",
    "    stdvClassifT_widget.value = 0\n",
    "\n",
    "    # Read the initial video from specified path - test\n",
    "    videoStreamTest.release()\n",
    "    videoStreamTest = cv2.VideoCapture(test_inputFile)\n",
    "\n",
    "    retval, frameT = videoStreamTest.read()\n",
    "    if retval:\n",
    "        imageTest_widget.value = bgr8_to_jpeg(frameT)\n",
    "restartTest_widget.on_click(restart_test)\n",
    "\n",
    "live_execution_widget = ipywidgets.VBox([\n",
    "    imageTest_widget,\n",
    "    ipywidgets.HBox([ ipywidgets.HBox(score_widgets), ipywidgets.VBox([timeTest_widget, \n",
    "                                                                       accuracyTest_widget, \n",
    "                                                                       ipywidgets.HBox([sepT1_widget, sepT2_widget]), \n",
    "                                                                       prediction_widget,\n",
    "                                                                       meanClassifT_widget,\n",
    "                                                                       stdvClassifT_widget\n",
    "                                                                      ]) \n",
    "                    ]),\n",
    "    ipywidgets.HBox([state_widget, restartTest_widget]),\n",
    "    frameskip_widget,\n",
    "    ipywidgets.HBox([\n",
    "        ipywidgets.Label(value=r'\\(\\color{red} {Confusion Matrix:}\\)', layout=Layout(height='80px')),\n",
    "        ipywidgets.VBox([\n",
    "            ipywidgets.HBox([confusion00_widget, confusion01_widget]),\n",
    "            ipywidgets.HBox([confusion10_widget, confusion11_widget]),\n",
    "        ], layout=Layout(display='flex', align_items='center'))\n",
    "    ], layout=Layout(border='solid', display='flex', align_items='center', width='70%'))   \n",
    "])\n",
    "\n",
    "# display a user's message\n",
    "print(\"Live execution widget created\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and Evaluation\n",
    "\n",
    "Execute the following cell to define the trainer, and the widget to control it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer configured and train eval widget created !!!\n"
     ]
    }
   ],
   "source": [
    "# was imported in the previous cell\n",
    "# import time\n",
    "import math\n",
    "\n",
    "BATCH_SIZE = 8\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)\n",
    "\n",
    "epochs_widget   = ipywidgets.IntText(description='epochs', value=1)\n",
    "eval_button     = ipywidgets.Button(description='evaluate')\n",
    "train_button    = ipywidgets.Button(description='train')\n",
    "loss_widget     = ipywidgets.FloatText(description='loss')\n",
    "accuracy_widget = ipywidgets.FloatText(description='accuracy')\n",
    "progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')\n",
    "time_widget     = ipywidgets.Text(description='T_spent = ')\n",
    "\n",
    "# 'success', 'info', 'warning', 'danger' or ''\n",
    "sep3_widget     = ipywidgets.Button(description='                     ', button_style='danger')\n",
    "sep4_widget     = ipywidgets.Button(description='                     ', button_style='danger')\n",
    "sep3_widget.disabled = True\n",
    "sep4_widget.disabled = True\n",
    "\n",
    "def train_eval(is_training):\n",
    "    \"\"\"Train the model on the active dataset, or evaluate it once over that dataset.\n",
    "\n",
    "    is_training -- True: run epochs until the 'epochs' field reaches 0, updating\n",
    "                   the parameters with the optimizer; False: run a single pass\n",
    "                   without parameter updates.\n",
    "\n",
    "    Live classification is stopped and both buttons are disabled while the model\n",
    "    is busy. Whatever happens inside, the model is put back in eval mode and the\n",
    "    buttons are re-enabled; an error is reported with its traceback instead of\n",
    "    leaving the interface locked. After training, live classification is\n",
    "    switched on again.\n",
    "    \"\"\"\n",
    "    import traceback  # local import: only needed to report failures below\n",
    "    global model\n",
    "\n",
    "    try:\n",
    "        start = time.time()\n",
    "\n",
    "        train_loader = torch.utils.data.DataLoader(\n",
    "            dataset,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            shuffle=True\n",
    "        )\n",
    "\n",
    "        state_widget.value    = 'stop'\n",
    "        train_button.disabled = True\n",
    "        eval_button.disabled  = True\n",
    "        # give a running live thread time to notice the 'stop' state\n",
    "        time.sleep(1)\n",
    "\n",
    "        model = model.train() if is_training else model.eval()\n",
    "\n",
    "        while epochs_widget.value > 0:\n",
    "            i = 0\n",
    "            sum_loss = 0.0\n",
    "            error_count = 0.0\n",
    "            for images, labels in train_loader:\n",
    "                # send data to device\n",
    "                images = images.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                if is_training:\n",
    "                    # zero gradients of parameters\n",
    "                    optimizer.zero_grad()\n",
    "\n",
    "                # execute model and compute loss; gradients are tracked only while training\n",
    "                with torch.set_grad_enabled(is_training):\n",
    "                    outputs = model(images)\n",
    "                    loss = F.cross_entropy(outputs, labels)\n",
    "\n",
    "                if is_training:\n",
    "                    # run backpropagation to accumulate gradients\n",
    "                    loss.backward()\n",
    "\n",
    "                    # step optimizer to adjust parameters\n",
    "                    optimizer.step()\n",
    "\n",
    "                # increment progress\n",
    "                error_count += len(torch.nonzero(outputs.argmax(1) - labels).flatten())\n",
    "                count = len(labels.flatten())\n",
    "                i += count\n",
    "                sum_loss += float(loss)\n",
    "                progress_widget.value = i / len(dataset)\n",
    "                # NOTE(review): sum_loss adds one value per batch but is divided by the\n",
    "                # number of samples seen; the threshold below is expressed on that scale.\n",
    "                loss_widget.value = sum_loss / i\n",
    "                accuracy_widget.value = 1.0 - error_count / i\n",
    "\n",
    "                # if the loss becomes very small, make this the last epoch so the model can be tested\n",
    "                if loss_widget.value < 0.00008:\n",
    "                    epochs_widget.value = 1\n",
    "\n",
    "            display_time(round(time.time() - start))\n",
    "            if is_training:\n",
    "                epochs_widget.value = epochs_widget.value - 1\n",
    "            else:\n",
    "                break\n",
    "        end = time.time()\n",
    "        display_time(round(end - start))\n",
    "    except Exception:\n",
    "        # report the problem instead of hiding it; the state is restored below\n",
    "        traceback.print_exc()\n",
    "    finally:\n",
    "        model = model.eval()\n",
    "\n",
    "        train_button.disabled = False\n",
    "        eval_button.disabled = False\n",
    "\n",
    "    if is_training:\n",
    "        state_widget.value = 'live'\n",
    "\n",
    "train_button.on_click(lambda c: train_eval(is_training=True))\n",
    "eval_button.on_click(lambda c: train_eval(is_training=False))\n",
    "\n",
    "def display_time(seconds_loc):\n",
    "    \"\"\"Write a duration given in seconds to the 'T_spent' field as 'h:<H> m:<M> s:<S>'.\"\"\"\n",
    "    # split off whole hours, then whole minutes; what is left are the seconds\n",
    "    hours = math.floor(seconds_loc / 3600)\n",
    "    remainder = seconds_loc - hours*3600\n",
    "    minutes = math.floor(remainder / 60)\n",
    "    remainder = remainder - minutes*60\n",
    "    time_widget.value = 'h:{} m:{} s:{}'.format(hours, minutes, remainder)\n",
    "        \n",
    "train_eval_widget = ipywidgets.VBox([\n",
    "    epochs_widget,\n",
    "    progress_widget,\n",
    "    loss_widget,\n",
    "    accuracy_widget,\n",
    "    ipywidgets.HBox([train_button, eval_button]),\n",
    "    time_widget,\n",
    "    ipywidgets.HBox([sep3_widget, sep4_widget])\n",
    "])\n",
    "\n",
    "# display a message to warn the user\n",
    "print(\"Trainer configured and train eval widget created !!!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display all widgets\n",
    "\n",
    "Here I will create and display the full interactive widget. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "891690448c2a4b2c84bab719106d1d0f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(VBox(children=(HBox(children=(Image(value=b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x01"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Combine all the widgets into one display\n",
    "all_widget = ipywidgets.VBox([\n",
    "    ipywidgets.HBox([data_collection_widget, live_execution_widget]),\n",
    "    train_eval_widget,\n",
    "    model_widget\n",
    "])\n",
    "\n",
    "display(all_widget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release all space and windows and resources once done \n",
    "# videoStream.release()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

Credits

Dobrea Dan Marius

Dobrea Dan Marius

11 projects • 31 followers
My research interests are in the areas of computational intelligence, biomedical engineering, IoT, drones and robotics.

Comments