Taro Yoshino
Published © LGPL

Get started with Tensorflow lite/micro by Sony Spresense

This tutorial shows how to recognize MNIST handwritten digits using TensorFlow Lite for Microcontrollers on the Sony Spresense main board.

Advanced · Protip · 1 hour · 2,096
Get started with Tensorflow lite/micro by Sony Spresense

Things used in this project

Hardware components

Sony Spresense Main board
×1
USB-A to Micro-USB Cable
USB-A to Micro-USB Cable
×1

Software apps and online services

TensorFlow
TensorFlow
Arduino IDE
Arduino IDE
Spresense Arduino Tensorflow Board Package
Bitmap Image Library for Arduino IDE
Sony xmodem_writer for windows
Sony xmodem_writer for linux
xmodem_writer for macOS

Story

Read more

Code

tf_mnist_traning.py

Python
This Python script trains a model on the MNIST dataset and writes the quantized model out as a C-style header.
## TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

## Helper libraries
import numpy as np
import binascii
import logging

## Report the installed TensorFlow version (this script targets 2.8.0)
print(tf.__version__)

## Quiet the console: AutoGraph verbosity 0, "tensorflow" logger at ERROR level
tf.autograph.set_verbosity(0)
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.ERROR)


## Download the MNIST dataset through Keras
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 60,000 training samples and 10,000 test samples of 28x28 pixel images
print("train_images shape", train_images.shape)
print("train_labels shape", train_labels.shape)
print("test_images shape", test_images.shape)
print("test_labels shape", test_labels.shape)


## Scale every pixel into the range 0..1 and one-hot encode the labels.
pixel_max = 255.0
train_images, test_images = train_images / pixel_max, test_images / pixel_max

num_classes = 10
train_labels = keras.utils.to_categorical(train_labels, num_classes)
test_labels = keras.utils.to_categorical(test_labels, num_classes)
print('Datasets are normalized')


## Model definition: one 5x5 convolution, 2x2 max pooling, then two dense layers
mnist_layers = [
  # 28x28 grayscale input, reshaped to carry an explicit single channel
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(
    filters=6,
    kernel_size=(5, 5),
    padding='same',
    activation=tf.nn.relu,
    name="conv2d_6",
  ),
  keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
  keras.layers.Flatten(),
  keras.layers.Dense(32, activation=tf.nn.relu, name="dense_32"),
  # Ten outputs (one per digit) followed by a separate softmax layer
  keras.layers.Dense(10),
  keras.layers.Activation(tf.nn.softmax),
]
model = keras.Sequential(mnist_layers)

model.compile(
  optimizer="adam",
  loss="categorical_crossentropy",
  metrics=["accuracy"],
)


## Print the layer-by-layer summary of the model
model.summary()


## Train the model, holding back 10% of the training data for validation
fit_options = dict(batch_size=128, epochs=30, verbose=1, validation_split=0.1)
model.fit(x=train_images, y=train_labels, **fit_options)


## Score the trained model against every image in the test dataset.
test_loss, test_acc = model.evaluate(
  x=test_images,
  y=test_labels,
  verbose=1,
)
print('Accuracy = %f' % test_acc)


## Convert the Keras model to a quantized TFLite model.
def representative_dataset_gen():
  """Yield the first 100 test images, one at a time, as calibration samples.

  Each sample is a single-element list holding a float32 tensor of shape
  [1, 28, 28].
  """
  for sample_index in range(100):
    sample = tf.reshape(tf.cast(test_images[sample_index], tf.float32), [1, 28, 28])
    yield [sample]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_model = converter.convert()


## Show the quantized model size in KBs.
tflite_model_size = len(tflite_model) / 1024
print('Quantized model size = %dKBs.' % tflite_model_size)
# Save the model to disk. The `with` block closes (and therefore flushes) the
# file deterministically, so the bytes are on disk before anything reads them back.
with open('qmodel.tflite', "wb") as model_file:
    model_file.write(tflite_model)


## Output the quantized tflite model to a c-style header
def convert_to_c_array(bytes) -> str:
  """Render a bytes-like object as the body of a C array initializer.

  Every byte becomes an upper-case ``0xNN`` token. Tokens are separated by
  ", ", ten per row, and rows are separated by ",\\n  ". Empty input gives an
  empty string.
  """
  hex_digits = binascii.hexlify(bytes).decode("UTF-8").upper()
  digits_per_row = 10 * 2  # ten bytes per row, two hex digits per byte
  rows = []
  for row_start in range(0, len(hex_digits), digits_per_row):
    row_digits = hex_digits[row_start:row_start + digits_per_row]
    tokens = ("0x" + row_digits[k:k + 2] for k in range(0, len(row_digits), 2))
    rows.append(", ".join(tokens))
  return ",\n  ".join(rows)

# Read back the quantized model saved earlier in this script. The script writes
# 'qmodel.tflite'; the previous 'model.tflite' path was never created here, so
# reading it raised FileNotFoundError (or silently picked up a stale file).
with open('qmodel.tflite', 'rb') as model_file:
    tflite_binary = model_file.read()
ascii_bytes = convert_to_c_array(tflite_binary)
# The header defines the model bytes as `model_tflite` and their count as
# `model_tflite_len`.
header_file = "const unsigned char model_tflite[] = {\n  " + ascii_bytes + "\n};\nunsigned int model_tflite_len = " + str(len(tflite_binary)) + ";"
# NOTE(review): the Arduino sketch on this page includes "qmodel.h" — rename or
# copy model.h accordingly when adding it to the sketch folder.
with open("model.h", "w") as f:
    f.write(header_file)

Spresense_tf_mnist.ino

C/C++
TensorFlow Lite for Microcontrollers inference sketch for the Sony Spresense.
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Quantized model header. setup() reads the symbol `model_tflite`, which is
// presumably defined here — confirm against the generated header.
#include "qmodel.h"  /* quantized model */
// Name of the bitmap that setup() opens from Flash and feeds to the model.
#define TEST_FILE "0003.bmp"

// Sketch-wide TensorFlow Lite Micro state; every pointer is assigned in setup().
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;   // set from interpreter->input(0)
TfLiteTensor* output = nullptr;  // set from interpreter->output(0)
int inference_count = 0;         // not referenced by the code visible in this sketch

// Working memory handed to the MicroInterpreter; setup() prints how many
// bytes of it were actually used after AllocateTensors().
constexpr int kTensorArenaSize = 30000;
uint8_t tensor_arena[kTensorArenaSize];

#include <Flash.h>
#include <BmpImage.h>
// Bitmap decoder instance used by setup() to read TEST_FILE.
BmpImage bmp;

/**
 * Runs a single inference on TEST_FILE and prints the results to Serial.
 *
 * Steps: initialise the target and the tensor arena, map `model_tflite`,
 * build the interpreter, allocate tensors, validate the input/output tensor
 * types, load the 8-bit grayscale bitmap from Flash, copy its pixels (scaled
 * to 0..1) into the input tensor, invoke the model and print every output
 * score together with the inference time in microseconds.
 *
 * Any failure is reported on Serial and aborts the remaining steps.
 */
void setup() {
  Serial.begin(115200);
  tflite::InitializeTarget();
  memset(tensor_arena, 0, kTensorArenaSize*sizeof(uint8_t));
  
  // Set up logging. 
  static tflite::MicroErrorReporter micro_error_reporter;
  error_reporter = &micro_error_reporter;
  
  // Map the model into a usable data structure..
  model = tflite::GetModel(model_tflite);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    Serial.println("Model provided is schema version " 
                  + String(model->version()) + " not equal "
                  + "to supported version "
                  + String(TFLITE_SCHEMA_VERSION));
    return;
  } else {
    Serial.println("Model version: " + String(model->version()));
  }
  
  // This pulls in all the operation implementations we need.
  static tflite::AllOpsResolver resolver;
  
  // Build an interpreter to run the model with.
  static tflite::MicroInterpreter static_interpreter(
      model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
  interpreter = &static_interpreter;
  
  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    Serial.println("AllocateTensors() failed");
    return;
  } else {
    Serial.println("AllocateTensor() Success");
  }

  size_t used_size = interpreter->arena_used_bytes();
  Serial.println("Area used bytes: " + String(used_size));
  input = interpreter->input(0);
  output = interpreter->output(0);

  /* check input */
  if (input->type != kTfLiteFloat32) {
    Serial.println("input type mismatch. expected input type is float32");
    return;
  } else {
    Serial.println("input type is float32");
  }

  /* check output: the scores are read through output->data.f below */
  if (output->type != kTfLiteFloat32) {
    Serial.println("output type mismatch. expected output type is float32");
    return;
  }

  Serial.println("Model input:");
  Serial.println("input->type: " + String(input->type));
  Serial.println("dims->size: " + String(input->dims->size));
  for (int n = 0; n < input->dims->size; ++n) {
    Serial.println("dims->data[n]: " + String(input->dims->data[n]));
  }

  Serial.println("Model output:");
  Serial.println("dims->size: " + String(output->dims->size));
  for (int n = 0; n < output->dims->size; ++n) {
    Serial.println("dims->data[n]: " + String(output->dims->data[n]));
  }  
  
  /* read test data */
  File myFile = Flash.open(TEST_FILE);
  if (!myFile) { Serial.println(TEST_FILE " not found"); return; }

  Serial.println("Read " TEST_FILE);
  bmp.begin(myFile);
  BmpImage::BMP_IMAGE_PIX_FMT fmt = bmp.getPixFormat();
  if (fmt != BmpImage::BMP_IMAGE_GRAY8) {
    Serial.println("support format error");
    bmp.end();
    myFile.close();
    return;
  }

  int width = bmp.getWidth();
  int height = bmp.getHeight();

  Serial.println("width:  " + String(width));
  Serial.println("height: " + String(height));
  uint8_t* img = bmp.getImgBuff();

  // The copy below writes one float per pixel into the input tensor, so the
  // image must contain exactly as many pixels as the tensor has elements;
  // anything else would read or write out of bounds.
  const size_t input_elems = input->bytes / sizeof(float);
  const bool size_ok = width > 0 && height > 0
      && static_cast<size_t>(width) * static_cast<size_t>(height) == input_elems;
  if (img == nullptr || !size_ok) {
    Serial.println("image size mismatch. expected "
                  + String(input_elems) + " pixels");
    bmp.end();
    myFile.close();
    return;
  }

  // Scale 0..255 pixel values to 0..1, matching the training-time division
  // by 255.0.
  for (size_t i = 0; i < input_elems; ++i) {
    input->data.f[i] = static_cast<float>(img[i] / 255.0);
  }

  // The pixels now live in the input tensor; release the bitmap and the file
  // before running inference.
  bmp.end();
  myFile.close();

  Serial.println("Do inference");
  uint32_t start_time = micros();
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk) {
    Serial.println("Invoke failed");
    return;
  }
  uint32_t duration = micros() - start_time;
  Serial.println("Inference time = " + String(duration));
  
  // Print one score per output element (ten for the MNIST model) instead of
  // assuming a fixed count.
  const int num_scores = static_cast<int>(output->bytes / sizeof(float));
  for (int n = 0; n < num_scores; ++n) {
    float value = output->data.f[n];
    Serial.println("[" + String(n) + "] " + String(value)); 
  }
}

// Nothing to do repeatedly: all of the work happens once in setup().
void loop() {
}

Credits

Taro Yoshino

Taro Yoshino

3 projects • 8 followers
IoT Engineer in Japan. My expertise is in Embedded Systems, Edge AI Systems, and Acoustic/Vibration Analysis.
Thanks to Taro Yoshino.

Comments