# Mouse Fatigue Estimation by GSR and EMG Values w/ TensorFlow
# Windows, Linux, or Ubuntu
# By Kutluhan Aktar
# Collate forearm muscle soreness data on the SD card, build and train a neural network model, and run the model directly on Wio Terminal.
# For more information:
# https://www.theamplituhedron.com/projects/Mouse_Fatigue_Estimation_by_GSR_and_EMG_Values_w_TensorFlow/

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tflite_to_c_array import hex_to_c_array
from test_data import test_inputs, test_labels

# Create a class to build a neural network model after visualizing and scaling (normalizing) the soreness data (GSR and EMG) collected by Wio Terminal.
class Mouse_Fatigue:
    """Train a forearm-soreness classifier from GSR and EMG readings.

    The pipeline reads the soreness CSV collected on the SD card, scales both
    sensor columns by ``scale_val``, trains a small dense Keras network with
    three output classes, and can export the model as ``.h5``, ``.tflite``
    and a C header (``model/<model_name>.h``) for the Wio Terminal sketch.
    """

    def __init__(self, csv_path):
        """Load the soreness data set from *csv_path*.

        The methods below read the ``GSR``, ``EMG`` and ``Soreness`` columns.
        """
        self.inputs = []
        self.labels = []
        self.model_name = "mouse_fatigue_level"
        # Raw readings are divided by this value before training; the
        # on-device inference code must apply the same scale.
        self.scale_val = 1000
        # Read the collated soreness data set (GSR and EMG):
        self.df = pd.read_csv(csv_path)

    def graphics(self, column_1, column_2, x_label, y_label):
        """Show a 2D histogram of *column_1* (x axis) against *column_2* (y axis).

        Args:
            column_1: Data-frame column plotted on the x axis.
            column_2: Data-frame column plotted on the y axis.
            x_label: Label for the x axis.
            y_label: Label for the y axis.
        """
        # ``canvas.set_window_title`` was removed in Matplotlib 3.6; the title
        # now lives on the figure manager, which is None for non-GUI backends.
        manager = plt.gcf().canvas.manager
        if manager is not None:
            manager.set_window_title('Mouse Fatigue Estimation by GSR and EMG Values')
        plt.hist2d(self.df[column_1], self.df[column_2], cmap="coolwarm")
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.show()

    def data_visualization(self):
        """Plot GSR against EMG before building and training the model."""
        self.graphics('GSR', 'EMG', 'GSR', 'EMG')

    def scale_data_and_define_inputs(self):
        """Scale GSR/EMG by ``scale_val`` and build the ``(n, 2)`` input matrix.

        Adds ``scaled_GSR``/``scaled_EMG`` columns to ``self.df`` and sets
        ``self.inputs`` to a float array with one ``[scaled_GSR, scaled_EMG]``
        row per record.
        """
        self.df["scaled_GSR"] = self.df["GSR"] / self.scale_val
        self.df["scaled_EMG"] = self.df["EMG"] / self.scale_val
        # Vectorized: one row per record, replacing the per-row append loop
        # (which also broke on a second call once inputs became an ndarray).
        self.inputs = self.df[["scaled_GSR", "scaled_EMG"]].to_numpy()

    def define_and_assign_labels(self):
        """Use the ``Soreness`` column (class index per record) as the labels."""
        self.labels = self.df["Soreness"]

    def split_data(self):
        """Assign training and test sets.

        The whole CSV is used for training. Evaluation uses the separate
        ``test_inputs``/``test_labels`` set. Its inputs are raw readings, so
        they are scaled here with the same ``scale_val``.
        """
        # (training)
        self.train_inputs = self.inputs
        self.train_labels = self.labels
        # (test)
        self.test_inputs = test_inputs / self.scale_val
        self.test_labels = test_labels

    def build_and_train_model(self):
        """Build, train (150 epochs) and evaluate the dense soreness classifier."""
        # Build the neural network: three softmax outputs, one per soreness class.
        self.model = keras.Sequential([
            keras.layers.Dense(64, activation='relu'),
            keras.layers.Dense(32, activation='relu'),
            keras.layers.Dense(8, activation='relu'),
            keras.layers.Dense(3, activation='softmax'),
        ])
        # Compile: integer class labels, hence the sparse categorical loss.
        self.model.compile(optimizer='adam', loss="sparse_categorical_crossentropy", metrics=['accuracy'])
        # Train:
        self.model.fit(self.train_inputs, self.train_labels, epochs=150)
        # Test the model accuracy:
        print("\n\nModel Evaluation:")
        test_loss, test_acc = self.model.evaluate(self.test_inputs, self.test_labels)
        print("Evaluated Accuracy: ", test_acc)

    def save_model(self):
        """Save the trained Keras model to ``model/<model_name>.h5``."""
        self.model.save("model/{}.h5".format(self.model_name))

    def convert_TF_model(self, path):
        """Convert the trained in-memory model to TensorFlow Lite and a C header.

        Args:
            path: Output path without extension; ``path + '.tflite'`` is
                written. The C header always goes to ``model/<model_name>.h``.
        """
        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
        # Optional float16 weight quantization (disabled):
        # converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # converter.target_spec.supported_types = [tf.float16]
        tflite_model = converter.convert()
        # Save the recently converted TensorFlow Lite model.
        with open(path + '.tflite', 'wb') as f:
            f.write(tflite_model)
        print("\r\nTensorFlow Keras H5 model converted to a TensorFlow Lite model!\r\n")
        # Convert the recently created TensorFlow Lite model to hex bytes (C array) to generate a .h file string.
        with open("model/{}.h".format(self.model_name), 'w') as file:
            file.write(hex_to_c_array(tflite_model, self.model_name))
        print("\r\nTensorFlow Lite model converted to a C header (.h) file!\r\n")

    def Neural_Network(self, save):
        """Run the training pipeline: scale, label, split, build and train.

        Args:
            save: When true, also save the trained model via ``save_model()``.
        """
        self.scale_data_and_define_inputs()
        self.define_and_assign_labels()
        self.split_data()
        self.build_and_train_model()
        if save:
            self.save_model()
# Define a new class object named 'mouse_fatigue_level':
mouse_fatigue_level = Mouse_Fatigue("data/mouse_fatigue_data_set.csv")

# Visualize data columns:
mouse_fatigue_level.data_visualization()

# Artificial Neural Network (ANN): train and save the Keras model.
mouse_fatigue_level.Neural_Network(True)

# Convert the TensorFlow Keras H5 model to a TensorFlow Lite model:
# writes model/<model_name>.tflite plus the C header used by the sketch.
mouse_fatigue_level.convert_TF_model("model/{}".format(mouse_fatigue_level.model_name))

import numpy as np

# NOTE(review): the test_inputs literal below is truncated in this source: its rows
# and closing "])" are missing, so this module cannot be imported as shown.
# test_labels has 16 entries, so presumably 16 raw [GSR, EMG] rows belong here
# (the training script divides test_inputs by its scale value) — TODO restore.
test_inputs = np.array([

# Expected soreness class index (0, 1 or 2) for each row of test_inputs.
test_labels = np.array([0,0,0,1,1,1,2,2,2,0,2,1,0,2,0,1])


# Code from:
# https://www.digikey.com/en/maker/projects/intro-to-tinyml-part-1-training-a-model-for-arduino-in-tensorflow/8f1fc8c0b83d417ab521c48864d2a8ec
# By Shawn Hymel

# Function: Convert some hex value into an array for C programming
def hex_to_c_array(hex_data, var_name):
    """Return the text of a C header that embeds *hex_data* as a byte array.

    Args:
        hex_data: Bytes-like object (e.g. a serialized ``.tflite`` model);
            iterating it must yield ints in 0-255.
        var_name: C identifier for the array. The header guard is
            ``<VAR_NAME>_H`` and the length variable is ``<var_name>_len``.

    Returns:
        Header source declaring ``unsigned int <var_name>_len`` (the byte
        count) and ``unsigned char <var_name>[]`` holding the bytes as
        comma-separated ``0x..`` literals, twelve per line.
    """
    guard = var_name.upper() + '_H'
    byte_count = len(hex_data)

    # Format each byte as a two-digit hex literal. Every byte but the last is
    # followed by a comma, and every 12th byte by a line break so the lines
    # stay within 80 characters.
    hex_array = []
    for i, val in enumerate(hex_data):
        hex_str = format(val, '#04x')
        if (i + 1) < byte_count:
            hex_str += ','
        if (i + 1) % 12 == 0:
            hex_str += '\n '
        # This append was missing, which left the emitted array empty.
        hex_array.append(hex_str)

    parts = [
        # Header guard
        '#ifndef ' + guard + '\n',
        '#define ' + guard + '\n\n',
        # Array length at the top of the file
        '\nunsigned int ' + var_name + '_len = ' + str(byte_count) + ';\n',
        # Array declaration, bytes and closing brace
        'unsigned char ' + var_name + '[] = {',
        '\n ' + ' '.join(hex_array) + '\n};\n\n',
        # Close out header guard
        '#endif //' + guard,
    ]
    return ''.join(parts)


      // Mouse Fatigue Estimation by GSR and EMG //
     //          Values w/ TensorFlow           //
    //             ---------------             //
   //              (Wio Terminal)             //
  //             by Kutluhan Aktar           //
 //                                         //

// Collate forearm muscle soreness data on the SD card, build and train a neural network model, and run the model directly on Wio Terminal.
// For more information:
// https://www.theamplituhedron.com/projects/Mouse_Fatigue_Estimation_by_GSR_and_EMG_Values_w_TensorFlow/
// Connections
// Wio Terminal :
//                                Grove - GSR sensor
// A0  --------------------------- Grove Connector
//                                Grove - EMG Detector
// A2  --------------------------- Grove Connector

// Include the required libraries.
#include <SPI.h>
#include <Seeed_FS.h>
#include "TFT_eSPI.h"
#include "seeed_line_chart.h"
#include "SD/Seeed_SD.h"
#include "RawImage.h"

// Define the TFT screen:
// Declared first on purpose: the color globals at the end of this block call
// tft.color565() in their initializers.
TFT_eSPI tft;

// Define the sprite settings:
#define max_size 50 // maximum size of data
// Rolling windows of the latest readings for the line charts; loop() drops the
// oldest entry once max_size is reached.
doubles gsr_data, emg_data;
// Off-screen sprite bound to the TFT: created in setup(), pushed to the screen in loop().
TFT_eSprite spr = TFT_eSprite(&tft);

// Initialize the File class and define the file name:
// File handle reused by save_data_to_SD_Card(), which opens data_file for appending.
File myFile;
const char* data_file = "mouse_fatigue_data_set.csv";

// Define the sensor voltage (signal) pins (see the wiring notes above):
#define GSR A0
#define EMG A2

// Define the data holders.
// Latest readings: written by get_GSR_data()/get_EMG_data(), saved by save_data_to_SD_Card().
int gsr_value, emg_value;
// Display colors packed with tft.color565().
uint32_t background_color = tft.color565(31,32,32);
uint32_t text_color = tft.color565(174,255,205);

void setup(){

  // Configurable Buttons:
  // Check the connection status between Wio Terminal and the SD card.
  if(!SD.begin(SDCARD_SS_PIN, SDCARD_SPI)) while (1);

  // Initiate the TFT screen:
  // Create the sprite.
  spr.createSprite(TFT_HEIGHT / 2, TFT_WIDTH);

  // Define and display the 16-bit images saved on the SD card:
  drawImage<uint16_t>("data_collect.bmp", TFT_HEIGHT, 0);
  drawImage<uint16_t>("carpal_tunnel.bmp", TFT_HEIGHT/2, 0);
  drawImage<uint16_t>("mouse.bmp", TFT_HEIGHT/2, TFT_WIDTH-90);

void loop(){
  // Obtain current measurements generated by the GSR sensor and the EMG sensor.
  // Initialize the sprite.
  // Adjust the line chart data arrays:
  if(gsr_data.size() == max_size) gsr_data.pop();
  if(emg_data.size() == max_size) emg_data.pop();
  // Append new data variables to the line chart data arrays:
  // Display the line charts on the TFT screen: 
  display_line_chart(0, "GSR", TFT_HEIGHT/2, 90, gsr_data, text_color, tft.color565(165,40,44));
  display_line_chart(110, "EMG", TFT_HEIGHT/2, 90, emg_data, text_color, tft.color565(165,40,44));
  spr.pushSprite(0, 0);

  // Save the data record to the given CSV file with the selected soreness class.
  if(digitalRead(WIO_KEY_A) == LOW) save_data_to_SD_Card("0");
  if(digitalRead(WIO_KEY_B) == LOW) save_data_to_SD_Card("1");
  if(digitalRead(WIO_KEY_C) == LOW) save_data_to_SD_Card("2");

void save_data_to_SD_Card(String Soreness){
  // Open the given CSV file on the SD card in the APPEND file mode.
  myFile = SD.open(data_file, FILE_APPEND);
  // If the given file is opened successfully:
  if(myFile){
    Serial.print("Writing to "); Serial.print(data_file); Serial.println("...");
    // Create the data record to be inserted as a new row (GSR,EMG,Soreness):
    String data_record = String(gsr_value) + "," + String(emg_value) + "," + Soreness;
    // Append the data record:
    myFile.println(data_record);
    // Close the CSV file so the row is flushed to the card:
    myFile.close();
    Serial.println("Data saved successfully!\n");
    // Notify the user after appending the given data record successfully.
    drawImage<uint16_t>("data_collect.bmp", TFT_HEIGHT/4, 0);
    tft.drawString("Selected Soreness Class: " + Soreness, 0, 140);
    tft.drawString("Data Stored!", 86, 180);
  }else{
    // If Wio Terminal cannot open the given CSV file successfully:
    Serial.println("Wio Terminal cannot open the given CSV file!\n");
    tft.drawString("Wio Terminal", 35, 105);
    tft.drawString("cannot open the file!", 35, 125);
  }
  // Keep the message visible before clearing it. This also stops a held key
  // from appending a new row on every loop() pass.
  delay(2000);
  // Exit and clear:
  drawImage<uint16_t>("carpal_tunnel.bmp", TFT_HEIGHT/2, 0);
  drawImage<uint16_t>("mouse.bmp", TFT_HEIGHT/2, TFT_WIDTH-90);
}

void display_line_chart(int header_y, const char* header_title, int chart_width, int chart_height, doubles data, uint32_t graph_color, uint32_t line_color){
  // Define the line graph title settings:
  auto header =  text(0, header_y)
  // Define the header height and draw the line graph title. 
  header.height(header.font_height() * 2);
  // Define the line chart settings:
  auto content = line_chart(0, header.height() + header_y);
  .height(chart_height) // the actual height of the line chart
  .width(chart_width) // the actual width of the line chart
  .based_on(0.0) // the starting point of the y-axis must be float
  .show_circle(false) // drawing a circle at each point, default is on
  .value(data) // passing the given data array to the line graph
  .color(line_color) // setting the line color 
  .x_role_color(graph_color) // setting the line graph color

void get_GSR_data(int calibration){
  long sum = 0;
  // Average ten consecutive GSR readings to smooth out glitches.
  for(int i=0;i<10;i++){
    sum += analogRead(GSR);
  }
  // Subtract the caller-supplied calibration offset from the average.
  gsr_value = (sum / 10) - calibration;
  Serial.print("GSR Value => "); Serial.println(gsr_value);
}

void get_EMG_data(){
  long sum = 0;
  // Sum 32 consecutive EMG readings.
  for(int i=0;i<32;i++){
    sum += analogRead(EMG);
  }
  // Shifting right by five divides by 32, i.e. the average of the readings.
  emg_value = sum >> 5;
  Serial.print("EMG Value => "); Serial.println(emg_value); Serial.println();
}


      // Mouse Fatigue Estimation by GSR and EMG //
     //          Values w/ TensorFlow           //
    //             ---------------             //
   //              (Wio Terminal)             //
  //             by Kutluhan Aktar           //
 //                                         //

// Collate forearm muscle soreness data on the SD card, build and train a neural network model, and run the model directly on Wio Terminal.
// For more information:
// https://www.theamplituhedron.com/projects/Mouse_Fatigue_Estimation_by_GSR_and_EMG_Values_w_TensorFlow/
// Connections
// Wio Terminal :
//                                Grove - GSR sensor
// A0  --------------------------- Grove Connector
//                                Grove - EMG Detector
// A2  --------------------------- Grove Connector

// Include the required libraries.
#include <SPI.h>
#include <Seeed_FS.h>
#include "TFT_eSPI.h"
#include "seeed_line_chart.h"
#include "SD/Seeed_SD.h"
#include "RawImage.h"

// Import the required TensorFlow modules.
#include "TensorFlowLite.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/version.h"

// Import the converted TensorFlow Lite model.
#include "mouse_fatigue_level.h"

// TFLite globals, used for compatibility with Arduino-style sketches:
namespace {
  tflite::ErrorReporter* error_reporter = nullptr;
  const tflite::Model* model = nullptr;
  tflite::MicroInterpreter* interpreter = nullptr;
  TfLiteTensor* model_input = nullptr;
  TfLiteTensor* model_output = nullptr;

  // Create an area of memory to use for input, output, and other TensorFlow arrays.
  // Increase it if AllocateTensors() fails in setup().
  constexpr int kTensorArenaSize = 15 * 1024;
  // TensorFlow Lite Micro expects the arena to be 16-byte aligned.
  alignas(16) uint8_t tensor_arena[kTensorArenaSize];
} // namespace

// Define the TFT screen.
// It must be declared before every global below whose initializer calls
// tft.color565() (color_codes, background_color, text_color).
TFT_eSPI tft;

// Define the threshold value for the model outputs (results).
// A class is reported only when its output score is at least this value.
float threshold = 0.75;

// Define the muscle soreness level (class) names and color codes, indexed by model output:
String classes[] = {"Relaxed", "Tense", "Exhausted"};
uint32_t color_codes[] = {tft.color565(1,156,0), tft.color565(255,169,2), tft.color565(226,16,1)};

// Define the class image list (same order as classes):
const char* images[] = {"relaxed.bmp", "tense.bmp", "exhausted.bmp"};

// Define the sprite settings:
#define max_size 50 // maximum size of data
doubles gsr_data, emg_data;
TFT_eSprite spr = TFT_eSprite(&tft);

// Define the sensor voltage (signal) pins:
#define GSR A0
#define EMG A2

// Define the data holders.
int gsr_value, emg_value;
uint32_t background_color = tft.color565(31,32,32);
uint32_t text_color = tft.color565(174,255,205);

void setup(){

  // 5-Way Switch

  // TensorFlow Lite Model settings:
  // Set up logging (will report to Serial, even within TFLite functions).
  static tflite::MicroErrorReporter micro_error_reporter;
  error_reporter = &micro_error_reporter;

  // Map the model into a usable data structure.
  model = tflite::GetModel(mouse_fatigue_level);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter->Report("Model version does not match Schema");

  // This pulls in all the operation implementations we need.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::AllOpsResolver resolver;

  // Build an interpreter to run the model.
  static tflite::MicroInterpreter static_interpreter(
    model, resolver, tensor_arena, kTensorArenaSize,
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    error_reporter->Report("AllocateTensors() failed");

  // Assign model input and output buffers (tensors) to pointers.
  model_input = interpreter->input(0);
  model_output = interpreter->output(0);


  // Check the connection status between Wio Terminal and the SD card.
  if(!SD.begin(SDCARD_SS_PIN, SDCARD_SPI)) while (1);

  // Initiate the TFT screen:
  // Create the sprite.
  spr.createSprite(TFT_HEIGHT / 2, TFT_WIDTH);

  // Define and display the 16-bit images saved on the SD card:
  for(int i=0; i<3; i++){ drawImage<uint16_t>(images[i], TFT_HEIGHT, 0); }
  drawImage<uint16_t>("carpal_tunnel.bmp", TFT_HEIGHT/2, 0);
  drawImage<uint16_t>("mouse.bmp", TFT_HEIGHT/2, TFT_WIDTH-90);


void loop(){
  // Obtain current measurements generated by the GSR sensor and the EMG sensor.
  // Initialize the sprite.
  // Adjust the line chart data arrays:
  if(gsr_data.size() == max_size) gsr_data.pop();
  if(emg_data.size() == max_size) emg_data.pop();
  // Append new data variables to the line chart data arrays:
  // Display the line charts on the TFT screen: 
  display_line_chart(0, "GSR", TFT_HEIGHT/2, 90, gsr_data, text_color, tft.color565(165,40,44));
  display_line_chart(110, "EMG", TFT_HEIGHT/2, 90, emg_data, text_color, tft.color565(165,40,44));
  spr.pushSprite(0, 0);

  // Execute the TensorFlow Lite model to make predictions on the muscle soreness levels (classes). 
  if(digitalRead(WIO_5S_PRESS) == LOW) run_inference_to_make_predictions();

void run_inference_to_make_predictions(){    
    // Scale (normalize) muscle soreness data depending on the given model and copy them to the input buffer (tensor):
    model_input->data.f[0] = gsr_value / 1000;
    model_input->data.f[1] = emg_value / 1000;
    // Run inference:
    TfLiteStatus invoke_status = interpreter->Invoke();
    if (invoke_status != kTfLiteOk) {
      error_reporter->Report("Invoke failed on the given input.");

    // Read predicted y values (muscle soreness classes) from the output buffer (tensor): 
    for(int i = 0; i<3; i++){
      if(model_output->data.f[i] >= threshold){
        int w = 150, h = 40, str_x = 12, str_y = 65;
        int y_offset = h + ((h - str_y) / 2);
        int x_offset = classes[i].length() * str_x;
        // Display the detection result (class).
        drawImage<uint16_t>(images[i], (TFT_HEIGHT-75)/2, 0);
        // Print:
        tft.fillRect((TFT_HEIGHT-w)/2, TFT_WIDTH-h, w, h, color_codes[i]);
        tft.drawString(classes[i], (TFT_HEIGHT-x_offset)/2, TFT_WIDTH-y_offset);
    // Exit and clear:
    drawImage<uint16_t>("carpal_tunnel.bmp", TFT_HEIGHT/2, 0);
    drawImage<uint16_t>("mouse.bmp", TFT_HEIGHT/2, TFT_WIDTH-90);

void display_line_chart(int header_y, const char* header_title, int chart_width, int chart_height, doubles data, uint32_t graph_color, uint32_t line_color){
  // Define the line graph title settings:
  auto header =  text(0, header_y)
  // Define the header height and draw the line graph title. 
  header.height(header.font_height() * 2);
  // Define the line chart settings:
  auto content = line_chart(0, header.height() + header_y);
  .height(chart_height) // the actual height of the line chart
  .width(chart_width) // the actual width of the line chart
  .based_on(0.0) // the starting point of the y-axis must be float
  .show_circle(false) // drawing a circle at each point, default is on
  .value(data) // passing the given data array to the line graph
  .color(line_color) // setting the line color 
  .x_role_color(graph_color) // setting the line graph color

void get_GSR_data(int calibration){
  long sum = 0;
  // Average ten consecutive GSR readings to smooth out glitches.
  for(int i=0;i<10;i++){
    sum += analogRead(GSR);
  }
  // Subtract the caller-supplied calibration offset from the average.
  gsr_value = (sum / 10) - calibration;
  Serial.print("GSR Value => "); Serial.println(gsr_value);
}

void get_EMG_data(){
  long sum = 0;
  // Sum 32 consecutive EMG readings.
  for(int i=0;i<32;i++){
    sum += analogRead(EMG);
  }
  // Shifting right by five divides by 32, i.e. the average of the readings.
  emg_value = sum >> 5;
  Serial.print("EMG Value => "); Serial.println(emg_value); Serial.println();
}



