Juan Acevedo
Created December 20, 2018

Detect potential for flood damage

Use ML and image recognition as a low-cost alternative for detecting potential flood damage.

102
Detect potential for flood damage

Things used in this project

Hardware components

Spresense boards (main & extension)
Sony Spresense boards (main & extension)
×1
Sony spresense camera
×1

Software apps and online services

Arduino IDE
Arduino IDE
google cloud datalab
google cloud storage
google cloud ml engine
Sony nnabla neural network libraries

Story

Read more

Schematics

Project Architecture

Code

Arduino code

C/C++
This code uses the following Spresense SDK libraries:
- SDHCI for SD card access.
- GNSS for location services.
- Camera for using the Spresense camera module.
- DNNRT for loading trained neural networks.
/*
 *  gnss.ino - GNSS example application with changes to use the Camera and 
 *  DNNRT libraries in order to make flood damage predictions from coastal
 *  image data.
 *  
 *  Copyright 2018 Sony Semiconductor Solutions Corporation
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License, or (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

/**
 * @file gnss.ino
 * @author Sony Semiconductor Solutions Corporation
 * @brief GNSS example application, extended to capture and classify camera images
 * @details Spresense has a built-in GNSS receiver that supports GPS and other
 *          GNSS satellite systems. This sketch provides an example of GNSS
 *          operation. Simply upload the sketch, reset the board and check the
 *          USB serial output. After 3 seconds status information should start
 *          to appear.\n\n
 *
 *          This example code is in the public domain.
 */

/* Include the SD card, GNSS, camera and DNN runtime libraries. */
#include <SDHCI.h>
#include <GNSS.h>
#include <math.h>
#include <stdio.h>  /* for sprintf */
#include <Camera.h>
#include <DNNRT.h>

/** %Buffer size for the snprintf() formatting buffers. */
static constexpr int STRING_BUFFER_SIZE = 128;

/** Number of loop() iterations between GNSS restarts (positioning test term). */
static constexpr int RESTART_CYCLE = 60 * 5;

/** Radius of the earth in kilometres (used by the haversine distance). */
static constexpr double R = 6378.137;

/* <math.h> normally provides M_PI already; define it only when missing so the
 * library's own definition is never redefined. */
#ifndef M_PI
#define M_PI 3.14159265358979323846 /* pi */
#endif

static SpGnss Gnss;                   /**< SpGnss object */

/** Reference position (degrees) used to decide when to take the next image;
 *  initialized to -1. */
double prevLat = -1;
double prevLon = -1;

SDClass  theSD;                       /**< SD card where captured images are stored */
int pictureCount = 0;                 /**< Picture counter, incremented by takeImage() */

DNNRT dnnrt;                          /**< DNN runtime holding the loaded network */

bool isDnnrtSetup = false;            /**< Result of setupDNNRT() in setup() */

/**
 * @brief Format a double into a caller-supplied buffer as "%<width>.<prec>f".
 *
 * Stand-in for the AVR libc helper of the same name.
 *
 * @param val   Value to format.
 * @param width Minimum field width (right-aligned, space padded).
 * @param prec  Number of digits after the decimal point.
 * @param sout  Destination buffer; must be large enough for the result.
 * @return sout.
 */
char *dtostrf(double val, signed char width, unsigned char prec, char *sout)
{
  const int fieldWidth = width;
  const int precision = prec;

  // Build the conversion spec (e.g. "%8.3f") first, then apply it to val.
  char conversion[20];
  snprintf(conversion, sizeof(conversion), "%%%d.%df", fieldWidth, precision);
  sprintf(sout, conversion, val);
  return sout;
}

/**
 * @brief Turn on / off the LED0 for CPU active notification.
 */
static void Led_isActive(void)
{
  static int state = 1;
  if (state == 1)
  {
    ledOn(PIN_LED0);
    state = 0;
  }
  else
  {
    ledOff(PIN_LED0);
    state = 1;
  }
}

/**
 * @brief Drive LED1, which signals the positioning (fix) state.
 *
 * @param [in] state true lights the LED, false turns it off.
 */
static void Led_isPosfix(bool state)
{
  if (!state)
  {
    ledOff(PIN_LED1);
    return;
  }
  ledOn(PIN_LED1);
}

/**
 * @brief Drive LED3, which signals an error condition.
 *
 * @param [in] state true lights the LED, false turns it off.
 */
static void Led_isError(bool state)
{
  if (!state)
  {
    ledOff(PIN_LED3);
    return;
  }
  ledOn(PIN_LED3);
}

/**
 * @brief Great-circle distance between two points (haversine formula).
 *
 * @param lat1 Latitude of the first point, decimal degrees.
 * @param lon1 Longitude of the first point, decimal degrees.
 * @param lat2 Latitude of the second point, decimal degrees.
 * @param lon2 Longitude of the second point, decimal degrees.
 * @return Distance in meters, computed with the earth radius R (km).
 */
static double haversineDistance(double lat1, double lon1, double lat2, double lon2){
    const double degToRad = M_PI / 180.0;
    const double phi1 = lat1 * degToRad;
    const double phi2 = lat2 * degToRad;
    const double sinHalfDLat = sin((lat2 - lat1) * degToRad / 2);
    const double sinHalfDLon = sin((lon2 - lon1) * degToRad / 2);

    double a = sinHalfDLat * sinHalfDLat +
               cos(phi1) * cos(phi2) * sinHalfDLon * sinHalfDLon;
    // Rounding can push `a` marginally outside [0, 1] (e.g. for near-antipodal
    // points), which would make sqrt(1 - a) or sqrt(a) return NaN.
    if (a > 1.0) {
        a = 1.0;
    } else if (a < 0.0) {
        a = 0.0;
    }

    const double c = 2 * atan2(sqrt(a), sqrt(1 - a));
    return R * c * 1000.0;  // R is in km; convert to meters
}

/**
 * @brief Load the trained network file "resnet_result.nnb" into the DNN runtime.
 *
 * @return true if dnnrt.begin() succeeded and inference can be run;
 *         false if the file could not be opened or dnnrt.begin() returned
 *         a negative error code.
 */
bool setupDNNRT() {
  Serial.println("loading nn");
  File nnbfile("resnet_result.nnb");
  if (!nnbfile) {
    Serial.println("nnb not found");
    return false;
  }

  Serial.print("dnnrt.begin, ret: ");
  int ret = dnnrt.begin(nnbfile);
  // End the status line so any following message starts on its own line.
  Serial.println(ret);
  if (ret < 0) {
    Serial.println("Runtime initialization failure.");
    return false;
  }

  return true;
}

/**
 * @brief Start the camera and configure it for still capture.
 *
 * Uses the daylight white-balance preset and 300x300 grayscale stills.
 */
void setupCamera() {
  Serial.println("Prepare camera");
  theCamera.begin();

  theCamera.setAutoWhiteBalanceMode(CAM_WHITE_BALANCE_DAYLIGHT);

  // 300x300 grayscale is intended for a model trained on 300x300 images.
  // An earlier configuration used CAM_IMGSIZE_QUADVGA_H x CAM_IMGSIZE_QUADVGA_V
  // with CAM_IMAGE_PIX_FMT_JPG.
  // NOTE(review): the accompanying training notebook loads 224x224 images;
  // confirm the deployed .nnb input size matches these stills.
  const int stillWidth = 300;
  const int stillHeight = 300;
  theCamera.setStillPictureImageFormat(stillWidth, stillHeight,
                                       CAM_IMAGE_PIX_FMT_GRAY);
}

void takeImage(double lat, double lon) {
  Serial.println("call takePicture()");
  CamImage img = theCamera.takePicture();

  if (img.isAvailable()) {
    //char filename[25] = {0};
    String filename = String("");

    char latStr[15];
    dtostrf(lat,8, 3, latStr);

    char lonStr[15];
    dtostrf(lon,8, 3, lonStr);
    
    // printing double into filename
    // https://stackoverflow.com/questions/27651012/arduino-sprintf-float-not-formatting
    if (isDnnrtSetup) {
      
      DNNVariable input(img.getImgSize());
      float *buf = input.data();

      /*
      * Normalize pixel data into between 0.0 and 1.0.
      * Gray scale image, so divide by 255.
      * This normalization due to how the network was trained.
      */
      unsigned char *imgBuf = img.getImgBuff();
      Serial.print("load image to buffer");
      for (int x = 0; x < img.getImgSize(); x++) {
        buf[x] = float(imgBuf[x]) / 255.0;
      }

      dnnrt.inputVariable(input,0);
      dnnrt.forward();
      DNNVariable output = dnnrt.outputVariable(0);
      int label = output.maxIndex();
      filename = filename
        + label + String("_");
      //sprintf(filename,"%03d_.JPG",label);
    } 

    filename = filename 
      + latStr + String("_") 
      + lonStr + String(".JPG");
    //sprintf(filename,"PICT%03d_.JPG",pictureCount);
    
    Serial.print("Save taken picture as ");
    Serial.print(filename);
    Serial.println("");
    /* Save to SD card as the finename */
    
    File myFile = theSD.open(filename, FILE_WRITE);
    myFile.write(img.getImgBuff(), img.getImgSize());
    myFile.close();
    ++pictureCount;
    // TODO - need to create csv file with image,lat,lon
    
  }
  
}

/**
 * @brief Activate GNSS device and start positioning.
 *
 * Also initializes the camera and loads the DNN model. All four LEDs are lit
 * while setup runs and turned off afterwards. If GNSS could not be begun or
 * started, the error LED is lit and the sketch stops via exit(0).
 */
void setup() {
  Serial.begin(115200);

  /* Give the hardware time to finish initializing. */
  sleep(3);

  /* Setup in progress: light every LED. */
  ledOn(PIN_LED0);
  ledOn(PIN_LED1);
  ledOn(PIN_LED2);
  ledOn(PIN_LED3);

  Gnss.setDebugMode(PrintInfo);

  bool gnssReady = false;
  if (Gnss.begin() != 0)
  {
    Serial.println("Gnss begin error!!");
  }
  else
  {
    Gnss.select(QZ_L1CA);  // Michibiki complement
    Gnss.select(QZ_L1S);   // Michibiki augmentation (valid only in Japan)

    if (Gnss.start(COLD_START) != 0)
    {
      Serial.println("Gnss start error!!");
    }
    else
    {
      Serial.println("Gnss setup OK");
      gnssReady = true;
    }
  }

  setupCamera();
  isDnnrtSetup = setupDNNRT();

  /* Setup finished: turn every LED off. */
  ledOff(PIN_LED0);
  ledOff(PIN_LED1);
  ledOff(PIN_LED2);
  ledOff(PIN_LED3);

  if (!gnssReady)
  {
    Led_isError(true);
    exit(0);
  }
}

/**
 * @brief %Print position information and trigger image capture.
 *
 * Prints date/time, satellite count, fix state and, when available, the
 * position. The first position reported with a valid fix becomes the
 * reference point. Every later fixed position at least 50 m from the
 * reference calls takeImage() and becomes the new reference.
 *
 * @param [in] pNavData Navigation data obtained from Gnss.getNavData().
 */
static void print_pos(SpNavData *pNavData)
{
  /* Minimum distance (meters) between consecutive pictures. */
  const double kCaptureDistanceM = 50.0;

  /* Whether prevLat/prevLon hold a real reference position; the first valid
   * fix seeds them. */
  static bool hasPrevPosition = false;

  char StringBuffer[STRING_BUFFER_SIZE];

  /* print time */
  snprintf(StringBuffer, STRING_BUFFER_SIZE, "%04d/%02d/%02d ", pNavData->time.year, pNavData->time.month, pNavData->time.day);
  Serial.print(StringBuffer);

  snprintf(StringBuffer, STRING_BUFFER_SIZE, "%02d:%02d:%02d.%06d, ", pNavData->time.hour, pNavData->time.minute, pNavData->time.sec, pNavData->time.usec);
  Serial.print(StringBuffer);

  /* print satellites count */
  snprintf(StringBuffer, STRING_BUFFER_SIZE, "numSat:%2d, ", pNavData->numSatellites);
  Serial.print(StringBuffer);

  /* print position data */
  if (pNavData->posFixMode == FixInvalid)
  {
    Serial.print("No-Fix, ");
  }
  else
  {
    Serial.print("Fix, ");
  }
  if (pNavData->posDataExist == 0)
  {
    Serial.print("No Position");
  }
  else
  {
    double lat = pNavData->latitude;
    double lon = pNavData->longitude;

    /* Only positions backed by a valid fix drive capture decisions. */
    if (pNavData->posFixMode != FixInvalid)
    {
      if (!hasPrevPosition)
      {
        prevLat = lat;
        prevLon = lon;
        hasPrevPosition = true;
      }
      else
      {
        double dist = haversineDistance(prevLat, prevLon, lat, lon);
        Serial.print("Dist=");
        Serial.print(dist, 6);
        if (dist >= kCaptureDistanceM)
        {
          takeImage(lat, lon);
          prevLat = lat;
          prevLon = lon;
        }
      }
    }

    Serial.print(",Lat=");
    Serial.print(pNavData->latitude, 6);
    Serial.print(", Lon=");
    Serial.print(pNavData->longitude, 6);
  }

  Serial.println("");
}

/**
 * @brief %Print satellite condition.
 *
 * Prints the satellite count, then one line per satellite with its type,
 * ID, elevation, azimuth and signal level (CN0).
 *
 * @param [in] pNavData Navigation data obtained from Gnss.getNavData().
 */
static void print_condition(SpNavData *pNavData)
{
  char StringBuffer[STRING_BUFFER_SIZE];

  /* Print satellite count. */
  snprintf(StringBuffer, STRING_BUFFER_SIZE, "numSatellites:%2d\n", static_cast<int>(pNavData->numSatellites));
  Serial.print(StringBuffer);

  for (unsigned long cnt = 0; cnt < pNavData->numSatellites; cnt++)
  {
    const char *pType;
    SpSatelliteType sattype = pNavData->getSatelliteType(cnt);

    /* Get satellite type. */
    /* Keep it to three letters. */
    switch (sattype)
    {
      case GPS:
        pType = "GPS";
        break;

      case GLONASS:
        pType = "GLN";
        break;

      case QZ_L1CA:
        pType = "QCA";
        break;

      case SBAS:
        pType = "SBA";
        break;

      case QZ_L1S:
        pType = "Q1S";
        break;

      default:
        pType = "UKN";
        break;
    }

    /* Get print conditions. */
    unsigned long Id  = pNavData->getSatelliteId(cnt);
    unsigned long Elv = pNavData->getSatelliteElevation(cnt);
    unsigned long Azm = pNavData->getSatelliteAzimuth(cnt);
    float sigLevel = pNavData->getSatelliteSignalLevel(cnt);

    /* Print satellite condition. These values are unsigned long, so they need
     * %lu; passing them to %d is a printf format mismatch (undefined behavior). */
    snprintf(StringBuffer, STRING_BUFFER_SIZE, "[%2lu] Type:%s, Id:%2lu, Elv:%2lu, Azm:%3lu, CN0:", cnt, pType, Id, Elv, Azm);
    Serial.print(StringBuffer);
    Serial.println(sigLevel, 6);
  }
}

/**
 * @brief Stop the GNSS receiver and start it again with a hot start.
 *
 * The begin/start sequence runs even if stopping failed. Each failure is
 * logged.
 *
 * @return true if every step succeeded.
 */
static bool restartGnss()
{
  bool ok = true;

  if (Gnss.stop() != 0)
  {
    Serial.println("Gnss stop error!!");
    ok = false;
  }
  else if (Gnss.end() != 0)
  {
    Serial.println("Gnss end error!!");
    ok = false;
  }
  else
  {
    Serial.println("Gnss stop OK.");
  }

  if (Gnss.begin() != 0)
  {
    Serial.println("Gnss begin error!!");
    ok = false;
  }
  else if (Gnss.start(HOT_START) != 0)
  {
    Serial.println("Gnss start error!!");
    ok = false;
  }
  else
  {
    Serial.println("Gnss restart OK.");
  }

  return ok;
}

/**
 * @brief %Print position information and satellite condition.
 *
 * @details When the loop count reaches the RESTART_CYCLE value, GNSS device is
 *          restarted. If the restart fails, the error LED is lit and the
 *          sketch stops via exit(0).
 */
void loop()
{
  static int LoopCount = 0;
  static int LastPrintMin = 0;

  /* Heartbeat. */
  Led_isActive();

  if (!Gnss.waitUpdate(-1))
  {
    Serial.println("data not update");
  }
  else
  {
    SpNavData NavData;
    Gnss.getNavData(&NavData);

    /* LED1 reflects whether a fixed position is available. */
    Led_isPosfix(NavData.posDataExist && (NavData.posFixMode != FixInvalid));

    /* Satellite details are printed once per minute. */
    if (NavData.time.minute != LastPrintMin)
    {
      print_condition(&NavData);
      LastPrintMin = NavData.time.minute;
    }

    print_pos(&NavData);
  }

  if (++LoopCount < RESTART_CYCLE)
  {
    return;
  }

  ledOff(PIN_LED0);
  Led_isPosfix(false);

  const bool restarted = restartGnss();
  LoopCount = 0;

  if (!restarted)
  {
    Led_isError(true);
    exit(0);
  }
}

notebook.ipynb

JSON
Used for preprocessing images and CSV files to train image-recognition models.
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from six.moves import range\n",
    "\n",
    "import os\n",
    "\n",
    "import nnabla as nn\n",
    "import nnabla.logger as logger\n",
    "import nnabla.functions as F\n",
    "import nnabla.parametric_functions as PF\n",
    "import nnabla.solvers as S\n",
    "import nnabla.utils.save as save\n",
    "\n",
    "from args import get_args\n",
    "from mnist_data import data_iterator_mnist\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import shutil\n",
    "\n",
    "from PIL import Image\n",
    "import cv2\n",
    "\n",
    "import numpy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categorical_error(pred, label):\n",
    "    \"\"\"\n",
    "    Compute categorical error given score vectors and labels as\n",
    "    numpy.ndarray.\n",
    "    \"\"\"\n",
    "    pred_label = pred.argmax(1)\n",
    "    return (pred_label != label.flat).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def augmentation(h, test, aug):\n",
    "    if aug is None:\n",
    "        aug = not test\n",
    "    if aug:\n",
    "        h = F.image_augmentation(h, (1, 28, 28), (0, 0), 0.9, 1.1, 0.3,\n",
    "                                 1.3, 0.1, False, False, 0.5, False, 1.5, 0.5, False, 0.1, 0)\n",
    "    return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mnist_lenet_prediction(image, test=False, aug=None):\n",
    "    \"\"\"\n",
    "    Construct LeNet for MNIST.\n",
    "    \"\"\"\n",
    "    image /= 255.0\n",
    "    image = augmentation(image, test, aug)\n",
    "    c1 = PF.convolution(image, 16, (5, 5), name='conv1')\n",
    "    c1 = F.relu(F.max_pooling(c1, (2, 2)), inplace=True)\n",
    "    c2 = PF.convolution(c1, 16, (5, 5), name='conv2')\n",
    "    c2 = F.relu(F.max_pooling(c2, (2, 2)), inplace=True)\n",
    "    c3 = F.relu(PF.affine(c2, 50, name='fc3'), inplace=True)\n",
    "    c4 = PF.affine(c3, 10, name='fc4')\n",
    "    return c4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mnist_resnet_prediction(image, test=False, aug=None):\n",
    "    \"\"\"\n",
    "    Construct ResNet for MNIST.\n",
    "    \"\"\"\n",
    "    image /= 255.0\n",
    "    image = augmentation(image, test, aug)\n",
    "\n",
    "    def bn(x):\n",
    "        return PF.batch_normalization(x, batch_stat=not test)\n",
    "\n",
    "    def res_unit(x, scope):\n",
    "        C = x.shape[1]\n",
    "        with nn.parameter_scope(scope):\n",
    "            with nn.parameter_scope('conv1'):\n",
    "                h = F.elu(bn(PF.convolution(x, C / 2, (1, 1), with_bias=False)))\n",
    "            with nn.parameter_scope('conv2'):\n",
    "                h = F.elu(\n",
    "                    bn(PF.convolution(h, C / 2, (3, 3), pad=(1, 1), with_bias=False)))\n",
    "            with nn.parameter_scope('conv3'):\n",
    "                h = bn(PF.convolution(h, C, (1, 1), with_bias=False))\n",
    "        return F.elu(F.add2(h, x, inplace=True))\n",
    "    # Conv1 --> 64 x 32 x 32\n",
    "    with nn.parameter_scope(\"conv1\"):\n",
    "        c1 = F.elu(\n",
    "            bn(PF.convolution(image, 64, (3, 3), pad=(3, 3), with_bias=False)))\n",
    "    # Conv2 --> 64 x 16 x 16\n",
    "    c2 = F.max_pooling(res_unit(c1, \"conv2\"), (2, 2))\n",
    "    # Conv3 --> 64 x 8 x 8\n",
    "    c3 = F.max_pooling(res_unit(c2, \"conv3\"), (2, 2))\n",
    "    # Conv4 --> 64 x 8 x 8\n",
    "    c4 = res_unit(c3, \"conv4\")\n",
    "    # Conv5 --> 64 x 4 x 4\n",
    "    c5 = F.max_pooling(res_unit(c4, \"conv5\"), (2, 2))\n",
    "    # Conv5 --> 64 x 4 x 4\n",
    "    c6 = res_unit(c5, \"conv6\")\n",
    "    pl = F.average_pooling(c6, (4, 4))\n",
    "    with nn.parameter_scope(\"classifier\"):\n",
    "        y = PF.affine(pl, 10)\n",
    "    return y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"\n",
    "    Main script.\n",
    "\n",
    "    Steps:\n",
    "\n",
    "    * Parse command line arguments.\n",
    "    * Specify a context for computation.\n",
    "    * Initialize DataIterator for MNIST.\n",
    "    * Construct a computation graph for training and validation.\n",
    "    * Initialize a solver and set parameter variables to it.\n",
    "    * Create monitor instances for saving and displaying training stats.\n",
    "    * Training loop\n",
    "      * Computate error rate for validation data (periodically)\n",
    "      * Get a next minibatch.\n",
    "      * Execute forwardprop on the training graph.\n",
    "      * Compute training error\n",
    "      * Set parameter gradients zero\n",
    "      * Execute backprop.\n",
    "      * Solver updates parameters by using gradients computed by backprop.\n",
    "    \"\"\"\n",
    "    args = get_args()\n",
    "\n",
    "    from numpy.random import seed\n",
    "    seed(0)\n",
    "\n",
    "    # Get context.\n",
    "    from nnabla.ext_utils import get_extension_context\n",
    "    logger.info(\"Running in %s\" % args.context)\n",
    "    ctx = get_extension_context(\n",
    "        args.context, device_id=args.device_id, type_config=args.type_config)\n",
    "    nn.set_default_context(ctx)\n",
    "\n",
    "    # Create CNN network for both training and testing.\n",
    "    if args.net == 'lenet':\n",
    "        mnist_cnn_prediction = mnist_lenet_prediction\n",
    "    elif args.net == 'resnet':\n",
    "        mnist_cnn_prediction = mnist_resnet_prediction\n",
    "    else:\n",
    "        raise ValueError(\"Unknown network type {}\".format(args.net))\n",
    "\n",
    "    # TRAIN\n",
    "    # Create input variables.\n",
    "    image = nn.Variable([args.batch_size, 1, 28, 28])\n",
    "    label = nn.Variable([args.batch_size, 1])\n",
    "    # Create prediction graph.\n",
    "    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)\n",
    "    pred.persistent = True\n",
    "    # Create loss function.\n",
    "    loss = F.mean(F.softmax_cross_entropy(pred, label))\n",
    "\n",
    "    # TEST\n",
    "    # Create input variables.\n",
    "    vimage = nn.Variable([args.batch_size, 1, 28, 28])\n",
    "    vlabel = nn.Variable([args.batch_size, 1])\n",
    "    # Create prediction graph.\n",
    "    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)\n",
    "\n",
    "    # Create Solver.\n",
    "    solver = S.Adam(args.learning_rate)\n",
    "    solver.set_parameters(nn.get_parameters())\n",
    "\n",
    "    # Create monitor.\n",
    "    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed\n",
    "    monitor = Monitor(args.monitor_path)\n",
    "    monitor_loss = MonitorSeries(\"Training loss\", monitor, interval=10)\n",
    "    monitor_err = MonitorSeries(\"Training error\", monitor, interval=10)\n",
    "    monitor_time = MonitorTimeElapsed(\"Training time\", monitor, interval=100)\n",
    "    monitor_verr = MonitorSeries(\"Test error\", monitor, interval=10)\n",
    "\n",
    "    # Initialize DataIterator for MNIST.\n",
    "    from numpy.random import RandomState\n",
    "    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))\n",
    "    vdata = data_iterator_mnist(args.batch_size, False)\n",
    "    # Training loop.\n",
    "    for i in range(args.max_iter):\n",
    "        if i % args.val_interval == 0:\n",
    "            # Validation\n",
    "            ve = 0.0\n",
    "            for j in range(args.val_iter):\n",
    "                vimage.d, vlabel.d = vdata.next()\n",
    "                vpred.forward(clear_buffer=True)\n",
    "                vpred.data.cast(np.float32, ctx)\n",
    "                ve += categorical_error(vpred.d, vlabel.d)\n",
    "            monitor_verr.add(i, ve / args.val_iter)\n",
    "        if i % args.model_save_interval == 0:\n",
    "            nn.save_parameters(os.path.join(\n",
    "                args.model_save_path, 'params_%06d.h5' % i))\n",
    "        # Training forward\n",
    "        image.d, label.d = data.next()\n",
    "        solver.zero_grad()\n",
    "        loss.forward(clear_no_need_grad=True)\n",
    "        loss.backward(clear_buffer=True)\n",
    "        solver.weight_decay(args.weight_decay)\n",
    "        solver.update()\n",
    "        loss.data.cast(np.float32, ctx)\n",
    "        pred.data.cast(np.float32, ctx)\n",
    "        e = categorical_error(pred.d, label.d)\n",
    "        monitor_loss.add(i, loss.d.copy())\n",
    "        monitor_err.add(i, e)\n",
    "        monitor_time.add(i)\n",
    "\n",
    "    ve = 0.0\n",
    "    for j in range(args.val_iter):\n",
    "        vimage.d, vlabel.d = vdata.next()\n",
    "        vpred.forward(clear_buffer=True)\n",
    "        ve += categorical_error(vpred.d, vlabel.d)\n",
    "    monitor_verr.add(i, ve / args.val_iter)\n",
    "\n",
    "    parameter_file = os.path.join(\n",
    "        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))\n",
    "    nn.save_parameters(parameter_file)\n",
    "\n",
    "    # append F.Softmax to the prediction graph so users see intuitive outputs\n",
    "    runtime_contents = {\n",
    "        'networks': [\n",
    "            {'name': 'Validation',\n",
    "             'batch_size': args.batch_size,\n",
    "             'outputs': {'y': F.softmax(vpred)},\n",
    "             'names': {'x': vimage}}],\n",
    "        'executors': [\n",
    "            {'name': 'Runtime',\n",
    "             'network': 'Validation',\n",
    "             'data': ['x'],\n",
    "             'output': ['y']}]}\n",
    "    save.save(os.path.join(args.model_save_path,\n",
    "                           '{}_result.nnp'.format(args.net)), runtime_contents)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2018-12-22 20:17:54,773 [nnabla][INFO]: Getting label data from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz.\n",
      "2018-12-22 20:17:54,775 [nnabla][INFO]: > /Users/hidden/nnabla_data/t10k-labels-idx1-ubyte.gz already exists.\n",
      "2018-12-22 20:17:54,776 [nnabla][INFO]: > If you have any issue when using this file, \n",
      "2018-12-22 20:17:54,777 [nnabla][INFO]: > manually remove the file and try download again.\n",
      "2018-12-22 20:17:54,781 [nnabla][INFO]: Getting label data done.\n",
      "2018-12-22 20:17:54,781 [nnabla][INFO]: Getting image data from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz.\n",
      "2018-12-22 20:17:54,782 [nnabla][INFO]: > /Users/hidden/nnabla_data/t10k-images-idx3-ubyte.gz already exists.\n",
      "2018-12-22 20:17:54,784 [nnabla][INFO]: > If you have any issue when using this file, \n",
      "2018-12-22 20:17:54,784 [nnabla][INFO]: > manually remove the file and try download again.\n",
      "2018-12-22 20:17:54,834 [nnabla][INFO]: Getting image data done.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n"
     ]
    }
   ],
   "source": [
    "# From mnist_data, I want to see what the image and label data looks like.\n",
    "\n",
    "import numpy\n",
    "import struct\n",
    "import zlib\n",
    "\n",
    "from nnabla.logger import logger\n",
    "from nnabla.utils.data_iterator import data_iterator\n",
    "from nnabla.utils.data_source import DataSource\n",
    "from nnabla.utils.data_source_loader import download\n",
    "\n",
    "def load_mnist(train=True):\n",
    "    '''\n",
    "    Load MNIST dataset images and labels from the original page by Yan LeCun or the cache file.\n",
    "\n",
    "    Args:\n",
    "        train (bool): The testing dataset will be returned if False. Training data has 60000 images, while testing has 10000 images.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: A shape of (#images, 1, 28, 28). Values in [0.0, 1.0].\n",
    "        numpy.ndarray: A shape of (#images, 1). Values in {0, 1, ..., 9}.\n",
    "\n",
    "    '''\n",
    "    if train:\n",
    "        image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'\n",
    "        label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'\n",
    "    else:\n",
    "        image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'\n",
    "        label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'\n",
    "    logger.info('Getting label data from {}.'.format(label_uri))\n",
    "    # With python3 we can write this logic as following, but with\n",
    "    # python2, gzip.object does not support file-like object and\n",
    "    # urllib.request does not support 'with statement'.\n",
    "    #\n",
    "    #   with request.urlopen(label_uri) as r, gzip.open(r) as f:\n",
    "    #       _, size = struct.unpack('>II', f.read(8))\n",
    "    #       labels = numpy.frombuffer(f.read(), numpy.uint8).reshape(-1, 1)\n",
    "    #\n",
    "    r = download(label_uri)\n",
    "    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)\n",
    "    _, size = struct.unpack('>II', data[0:8])\n",
    "    print(size)\n",
    "    labels = numpy.frombuffer(data[8:], numpy.uint8).reshape(-1, 1)\n",
    "    r.close()\n",
    "    logger.info('Getting label data done.')\n",
    "\n",
    "    logger.info('Getting image data from {}.'.format(image_uri))\n",
    "    r = download(image_uri)\n",
    "    data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)\n",
    "    _, size, height, width = struct.unpack('>IIII', data[0:16])\n",
    "    images = numpy.frombuffer(data[16:], numpy.uint8).reshape(\n",
    "        size, 1, height, width)\n",
    "    r.close()\n",
    "    logger.info('Getting image data done.')\n",
    "\n",
    "    return images, labels\n",
    "\n",
    "images, labels = load_mnist(train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n",
      "<class 'numpy.ndarray'>\n",
      "labels shape: (10000, 1)\n",
      "<class 'numpy.ndarray'>\n",
      "(28, 28)\n",
      "(10000, 1, 28, 28)\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "print(len(labels))\n",
    "print(type(images))\n",
    "#print(labels)\n",
    "print('labels shape: ' + str(labels.shape))\n",
    "img = images[0][0]\n",
    "print(type(img))\n",
    "print(img.shape)\n",
    "img = Image.fromarray(img, 'L')\n",
    "#print(images)\n",
    "img.show()\n",
    "#print(len(images[0][0]))\n",
    "print(images.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 1, 3, 3)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = np.array([[[[1, 2, 3],[1, 2, 3],[1, 2, 3]]],[[[1, 2, 3],[1, 2, 3],[1, 2, 3]]]])\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels shape: (743, 1)\n",
      "images shape: (743, 1, 224, 224)\n"
     ]
    }
   ],
   "source": [
    "LOCAL_FILEPATH = '/Users/hidden/MyApps/hackathons/sony_make_it_better/coastal_images/'\n",
    "IMAGES_FILEPATH = LOCAL_FILEPATH + 'processed_images/'\n",
    "def load_coastal_images(train=True):\n",
    "    images = numpy.empty(0)\n",
    "    if train:\n",
    "        csvPath = LOCAL_FILEPATH + 'local_labeled_images_train.csv'\n",
    "    else:\n",
    "        csvPath = LOCAL_FILEPATH + 'local_labeled_images_eval.csv'\n",
    "    dataset = pd.read_csv(csvPath)\n",
    "    labels = dataset['class'].values\n",
    "    labels = labels.reshape(len(labels),1)\n",
    "    #print('labels shape: ' + str(labels.shape))\n",
    "    #print(labels)\n",
    "    \n",
    "    # Remove this when is working\n",
    "    #dataset = dataset.head()\n",
    "    \n",
    "    # Now images\n",
    "    \n",
    "    imagePaths = dataset['filename'].values\n",
    "    for row in imagePaths:\n",
    "        imgPath = IMAGES_FILEPATH + row\n",
    "        im = Image.open(imgPath)\n",
    "        # read into numpy array\n",
    "        im = numpy.asarray(im)\n",
    "        # drop second value in every pixel which is 255\n",
    "        im = numpy.delete(im, numpy.s_[1:2], axis=2) \n",
    "        # transpose it to set it to (1,224,224)\n",
    "        im = im.T\n",
    "        ## drop the 0 and transpose it again to set it to (224, 224)\n",
    "        im = im[0].T\n",
    "        #im = im.reshape(1,224,224)\n",
    "        #print(im.shape)\n",
    "        #print(imgPath)\n",
    "        images = numpy.append(images,im)\n",
    "        \n",
    "    images = images.reshape(len(imagePaths),1,224,224)\n",
    "    return images, labels\n",
    "\n",
    "\n",
    "images, labels = load_coastal_images()\n",
    "print('labels shape: ' + str(labels.shape))\n",
    "print('images shape: ' + str(images.shape))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'PIL.PngImagePlugin.PngImageFile'>\n",
      "1\n",
      "(1, 224, 224)\n"
     ]
    }
   ],
   "source": [
    "# Test code. No need to run it. \n",
    "#im = Image.open('/Users/hidden/MyApps/hackathons/sony_make_it_better/coastal_images/processed_images/IMG_0001_SecHKL_Sum12_Pt2.png')#.convert('LA')\n",
    "#print(type(im))\n",
    "#pic = numpy.asarray(im)\n",
    "#pic = numpy.delete(pic, numpy.s_[1:2], axis=2)  \n",
    "#pic = pic.T\n",
    "#pic = pic[0].T\n",
    "#pic = pic.reshape(1,224,224)\n",
    "#print(len(pic))\n",
    "#print(pic.shape)\n",
    "#img = Image.fromarray(pic[0], 'L')\n",
    "#img.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['filename', 'class'], dtype='object')\n",
      "736\n",
      "processed_images/9/IMG_3062_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3065_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3068_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3072_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3075_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2956_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2464_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3842_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3714_SecDE_Spr12.png\n",
      "processed_images/9/IMG_3843_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1871_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1831_SecDE_Spr12.png\n",
      "processed_images/9/IMG_0672_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0675_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0678_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0681_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0683_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0685_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0688_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0691_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0695_SecBC_Spr12.png\n",
      "processed_images/9/IMG_0697_SecBC_Spr12.png\n",
      "processed_images/9/IMG_2645_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2647_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2744_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2390_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2398_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2403_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2407_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2409_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2412_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2415_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2417_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2419_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2422_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2424_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2426_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2428_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2430_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2434_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2436_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2439_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2444_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2446_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2450_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2451_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2455_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2499_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2505_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2508_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2511_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2514_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2541_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2547_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2550_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2553_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2559_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2566_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2570_SecDE_Spr12.png\n",
      "processed_images/9/IMG_2630_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1419_SecBC_Spr12.png\n",
      "processed_images/9/IMG_1765_SecBC_Spr12.png\n",
      "processed_images/9/IMG_1768_SecBC_Spr12.png\n",
      "processed_images/9/IMG_1771_SecBC_Spr12.png\n",
      "processed_images/9/IMG_1773_SecBC_Spr12.png\n",
      "processed_images/9/IMG_1776_SecBC_Spr12.png\n",
      "processed_images/9/IMG_1877_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1881_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1884_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1889_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1891_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1893_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1896_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1899_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1902_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1916_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1921_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1949_SecDE_Spr12.png\n",
      "processed_images/9/IMG_1961_SecDE_Spr12.png\n",
      "processed_images/31/IMG_4173_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4191_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4185_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4232_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4221_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4216_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4210_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4201_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4205_SecFG_Spr12.png\n",
      "processed_images/31/IMG_1571_SecBC_Spr12.png\n",
      "processed_images/31/IMG_1573_SecBC_Spr12.png\n",
      "processed_images/31/IMG_3907_SecDE_Spr12.png\n",
      "processed_images/31/IMG_3911_SecDE_Spr12.png\n",
      "processed_images/31/IMG_3914_SecDE_Spr12.png\n",
      "processed_images/31/IMG_3918_SecDE_Spr12.png\n",
      "processed_images/31/IMG_0369_SecBC_Spr12.png\n",
      "processed_images/31/IMG_0372_SecBC_Spr12.png\n",
      "processed_images/31/IMG_0378_SecBC_Spr12.png\n",
      "processed_images/31/IMG_2788_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2792_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2794_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2798_SecDE_Spr12.png\n",
      "processed_images/31/IMG_4760_SecFG_Spr12.png\n",
      "processed_images/31/IMG_0349_SecBC_Spr12.png\n",
      "processed_images/31/IMG_2819_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2822_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2824_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2812_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2815_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2804_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2807_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2783_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2845_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2849_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2852_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2857_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2870_SecDE_Spr12.png\n",
      "processed_images/31/IMG_4320_SecFG_Spr12.png\n",
      "processed_images/31/IMG_2375_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2269_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2272_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2276_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2281_SecDE_Spr12.png\n",
      "processed_images/31/IMG_4313_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4316_SecFG_Spr12.png\n",
      "processed_images/31/IMG_4312_SecFG_Spr12.png\n",
      "processed_images/31/IMG_2873_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2876_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2879_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2887_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2884_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2898_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2231_SecDE_Spr12.png\n",
      "processed_images/31/IMG_2022_SecDE_Spr12.png\n",
      "processed_images/31/IMG_3817_SecDE_Spr12.png\n",
      "processed_images/31/IMG_3737_SecDE_Spr12.png\n",
      "processed_images/31/IMG_3729_SecDE_Spr12.png\n",
      "processed_images/31/IMG_1862_SecDE_Spr12.png\n",
      "processed_images/7/IMG_4229_SecFG_Spr12.png\n",
      "processed_images/7/IMG_2489_SecDE_Spr12.png\n",
      "processed_images/7/IMG_2497_SecDE_Spr12.png\n",
      "processed_images/7/IMG_2080_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3794_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3795_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3797_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3801_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3765_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3762_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3749_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3752_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3685_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3718_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3848_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3743_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3746_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3740_SecDE_Spr12.png\n",
      "processed_images/7/IMG_1663_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1661_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1659_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1669_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1649_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1652_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1655_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1593_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1598_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1602_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1404_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1401_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1415_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1408_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1604_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1397_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1410_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1606_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1609_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1389_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1382_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1379_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1623_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0107_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0617_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0620_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0614_SecBC_Spr12.png\n",
      "processed_images/7/IMG_3890_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3894_SecDE_Spr12.png\n",
      "processed_images/7/IMG_3898_SecDE_Spr12.png\n",
      "processed_images/7/IMG_2462_SecDE_Spr12.png\n",
      "processed_images/7/IMG_2478_SecDE_Spr12.png\n",
      "processed_images/7/IMG_2161_SecDE_Spr12.png\n",
      "processed_images/7/IMG_2386_SecDE_Spr12.png\n",
      "processed_images/7/IMG_0096_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0099_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0103_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0119_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0123_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0125_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0130_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0134_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0138_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0142_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0145_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0155_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0601_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0594_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1079_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1082_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1085_SecBC_Spr12.png\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processed_images/7/IMG_1087_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1089_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1091_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1094_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1096_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1109_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1112_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1100_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1103_SecBC_Spr12.png\n",
      "processed_images/7/IMG_1105_SecBC_Spr12.png\n",
      "processed_images/7/IMG_0807_SecBC_Spr12.png\n",
      "processed_images/11/IMG_4251_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4587_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4589_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4592_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4595_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4604_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4622_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4598_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4601_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4697_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4690_SecFG_Spr12.png\n",
      "processed_images/11/IMG_2967_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2962_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2638_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2643_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2599_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2360_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2362_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2277_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2369_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2260_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2243_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2245_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2248_SecDE_Spr12.png\n",
      "processed_images/11/IMG_0047_SecBC_Spr12.png\n",
      "processed_images/11/IMG_0088_SecBC_Spr12.png\n",
      "processed_images/11/IMG_0252_SecBC_Spr12.png\n",
      "processed_images/11/IMG_0264_SecBC_Spr12.png\n",
      "processed_images/11/IMG_2604_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2905_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2164_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2380_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2377_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2684_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2486_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2492_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2473_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2946_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2949_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2761_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2775_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2782_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2800_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2957_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2235_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2167_SecDE_Spr12.png\n",
      "processed_images/11/IMG_2171_SecDE_Spr12.png\n",
      "processed_images/11/IMG_4768_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4771_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4774_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4777_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4780_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4783_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4787_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4791_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4794_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4667_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4670_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4843_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4845_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4859_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4862_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4865_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4868_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4872_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4701_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4687_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4684_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4681_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4704_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4765_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4607_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4394_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4378_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4850_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4756_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4853_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4856_SecFG_Spr12.png\n",
      "processed_images/11/IMG_4846_SecFG_Spr12.png\n",
      "processed_images/11/IMG_2264_SecDE_Spr12.png\n",
      "processed_images/3/IMG_4279_SecFG_Spr12.png\n",
      "processed_images/3/IMG_3117_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3131_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3134_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3137_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3139_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3141_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3143_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3145_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3148_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3150_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3153_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3156_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3159_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3163_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3166_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3169_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3171_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3175_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3177_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3180_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3199_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3202_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3205_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3209_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3211_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3213_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3215_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3218_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3221_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3223_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3225_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3237_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3239_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3242_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3245_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3248_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3254_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3259_SecDE_Spr12.png\n",
      "processed_images/3/IMG_3264_SecDE_Spr12.png\n",
      "processed_images/3/IMG_2827_SecDE_Spr12.png\n",
      "processed_images/3/IMG_4288_SecFG_Spr12.png\n",
      "processed_images/3/IMG_4286_SecFG_Spr12.png\n",
      "processed_images/3/IMG_4264_SecFG_Spr12.png\n",
      "processed_images/3/IMG_4258_SecFG_Spr12.png\n",
      "processed_images/3/IMG_4239_SecFG_Spr12.png\n",
      "processed_images/3/IMG_0535_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0536_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0538_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0543_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0546_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0549_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0311_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0315_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0318_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0320_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0322_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0327_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0329_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0305_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0309_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0291_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0294_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0779_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0786_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0798_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0221_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0552_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0555_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0559_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0561_SecBC_Spr12.png\n",
      "processed_images/3/IMG_1045_SecBC_Spr12.png\n",
      "processed_images/3/IMG_1047_SecBC_Spr12.png\n",
      "processed_images/3/IMG_1037_SecBC_Spr12.png\n",
      "processed_images/3/IMG_1040_SecBC_Spr12.png\n",
      "processed_images/3/IMG_1042_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0986_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0990_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0992_SecBC_Spr12.png\n",
      "processed_images/3/IMG_0994_SecBC_Spr12.png\n",
      "processed_images/5/IMG_2284_SecDE_Spr12.png\n",
      "processed_images/5/IMG_1705_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0005_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0009_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0028_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0023_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1031_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1034_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1006_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1009_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0410_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0333_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0336_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0257_SecBC_Spr12.png\n",
      "processed_images/5/IMG_0260_SecBC_Spr12.png\n",
      "processed_images/5/IMG_4003_SecDE_Spr12.png\n",
      "processed_images/5/IMG_4005_SecDE_Spr12.png\n",
      "processed_images/5/IMG_4009_SecDE_Spr12.png\n",
      "processed_images/5/IMG_1438_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1441_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1444_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1450_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1452_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1455_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1460_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1464_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1468_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1471_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1474_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1477_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1481_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1486_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1491_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1495_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1498_SecBC_Spr12.png\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processed_images/5/IMG_1505_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1421_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1425_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1429_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1433_SecBC_Spr12.png\n",
      "processed_images/5/IMG_1436_SecBC_Spr12.png\n",
      "processed_images/5/IMG_3270_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3273_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3276_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3280_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3282_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3285_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3288_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3291_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3299_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3302_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3305_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3308_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3310_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3313_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3316_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3319_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3322_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3325_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3327_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3330_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3333_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3337_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3340_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3343_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3346_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3349_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3352_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3354_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3357_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3359_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3361_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3364_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3367_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3370_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3373_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3384_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3388_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3391_SecDE_Spr12.png\n",
      "processed_images/5/IMG_3394_SecDE_Spr12.png\n",
      "processed_images/61/IMG_1720_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1724_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1546_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1726_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1729_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1536_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1539_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1732_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1522_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1742_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1744_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1519_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1746_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1517_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1747_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1748_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1515_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1514_SecBC_Spr12.png\n",
      "processed_images/61/IMG_1510_SecBC_Spr12.png\n",
...

This file has been truncated; please download it to see its full contents.

classification_coastal.py

Python
Training script (nnabla) for the coastal image classifier.
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from six.moves import range

import os

import nnabla as nn
import nnabla.logger as logger
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
import nnabla.utils.save as save

from args import get_args
from coastal_data import data_iterator_mnist

import numpy as np

# Side length, in pixels, of the square single-channel input images
# (train() builds [batch, 1, IMG_SIZE, IMG_SIZE] input variables, and
# augmentation() emits (1, IMG_SIZE, IMG_SIZE) images).
IMG_SIZE = 224

def categorical_error(pred, label):
    """
    Return the fraction of samples whose arg-max prediction is wrong.

    Args:
        pred: numpy score array; axis 1 indexes classes, one row per sample.
        label: numpy array of integer class ids, one per sample. It may have
            any shape (for example (batch, 1)) because it is flattened
            before comparison.

    Returns:
        The mean of the per-sample mismatch indicators (0.0 means all correct).
    """
    predicted_classes = pred.argmax(axis=1)
    mismatches = predicted_classes != label.flat
    return mismatches.mean()


def augmentation(h, test, aug):
    """
    Return ``h`` with random image augmentation applied, or ``h`` unchanged.

    Args:
        h: image variable to augment.
        test: True when building an evaluation graph.
        aug: explicit on/off switch. When None, augmentation is enabled
            exactly for training graphs (``not test``).
    """
    enabled = (not test) if aug is None else aug
    if not enabled:
        return h
    # Positional arguments after the output shape are intended to follow
    # nnabla's F.image_augmentation signature (pad, min/max scale, angle,
    # aspect ratio, distortion, flip_lr, flip_ud, brightness,
    # brightness_each, contrast, contrast_center, contrast_each, noise,
    # seed). Verify the order against the installed nnabla version before
    # changing any value.
    output_shape = (1, IMG_SIZE, IMG_SIZE)
    return F.image_augmentation(h, output_shape, (0, 0), 0.9, 1.1, 0.3,
                                1.3, 0.1, False, False, 0.5, False, 1.5,
                                0.5, False, 0.1, 0)


def mnist_lenet_prediction(image, test=False, aug=None):
    """
    Build a small LeNet-style CNN and return its unnormalized class scores.

    Layout: two stages of (5x5 convolution with 16 maps -> 2x2 max-pool ->
    ReLU), then a 50-unit fully connected ReLU layer and a 10-way affine
    output. Pixel values are divided by 255 before optional augmentation.

    Args:
        image: input image variable.
        test: True when building an evaluation graph (forwarded to
            augmentation()).
        aug: explicit augmentation switch forwarded to augmentation().

    Returns:
        The output of the final 'fc4' affine layer (10 outputs).
    """
    image /= 255.0
    features = augmentation(image, test, aug)
    for scope_name in ('conv1', 'conv2'):
        features = PF.convolution(features, 16, (5, 5), name=scope_name)
        features = F.relu(F.max_pooling(features, (2, 2)), inplace=True)
    hidden = F.relu(PF.affine(features, 50, name='fc3'), inplace=True)
    return PF.affine(hidden, 10, name='fc4')


def mnist_resnet_prediction(image, test=False, aug=None):
    """
    Construct a small ResNet classifier (originally written for MNIST) and
    return its unnormalized class scores.

    Args:
        image: input image variable. Pixel values are divided by 255 before
            optional augmentation. Presumably shaped
            (batch, 1, IMG_SIZE, IMG_SIZE); confirm against the caller.
        test: if True, batch normalization uses accumulated statistics
            instead of mini-batch statistics, and augmentation defaults to off.
        aug: explicit augmentation switch forwarded to augmentation().

    Returns:
        The output of the "classifier" affine layer (10 outputs per sample).
    """
    # Trace marker kept from the original implementation. It must come after
    # the docstring, otherwise the docstring is a dead expression.
    print('resnet prediction')
    image /= 255.0
    image = augmentation(image, test, aug)

    def bn(x):
        # Mini-batch statistics while training, running statistics at test time.
        return PF.batch_normalization(x, batch_stat=not test)

    def res_unit(x, scope):
        """Bottleneck residual unit (1x1 -> 3x3 -> 1x1) preserving x's shape."""
        C = x.shape[1]
        # Bottleneck width. Floor division is required: under Python 3,
        # `C / 2` is a float, which is not a valid output-map count.
        mid = C // 2
        with nn.parameter_scope(scope):
            with nn.parameter_scope('conv1'):
                h = F.elu(bn(PF.convolution(x, mid, (1, 1), with_bias=False)))
            with nn.parameter_scope('conv2'):
                h = F.elu(
                    bn(PF.convolution(h, mid, (3, 3), pad=(1, 1), with_bias=False)))
            with nn.parameter_scope('conv3'):
                h = bn(PF.convolution(h, C, (1, 1), with_bias=False))
        return F.elu(F.add2(h, x, inplace=True))

    # Spatial sizes below assume a 224x224 input and nnabla's pooling
    # defaults (stride = kernel, ignore_border=True).
    # Conv1 --> 64 x 228 x 228 (pad=3 on a 3x3 kernel grows each side by 4,
    # presumably inherited from the MNIST example where 28 -> 32)
    with nn.parameter_scope("conv1"):
        c1 = F.elu(
            bn(PF.convolution(image, 64, (3, 3), pad=(3, 3), with_bias=False)))
    # Conv2 --> 64 x 114 x 114
    c2 = F.max_pooling(res_unit(c1, "conv2"), (2, 2))
    # Conv3 --> 64 x 57 x 57
    c3 = F.max_pooling(res_unit(c2, "conv3"), (2, 2))
    # Conv4 --> 64 x 57 x 57
    c4 = res_unit(c3, "conv4")
    # Conv5 --> 64 x 28 x 28
    c5 = F.max_pooling(res_unit(c4, "conv5"), (2, 2))
    # Conv6 --> 64 x 28 x 28
    c6 = res_unit(c5, "conv6")
    # 4x4 average pooling --> 64 x 7 x 7, flattened by the affine layer.
    pl = F.average_pooling(c6, (4, 4))
    with nn.parameter_scope("classifier"):
        y = PF.affine(pl, 10)
    return y


def _validation_error(vdata, vimage, vlabel, vpred, n_iter, ctx):
    """
    Return the mean categorical error of ``vpred`` over ``n_iter`` batches.

    For each batch, this function:

    * loads the next mini-batch from ``vdata`` into ``vimage``/``vlabel``;
    * runs the validation graph forward;
    * casts the output to float32 before ``categorical_error`` scores it.

    The cast mirrors the one done after training steps (presumably so that
    non-float32 ``--type-config`` values such as "half" are handled).
    """
    total_error = 0.0
    for _ in range(n_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        total_error += categorical_error(vpred.d, vlabel.d)
    return total_error / n_iter


def train():
    """
    Main training script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Initialize the training and validation data iterators
      (``data_iterator_mnist``; NOTE(review): in this project it presumably
      comes from coastal_data.py and yields coastal images - confirm).
    * Training loop:
      * Compute the validation error rate (every ``val_interval`` iterations).
      * Save parameters (every ``model_save_interval`` iterations).
      * Get the next minibatch and run forward and backward passes.
      * Apply weight decay and update parameters with the solver.
      * Record training loss, training error and elapsed time.
    * Run a final validation and save the final parameters.
    * Export a ``<net>_result.nnp`` file whose 'Runtime' executor outputs
      softmax probabilities.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing. Each --net value
    # selects its own builder, so the '<net>'-prefixed output files really
    # contain that architecture.
    builders = {
        'lenet': mnist_lenet_prediction,
        'resnet': mnist_resnet_prediction,
    }
    if args.net not in builders:
        raise ValueError("Unknown network type {}".format(args.net))
    mnist_cnn_prediction = builders[args.net]

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, IMG_SIZE, IMG_SIZE])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    # Keep pred's buffer alive through loss.backward(clear_buffer=True) so the
    # training error can be computed from pred.d afterwards.
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, IMG_SIZE, IMG_SIZE])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize the training / validation data iterators.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, train=True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, train=False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            monitor_verr.add(i, _validation_error(
                vdata, vimage, vlabel, vpred, args.val_iter, ctx))
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward / backward / update.
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        loss.data.cast(np.float32, ctx)
        pred.data.cast(np.float32, ctx)
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation, logged at the index of the last training iteration.
    # Using args.max_iter - 1 instead of the loop variable also works when
    # max_iter is 0 (the loop variable would be unbound).
    monitor_verr.add(args.max_iter - 1, _validation_error(
        vdata, vimage, vlabel, vpred, args.val_iter, ctx))

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)

    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': F.softmax(vpred)},
             'names': {'x': vimage}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(os.path.join(args.model_save_path,
                           '{}_result.nnp'.format(args.net)), runtime_contents)

if __name__ == '__main__':
    train()

coastal_data.py

Python
Reads the preprocessed coastal images used to train the image-recognition model.
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Provide data iterator for MNIST examples.
'''
import numpy
import struct
import zlib
import pandas as pd

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download
from PIL import Image

# Directory prefix (with trailing separator) for the image files named in the
# 'filename' column of the label CSVs.
IMAGES_FILEPATH = 'processed_images/'


def load_coastal_images(train=True, image_size=224):
    """Load the labeled coastal images listed in a local CSV file.

    The CSV is ``local_labeled_images_train.csv`` when ``train`` is true and
    ``local_labeled_images_eval.csv`` otherwise. It must have two columns:

    * ``filename``: resolved relative to ``IMAGES_FILEPATH``;
    * ``class``: the label of each image.

    Only the first channel of each image is kept. The original code comments
    say the second channel of the preprocessed images is a constant 255.
    Single-channel images are used as-is.

    Args:
        train: read the training CSV (True) or the evaluation CSV (False).
        image_size: expected height and width of every image, in pixels.

    Returns:
        A tuple ``(images, labels)``:

        * ``images``: float64 array of shape (N, 1, image_size, image_size);
        * ``labels``: array of shape (N, 1).

    Raises:
        ValueError: if an image is not image_size x image_size.
    """
    if train:
        csv_path = 'local_labeled_images_train.csv'
    else:
        csv_path = 'local_labeled_images_eval.csv'
    dataset = pd.read_csv(csv_path)
    labels = dataset['class'].values
    labels = labels.reshape(len(labels), 1)

    # Collect per-image arrays and convert once at the end; calling
    # numpy.append inside the loop re-copies every earlier image each time.
    channels = []
    for row in dataset['filename'].values:
        img_path = IMAGES_FILEPATH + row
        # The context manager closes the underlying file once pixels are read.
        with Image.open(img_path) as im:
            pixels = numpy.asarray(im)
        if pixels.ndim == 3:
            # Keep channel 0 only, as an (H, W) array.
            pixels = pixels[:, :, 0]
        if pixels.shape != (image_size, image_size):
            raise ValueError(
                'Image {} has shape {}, expected ({}, {})'.format(
                    img_path, pixels.shape, image_size, image_size))
        channels.append(pixels)

    if not channels:
        return numpy.empty((0, 1, image_size, image_size)), labels
    images = numpy.asarray(channels, dtype=numpy.float64)
    # Add the singleton channel axis: (N, H, W) -> (N, 1, H, W).
    images = images[:, numpy.newaxis, :, :]
    return images, labels


class MnistDataSource(DataSource):
    '''
    In-memory nnabla data source over the labeled coastal images.

    The class name is inherited from the MNIST example. The samples
    themselves come from ``load_coastal_images``, which reads a local label
    CSV and preprocessed image files, not an Internet download.

    Each sample is an ``(image, label)`` pair exposed under the variable
    names ``('x', 'y')``.
    '''

    def _get_data(self, position):
        # Translate the iterator position through the current sample order.
        sample_index = self._indexes[position]
        return (self._images[sample_index], self._labels[sample_index])

    def __init__(self, train=True, shuffle=False, rng=None):
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train

        self._images, self._labels = load_coastal_images(train)

        self._size = self._labels.size
        self._variables = ('x', 'y')
        # Fall back to a fixed seed so unshuffled/shuffled order is repeatable.
        self.rng = numpy.random.RandomState(313) if rng is None else rng
        self.reset()

    def reset(self):
        """Recompute the sample order (a fresh permutation when shuffling)."""
        self._indexes = (self.rng.permutation(self._size) if self._shuffle
                         else numpy.arange(self._size))
        super(MnistDataSource, self).reset()

    @property
    def images(self):
        """Copy of every image, shaped (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Copy of every label, shaped (N, 1)."""
        return self._labels.copy()


def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_parallel=False,
                        with_file_cache=False):
    '''
    Build a DataIterator over :py:class:`MnistDataSource` (coastal images).

    ``with_memory_cache``, ``with_parallel`` and ``with_file_cache`` default
    to False because :py:class:`MnistDataSource` already keeps the whole
    dataset in memory.

    Example:

    .. code-block:: python

        di = data_iterator_mnist(batch_size, train=True)
        images, labels = di.next()

    NOTE(review): the three options are forwarded positionally to
    ``data_iterator``. This assumes an nnabla version whose signature is
    (data_source, batch_size, with_memory_cache, with_parallel,
    with_file_cache, ...). Confirm against the installed nnabla.
    '''
    source = MnistDataSource(train=train, shuffle=shuffle, rng=rng)
    return data_iterator(source, batch_size,
                         with_memory_cache, with_parallel, with_file_cache)

args.py

Python
This file is required by classification_coastal.py to read arguments from the command line.
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_args(monitor_path='tmp.monitor', max_iter=10000, model_save_path=None, learning_rate=1e-3, batch_size=64, weight_decay=0, description=None):
    """
    Parse command line arguments for the training scripts.

    The keyword arguments only set the *defaults* of the corresponding
    command line options; values given on the command line take precedence.

    Args:
        monitor_path: default for ``--monitor-path``.
        max_iter: default for ``--max-iter``.
        model_save_path: default for ``--model-save-path``; falls back to
            ``monitor_path`` when None.
        learning_rate: default for ``--learning-rate``.
        batch_size: default for ``--batch-size``.
        weight_decay: default for ``--weight-decay``.
        description: help-text description; a generic one is used if None.

    Returns:
        The ``argparse.Namespace`` of parsed options.

    Side effects:
        Creates ``args.model_save_path`` (including parent directories) if it
        is not already a directory.
    """
    import argparse
    import os
    if model_save_path is None:
        model_save_path = monitor_path
    if description is None:
        description = "Examples on MNIST dataset. The following help shared among examples in this folder. Some arguments are valid or invalid in some examples."
    # Pass by keyword: ArgumentParser's first positional parameter is `prog`,
    # so passing the description positionally would make it the program name.
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--batch-size", "-b", type=int, default=batch_size)
    parser.add_argument("--learning-rate", "-l",
                        type=float, default=learning_rate)
    parser.add_argument("--monitor-path", "-m",
                        type=str, default=monitor_path,
                        help='Path monitoring logs saved.')
    parser.add_argument("--max-iter", "-i", type=int, default=max_iter,
                        help='Max iteration of training.')
    parser.add_argument("--val-interval", "-v", type=int, default=100,
                        help='Validation interval.')
    parser.add_argument("--val-iter", "-j", type=int, default=10,
                        help='Each validation runs `val_iter mini-batch iteration.')
    parser.add_argument("--weight-decay", "-w",
                        type=float, default=weight_decay,
                        help='Weight decay factor of SGD update.')
    parser.add_argument("--device-id", "-d", type=str, default='0',
                        help='Device ID the training run on. This is only valid if you specify `-c cudnn`.')
    parser.add_argument("--type-config", "-t", type=str, default='float',
                        help='Type of computation. e.g. "float", "half".')
    parser.add_argument("--model-save-interval", "-s", type=int, default=1000,
                        help='The interval of saving model parameters.')
    parser.add_argument("--model-save-path", "-o",
                        type=str, default=model_save_path,
                        help='Path the model parameters saved.')
    parser.add_argument("--net", "-n", type=str,
                        default='lenet',
                        help="Neural network architecture type (used only in classification*.py).\n  classification.py: ('lenet'|'resnet'),  classification_bnn.py: ('bincon'|'binnet'|'bwn'|'bincon_resnet'|'binnet_resnet'|'bwn_resnet')")
    parser.add_argument('--context', '-c', type=str,
                        default='cpu', help="Extension modules. ex) 'cpu', 'cudnn'.")
    parser.add_argument('--augment-train', action='store_true',
                        default=False, help="Enable data augmentation of training data.")
    parser.add_argument('--augment-test', action='store_true',
                        default=False, help="Enable data augmentation of testing data.")
    args = parser.parse_args()
    if not os.path.isdir(args.model_save_path):
        os.makedirs(args.model_save_path)
    return args

Credits

Juan Acevedo

Juan Acevedo

2 projects • 1 follower
I have been developing for 10 years and am currently an engineering manager building cutting-edge auto and voice applications.

Comments