Yuyu Qian, Yuetong Yang, Xurui Zhang, Yifan Wang
Published © GPL3+

COMP554_TinyML_WAKEWORD_DETECTION_OWLWO

A TinyML project implemented with TensorFlow. The application is deployed on an Arduino Nano 33 BLE Sense board and detects the spoken commands "yes" and "no".

Beginner · Full instructions provided · 8 hours · 522

Things used in this project

Hardware components

Nano 33 BLE Sense
Arduino Nano 33 BLE Sense
The Arduino Nano 33 BLE Sense is designed to offer a power-efficient and cost-effective solution for makers who want Bluetooth Low Energy connectivity in their projects. It is based on the NINA-B306 module, which hosts a Nordic nRF52840 containing a Cortex-M4F microcontroller. The board also includes an LSM9DS1 9-axis IMU.
×1

Software apps and online services

Windows 10
Microsoft Windows 10
Arduino IDE
Arduino IDE
Google Colab

Hand tools and fabrication machines

Tape, Foam

Story


Schematics

Demo 1

This demo shows the result of our training for "yes/no" wake-word detection.

Demo 2

This second demo shows the wake-word detection deployed on the Nano 33 BLE Sense board.

Code

Environment Setup

Python
Sets the environment variables used for training.
import os
os.environ["WANTED_WORDS"] = "yes,no"
os.environ["TRAINING_STEPS"]="15000,3000"
os.environ["LEARNING_RATE"]="0.001,0.0001"
os.environ["TOTAL_STEPS"]="18000"
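For context, the training notebook interpolates these variables into the flags of TensorFlow's `speech_commands/train.py`. A minimal sketch of that expansion is below; the flag names match the upstream script, but the exact command in our Colab may include additional flags.

```python
import os

# Same values as the setup cell above.
os.environ["WANTED_WORDS"] = "yes,no"
os.environ["TRAINING_STEPS"] = "15000,3000"   # steps per learning-rate phase
os.environ["LEARNING_RATE"] = "0.001,0.0001"  # learning rate for each phase

# Build the training command. train.py accepts comma-separated step counts
# and learning rates, giving a two-phase training schedule.
train_cmd = (
    "python tensorflow/examples/speech_commands/train.py"
    f" --wanted_words={os.environ['WANTED_WORDS']}"
    f" --how_many_training_steps={os.environ['TRAINING_STEPS']}"
    f" --learning_rate={os.environ['LEARNING_RATE']}"
)
print(train_cmd)
```

The two-phase schedule trains 15000 steps at 0.001 and then fine-tunes 3000 steps at 0.0001, matching the 18000 total steps set above.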

generate_streaming_test_wav.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a .wav file with synthesized conversational data and labels.

The best way to estimate the real-world performance of an audio recognition
model is by running it against a continuous stream of data, the way that it
would be used in an application. Training evaluations are only run against
discrete individual samples, so the results aren't as realistic.

To make it easy to run evaluations against audio streams, this script uses
samples from the testing partition of the data set, mixes them in at random
positions together with background noise, and saves out the result as one long
audio file.

Here's an example of generating a test file:

bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav -- \
--data_dir=/tmp/my_wavs --background_dir=/tmp/my_backgrounds \
--background_volume=0.1 --test_duration_seconds=600 \
--output_audio_file=/tmp/streaming_test.wav \
--output_labels_file=/tmp/streaming_test_labels.txt

Once you've created a streaming audio file, you can then use the
test_streaming_accuracy tool to calculate accuracy metrics for a model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math
import sys

import numpy as np
import tensorflow as tf

import input_data
import models

FLAGS = None


def mix_in_audio_sample(track_data, track_offset, sample_data, sample_offset,
                        clip_duration, sample_volume, ramp_in, ramp_out):
  """Mixes the sample data into the main track at the specified offset.

  Args:
    track_data: Numpy array holding main audio data. Modified in-place.
    track_offset: Where to mix the sample into the main track.
    sample_data: Numpy array of audio data to mix into the main track.
    sample_offset: Where to start in the audio sample.
    clip_duration: How long the sample segment is.
    sample_volume: Loudness to mix the sample in at.
    ramp_in: Length in samples of volume increase stage.
    ramp_out: Length in samples of volume decrease stage.
  """
  ramp_out_index = clip_duration - ramp_out
  track_end = min(track_offset + clip_duration, track_data.shape[0])
  track_end = min(track_end,
                  track_offset + (sample_data.shape[0] - sample_offset))
  sample_range = track_end - track_offset
  for i in range(sample_range):
    if i < ramp_in:
      envelope_scale = i / ramp_in
    elif i > ramp_out_index:
      envelope_scale = (clip_duration - i) / ramp_out
    else:
      envelope_scale = 1
    sample_input = sample_data[sample_offset + i]
    track_data[track_offset
               + i] += sample_input * envelope_scale * sample_volume


def main(_):
  words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count,
      'mfcc')
  audio_processor = input_data.AudioProcessor(
      '', FLAGS.data_dir, FLAGS.silence_percentage, 10,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.data_dir)

  output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
  output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)

  # Set up background audio.
  background_crossover_ms = 500
  background_segment_duration_ms = (
      FLAGS.clip_duration_ms + background_crossover_ms)
  background_segment_duration_samples = int(
      (background_segment_duration_ms * FLAGS.sample_rate) / 1000)
  background_segment_stride_samples = int(
      (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
  background_ramp_samples = int(
      ((background_crossover_ms / 2) * FLAGS.sample_rate) / 1000)

  # Mix the background audio into the main track.
  how_many_backgrounds = int(
      math.ceil(output_audio_sample_count / background_segment_stride_samples))
  for i in range(how_many_backgrounds):
    output_offset = int(i * background_segment_stride_samples)
    background_index = np.random.randint(len(audio_processor.background_data))
    background_samples = audio_processor.background_data[background_index]
    background_offset = np.random.randint(
        0, len(background_samples) - model_settings['desired_samples'])
    background_volume = np.random.uniform(0, FLAGS.background_volume)
    mix_in_audio_sample(output_audio, output_offset, background_samples,
                        background_offset, background_segment_duration_samples,
                        background_volume, background_ramp_samples,
                        background_ramp_samples)

  # Mix the words into the main track, noting their labels and positions.
  output_labels = []
  word_stride_ms = FLAGS.clip_duration_ms + FLAGS.word_gap_ms
  word_stride_samples = int((word_stride_ms * FLAGS.sample_rate) / 1000)
  clip_duration_samples = int(
      (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
  word_gap_samples = int((FLAGS.word_gap_ms * FLAGS.sample_rate) / 1000)
  how_many_words = int(
      math.floor(output_audio_sample_count / word_stride_samples))
  all_test_data, all_test_labels = audio_processor.get_unprocessed_data(
      -1, model_settings, 'testing')
  for i in range(how_many_words):
    output_offset = (
        int(i * word_stride_samples) + np.random.randint(word_gap_samples))
    output_offset_ms = (output_offset * 1000) / FLAGS.sample_rate
    is_unknown = np.random.randint(100) < FLAGS.unknown_percentage
    if is_unknown:
      wanted_label = input_data.UNKNOWN_WORD_LABEL
    else:
      wanted_label = words_list[2 + np.random.randint(len(words_list) - 2)]
    test_data_start = np.random.randint(len(all_test_data))
    found_sample_data = None
    index_lookup = np.arange(len(all_test_data), dtype=np.int32)
    np.random.shuffle(index_lookup)
    for test_data_offset in range(len(all_test_data)):
      test_data_index = index_lookup[(
          test_data_start + test_data_offset) % len(all_test_data)]
      current_label = all_test_labels[test_data_index]
      if current_label == wanted_label:
        found_sample_data = all_test_data[test_data_index]
        break
    mix_in_audio_sample(output_audio, output_offset, found_sample_data, 0,
                        clip_duration_samples, 1.0, 500, 500)
    output_labels.append({'label': wanted_label, 'time': output_offset_ms})

  input_data.save_wav_file(FLAGS.output_audio_file, output_audio,
                           FLAGS.sample_rate)
  tf.compat.v1.logging.info('Saved streaming test wav to %s',
                            FLAGS.output_audio_file)

  with open(FLAGS.output_labels_file, 'w') as f:
    for output_label in output_labels:
      f.write('%s, %f\n' % (output_label['label'], output_label['time']))
  tf.compat.v1.logging.info('Saved streaming test labels to %s',
                            FLAGS.output_labels_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_dir',
      type=str,
      default='',
      help="""\
      Path to a directory of .wav files to mix in as background noise during training.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs.',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices',)
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_audio_file',
      type=str,
      default='/tmp/speech_commands_train/streaming_test.wav',
      help='File to save the generated test audio to.')
  parser.add_argument(
      '--output_labels_file',
      type=str,
      default='/tmp/speech_commands_train/streaming_test_labels.txt',
      help='File to save the generated test labels to.')
  parser.add_argument(
      '--test_duration_seconds',
      type=int,
      default=600,
      help='How long the generated test audio file should be.',)
  parser.add_argument(
      '--word_gap_ms',
      type=int,
      default=2000,
      help='How long the average gap should be between words.',)
  parser.add_argument(
      '--unknown_percentage',
      type=int,
      default=30,
      help='What percentage of words should be unknown.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
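The core of this script is `mix_in_audio_sample`, which applies a linear fade-in/fade-out envelope while adding a sample into the output track. Here is a simplified, self-contained re-implementation of that envelope logic (the function and variable names are our own, for illustration only):

```python
import numpy as np

def mix_with_ramp(track, offset, sample, volume, ramp):
    """Add `sample` into `track` at `offset` with linear fade-in/out ramps."""
    n = min(len(sample), len(track) - offset)
    for i in range(n):
        if i < ramp:
            scale = i / ramp            # fade in: 0 -> 1
        elif i > n - ramp:
            scale = (n - i) / ramp      # fade out: 1 -> 0
        else:
            scale = 1.0                 # full volume in the middle
        track[offset + i] += sample[i] * scale * volume

track = np.zeros(100, dtype=np.float32)
sample = np.ones(50, dtype=np.float32)
mix_with_ramp(track, 10, sample, volume=0.5, ramp=10)
```

The middle of the mixed region sits at the full mix volume (0.5 here), while the edges are attenuated by the ramps, which is what lets consecutive background segments cross-fade smoothly in the generated stream.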

label_wav_dir.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a trained audio graph against WAVE files and reports the results.

The model, labels and .wav files specified in the arguments will be loaded, and
then the predictions from running the model against the audio data will be
printed to the console. This is a useful script for sanity checking trained
models, and as an example of how to use an audio model from Python.

Here's an example of running it:

python tensorflow/examples/speech_commands/label_wav_dir.py \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav_dir=/tmp/speech_dataset/left

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import glob
import sys

import tensorflow as tf


FLAGS = None


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.io.gfile.GFile(filename, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.io.gfile.GFile(filename)]


def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.compat.v1.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input audio clip count, and the other has
    #   predictions per class
    for wav_path in glob.glob(wav_dir + '/*.wav'):
      if not wav_path or not tf.io.gfile.exists(wav_path):
        raise ValueError('Audio file does not exist at {0}'.format(wav_path))
      with open(wav_path, 'rb') as wav_file:
        wav_data = wav_file.read()

      softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
      predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

      # Sort to show labels in order of confidence
      print('\n%s' % (wav_path.split('/')[-1]))
      top_k = predictions.argsort()[-num_top_predictions:][::-1]
      for node_id in top_k:
        human_string = labels[node_id]
        score = predictions[node_id]
        print('%s (score = %.5f)' % (human_string, score))

    return 0


def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels):
  """Loads the model and labels, and runs the inference to print predictions."""
  if not labels or not tf.io.gfile.exists(labels):
    raise ValueError('Labels file does not exist at {0}'.format(labels))

  if not graph or not tf.io.gfile.exists(graph):
    raise ValueError('Graph file does not exist at {0}'.format(graph))

  labels_list = load_labels(labels)

  # load graph, which is stored in the default session
  load_graph(graph)

  run_graph(wav_dir, labels_list, input_name, output_name, how_many_labels)


def main(_):
  """Entry point for script, converts flags to arguments."""
  label_wav(FLAGS.wav_dir, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
            FLAGS.output_name, FLAGS.how_many_labels)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--wav_dir', type=str, default='',
      help='Directory of audio files to be identified.')
  parser.add_argument(
      '--graph', type=str, default='', help='Model to use for identification.')
  parser.add_argument(
      '--labels', type=str, default='', help='Path to file containing labels.')
  parser.add_argument(
      '--input_name',
      type=str,
      default='wav_data:0',
      help='Name of WAVE data input node in model.')
  parser.add_argument(
      '--output_name',
      type=str,
      default='labels_softmax:0',
      help='Name of node outputting a prediction in the model.')
  parser.add_argument(
      '--how_many_labels',
      type=int,
      default=3,
      help='Number of results to show.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
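The ranking step in `run_graph` is a standard NumPy top-k pattern: `argsort` sorts indices by score ascending, the slice keeps the last `num_top_predictions` of them, and `[::-1]` reverses into descending-confidence order. A small standalone illustration (the labels and scores here are hypothetical):

```python
import numpy as np

# Hypothetical softmax output for a 4-label model.
labels = ["_silence_", "_unknown_", "yes", "no"]
predictions = np.array([0.05, 0.10, 0.80, 0.05])

num_top_predictions = 3
# argsort ascending, take the last k indices, reverse to descending order.
top_k = predictions.argsort()[-num_top_predictions:][::-1]
for node_id in top_k:
    print('%s (score = %.5f)' % (labels[node_id], predictions[node_id]))
```

With these scores, "yes" is printed first with the highest confidence, followed by the next two labels in decreasing order.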

label_wav.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a trained audio graph against a WAVE file and reports the results.

The model, labels and .wav file specified in the arguments will be loaded, and
then the predictions from running the model against the audio data will be
printed to the console. This is a useful script for sanity checking trained
models, and as an example of how to use an audio model from Python.

Here's an example of running it:

python tensorflow/examples/speech_commands/label_wav.py \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf


FLAGS = None


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.io.gfile.GFile(filename, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.io.gfile.GFile(filename)]


def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.compat.v1.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input audio clip count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0


def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
  """Loads the model and labels, and runs the inference to print predictions."""
  if not wav or not tf.io.gfile.exists(wav):
    raise ValueError('Audio file does not exist at {0}'.format(wav))
  if not labels or not tf.io.gfile.exists(labels):
    raise ValueError('Labels file does not exist at {0}'.format(labels))

  if not graph or not tf.io.gfile.exists(graph):
    raise ValueError('Graph file does not exist at {0}'.format(graph))

  labels_list = load_labels(labels)

  # load graph, which is stored in the default session
  load_graph(graph)

  with open(wav, 'rb') as wav_file:
    wav_data = wav_file.read()

  run_graph(wav_data, labels_list, input_name, output_name, how_many_labels)


def main(_):
  """Entry point for script, converts flags to arguments."""
  label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
            FLAGS.output_name, FLAGS.how_many_labels)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--wav', type=str, default='', help='Audio file to be identified.')
  parser.add_argument(
      '--graph', type=str, default='', help='Model to use for identification.')
  parser.add_argument(
      '--labels', type=str, default='', help='Path to file containing labels.')
  parser.add_argument(
      '--input_name',
      type=str,
      default='wav_data:0',
      help='Name of WAVE data input node in model.')
  parser.add_argument(
      '--output_name',
      type=str,
      default='labels_softmax:0',
      help='Name of node outputting a prediction in the model.')
  parser.add_argument(
      '--how_many_labels',
      type=int,
      default=3,
      help='Number of results to show.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)

label_wav.cc

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <fstream>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"

// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::int32;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::Tensor;
using tensorflow::tstring;

namespace {

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return Status::OK();
}

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Labels file ", file_name,
                                        " not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  return Status::OK();
}

// Analyzes the output of the graph to retrieve the highest scores and
// their positions in the tensor.
void GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
                  Tensor* out_indices, Tensor* out_scores) {
  const Tensor& unsorted_scores_tensor = outputs[0];
  auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
  std::vector<std::pair<int, float>> scores;
  scores.reserve(unsorted_scores_flat.size());
  for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
    scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
  }
  std::sort(scores.begin(), scores.end(),
            [](const std::pair<int, float>& left,
               const std::pair<int, float>& right) {
              return left.second > right.second;
            });
  scores.resize(how_many_labels);
  Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
  Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
  for (int i = 0; i < scores.size(); ++i) {
    sorted_indices.flat<int>()(i) = scores[i].first;
    sorted_scores.flat<float>()(i) = scores[i].second;
  }
  *out_indices = sorted_indices;
  *out_scores = sorted_scores;
}

}  // namespace

int main(int argc, char* argv[]) {
  string wav = "";
  string graph = "";
  string labels = "";
  string input_name = "wav_data";
  string output_name = "labels_softmax";
  int32 how_many_labels = 3;
  std::vector<Flag> flag_list = {
      Flag("wav", &wav, "audio file to be identified"),
      Flag("graph", &graph, "model to be executed"),
      Flag("labels", &labels, "path to file containing labels"),
      Flag("input_name", &input_name, "name of input node in model"),
      Flag("output_name", &output_name, "name of output node in model"),
      Flag("how_many_labels", &how_many_labels, "number of results to show"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  Status load_graph_status = LoadGraph(graph, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  std::vector<string> labels_list;
  Status read_labels_status = ReadLabelsFile(labels, &labels_list);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return -1;
  }

  string wav_string;
  Status read_wav_status = tensorflow::ReadFileToString(
      tensorflow::Env::Default(), wav, &wav_string);
  if (!read_wav_status.ok()) {
    LOG(ERROR) << read_wav_status;
    return -1;
  }
  Tensor wav_tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
  wav_tensor.scalar<tstring>()() = wav_string;

  // Actually run the audio through the model.
  std::vector<Tensor> outputs;
  Status run_status =
      session->Run({{input_name, wav_tensor}}, {output_name}, {}, &outputs);
  if (!run_status.ok()) {
    LOG(ERROR) << "Running model failed: " << run_status;
    return -1;
  }

  Tensor indices;
  Tensor scores;
  GetTopLabels(outputs, how_many_labels, &indices, &scores);
  tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
  for (int pos = 0; pos < how_many_labels; ++pos) {
    const int label_index = indices_flat(pos);
    const float score = scores_flat(pos);
    LOG(INFO) << labels_list[label_index] << " (" << label_index
              << "): " << score;
  }

  return 0;
}
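The tool above prints the `how_many_labels` highest-scoring classes by sorting (index, score) pairs. The same top-k selection can be sketched in plain Python, here with hypothetical labels and softmax scores rather than real model output:

```python
# Sketch of the top-k selection performed by GetTopLabels: sort
# (index, score) pairs by descending score and keep the first k.
def get_top_labels(scores, how_many_labels):
    indexed = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    top = indexed[:how_many_labels]
    indices = [i for i, _ in top]
    values = [s for _, s in top]
    return indices, values

labels = ["_silence_", "_unknown_", "yes", "no"]
softmax = [0.05, 0.10, 0.70, 0.15]  # hypothetical model output
indices, values = get_top_labels(softmax, 3)
for i, v in zip(indices, values):
    print("%s (%d): %.2f" % (labels[i], i, v))
```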

accuracy_utils.cc

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/accuracy_utils.h"

#include <fstream>
#include <iomanip>
#include <unordered_set>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {

Status ReadGroundTruthFile(const string& file_name,
                           std::vector<std::pair<string, int64>>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Ground truth file '", file_name,
                                        "' not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    std::vector<string> pieces = tensorflow::str_util::Split(line, ',');
    if (pieces.size() != 2) {
      continue;
    }
    float timestamp;
    if (!tensorflow::strings::safe_strtof(pieces[1].c_str(), &timestamp)) {
      return tensorflow::errors::InvalidArgument(
          "Wrong number format at line: ", line);
    }
    string label = pieces[0];
    auto timestamp_int64 = static_cast<int64>(timestamp);
    result->push_back({label, timestamp_int64});
  }
  std::sort(result->begin(), result->end(),
            [](const std::pair<string, int64>& left,
               const std::pair<string, int64>& right) {
              return left.second < right.second;
            });
  return Status::OK();
}

void CalculateAccuracyStats(
    const std::vector<std::pair<string, int64>>& ground_truth_list,
    const std::vector<std::pair<string, int64>>& found_words,
    int64 up_to_time_ms, int64 time_tolerance_ms,
    StreamingAccuracyStats* stats) {
  int64 latest_possible_time;
  if (up_to_time_ms == -1) {
    latest_possible_time = std::numeric_limits<int64>::max();
  } else {
    latest_possible_time = up_to_time_ms + time_tolerance_ms;
  }
  stats->how_many_ground_truth_words = 0;
  for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
    const int64 ground_truth_time = ground_truth.second;
    if (ground_truth_time > latest_possible_time) {
      break;
    }
    ++stats->how_many_ground_truth_words;
  }

  stats->how_many_false_positives = 0;
  stats->how_many_correct_words = 0;
  stats->how_many_wrong_words = 0;
  std::unordered_set<int64> has_ground_truth_been_matched;
  for (const std::pair<string, int64>& found_word : found_words) {
    const string& found_label = found_word.first;
    const int64 found_time = found_word.second;
    const int64 earliest_time = found_time - time_tolerance_ms;
    const int64 latest_time = found_time + time_tolerance_ms;
    bool has_match_been_found = false;
    for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
      const int64 ground_truth_time = ground_truth.second;
      if ((ground_truth_time > latest_time) ||
          (ground_truth_time > latest_possible_time)) {
        break;
      }
      if (ground_truth_time < earliest_time) {
        continue;
      }
      const string& ground_truth_label = ground_truth.first;
      if ((ground_truth_label == found_label) &&
          (has_ground_truth_been_matched.count(ground_truth_time) == 0)) {
        ++stats->how_many_correct_words;
      } else {
        ++stats->how_many_wrong_words;
      }
      has_ground_truth_been_matched.insert(ground_truth_time);
      has_match_been_found = true;
      break;
    }
    if (!has_match_been_found) {
      ++stats->how_many_false_positives;
    }
  }
  stats->how_many_ground_truth_matched = has_ground_truth_been_matched.size();
}

void PrintAccuracyStats(const StreamingAccuracyStats& stats) {
  if (stats.how_many_ground_truth_words == 0) {
    LOG(INFO) << "No ground truth yet, " << stats.how_many_false_positives
              << " false positives";
  } else {
    float any_match_percentage =
        (stats.how_many_ground_truth_matched * 100.0f) /
        stats.how_many_ground_truth_words;
    float correct_match_percentage = (stats.how_many_correct_words * 100.0f) /
                                     stats.how_many_ground_truth_words;
    float wrong_match_percentage = (stats.how_many_wrong_words * 100.0f) /
                                   stats.how_many_ground_truth_words;
    float false_positive_percentage =
        (stats.how_many_false_positives * 100.0f) /
        stats.how_many_ground_truth_words;

    LOG(INFO) << std::setprecision(1) << std::fixed << any_match_percentage
              << "% matched, " << correct_match_percentage << "% correctly, "
              << wrong_match_percentage << "% wrongly, "
              << false_positive_percentage << "% false positives ";
  }
}

}  // namespace tensorflow
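The core of CalculateAccuracyStats is the matching loop: each detected word is compared against ground-truth entries whose timestamps fall within the tolerance window, and each ground-truth entry can be consumed at most once. A minimal Python sketch of that loop, exercised with the same numbers as accuracy_utils_test.cc:

```python
# Sketch of the matching loop in CalculateAccuracyStats: each found word
# claims the first unconsumed ground-truth entry whose timestamp lies
# within +/- tolerance_ms; unmatched detections count as false positives.
def calculate_accuracy_stats(ground_truth, found_words, up_to_ms, tolerance_ms):
    latest = float("inf") if up_to_ms == -1 else up_to_ms + tolerance_ms
    stats = {"gt_words": sum(1 for _, t in ground_truth if t <= latest),
             "false_positives": 0, "correct": 0, "wrong": 0}
    matched_times = set()
    for label, time in found_words:
        earliest_time, latest_time = time - tolerance_ms, time + tolerance_ms
        found = False
        for gt_label, gt_time in ground_truth:  # sorted by time
            if gt_time > latest_time or gt_time > latest:
                break
            if gt_time < earliest_time:
                continue
            if gt_label == label and gt_time not in matched_times:
                stats["correct"] += 1
            else:
                stats["wrong"] += 1
            matched_times.add(gt_time)
            found = True
            break
        if not found:
            stats["false_positives"] += 1
    stats["matched"] = len(matched_times)
    return stats

stats = calculate_accuracy_stats(
    [("a", 1000), ("b", 9000)],
    [("a", 1200), ("b", 5000), ("a", 8700)], 10000, 500)
print(stats)
```

Run on the test's inputs, this reproduces its expectations: two ground-truth words, both matched, one correct, one wrong, one false positive.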

accuracy_utils.h

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

struct StreamingAccuracyStats {
  StreamingAccuracyStats()
      : how_many_ground_truth_words(0),
        how_many_ground_truth_matched(0),
        how_many_false_positives(0),
        how_many_correct_words(0),
        how_many_wrong_words(0) {}
  int32 how_many_ground_truth_words;
  int32 how_many_ground_truth_matched;
  int32 how_many_false_positives;
  int32 how_many_correct_words;
  int32 how_many_wrong_words;
};

// Takes a file name, and loads a list of expected word labels and times from
// it, as comma-separated values.
Status ReadGroundTruthFile(const string& file_name,
                           std::vector<std::pair<string, int64>>* result);

// Given ground truth labels and corresponding predictions found by a model,
// figure out how many were correct. Takes a time limit, so that only
// predictions up to a point in time are considered, in case we're evaluating
// accuracy when the model has only been run on part of the stream.
void CalculateAccuracyStats(
    const std::vector<std::pair<string, int64>>& ground_truth_list,
    const std::vector<std::pair<string, int64>>& found_words,
    int64 up_to_time_ms, int64 time_tolerance_ms,
    StreamingAccuracyStats* stats);

// Writes a human-readable description of the statistics to stdout.
void PrintAccuracyStats(const StreamingAccuracyStats& stats);

}  // namespace tensorflow

#endif  // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
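The ground-truth file format declared above is one `label,time_ms` pair per line, returned sorted by timestamp. That parsing can be sketched in a few lines of Python (a simplification of ReadGroundTruthFile: it skips malformed lines and raises ValueError on a non-numeric timestamp instead of returning a Status):

```python
# Parse "label,milliseconds" lines into (label, time) pairs sorted by
# time, skipping lines that don't have exactly two fields.
def read_ground_truth(text):
    result = []
    for line in text.splitlines():
        pieces = line.split(",")
        if len(pieces) != 2:
            continue
        label, timestamp = pieces[0], float(pieces[1])
        result.append((label, int(timestamp)))
    return sorted(result, key=lambda pair: pair[1])

print(read_ground_truth("b,12\na,10\n"))  # [('a', 10), ('b', 12)]
```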

accuracy_utils_test.cc

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/accuracy_utils.h"

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(AccuracyUtilsTest, ReadGroundTruthFile) {
  string file_name = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(),
                                              "ground_truth.txt");
  string file_data = "a,10\nb,12\n";
  TF_ASSERT_OK(WriteStringToFile(Env::Default(), file_name, file_data));

  std::vector<std::pair<string, int64>> ground_truth;
  TF_ASSERT_OK(ReadGroundTruthFile(file_name, &ground_truth));
  ASSERT_EQ(2, ground_truth.size());
  EXPECT_EQ("a", ground_truth[0].first);
  EXPECT_EQ(10, ground_truth[0].second);
  EXPECT_EQ("b", ground_truth[1].first);
  EXPECT_EQ(12, ground_truth[1].second);
}

TEST(AccuracyUtilsTest, CalculateAccuracyStats) {
  StreamingAccuracyStats stats;
  CalculateAccuracyStats({{"a", 1000}, {"b", 9000}},
                         {{"a", 1200}, {"b", 5000}, {"a", 8700}}, 10000, 500,
                         &stats);
  EXPECT_EQ(2, stats.how_many_ground_truth_words);
  EXPECT_EQ(2, stats.how_many_ground_truth_matched);
  EXPECT_EQ(1, stats.how_many_false_positives);
  EXPECT_EQ(1, stats.how_many_correct_words);
  EXPECT_EQ(1, stats.how_many_wrong_words);
}

TEST(AccuracyUtilsTest, PrintAccuracyStats) {
  StreamingAccuracyStats stats;
  PrintAccuracyStats(stats);
}

}  // namespace tensorflow

freeze_test.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

from tensorflow.examples.speech_commands import freeze
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops.variables import global_variables_initializer
from tensorflow.python.platform import test


class FreezeTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testCreateInferenceGraphWithMfcc(self):
    with self.cached_session() as sess:
      freeze.create_inference_graph(
          wanted_words='a,b,c,d',
          sample_rate=16000,
          clip_duration_ms=1000.0,
          clip_stride_ms=30.0,
          window_size_ms=30.0,
          window_stride_ms=10.0,
          feature_bin_count=40,
          model_architecture='conv',
          preprocess='mfcc')
      self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
      self.assertIsNotNone(
          sess.graph.get_tensor_by_name('decoded_sample_data:0'))
      self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
      ops = [node.op for node in sess.graph_def.node]
      self.assertEqual(1, ops.count('Mfcc'))

  @test_util.run_deprecated_v1
  def testCreateInferenceGraphWithoutMfcc(self):
    with self.cached_session() as sess:
      freeze.create_inference_graph(
          wanted_words='a,b,c,d',
          sample_rate=16000,
          clip_duration_ms=1000.0,
          clip_stride_ms=30.0,
          window_size_ms=30.0,
          window_stride_ms=10.0,
          feature_bin_count=40,
          model_architecture='conv',
          preprocess='average')
      self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
      self.assertIsNotNone(
          sess.graph.get_tensor_by_name('decoded_sample_data:0'))
      self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
      ops = [node.op for node in sess.graph_def.node]
      self.assertEqual(0, ops.count('Mfcc'))

  @test_util.run_deprecated_v1
  def testCreateInferenceGraphWithMicro(self):
    with self.cached_session() as sess:
      freeze.create_inference_graph(
          wanted_words='a,b,c,d',
          sample_rate=16000,
          clip_duration_ms=1000.0,
          clip_stride_ms=30.0,
          window_size_ms=30.0,
          window_stride_ms=10.0,
          feature_bin_count=40,
          model_architecture='conv',
          preprocess='micro')
      self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
      self.assertIsNotNone(
          sess.graph.get_tensor_by_name('decoded_sample_data:0'))
      self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))

  @test_util.run_deprecated_v1
  def testFeatureBinCount(self):
    with self.cached_session() as sess:
      freeze.create_inference_graph(
          wanted_words='a,b,c,d',
          sample_rate=16000,
          clip_duration_ms=1000.0,
          clip_stride_ms=30.0,
          window_size_ms=30.0,
          window_stride_ms=10.0,
          feature_bin_count=80,
          model_architecture='conv',
          preprocess='average')
      self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
      self.assertIsNotNone(
          sess.graph.get_tensor_by_name('decoded_sample_data:0'))
      self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
      ops = [node.op for node in sess.graph_def.node]
      self.assertEqual(0, ops.count('Mfcc'))

  @test_util.run_deprecated_v1
  def testCreateSavedModel(self):
    tmp_dir = self.get_temp_dir()
    saved_model_path = os.path.join(tmp_dir, 'saved_model')
    with self.cached_session() as sess:
      input_tensor, output_tensor = freeze.create_inference_graph(
          wanted_words='a,b,c,d',
          sample_rate=16000,
          clip_duration_ms=1000.0,
          clip_stride_ms=30.0,
          window_size_ms=30.0,
          window_stride_ms=10.0,
          feature_bin_count=40,
          model_architecture='conv',
          preprocess='micro')
      global_variables_initializer().run()
      graph_util.convert_variables_to_constants(
          sess, sess.graph_def, ['labels_softmax'])
      freeze.save_saved_model(saved_model_path, sess, input_tensor,
                              output_tensor)


if __name__ == '__main__':
  test.main()
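The window parameters passed to create_inference_graph above determine the size of the model's input feature matrix. Under the usual sliding-window arithmetic (an assumption about what prepare_model_settings computes, not taken from this file), a 1000 ms clip with 30 ms windows and a 10 ms stride yields 98 frames, each with feature_bin_count values:

```python
# Frame-count arithmetic for a sliding analysis window: one frame per
# stride, starting once a full window fits inside the clip.
def spectrogram_frames(clip_ms, window_ms, stride_ms, sample_rate=16000):
    desired_samples = int(sample_rate * clip_ms / 1000)
    window_samples = int(sample_rate * window_ms / 1000)
    stride_samples = int(sample_rate * stride_ms / 1000)
    if desired_samples < window_samples:
        return 0
    return 1 + (desired_samples - window_samples) // stride_samples

frames = spectrogram_frames(1000, 30, 10)
print(frames, frames * 40)  # 98 frames x 40 bins = 3920 input values
```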

generate_streaming_test_wav_test.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test file generation for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.examples.speech_commands import generate_streaming_test_wav
from tensorflow.python.platform import test


class GenerateStreamingTestWavTest(test.TestCase):

  def testMixInAudioSample(self):
    track_data = np.zeros([10000])
    sample_data = np.ones([1000])
    generate_streaming_test_wav.mix_in_audio_sample(
        track_data, 2000, sample_data, 0, 1000, 1.0, 100, 100)
    self.assertNear(1.0, track_data[2500], 0.0001)
    self.assertNear(0.0, track_data[3500], 0.0001)


if __name__ == "__main__":
  test.main()
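testMixInAudioSample above expects the mixed-in sample to be at full volume 500 samples past the output offset and absent 1500 samples past it. A plausible pure-Python version of such a mixer, with linear fade-in/fade-out ramps at the clip edges (a sketch consistent with the test's expectations, not the actual implementation):

```python
# Add `sample` into `track` at `output_offset`, scaled by `volume`,
# with linear ramps of ramp_in/ramp_out samples at the clip edges.
def mix_in_audio_sample(track, output_offset, sample, sample_offset,
                        clip_duration, volume, ramp_in, ramp_out):
    for i in range(clip_duration):
        if i < ramp_in:
            scale = i / ramp_in                       # fade in
        elif i >= clip_duration - ramp_out:
            scale = (clip_duration - i) / ramp_out    # fade out
        else:
            scale = 1.0                               # full volume
        track[output_offset + i] += volume * scale * sample[sample_offset + i]

track = [0.0] * 10000
mix_in_audio_sample(track, 2000, [1.0] * 1000, 0, 1000, 1.0, 100, 100)
print(track[2500], track[3500])  # full volume mid-clip, silence past the clip
```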

input_data_test.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf


from tensorflow.examples.speech_commands import input_data
from tensorflow.examples.speech_commands import models
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class InputDataTest(test.TestCase):

  def _getWavData(self):
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    with open(filename, "wb") as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
        self._saveTestWavFile(file_path, wav_data)

  def _model_settings(self):
    return {
        "desired_samples": 160,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "preprocess": "mfcc",
    }

  def _runGetDataTest(self, preprocess, window_length_ms):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    background_dir = os.path.join(wav_dir, "_background_noise_")
    os.mkdir(background_dir)
    wav_data = self._getWavData()
    for i in range(10):
      file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
      self._saveTestWavFile(file_path, wav_data)
    model_settings = models.prepare_model_settings(
        4, 16000, 1000, window_length_ms, 20, 40, preprocess)
    with self.cached_session() as sess:
      audio_processor = input_data.AudioProcessor(
          "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
      result_data, result_labels = audio_processor.get_data(
          10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
      self.assertEqual(10, len(result_data))
      self.assertEqual(10, len(result_labels))

  def testPrepareWordsList(self):
    words_list = ["a", "b"]
    self.assertGreater(
        len(input_data.prepare_words_list(words_list)), len(words_list))

  def testWhichSet(self):
    self.assertEqual(
        input_data.which_set("foo.wav", 10, 10),
        input_data.which_set("foo.wav", 10, 10))
    self.assertEqual(
        input_data.which_set("foo_nohash_0.wav", 10, 10),
        input_data.which_set("foo_nohash_1.wav", 10, 10))

  @test_util.run_deprecated_v1
  def testPrepareDataIndex(self):
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
                                                ["a", "b"], 10, 10,
                                                self._model_settings(), tmp_dir)
    self.assertLess(0, audio_processor.set_size("training"))
    self.assertIn("training", audio_processor.data_index)
    self.assertIn("validation", audio_processor.data_index)
    self.assertIn("testing", audio_processor.data_index)
    self.assertEqual(input_data.UNKNOWN_WORD_INDEX,
                     audio_processor.word_to_index["c"])

  def testPrepareDataIndexEmpty(self):
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
    with self.assertRaises(Exception) as e:
      _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
                                    self._model_settings(), tmp_dir)
    self.assertIn("No .wavs found", str(e.exception))

  def testPrepareDataIndexMissing(self):
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    with self.assertRaises(Exception) as e:
      _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
                                    10, self._model_settings(), tmp_dir)
    self.assertIn("Expected to find", str(e.exception))

  @test_util.run_deprecated_v1
  def testPrepareBackgroundData(self):
    tmp_dir = self.get_temp_dir()
    background_dir = os.path.join(tmp_dir, "_background_noise_")
    os.mkdir(background_dir)
    wav_data = self._getWavData()
    for i in range(10):
      file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
      self._saveTestWavFile(file_path, wav_data)
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
                                                ["a", "b"], 10, 10,
                                                self._model_settings(), tmp_dir)
    self.assertEqual(10, len(audio_processor.background_data))

  def testLoadWavFile(self):
    tmp_dir = self.get_temp_dir()
    file_path = os.path.join(tmp_dir, "load_test.wav")
    wav_data = self._getWavData()
    self._saveTestWavFile(file_path, wav_data)
    sample_data = input_data.load_wav_file(file_path)
    self.assertIsNotNone(sample_data)

  def testSaveWavFile(self):
    tmp_dir = self.get_temp_dir()
    file_path = os.path.join(tmp_dir, "load_test.wav")
    save_data = np.zeros([16000, 1])
    input_data.save_wav_file(file_path, save_data, 16000)
    loaded_data = input_data.load_wav_file(file_path)
    self.assertIsNotNone(loaded_data)
    self.assertEqual(16000, len(loaded_data))

  @test_util.run_deprecated_v1
  def testPrepareProcessingGraph(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    background_dir = os.path.join(wav_dir, "_background_noise_")
    os.mkdir(background_dir)
    wav_data = self._getWavData()
    for i in range(10):
      file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
      self._saveTestWavFile(file_path, wav_data)
    model_settings = {
        "desired_samples": 160,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "preprocess": "mfcc",
    }
    audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
                                                10, 10, model_settings, tmp_dir)
    self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
    self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
    self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
    self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
    self.assertIsNotNone(audio_processor.background_data_placeholder_)
    self.assertIsNotNone(audio_processor.background_volume_placeholder_)
    self.assertIsNotNone(audio_processor.output_)

  @test_util.run_deprecated_v1
  def testGetDataAverage(self):
    self._runGetDataTest("average", 10)

  @test_util.run_deprecated_v1
  def testGetDataAverageLongWindow(self):
    self._runGetDataTest("average", 30)

  @test_util.run_deprecated_v1
  def testGetDataMfcc(self):
    self._runGetDataTest("mfcc", 30)

  @test_util.run_deprecated_v1
  def testGetDataMicro(self):
    self._runGetDataTest("micro", 20)

  @test_util.run_deprecated_v1
  def testGetUnprocessedData(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    model_settings = {
        "desired_samples": 160,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "preprocess": "mfcc",
    }
    audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
                                                10, 10, model_settings, tmp_dir)
    result_data, result_labels = audio_processor.get_unprocessed_data(
        10, model_settings, "training")
    self.assertEqual(10, len(result_data))
    self.assertEqual(10, len(result_labels))

  @test_util.run_deprecated_v1
  def testGetFeaturesForWav(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 1)
    desired_samples = 1600
    model_settings = {
        "desired_samples": desired_samples,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "average_window_width": 6,
        "preprocess": "average",
    }
    with self.cached_session() as sess:
      audio_processor = input_data.AudioProcessor(
          "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
      sample_data = np.zeros([desired_samples, 1])
      for i in range(desired_samples):
        phase = i % 4
        if phase == 0:
          sample_data[i, 0] = 0
        elif phase == 1:
          sample_data[i, 0] = -1
        elif phase == 2:
          sample_data[i, 0] = 0
        elif phase == 3:
          sample_data[i, 0] = 1
      test_wav_path = os.path.join(tmp_dir, "test_wav.wav")
      input_data.save_wav_file(test_wav_path, sample_data, 16000)

      results = audio_processor.get_features_for_wav(test_wav_path,
                                                     model_settings, sess)
      spectrogram = results[0]
      self.assertEqual(1, spectrogram.shape[0])
      self.assertEqual(16, spectrogram.shape[1])
      self.assertEqual(11, spectrogram.shape[2])
      self.assertNear(0, spectrogram[0, 0, 0], 0.1)
      self.assertNear(200, spectrogram[0, 0, 5], 0.1)

  def testGetFeaturesRange(self):
    model_settings = {
        "preprocess": "average",
    }
    features_min, _ = input_data.get_features_range(model_settings)
    self.assertNear(0.0, features_min, 1e-5)

  def testGetMfccFeaturesRange(self):
    model_settings = {
        "preprocess": "mfcc",
    }
    features_min, features_max = input_data.get_features_range(model_settings)
    self.assertLess(features_min, features_max)


if __name__ == "__main__":
  test.main()
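testWhichSet above checks two properties of the train/validation/test splitter: the assignment is deterministic per filename, and files that differ only after a `_nohash_` suffix land in the same set. A sketch of the hash-based scheme the speech_commands example uses (details such as the exact hash function are assumptions here):

```python
import hashlib
import os
import re

MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1  # keeps the hash bucketing stable

def which_set(filename, validation_percentage, testing_percentage):
    base_name = os.path.basename(filename)
    # Strip anything after '_nohash_' so variants of one recording
    # always fall into the same partition.
    hash_name = re.sub(r"_nohash_.*$", "", base_name)
    hash_hex = hashlib.sha1(hash_name.encode("utf-8")).hexdigest()
    percentage_hash = ((int(hash_hex, 16) % (MAX_NUM_WAVS_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_WAVS_PER_CLASS))
    if percentage_hash < validation_percentage:
        return "validation"
    if percentage_hash < testing_percentage + validation_percentage:
        return "testing"
    return "training"

# Both _nohash_ variants hash the same base name, so they agree:
print(which_set("foo_nohash_0.wav", 10, 10) ==
      which_set("foo_nohash_1.wav", 10, 10))  # True
```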

label_wav_test.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for WAVE file labeling tool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

from tensorflow.examples.speech_commands import label_wav
from tensorflow.python.platform import test


class LabelWavTest(test.TestCase):

  def _getWavData(self):
    with self.cached_session():
      sample_data = tf.zeros([1000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    with open(filename, "wb") as f:
      f.write(wav_data)

  def testLabelWav(self):
    tmp_dir = self.get_temp_dir()
    wav_data = self._getWavData()
    wav_filename = os.path.join(tmp_dir, "wav_file.wav")
    self._saveTestWavFile(wav_filename, wav_data)
    input_name = "test_input"
    output_name = "test_output"
    graph_filename = os.path.join(tmp_dir, "test_graph.pb")
    with tf.compat.v1.Session() as sess:
      tf.compat.v1.placeholder(tf.string, name=input_name)
      tf.zeros([1, 3], name=output_name)
      with open(graph_filename, "wb") as f:
        f.write(sess.graph.as_graph_def().SerializeToString())
    labels_filename = os.path.join(tmp_dir, "test_labels.txt")
    with open(labels_filename, "w") as f:
      f.write("a\nb\nc\n")
    label_wav.label_wav(wav_filename, labels_filename, graph_filename,
                        input_name + ":0", output_name + ":0", 3)


if __name__ == "__main__":
  test.main()

models_test.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for speech commands models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.examples.speech_commands import models
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ModelsTest(test.TestCase):

  def _modelSettings(self):
    return models.prepare_model_settings(
        label_count=10,
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=20,
        window_stride_ms=10,
        feature_bin_count=40,
        preprocess="mfcc")

  def testPrepareModelSettings(self):
    self.assertIsNotNone(
        models.prepare_model_settings(
            label_count=10,
            sample_rate=16000,
            clip_duration_ms=1000,
            window_size_ms=20,
            window_stride_ms=10,
            feature_bin_count=40,
            preprocess="mfcc"))

  @test_util.run_deprecated_v1
  def testCreateModelConvTraining(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, "conv", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))

  @test_util.run_deprecated_v1
  def testCreateModelConvInference(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits = models.create_model(fingerprint_input, model_settings, "conv",
                                   False)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))

  @test_util.run_deprecated_v1
  def testCreateModelLowLatencyConvTraining(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, "low_latency_conv", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))

  @test_util.run_deprecated_v1
  def testCreateModelFullyConnectedTraining(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, "single_fc", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))

  def testCreateModelBadArchitecture(self):
    model_settings = self._modelSettings()
    with self.cached_session():
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      with self.assertRaises(Exception) as e:
        models.create_model(fingerprint_input, model_settings,
                            "bad_architecture", True)
      self.assertIn("not recognized", str(e.exception))

  @test_util.run_deprecated_v1
  def testCreateModelTinyConvTraining(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, "tiny_conv", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))


if __name__ == "__main__":
  test.main()
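
The settings exercised above imply a fixed input "fingerprint" size for the model. As a rough illustration (this snippet is ours, not part of the upstream test file), the spectrogram geometry for the 16 kHz / 1 s clip / 20 ms window / 10 ms stride configuration can be recomputed by hand, assuming the usual frame-count formula `1 + (samples - window) // stride` used by `models.prepare_model_settings`:

```python
# Hedged sketch: recompute the feature geometry implied by _modelSettings().
sample_rate = 16000
clip_duration_ms = 1000
window_size_ms = 20
window_stride_ms = 10
feature_bin_count = 40

desired_samples = int(sample_rate * clip_duration_ms / 1000)        # 16000
window_size_samples = int(sample_rate * window_size_ms / 1000)      # 320
window_stride_samples = int(sample_rate * window_stride_ms / 1000)  # 160

# Number of analysis frames that fit in one clip.
length_minus_window = desired_samples - window_size_samples
spectrogram_length = 1 + length_minus_window // window_stride_samples  # 99

# Flattened input size fed to the model.
fingerprint_size = feature_bin_count * spectrogram_length  # 3960

print(spectrogram_length, fingerprint_size)
```

This is why the `fingerprint_input` tensors in the tests above have a single known width: every architecture is created against the same flattened feature vector.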

recognize_commands.cc

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/recognize_commands.h"

namespace tensorflow {

RecognizeCommands::RecognizeCommands(const std::vector<string>& labels,
                                     int32 average_window_duration_ms,
                                     float detection_threshold,
                                     int32 suppression_ms, int32 minimum_count)
    : labels_(labels),
      average_window_duration_ms_(average_window_duration_ms),
      detection_threshold_(detection_threshold),
      suppression_ms_(suppression_ms),
      minimum_count_(minimum_count) {
  labels_count_ = labels.size();
  previous_top_label_ = "_silence_";
  previous_top_label_time_ = std::numeric_limits<int64>::min();
}

Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results,
                                               const int64 current_time_ms,
                                               string* found_command,
                                               float* score,
                                               bool* is_new_command) {
  if (latest_results.NumElements() != labels_count_) {
    return errors::InvalidArgument(
        "The results for recognition should contain ", labels_count_,
        " elements, but there are ", latest_results.NumElements());
  }

  if ((!previous_results_.empty()) &&
      (current_time_ms < previous_results_.front().first)) {
    return errors::InvalidArgument(
        "Results must be fed in increasing time order, but received a "
        "timestamp of ",
        current_time_ms, " that was earlier than the previous one of ",
        previous_results_.front().first);
  }

  // Add the latest results to the head of the queue.
  previous_results_.push_back({current_time_ms, latest_results});

  // Prune any earlier results that are too old for the averaging window.
  const int64 time_limit = current_time_ms - average_window_duration_ms_;
  while (previous_results_.front().first < time_limit) {
    previous_results_.pop_front();
  }

  // If there are too few results, assume the result will be unreliable and
  // bail.
  const int64 how_many_results = previous_results_.size();
  const int64 earliest_time = previous_results_.front().first;
  const int64 samples_duration = current_time_ms - earliest_time;
  if ((how_many_results < minimum_count_) ||
      (samples_duration < (average_window_duration_ms_ / 4))) {
    *found_command = previous_top_label_;
    *score = 0.0f;
    *is_new_command = false;
    return Status::OK();
  }

  // Calculate the average score across all the results in the window.
  std::vector<float> average_scores(labels_count_);
  for (const auto& previous_result : previous_results_) {
    const Tensor& scores_tensor = previous_result.second;
    auto scores_flat = scores_tensor.flat<float>();
    for (int i = 0; i < scores_flat.size(); ++i) {
      average_scores[i] += scores_flat(i) / how_many_results;
    }
  }

  // Sort the averaged results in descending score order.
  std::vector<std::pair<int, float>> sorted_average_scores;
  sorted_average_scores.reserve(labels_count_);
  for (int i = 0; i < labels_count_; ++i) {
    sorted_average_scores.push_back(
        std::pair<int, float>({i, average_scores[i]}));
  }
  std::sort(sorted_average_scores.begin(), sorted_average_scores.end(),
            [](const std::pair<int, float>& left,
               const std::pair<int, float>& right) {
              return left.second > right.second;
            });

  // See if the latest top score is enough to trigger a detection.
  const int current_top_index = sorted_average_scores[0].first;
  const string current_top_label = labels_[current_top_index];
  const float current_top_score = sorted_average_scores[0].second;
  // If we've recently had another label trigger, assume one that occurs too
  // soon afterwards is a bad result.
  int64 time_since_last_top;
  if ((previous_top_label_ == "_silence_") ||
      (previous_top_label_time_ == std::numeric_limits<int64>::min())) {
    time_since_last_top = std::numeric_limits<int64>::max();
  } else {
    time_since_last_top = current_time_ms - previous_top_label_time_;
  }
  if ((current_top_score > detection_threshold_) &&
      (current_top_label != previous_top_label_) &&
      (time_since_last_top > suppression_ms_)) {
    previous_top_label_ = current_top_label;
    previous_top_label_time_ = current_time_ms;
    *is_new_command = true;
  } else {
    *is_new_command = false;
  }
  *found_command = current_top_label;
  *score = current_top_score;

  return Status::OK();
}

}  // namespace tensorflow

recognize_commands.h

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_

#include <deque>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// This class is designed to apply a very primitive decoding model on top of the
// instantaneous results from running an audio recognition model on a single
// window of samples. It applies smoothing over time so that noisy individual
// label scores are averaged, increasing the confidence that apparent matches
// are real.
// To use it, you should create a class object with the configuration you
// want, and then feed results from running a TensorFlow model into the
// processing method. The timestamp for each subsequent call should be
// increasing from the previous, since the class is designed to process a stream
// of data over time.
class RecognizeCommands {
 public:
  // labels should be a list of the strings associated with each one-hot score.
  // The window duration controls the smoothing. Longer durations will give a
  // higher confidence that the results are correct, but may miss some commands.
  // The detection threshold has a similar effect, with high values increasing
  // the precision at the cost of recall. The minimum count controls how many
  // results need to be in the averaging window before it's seen as a reliable
  // average. This prevents erroneous results while the averaging window is
  // initially being populated, for example. The suppression argument disables
  // further recognitions for a set time after one has been triggered, which can
  // help reduce spurious recognitions.
  explicit RecognizeCommands(const std::vector<string>& labels,
                             int32 average_window_duration_ms = 1000,
                             float detection_threshold = 0.2,
                             int32 suppression_ms = 500,
                             int32 minimum_count = 3);

  // Call this with the results of running a model on sample data.
  Status ProcessLatestResults(const Tensor& latest_results,
                              const int64 current_time_ms,
                              string* found_command, float* score,
                              bool* is_new_command);

 private:
  // Configuration
  std::vector<string> labels_;
  int32 average_window_duration_ms_;
  float detection_threshold_;
  int32 suppression_ms_;
  int32 minimum_count_;

  // Working variables
  std::deque<std::pair<int64, Tensor>> previous_results_;
  string previous_top_label_;
  int64 labels_count_;
  int64 previous_top_label_time_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
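
The decoding strategy this header documents — average the per-label scores over a sliding time window, then apply a threshold and a suppression period — can be sketched in a few lines of Python. The class and method names below are illustrative only (not part of the TensorFlow API), and the sketch omits some of the C++ guards, such as the minimum-window-duration check:

```python
from collections import deque


class SimpleRecognizer:
    """Illustrative Python sketch of the smoothing in RecognizeCommands."""

    def __init__(self, labels, window_ms=1000, threshold=0.2,
                 suppression_ms=500, minimum_count=3):
        self.labels = labels
        self.window_ms = window_ms
        self.threshold = threshold
        self.suppression_ms = suppression_ms
        self.minimum_count = minimum_count
        self.results = deque()  # (time_ms, scores) pairs, oldest first
        self.prev_label = "_silence_"
        self.prev_time = None

    def process(self, scores, now_ms):
        self.results.append((now_ms, scores))
        # Prune results older than the averaging window.
        while self.results[0][0] < now_ms - self.window_ms:
            self.results.popleft()
        # Too few results: treat the average as unreliable.
        if len(self.results) < self.minimum_count:
            return self.prev_label, 0.0, False
        # Average each label's score across the window.
        n = len(self.results)
        avg = [sum(s[i] for _, s in self.results) / n
               for i in range(len(self.labels))]
        top = max(range(len(avg)), key=avg.__getitem__)
        label, score = self.labels[top], avg[top]
        since_last = (float("inf") if self.prev_time is None
                      else now_ms - self.prev_time)
        # Trigger only on a confident, new label outside the suppression window.
        is_new = (score > self.threshold and label != self.prev_label
                  and since_last > self.suppression_ms)
        if is_new:
            self.prev_label, self.prev_time = label, now_ms
        return label, score, is_new
```

Feeding ten identical score vectors at 100 ms intervals triggers exactly one new command, which is the behaviour exercised by recognize_commands_test.cc below.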

recognize_commands_test.cc

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/recognize_commands.h"

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(RecognizeCommandsTest, Basic) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"});

  Tensor results(DT_FLOAT, {3});
  test::FillValues<float>(&results, {1.0f, 0.0f, 0.0f});

  string found_command;
  float score;
  bool is_new_command;
  TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
      results, 0, &found_command, &score, &is_new_command));
}

TEST(RecognizeCommandsTest, FindCommands) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);

  Tensor results(DT_FLOAT, {3});

  test::FillValues<float>(&results, {0.0f, 1.0f, 0.0f});
  bool has_found_new_command = false;
  string new_command;
  for (int i = 0; i < 10; ++i) {
    string found_command;
    float score;
    bool is_new_command;
    int64 current_time_ms = 0 + (i * 100);
    TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
        results, current_time_ms, &found_command, &score, &is_new_command));
    if (is_new_command) {
      EXPECT_FALSE(has_found_new_command);
      has_found_new_command = true;
      new_command = found_command;
    }
  }
  EXPECT_TRUE(has_found_new_command);
  EXPECT_EQ("a", new_command);

  test::FillValues<float>(&results, {0.0f, 0.0f, 1.0f});
  has_found_new_command = false;
  new_command = "";
  for (int i = 0; i < 10; ++i) {
    string found_command;
    float score;
    bool is_new_command;
    int64 current_time_ms = 1000 + (i * 100);
    TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
        results, current_time_ms, &found_command, &score, &is_new_command));
    if (is_new_command) {
      EXPECT_FALSE(has_found_new_command);
      has_found_new_command = true;
      new_command = found_command;
    }
  }
  EXPECT_TRUE(has_found_new_command);
  EXPECT_EQ("b", new_command);
}

TEST(RecognizeCommandsTest, BadInputLength) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);

  Tensor bad_results(DT_FLOAT, {2});
  test::FillValues<float>(&bad_results, {1.0f, 0.0f});

  string found_command;
  float score;
  bool is_new_command;
  EXPECT_FALSE(recognize_commands
                   .ProcessLatestResults(bad_results, 0, &found_command, &score,
                                         &is_new_command)
                   .ok());
}

TEST(RecognizeCommandsTest, BadInputTimes) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);

  Tensor results(DT_FLOAT, {3});
  test::FillValues<float>(&results, {1.0f, 0.0f, 0.0f});

  string found_command;
  float score;
  bool is_new_command;
  TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
      results, 100, &found_command, &score, &is_new_command));
  EXPECT_FALSE(recognize_commands
                   .ProcessLatestResults(results, 0, &found_command, &score,
                                         &is_new_command)
                   .ok());
}

}  // namespace tensorflow

test_streaming_accuracy.cc

C/C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*

Tool to create accuracy statistics from running an audio recognition model on a
continuous stream of samples.

This is designed to be an environment for running experiments on new models and
settings to understand the effects they will have in a real application. You
need to supply it with a long audio file containing sounds you want to recognize
and a text file listing the labels of each sound along with the time they occur.
With this information, and a frozen model, the tool will process the audio
stream, apply the model, and keep track of how many mistakes and successes the
model achieved.

The matched percentage is the number of sounds that were correctly classified,
as a percentage of the total number of sounds listed in the ground truth file.
A correct classification is when the right label is chosen within a short time
of the expected ground truth, where the time tolerance is controlled by the
'time_tolerance_ms' command line flag.

The wrong percentage is how many sounds triggered a detection (the classifier
figured out it wasn't silence or background noise), but the detected class was
wrong. This is also a percentage of the total number of ground truth sounds.

The false positive percentage is how many sounds were detected when there was
only silence or background noise. This is also expressed as a percentage of the
total number of ground truth sounds, though since it can be large it may go
above 100%.

The easiest way to get an audio file and labels to test with is by using the
'generate_streaming_test_wav' script. This will synthesize a test file with
randomly placed sounds and background noise, and output a text file with the
ground truth.

If you want to test natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down where the
sounds occur in time. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in seconds from
the start of the file that it occurs.

Here's an example of how to run the tool:

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--wav=/tmp/streaming_test_bg.wav \
--graph=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--ground_truth=/tmp/streaming_test_labels.txt --verbose \
--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
--suppression_ms=500 --time_tolerance_ms=1500

 */

#include <fstream>
#include <iomanip>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
#include "tensorflow/examples/speech_commands/recognize_commands.h"

// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::int32;
using tensorflow::int64;
using tensorflow::string;
using tensorflow::uint16;
using tensorflow::uint32;

namespace {

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return Status::OK();
}

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Labels file '", file_name,
                                        "' not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  return Status::OK();
}

}  // namespace

int main(int argc, char* argv[]) {
  string wav = "";
  string graph = "";
  string labels = "";
  string ground_truth = "";
  string input_data_name = "decoded_sample_data:0";
  string input_rate_name = "decoded_sample_data:1";
  string output_name = "labels_softmax";
  int32 clip_duration_ms = 1000;
  int32 clip_stride_ms = 30;
  int32 average_window_ms = 500;
  int32 time_tolerance_ms = 750;
  int32 suppression_ms = 1500;
  float detection_threshold = 0.7f;
  bool verbose = false;
  std::vector<Flag> flag_list = {
      Flag("wav", &wav, "audio file to be identified"),
      Flag("graph", &graph, "model to be executed"),
      Flag("labels", &labels, "path to file containing labels"),
      Flag("ground_truth", &ground_truth,
           "path to file containing correct times and labels of words in the "
           "audio as <word>,<timestamp in ms> lines"),
      Flag("input_data_name", &input_data_name,
           "name of input data node in model"),
      Flag("input_rate_name", &input_rate_name,
           "name of input sample rate node in model"),
      Flag("output_name", &output_name, "name of output node in model"),
      Flag("clip_duration_ms", &clip_duration_ms,
           "length of recognition window"),
      Flag("average_window_ms", &average_window_ms,
           "length of window to smooth results over"),
      Flag("time_tolerance_ms", &time_tolerance_ms,
           "maximum gap allowed between a recognition and ground truth"),
      Flag("suppression_ms", &suppression_ms,
           "how long to ignore others for after a recognition"),
      Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),
      Flag("detection_threshold", &detection_threshold,
           "what score is required to trigger detection of a word"),
      Flag("verbose", &verbose, "whether to log extra debugging information"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  Status load_graph_status = LoadGraph(graph, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  std::vector<string> labels_list;
  Status read_labels_status = ReadLabelsFile(labels, &labels_list);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return -1;
  }

  std::vector<std::pair<string, tensorflow::int64>> ground_truth_list;
  Status read_ground_truth_status =
      tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);
  if (!read_ground_truth_status.ok()) {
    LOG(ERROR) << read_ground_truth_status;
    return -1;
  }

  string wav_string;
  Status read_wav_status = tensorflow::ReadFileToString(
      tensorflow::Env::Default(), wav, &wav_string);
  if (!read_wav_status.ok()) {
    LOG(ERROR) << read_wav_status;
    return -1;
  }
  std::vector<float> audio_data;
  uint32 sample_count;
  uint16 channel_count;
  uint32 sample_rate;
  Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(
      wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);
  if (!decode_wav_status.ok()) {
    LOG(ERROR) << decode_wav_status;
    return -1;
  }
  if (channel_count != 1) {
    LOG(ERROR) << "Only mono .wav files can be used, but input has "
               << channel_count << " channels.";
    return -1;
  }

  const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
  const int64 clip_stride_samples = (clip_stride_ms * sample_rate) / 1000;
  Tensor audio_data_tensor(tensorflow::DT_FLOAT,
                           tensorflow::TensorShape({clip_duration_samples, 1}));

  Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));
  sample_rate_tensor.scalar<int32>()() = sample_rate;

  tensorflow::RecognizeCommands recognize_commands(
      labels_list, average_window_ms, detection_threshold, suppression_ms);

  std::vector<std::pair<string, int64>> all_found_words;
  tensorflow::StreamingAccuracyStats previous_stats;

  const int64 audio_data_end = (sample_count - clip_duration_samples);
  for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end;
       audio_data_offset += clip_stride_samples) {
    const float* input_start = &(audio_data[audio_data_offset]);
    const float* input_end = input_start + clip_duration_samples;
    std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());

    // Actually run the audio through the model.
    std::vector<Tensor> outputs;
    Status run_status = session->Run({{input_data_name, audio_data_tensor},
                                      {input_rate_name, sample_rate_tensor}},
                                     {output_name}, {}, &outputs);
    if (!run_status.ok()) {
      LOG(ERROR) << "Running model failed: " << run_status;
      return -1;
    }

    const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate;
    string found_command;
    float score;
    bool is_new_command;
    Status recognize_status = recognize_commands.ProcessLatestResults(
        outputs[0], current_time_ms, &found_command, &score, &is_new_command);
    if (!recognize_status.ok()) {
      LOG(ERROR) << "Recognition processing failed: " << recognize_status;
      return -1;
    }

    if (is_new_command && (found_command != "_silence_")) {
      all_found_words.push_back({found_command, current_time_ms});
      if (verbose) {
        tensorflow::StreamingAccuracyStats stats;
        tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,
                                           current_time_ms, time_tolerance_ms,
                                           &stats);
        int32 false_positive_delta = stats.how_many_false_positives -
                                     previous_stats.how_many_false_positives;
        int32 correct_delta = stats.how_many_correct_words -
                              previous_stats.how_many_correct_words;
        int32 wrong_delta =
            stats.how_many_wrong_words - previous_stats.how_many_wrong_words;
        string recognition_state;
        if (false_positive_delta == 1) {
          recognition_state = " (False Positive)";
        } else if (correct_delta == 1) {
          recognition_state = " (Correct)";
        } else if (wrong_delta == 1) {
          recognition_state = " (Wrong)";
        } else {
          LOG(ERROR) << "Unexpected state in statistics";
        }
        LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score
                  << recognition_state;
        previous_stats = stats;
        tensorflow::PrintAccuracyStats(stats);
      }
    }
  }

  tensorflow::StreamingAccuracyStats stats;
  tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,
                                     time_tolerance_ms, &stats);
  tensorflow::PrintAccuracyStats(stats);

  return 0;
}
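
The three statistics described in the comment block at the top of this file are all ratios against the ground-truth count, which is why the false-positive figure can exceed 100%. A minimal sketch of that arithmetic (the helper name is ours, not from accuracy_utils.h):

```python
def streaming_accuracy_percentages(correct, wrong, false_positives,
                                   ground_truth_count):
    """Compute the matched / wrong / false-positive percentages as described
    in test_streaming_accuracy.cc. All three share the same denominator:
    the number of sounds listed in the ground-truth file."""
    denom = float(ground_truth_count)
    return (100.0 * correct / denom,
            100.0 * wrong / denom,
            100.0 * false_positives / denom)


# e.g. 8 correct, 1 wrong, 3 false positives out of 10 ground-truth sounds
print(streaming_accuracy_percentages(8, 1, 3, 10))  # (80.0, 10.0, 30.0)
```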

train_test.py

Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest

import tensorflow as tf

from tensorflow.examples.speech_commands import train
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test


def requires_contrib(test_method):
  try:
    _ = tf.contrib
  except AttributeError:
    test_method = unittest.skip(
        'This test requires tf.contrib:\n    `pip install tensorflow<=1.15`')(
            test_method)

  return test_method


# Used to convert a dictionary into an object, for mocking parsed flags.
class DictStruct(object):

  def __init__(self, **entries):
    self.__dict__.update(entries)


class TrainTest(test.TestCase):

  def _getWavData(self):
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    with open(filename, 'wb') as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, 'some_audio_%d.wav' % i)
        self._saveTestWavFile(file_path, wav_data)

  def _prepareDummyTrainingData(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, 'wavs')
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ['a', 'b', 'c'], 100)
    background_dir = os.path.join(wav_dir, '_background_noise_')
    os.mkdir(background_dir)
    wav_data = self._getWavData()
    for i in range(10):
      file_path = os.path.join(background_dir, 'background_audio_%d.wav' % i)
      self._saveTestWavFile(file_path, wav_data)
    return wav_dir

  def _getDefaultFlags(self):
    flags = {
        'data_url': '',
        'data_dir': self._prepareDummyTrainingData(),
        'wanted_words': 'a,b,c',
        'sample_rate': 16000,
        'clip_duration_ms': 1000,
        'window_size_ms': 30,
        'window_stride_ms': 20,
        'feature_bin_count': 40,
        'preprocess': 'mfcc',
        'silence_percentage': 25,
        'unknown_percentage': 25,
        'validation_percentage': 10,
        'testing_percentage': 10,
        'summaries_dir': os.path.join(self.get_temp_dir(), 'summaries'),
        'train_dir': os.path.join(self.get_temp_dir(), 'train'),
        'time_shift_ms': 100,
        'how_many_training_steps': '2',
        'learning_rate': '0.01',
        'quantize': False,
        'model_architecture': 'conv',
        'check_nans': False,
        'start_checkpoint': '',
        'batch_size': 1,
        'background_volume': 0.25,
        'background_frequency': 0.8,
        'eval_step_interval': 1,
        'save_step_interval': 1,
        'verbosity': tf.compat.v1.logging.INFO,
        'optimizer': 'gradient_descent'
    }
    return DictStruct(**flags)

  @test_util.run_deprecated_v1
  def testTrain(self):
    train.FLAGS = self._getDefaultFlags()
    train.main('')
    self.assertTrue(
        gfile.Exists(
            os.path.join(train.FLAGS.train_dir,
                         train.FLAGS.model_architecture + '.pbtxt')))
    self.assertTrue(
        gfile.Exists(
            os.path.join(train.FLAGS.train_dir,
                         train.FLAGS.model_architecture + '_labels.txt')))
    self.assertTrue(
        gfile.Exists(
            os.path.join(train.FLAGS.train_dir,
                         train.FLAGS.model_architecture + '.ckpt-1.meta')))


if __name__ == '__main__':
  test.main()
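The `DictStruct` helper above stands in for the `Namespace` object that argparse normally produces, so the test can hand `train.main` a fully populated flags object without parsing a command line. A minimal standalone sketch of the same pattern (the flag values here are illustrative, not the full set the test uses):

```python
# DictStruct turns keyword arguments into attributes, mimicking the
# object that argparse's parse_args() returns.
class DictStruct(object):

  def __init__(self, **entries):
    self.__dict__.update(entries)


# Mock just the flags we care about; real code would fill in every
# flag that train.main reads (see _getDefaultFlags above).
flags = DictStruct(wanted_words='yes,no', sample_rate=16000)
print(flags.wanted_words)   # attribute access, like a parsed-flags object
print(flags.sample_rate)
```

This is why the test can simply assign `train.FLAGS = self._getDefaultFlags()` before calling `train.main('')`.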

wav_to_features_test.py

Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

from tensorflow.examples.speech_commands import wav_to_features
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class WavToFeaturesTest(test.TestCase):

  def _getWavData(self):
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    with open(filename, "wb") as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
        self._saveTestWavFile(file_path, wav_data)

  @test_util.run_deprecated_v1
  def testWavToFeatures(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    input_file_path = os.path.join(tmp_dir, "input.wav")
    output_file_path = os.path.join(tmp_dir, "output.c")
    wav_data = self._getWavData()
    self._saveTestWavFile(input_file_path, wav_data)
    wav_to_features.wav_to_features(16000, 1000, 10, 10, 40, True, "average",
                                    input_file_path, output_file_path)
    with open(output_file_path, "rb") as f:
      content = f.read()
      self.assertIn(b"const unsigned char g_input_data", content)

  @test_util.run_deprecated_v1
  def testWavToFeaturesMicro(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    input_file_path = os.path.join(tmp_dir, "input.wav")
    output_file_path = os.path.join(tmp_dir, "output.c")
    wav_data = self._getWavData()
    self._saveTestWavFile(input_file_path, wav_data)
    wav_to_features.wav_to_features(16000, 1000, 10, 10, 40, True, "micro",
                                    input_file_path, output_file_path)
    with open(output_file_path, "rb") as f:
      content = f.read()
      self.assertIn(b"const unsigned char g_input_data", content)


if __name__ == "__main__":
  test.main()
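Before reading models.py, it helps to see the feature geometry that its `prepare_model_settings` function computes for the configuration this project uses (16 kHz audio, 1 s clips, 30 ms analysis windows, 20 ms stride, 40 feature bins). A standalone sketch of that arithmetic, mirroring the formulas in the function below:

```python
# Settings used in this project's training run (assumed defaults).
sample_rate = 16000        # samples per second
clip_duration_ms = 1000    # each clip is 1 second long
window_size_ms = 30        # analysis window duration
window_stride_ms = 20      # hop between windows
feature_bin_count = 40     # frequency bins per frame

desired_samples = int(sample_rate * clip_duration_ms / 1000)        # 16000
window_size_samples = int(sample_rate * window_size_ms / 1000)      # 480
window_stride_samples = int(sample_rate * window_stride_ms / 1000)  # 320

# Number of spectrogram frames that fit in one clip.
spectrogram_length = 1 + (desired_samples - window_size_samples) // window_stride_samples

# Flattened feature-vector length fed to the model.
fingerprint_size = feature_bin_count * spectrogram_length

print(spectrogram_length)  # 49 frames
print(fingerprint_size)    # 1960 features
```

The resulting 49 x 40 = 1960-element fingerprint is the input size the `tiny_conv` model expects on the Nano 33 BLE Sense.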

models.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definitions for simple speech recognition.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf


def _next_power_of_two(x):
  """Calculates the smallest enclosing power of two for an input.

  Args:
    x: Positive float or integer number.

  Returns:
    Next largest power of two integer.
  """
  return 1 if x == 0 else 2**(int(x) - 1).bit_length()


def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms, feature_bin_count,
                           preprocess):
  """Calculates common settings needed for all models.

  Args:
    label_count: How many classes are to be recognized.
    sample_rate: Number of audio samples per second.
    clip_duration_ms: Length of each audio clip to be analyzed.
    window_size_ms: Duration of frequency analysis window.
    window_stride_ms: How far to move in time between frequency windows.
    feature_bin_count: Number of frequency bins to use for analysis.
    preprocess: How the spectrogram is processed to produce features.

  Returns:
    Dictionary containing common settings.

  Raises:
    ValueError: If the preprocessing mode isn't recognized.
  """
  desired_samples = int(sample_rate * clip_duration_ms / 1000)
  window_size_samples = int(sample_rate * window_size_ms / 1000)
  window_stride_samples = int(sample_rate * window_stride_ms / 1000)
  length_minus_window = (desired_samples - window_size_samples)
  if length_minus_window < 0:
    spectrogram_length = 0
  else:
    spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
  if preprocess == 'average':
    fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)
    average_window_width = int(math.floor(fft_bin_count / feature_bin_count))
    fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))
  elif preprocess == 'mfcc':
    average_window_width = -1
    fingerprint_width = feature_bin_count
  elif preprocess == 'micro':
    average_window_width = -1
    fingerprint_width = feature_bin_count
  else:
    raise ValueError('Unknown preprocess mode "%s" (should be "mfcc",'
                     ' "average", or "micro")' % (preprocess))
  fingerprint_size = fingerprint_width * spectrogram_length
  return {
      'desired_samples': desired_samples,
      'window_size_samples': window_size_samples,
      'window_stride_samples': window_stride_samples,
      'spectrogram_length': spectrogram_length,
      'fingerprint_width': fingerprint_width,
      'fingerprint_size': fingerprint_size,
      'label_count': label_count,
      'sample_rate': sample_rate,
      'preprocess': preprocess,
      'average_window_width': average_window_width,
  }


def create_model(fingerprint_input, model_settings, model_architecture,
                 is_training, runtime_settings=None):
  """Builds a model of the requested architecture compatible with the settings.

  There are many possible ways of deriving predictions from a spectrogram
  input, so this function provides an abstract interface for creating different
  kinds of models in a black-box way. You need to pass in a TensorFlow node as
  the 'fingerprint' input, and this should output a batch of 1D features that
  describe the audio. Typically this will be derived from a spectrogram that's
  been run through an MFCC, but in theory it can be any feature vector of the
  size specified in model_settings['fingerprint_size'].

  The function will build the graph it needs in the current TensorFlow graph,
  and return the tensorflow output that will contain the 'logits' input to the
  softmax prediction process. If training flag is on, it will also return a
  placeholder node that can be used to control the dropout amount.

  See the implementations below for the possible model architectures that can be
  requested.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    model_architecture: String specifying which kind of model to create.
    is_training: Whether the model is going to be used for training.
    runtime_settings: Dictionary of information about the runtime.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.

  Raises:
    Exception: If the architecture type isn't recognized.
  """
  if model_architecture == 'single_fc':
    return create_single_fc_model(fingerprint_input, model_settings,
                                  is_training)
  elif model_architecture == 'conv':
    return create_conv_model(fingerprint_input, model_settings, is_training)
  elif model_architecture == 'low_latency_conv':
    return create_low_latency_conv_model(fingerprint_input, model_settings,
                                         is_training)
  elif model_architecture == 'low_latency_svdf':
    return create_low_latency_svdf_model(fingerprint_input, model_settings,
                                         is_training, runtime_settings)
  elif model_architecture == 'tiny_conv':
    return create_tiny_conv_model(fingerprint_input, model_settings,
                                  is_training)
  elif model_architecture == 'tiny_embedding_conv':
    return create_tiny_embedding_conv_model(fingerprint_input, model_settings,
                                            is_training)
  else:
    raise Exception('model_architecture argument "' + model_architecture +
                    '" not recognized, should be one of "single_fc", "conv",' +
                    ' "low_latency_conv", "low_latency_svdf",' +
                    ' "tiny_conv", or "tiny_embedding_conv"')


def load_variables_from_checkpoint(sess, start_checkpoint):
  """Utility function to centralize checkpoint restoration.

  Args:
    sess: TensorFlow session.
    start_checkpoint: Path to saved checkpoint on disk.
  """
  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
  saver.restore(sess, start_checkpoint)


def create_single_fc_model(fingerprint_input, model_settings, is_training):
  """Builds a model with a single hidden fully-connected layer.

  This is a very simple model with just one matmul and bias layer. As you'd
  expect, it doesn't produce very accurate results, but it is very fast and
  simple, so it's useful for sanity testing.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  weights = tf.compat.v1.get_variable(
      name='weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.001),
      shape=[fingerprint_size, label_count])
  bias = tf.compat.v1.get_variable(name='bias',
                                   initializer=tf.compat.v1.zeros_initializer,
                                   shape=[label_count])
  logits = tf.matmul(fingerprint_input, weights) + bias
  if is_training:
    return logits, dropout_rate
  else:
    return logits


def create_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a standard convolutional model.

  This is roughly the network labeled as 'cnn-trad-fpool3' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This produces fairly good quality results, but can involve a large number of
  weight parameters and computations. For a cheaper alternative from the same
  paper with slightly less accuracy, see 'low_latency_conv' below.

  During training, dropout nodes are introduced after each relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 20
  first_filter_count = 64
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])

  first_conv = tf.nn.conv2d(input=fingerprint_4d,
                            filters=first_weights,
                            strides=[1, 1, 1, 1],
                            padding='SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu
  max_pool = tf.nn.max_pool2d(input=first_dropout,
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding='SAME')
  second_filter_width = 4
  second_filter_height = 10
  second_filter_count = 64
  second_weights = tf.compat.v1.get_variable(
      name='second_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[
          second_filter_height, second_filter_width, first_filter_count,
          second_filter_count
      ])
  second_bias = tf.compat.v1.get_variable(
      name='second_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_filter_count])
  second_conv = tf.nn.conv2d(input=max_pool,
                             filters=second_weights,
                             strides=[1, 1, 1, 1],
                             padding='SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, rate=dropout_rate)
  else:
    second_dropout = second_relu
  second_conv_shape = second_dropout.get_shape()
  second_conv_output_width = second_conv_shape[2]
  second_conv_output_height = second_conv_shape[1]
  second_conv_element_count = int(
      second_conv_output_width * second_conv_output_height *
      second_filter_count)
  flattened_second_conv = tf.reshape(second_dropout,
                                     [-1, second_conv_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_conv_element_count, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_low_latency_conv_model(fingerprint_input, model_settings,
                                  is_training):
  """Builds a convolutional model with low compute requirements.

  This is roughly the network labeled as 'cnn-one-fstride4' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This produces slightly lower quality results than the 'conv' model, but needs
  fewer weight parameters and computations.

  During training, dropout nodes are introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = input_time_size
  first_filter_count = 186
  first_filter_stride_x = 1
  first_filter_stride_y = 1
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])
  first_conv = tf.nn.conv2d(
      input=fingerprint_4d,
      filters=first_weights,
      strides=[1, first_filter_stride_y, first_filter_stride_x, 1],
      padding='VALID') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu
  first_conv_output_width = math.floor(
      (input_frequency_size - first_filter_width + first_filter_stride_x) /
      first_filter_stride_x)
  first_conv_output_height = math.floor(
      (input_time_size - first_filter_height + first_filter_stride_y) /
      first_filter_stride_y)
  first_conv_element_count = int(
      first_conv_output_width * first_conv_output_height * first_filter_count)
  flattened_first_conv = tf.reshape(first_dropout,
                                    [-1, first_conv_element_count])
  first_fc_output_channels = 128
  first_fc_weights = tf.compat.v1.get_variable(
      name='first_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_conv_element_count, first_fc_output_channels])
  first_fc_bias = tf.compat.v1.get_variable(
      name='first_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_fc_output_channels])
  first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, rate=dropout_rate)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 128
  second_fc_weights = tf.compat.v1.get_variable(
      name='second_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_fc_output_channels, second_fc_output_channels])
  second_fc_bias = tf.compat.v1.get_variable(
      name='second_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_fc_output_channels])
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, rate=dropout_rate)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_fc_output_channels, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_low_latency_svdf_model(fingerprint_input, model_settings,
                                  is_training, runtime_settings):
  """Builds an SVDF model with low compute requirements.

  This is based on the topology presented in the 'Compressing Deep Neural
  Networks using a Rank-Constrained Topology' paper:
  https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
        [SVDF]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This model produces lower recognition accuracy than the 'conv' model above,
  but requires fewer weight parameters and significantly fewer computations.

  During training, dropout nodes are introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    The node is expected to produce a 2D Tensor of shape:
      [batch, model_settings['fingerprint_width'] *
              model_settings['spectrogram_length']]
    with the features corresponding to the same time slot arranged contiguously,
    and the oldest slot at index [:, 0], and newest at [:, -1].
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.
    runtime_settings: Dictionary of information about the runtime.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.

  Raises:
      ValueError: If the inputs tensor is incorrectly shaped.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')

  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']

  # Validation.
  input_shape = fingerprint_input.get_shape()
  if len(input_shape) != 2:
    raise ValueError('Inputs to `SVDF` should have rank == 2.')
  if input_shape[-1].value is None:
    raise ValueError('The last dimension of the input to `SVDF` '
                     'should be defined. Found `None`.')
  if input_shape[-1].value % input_frequency_size != 0:
    raise ValueError('The last dimension of the input to `SVDF` = {0} must be '
                     'a multiple of the frame size = {1}'.format(
                         input_shape[-1].value, input_frequency_size))

  # Set number of units (i.e. nodes) and rank.
  rank = 2
  num_units = 1280
  # Number of filters: pairs of feature and time filters.
  num_filters = rank * num_units
  # Create the runtime memory: [num_filters, batch, input_time_size]
  batch = 1
  memory = tf.compat.v1.get_variable(
      initializer=tf.compat.v1.zeros_initializer,
      shape=[num_filters, batch, input_time_size],
      trainable=False,
      name='runtime-memory')
  first_time_flag = tf.compat.v1.get_variable(
      name='first_time_flag', dtype=tf.int32, initializer=1)
  # Determine the number of new frames in the input, such that we only operate
  # on those. For training we do not use the memory, and thus use all frames
  # provided in the input.
  # new_fingerprint_input: [batch, num_new_frames*input_frequency_size]
  if is_training:
    num_new_frames = input_time_size
  else:
    window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
                           model_settings['sample_rate'])
    num_new_frames = tf.cond(
        pred=tf.equal(first_time_flag, 1),
        true_fn=lambda: input_time_size,
        false_fn=lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))  # pylint:disable=line-too-long
  first_time_flag = 0
  new_fingerprint_input = fingerprint_input[
      :, -num_new_frames*input_frequency_size:]
  # Expand to add input channels dimension.
  new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)

  # Create the frequency filters.
  weights_frequency = tf.compat.v1.get_variable(
      name='weights_frequency',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[input_frequency_size, num_filters])
  # Expand to add input channels dimensions.
  # weights_frequency: [input_frequency_size, 1, num_filters]
  weights_frequency = tf.expand_dims(weights_frequency, 1)
  # Convolve the 1D feature filters sliding over the time dimension.
  # activations_time: [batch, num_new_frames, num_filters]
  activations_time = tf.nn.conv1d(input=new_fingerprint_input,
                                  filters=weights_frequency,
                                  stride=input_frequency_size,
                                  padding='VALID')
  # Rearrange such that we can perform the batched matmul.
  # activations_time: [num_filters, batch, num_new_frames]
  activations_time = tf.transpose(a=activations_time, perm=[2, 0, 1])

  # Runtime memory optimization.
  if not is_training:
    # We need to drop the activations corresponding to the oldest frames, and
    # then add those corresponding to the new frames.
    new_memory = memory[:, :, num_new_frames:]
    new_memory = tf.concat([new_memory, activations_time], 2)
    tf.compat.v1.assign(memory, new_memory)
    activations_time = new_memory

  # Create the time filters.
  weights_time = tf.compat.v1.get_variable(
      name='weights_time',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[num_filters, input_time_size])
  # Apply the time filter on the outputs of the feature filters.
  # weights_time: [num_filters, input_time_size, 1]
  # outputs: [num_filters, batch, 1]
  weights_time = tf.expand_dims(weights_time, 2)
  outputs = tf.matmul(activations_time, weights_time)
  # Split num_units and rank into separate dimensions (the remaining
  # dimension is input_shape[0], i.e. the batch size). This also squeezes
  # the last dimension, since it's not used.
  # [num_filters, batch, 1] => [num_units, rank, batch]
  outputs = tf.reshape(outputs, [num_units, rank, -1])
  # Sum the rank outputs per unit => [num_units, batch].
  units_output = tf.reduce_sum(input_tensor=outputs, axis=1)
  # Transpose to shape [batch, num_units]
  units_output = tf.transpose(a=units_output)

  # Apply bias.
  bias = tf.compat.v1.get_variable(name='bias',
                                   initializer=tf.compat.v1.zeros_initializer,
                                   shape=[num_units])
  first_bias = tf.nn.bias_add(units_output, bias)

  # Relu.
  first_relu = tf.nn.relu(first_bias)

  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu

  first_fc_output_channels = 256
  first_fc_weights = tf.compat.v1.get_variable(
      name='first_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[num_units, first_fc_output_channels])
  first_fc_bias = tf.compat.v1.get_variable(
      name='first_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_fc_output_channels])
  first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, rate=dropout_rate)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 256
  second_fc_weights = tf.compat.v1.get_variable(
      name='second_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_fc_output_channels, second_fc_output_channels])
  second_fc_bias = tf.compat.v1.get_variable(
      name='second_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_fc_output_channels])
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, rate=dropout_rate)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_fc_output_channels, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_tiny_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a convolutional model aimed at microcontrollers.

  Devices like DSPs and microcontrollers can have very small amounts of
  memory and limited processing power. This model is designed to use less
  than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This doesn't produce particularly accurate results, but it's designed to be
  used as the first stage of a pipeline, running on a low-energy piece of
  hardware that can always be on, and then wake higher-power chips when a
  possible utterance has been found, so that more accurate analysis can be done.

  During training, a dropout node is introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 10
  first_filter_count = 8
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])
  first_conv_stride_x = 2
  first_conv_stride_y = 2
  first_conv = tf.nn.conv2d(
      input=fingerprint_4d, filters=first_weights,
      strides=[1, first_conv_stride_y, first_conv_stride_x, 1],
      padding='SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu
  first_dropout_shape = first_dropout.get_shape()
  first_dropout_output_width = first_dropout_shape[2]
  first_dropout_output_height = first_dropout_shape[1]
  first_dropout_element_count = int(
      first_dropout_output_width * first_dropout_output_height *
      first_filter_count)
  flattened_first_dropout = tf.reshape(first_dropout,
                                       [-1, first_dropout_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_dropout_element_count, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = (
      tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_tiny_embedding_conv_model(fingerprint_input, model_settings,
                                     is_training):
  """Builds a convolutional model aimed at microcontrollers.

  Devices like DSPs and microcontrollers can have very small amounts of
  memory and limited processing power. This model is designed to use less
  than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This doesn't produce particularly accurate results, but it's designed to be
  used as the first stage of a pipeline, running on a low-energy piece of
  hardware that can always be on, and then wake higher-power chips when a
  possible utterance has been found, so that more accurate analysis can be done.

  During training, a dropout node is introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])

  first_filter_width = 8
  first_filter_height = 10
  first_filter_count = 8
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])
  first_conv_stride_x = 2
  first_conv_stride_y = 2

  first_conv = tf.nn.conv2d(
      input=fingerprint_4d, filters=first_weights,
      strides=[1, first_conv_stride_y, first_conv_stride_x, 1],
      padding='SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu

  second_filter_width = 8
  second_filter_height = 10
  second_filter_count = 8
  second_weights = tf.compat.v1.get_variable(
      name='second_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[
          second_filter_height, second_filter_width, first_filter_count,
          second_filter_count
      ])
  second_bias = tf.compat.v1.get_variable(
      name='second_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_filter_count])
  second_conv_stride_x = 8
  second_conv_stride_y = 8
  second_conv = tf.nn.conv2d(
      input=first_dropout, filters=second_weights,
      strides=[1, second_conv_stride_y, second_conv_stride_x, 1],
      padding='SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, rate=dropout_rate)
  else:
    second_dropout = second_relu

  second_dropout_shape = second_dropout.get_shape()
  second_dropout_output_width = second_dropout_shape[2]
  second_dropout_output_height = second_dropout_shape[1]
  second_dropout_element_count = int(second_dropout_output_width *
                                     second_dropout_output_height *
                                     second_filter_count)
  flattened_second_dropout = tf.reshape(second_dropout,
                                        [-1, second_dropout_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_dropout_element_count, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = (
      tf.matmul(flattened_second_dropout, final_fc_weights) + final_fc_bias)
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc
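As a quick check on the docstring's claim that the tiny conv models fit within roughly 32KB of read-only memory, the standalone sketch below (separate from models.py; the 49x40 'micro' fingerprint shape and 4-label output are assumed defaults, not values taken from this file) counts the tiny_conv weights:

```python
# Standalone sketch: estimate the tiny_conv weight count for assumed
# defaults (49 frames x 40 features, 4 labels). With 'SAME' padding and
# stride 2, each conv output dimension is ceil(input / stride).
import math

def tiny_conv_weight_count(freq=40, time=49, filt_h=10, filt_w=8,
                           filt_count=8, stride=2, labels=4):
    conv = filt_h * filt_w * 1 * filt_count + filt_count  # weights + bias
    out_h = math.ceil(time / stride)
    out_w = math.ceil(freq / stride)
    fc_in = out_h * out_w * filt_count
    fc = fc_in * labels + labels                          # weights + bias
    return conv + fc

print(tiny_conv_weight_count())  # -> 16652, well under 32KB as int8 weights
```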

input_data.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads, partitions, and prepares audio data for simple speech recognition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import math
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.python.ops import gen_audio_ops as audio_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

tf.compat.v1.disable_eager_execution()

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185


def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
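For this project's `WANTED_WORDS="yes,no"` setting, the helper above yields a four-label list. A standalone copy, shown for illustration:

```python
# Standalone copy of prepare_words_list(): silence and unknown tokens are
# always prepended, so the model this project trains has four output labels.
SILENCE_LABEL = '_silence_'
UNKNOWN_WORD_LABEL = '_unknown_'

def prepare_words_list(wanted_words):
    return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words

print(prepare_words_list(['yes', 'no']))
# -> ['_silence_', '_unknown_', 'yes', 'no']
```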


def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are restarted
  for example. To keep this stability, a hash of the filename is taken and used
  to determine which set it should belong to. This determination only depends on
  the name and the set proportions, so it won't change as other files are added.

  It's also useful to associate particular files as related (for example words
  spoken by the same person), so anything after '_nohash_' in a filename is
  ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result
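The stable-hash partitioning above can be exercised on its own. This standalone sketch (the filenames are hypothetical) shows that '_nohash_' variants of the same recording always land in the same partition:

```python
# Standalone sketch of which_set()'s stable-hash partitioning. Everything
# after '_nohash_' is stripped before hashing, so close variations of the
# same recording always end up in the same partition.
import hashlib
import re

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M

def partition(filename, validation_percentage=10, testing_percentage=10):
    hash_name = re.sub(r'_nohash_.*$', '', filename)
    hashed = hashlib.sha1(hash_name.encode()).hexdigest()
    percentage_hash = ((int(hashed, 16) % (MAX_NUM_WAVS_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_WAVS_PER_CLASS))
    if percentage_hash < validation_percentage:
        return 'validation'
    elif percentage_hash < validation_percentage + testing_percentage:
        return 'testing'
    return 'training'

# Both variants hash identically, so they share a partition.
assert partition('bobby_nohash_0.wav') == partition('bobby_nohash_1.wav')
```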


def load_wav_file(filename):
  """Loads an audio file and returns a float PCM-encoded array of samples.

  Args:
    filename: Path to the .wav file to load.

  Returns:
    Numpy array holding the sample data as floats between -1.0 and 1.0.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    return sess.run(
        wav_decoder,
        feed_dict={wav_filename_placeholder: filename}).audio.flatten()


def save_wav_file(filename, wav_data, sample_rate):
  """Saves audio sample data to a .wav audio file.

  Args:
    filename: Path to save the file to.
    wav_data: 2D array of float PCM-encoded audio data.
    sample_rate: Samples per second to encode in the file.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    sample_rate_placeholder = tf.compat.v1.placeholder(tf.int32, [])
    wav_data_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 1])
    wav_encoder = tf.audio.encode_wav(wav_data_placeholder,
                                      sample_rate_placeholder)
    wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
    sess.run(
        wav_saver,
        feed_dict={
            wav_filename_placeholder: filename,
            sample_rate_placeholder: sample_rate,
            wav_data_placeholder: np.reshape(wav_data, (-1, 1))
        })


def get_features_range(model_settings):
  """Returns the expected min/max for generated features.

  Args:
    model_settings: Information about the current model being trained.

  Returns:
    Min/max float pair holding the range of features.

  Raises:
    Exception: If preprocessing mode isn't recognized.
  """
  # TODO(petewarden): These values have been derived from the observed ranges
  # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
  # they may need to be updated.
  if model_settings['preprocess'] == 'average':
    features_min = 0.0
    features_max = 127.5
  elif model_settings['preprocess'] == 'mfcc':
    features_min = -247.0
    features_max = 30.0
  elif model_settings['preprocess'] == 'micro':
    features_min = 0.0
    features_max = 26.0
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (model_settings['preprocess']))
  return features_min, features_max
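These ranges are what a post-training quantization step can use to map float features onto 8 bits. A minimal sketch using a hypothetical helper (not part of this codebase) and the 'micro' range:

```python
# Hypothetical helper (not from input_data.py): linearly map a float feature
# into uint8 using the 'micro' frontend's observed range of 0.0 .. 26.0.
def quantize_feature(value, features_min=0.0, features_max=26.0):
    q = round((value - features_min) * 255.0 / (features_max - features_min))
    return min(255, max(0, q))  # clamp to the uint8 range

print(quantize_feature(0.0), quantize_feature(13.0), quantize_feature(26.0))
# -> 0 128 255
```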


class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
               wanted_words, validation_percentage, testing_percentage,
               model_settings, summaries_dir):
    if data_dir:
      self.data_dir = data_dir
      self.maybe_download_and_extract_dataset(data_url, data_dir)
      self.prepare_data_index(silence_percentage, unknown_percentage,
                              wanted_words, validation_percentage,
                              testing_percentage)
      self.prepare_background_data()
    self.prepare_processing_graph(model_settings, summaries_dir)

  def maybe_download_and_extract_dataset(self, data_url, dest_directory):
    """Download and extract data set tar file.

    If the data set we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a
    directory.
    If the data_url is None, don't download anything and expect the data
    directory to contain the correct files already.

    Args:
      data_url: Web location of the tar file containing the data set.
      dest_directory: File path to extract data to.
    """
    if not data_url:
      return
    if not gfile.Exists(dest_directory):
      os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not gfile.Exists(filepath):

      def _progress(count, block_size, total_size):
        sys.stdout.write(
            '\r>> Downloading %s %.1f%%' %
            (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

      try:
        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
      except Exception:
        tf.compat.v1.logging.error(
            'Failed to download URL: {0} to folder: {1}. Please make sure you '
            'have enough free space and an internet connection'.format(
                data_url, filepath))
        raise
      print()
      statinfo = os.stat(filepath)
      tf.compat.v1.logging.info(
          'Successfully downloaded {0} ({1} bytes)'.format(
              filename, statinfo.st_size))
      tarfile.open(filepath, 'r:gz').extractall(dest_directory)

  def prepare_data_index(self, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right labels for each file based on the name of the subdirectory it
    belongs to, and uses a stable hash to assign it to a data set partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    Returns:
      Dictionary containing a list of file information for each set partition,
      and a lookup map for each class to determine its numeric index.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
      wanted_words_index[wanted_word] = index + 2
    self.data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}
    # Look through all the subfolders to find audio samples
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    for wav_path in gfile.Glob(search_path):
      _, word = os.path.split(os.path.dirname(wav_path))
      word = word.lower()
      # Treat the '_background_noise_' folder as a special case, since we expect
      # it to contain long audio samples we mix in to improve training.
      if word == BACKGROUND_NOISE_DIR_NAME:
        continue
      all_words[word] = True
      set_index = which_set(wav_path, validation_percentage, testing_percentage)
      # If it's a known class, store its detail, otherwise add it to the list
      # we'll use to train the unknown label.
      if word in wanted_words_index:
        self.data_index[set_index].append({'label': word, 'file': wav_path})
      else:
        unknown_index[set_index].append({'label': word, 'file': wav_path})
    if not all_words:
      raise Exception('No .wavs found at ' + search_path)
    for index, wanted_word in enumerate(wanted_words):
      if wanted_word not in all_words:
        raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))
    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = self.data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
      set_size = len(self.data_index[set_index])
      silence_size = int(math.ceil(set_size * silence_percentage / 100))
      for _ in range(silence_size):
        self.data_index[set_index].append({
            'label': SILENCE_LABEL,
            'file': silence_wav_path
        })
      # Pick some unknowns to add to each partition of the data set.
      random.shuffle(unknown_index[set_index])
      unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
      self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
      random.shuffle(self.data_index[set_index])
    # Prepare the rest of the result data structure.
    self.words_list = prepare_words_list(wanted_words)
    self.word_to_index = {}
    for word in all_words:
      if word in wanted_words_index:
        self.word_to_index[word] = wanted_words_index[word]
      else:
        self.word_to_index[word] = UNKNOWN_WORD_INDEX
    self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX

  def prepare_background_data(self):
    """Searches a folder for background noise audio, and loads it into memory.

    It's expected that the background audio samples will be in a subdirectory
    named '_background_noise_' inside the 'data_dir' folder, as .wavs that match
    the sample rate of the training data, but can be much longer in duration.

    If the '_background_noise_' folder doesn't exist at all, this isn't an
    error, it's just taken to mean that no background noise augmentation should
    be used. If the folder does exist, but it's empty, that's treated as an
    error.

    Returns:
      List of raw PCM-encoded audio samples of background noise.

    Raises:
      Exception: If files aren't found in the folder.
    """
    self.background_data = []
    background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
    if not gfile.Exists(background_dir):
      return self.background_data
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
      search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
                                 '*.wav')
      for wav_path in gfile.Glob(search_path):
        wav_data = sess.run(
            wav_decoder,
            feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
        self.background_data.append(wav_data)
      if not self.background_data:
        raise Exception('No background wav files were found in ' + search_path)
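The loaded background clips are later scaled, added to the foreground, and clamped. A minimal pure-Python sketch of that mixing step (the sample values are hypothetical):

```python
# Standalone sketch of background mixing: scale the noise by a volume
# factor, add it to the foreground samples, and clamp the sum to the valid
# [-1.0, 1.0] PCM range -- the same steps the processing graph performs.
foreground = [0.4, -0.2, 0.9]
background = [0.5, 0.5, 0.5]
background_volume = 0.5

mixed = [max(-1.0, min(1.0, b * background_volume + f))
         for f, b in zip(foreground, background)]
assert mixed[2] == 1.0  # 0.9 + 0.25 clamps to 1.0
```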

  def prepare_processing_graph(self, model_settings, summaries_dir):
    """Builds a TensorFlow graph to apply the input distortions.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, adds in background noise, calculates a spectrogram, and
    then builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - background_data_placeholder_: PCM sample data for background noise.
      - background_volume_placeholder_: Loudness of mixed-in background.
      - output_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
      summaries_dir: Path to save training summary information to.

    Raises:
      ValueError: If the preprocessing mode isn't recognized.
      Exception: If the preprocessor wasn't compiled in.
    """
    with tf.compat.v1.get_default_graph().name_scope('data'):
      desired_samples = model_settings['desired_samples']
      self.wav_filename_placeholder_ = tf.compat.v1.placeholder(
          tf.string, [], name='wav_filename')
      wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      # Allow the audio sample's volume to be adjusted.
      self.foreground_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='foreground_volume')
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      self.foreground_volume_placeholder_)
      # Shift the sample's start position, and pad any gaps with zeros.
      self.time_shift_padding_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2, 2], name='time_shift_padding')
      self.time_shift_offset_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2], name='time_shift_offset')
      padded_foreground = tf.pad(
          tensor=scaled_foreground,
          paddings=self.time_shift_padding_placeholder_,
          mode='CONSTANT')
      sliced_foreground = tf.slice(padded_foreground,
                                   self.time_shift_offset_placeholder_,
                                   [desired_samples, -1])
      # Mix in background noise.
      self.background_data_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [desired_samples, 1], name='background_data')
      self.background_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='background_volume')
      background_mul = tf.multiply(self.background_data_placeholder_,
                                   self.background_volume_placeholder_)
      background_add = tf.add(background_mul, sliced_foreground)
      background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
      # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
      spectrogram = audio_ops.audio_spectrogram(
          background_clamp,
          window_size=model_settings['window_size_samples'],
          stride=model_settings['window_stride_samples'],
          magnitude_squared=True)
      tf.compat.v1.summary.image(
          'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
      # The number of buckets in each FFT row in the spectrogram will depend on
      # how many input samples there are in each window. This can be quite
      # large, with a 160 sample window producing 127 buckets for example. We
      # don't need this level of detail for classification, so we often want to
      # shrink them down to produce a smaller result. That's what this section
      # implements. One method is to use average pooling to merge adjacent
      # buckets, but a more sophisticated approach is to apply the MFCC
      # algorithm to shrink the representation.
      if model_settings['preprocess'] == 'average':
        self.output_ = tf.nn.pool(
            input=tf.expand_dims(spectrogram, -1),
            window_shape=[1, model_settings['average_window_width']],
            strides=[1, model_settings['average_window_width']],
            pooling_type='AVG',
            padding='SAME')
        tf.compat.v1.summary.image('shrunk_spectrogram',
                                   self.output_,
                                   max_outputs=1)
      elif model_settings['preprocess'] == 'mfcc':
        self.output_ = audio_ops.mfcc(
            spectrogram,
            wav_decoder.sample_rate,
            dct_coefficient_count=model_settings['fingerprint_width'])
        tf.compat.v1.summary.image(
            'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
      elif model_settings['preprocess'] == 'micro':
        if not frontend_op:
          raise Exception(
              'Micro frontend op is currently not available when running'
              ' TensorFlow directly from Python, you need to build and run'
              ' through Bazel')
        sample_rate = model_settings['sample_rate']
        window_size_ms = (model_settings['window_size_samples'] *
                          1000) / sample_rate
        window_step_ms = (model_settings['window_stride_samples'] *
                          1000) / sample_rate
        int16_input = tf.cast(tf.multiply(background_clamp, 32768), tf.int16)
        micro_frontend = frontend_op.audio_microfrontend(
            int16_input,
            sample_rate=sample_rate,
            window_size=window_size_ms,
            window_step=window_step_ms,
            num_channels=model_settings['fingerprint_width'],
            out_scale=1,
            out_type=tf.float32)
        self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0))
        tf.compat.v1.summary.image(
            'micro',
            tf.expand_dims(tf.expand_dims(self.output_, -1), 0),
            max_outputs=1)
      else:
        raise ValueError('Unknown preprocess mode "%s" (should be "mfcc", '
                         '"average", or "micro")' %
                         (model_settings['preprocess']))

      # Merge all the summaries and write them out to /tmp/retrain_logs (by
      # default)
      self.merged_summaries_ = tf.compat.v1.summary.merge_all(scope='data')
      if summaries_dir:
        self.summary_writer_ = tf.compat.v1.summary.FileWriter(
            summaries_dir + '/data', tf.compat.v1.get_default_graph())
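The 'micro' branch above converts window sizes from samples to milliseconds. A standalone sketch with the defaults this project is assumed to train with (16 kHz audio, 30 ms windows stepped by 20 ms; these values are not taken from this file):

```python
# Standalone sketch: the sample -> millisecond conversion used for the
# 'micro' frontend, with assumed defaults of 16 kHz audio, a 480-sample
# (30 ms) window, and a 320-sample (20 ms) stride.
sample_rate = 16000
window_size_samples = 480    # 30 ms at 16 kHz
window_stride_samples = 320  # 20 ms at 16 kHz

window_size_ms = (window_size_samples * 1000) / sample_rate
window_step_ms = (window_stride_samples * 1000) / sample_rate
print(window_size_ms, window_step_ms)  # -> 30.0 20.0
```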

  def set_size(self, mode):
    """Calculates the number of samples in the dataset partition.

    Args:
      mode: Which partition, must be 'training', 'validation', or 'testing'.

    Returns:
      Number of samples in the partition.
    """
    return len(self.data_index[mode])

  def get_data(self, how_many, offset, model_settings, background_frequency,
               background_volume_range, time_shift, mode, sess):
    """Gather samples from the data set, applying transformations as needed.

    When the mode is 'training', a random selection of samples will be returned,
    otherwise the first N clips in the partition will be used. This ensures that
    validation always uses the same samples, reducing noise in the metrics.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      offset: Where to start when fetching deterministically.
      model_settings: Information about the current model being trained.
      background_frequency: How many clips will have background noise, 0.0 to
        1.0.
      background_volume_range: How loud the background noise will be.
      time_shift: How much to randomly shift the clips by in time.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      List of sample data for the transformed samples, and list of label indexes

    Raises:
      ValueError: If background samples are too short.
    """
    # Pick one of the partitions to choose samples from.
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and labels will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    labels = np.zeros(sample_count)
    desired_samples = model_settings['desired_samples']
    use_background = self.background_data and (mode == 'training')
    pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier to repeatedly generate the
    # final output sample data we'll use in training.
    for i in xrange(offset, offset + sample_count):
      # Pick which audio sample to use.
      if how_many == -1 or pick_deterministically:
        sample_index = i
      else:
        sample_index = np.random.randint(len(candidates))
      sample = candidates[sample_index]
      # If we're time shifting, set up the offset for this sample.
      if time_shift > 0:
        time_shift_amount = np.random.randint(-time_shift, time_shift)
      else:
        time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
      input_dict = {
          self.wav_filename_placeholder_: sample['file'],
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
      }
      # Choose a section of background noise to mix in.
      if use_background or sample['label'] == SILENCE_LABEL:
        background_index = np.random.randint(len(self.background_data))
        background_samples = self.background_data[background_index]
        if len(background_samples) <= model_settings['desired_samples']:
          raise ValueError(
              'Background sample is too short! Need more than %d'
              ' samples but only %d were found' %
              (model_settings['desired_samples'], len(background_samples)))
        background_offset = np.random.randint(
            0, len(background_samples) - model_settings['desired_samples'])
        background_clipped = background_samples[background_offset:(
            background_offset + desired_samples)]
        background_reshaped = background_clipped.reshape([desired_samples, 1])
        if sample['label'] == SILENCE_LABEL:
          background_volume = np.random.uniform(0, 1)
        elif np.random.uniform(0, 1) < background_frequency:
          background_volume = np.random.uniform(0, background_volume_range)
        else:
          background_volume = 0
      else:
        background_reshaped = np.zeros([desired_samples, 1])
        background_volume = 0
      input_dict[self.background_data_placeholder_] = background_reshaped
      input_dict[self.background_volume_placeholder_] = background_volume
      # If we want silence, mute out the main sample but leave the background.
      if sample['label'] == SILENCE_LABEL:
        input_dict[self.foreground_volume_placeholder_] = 0
      else:
        input_dict[self.foreground_volume_placeholder_] = 1
      # Run the graph to produce the output audio.
      summary, data_tensor = sess.run(
          [self.merged_summaries_, self.output_], feed_dict=input_dict)
      self.summary_writer_.add_summary(summary)
      data[i - offset, :] = data_tensor.flatten()
      label_index = self.word_to_index[sample['label']]
      labels[i - offset] = label_index
    return data, labels

  def get_features_for_wav(self, wav_filename, model_settings, sess):
    """Applies the feature transformation process to the input_wav.

    Runs the feature generation process (generally producing a spectrogram from
    the input samples) on the WAV file. This can be useful for testing and
    verifying implementations being run on other platforms.

    Args:
      wav_filename: The path to the input audio file.
      model_settings: Information about the current model being trained.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      Numpy data array containing the generated features.
    """
    desired_samples = model_settings['desired_samples']
    input_dict = {
        self.wav_filename_placeholder_: wav_filename,
        self.time_shift_padding_placeholder_: [[0, 0], [0, 0]],
        self.time_shift_offset_placeholder_: [0, 0],
        self.background_data_placeholder_: np.zeros([desired_samples, 1]),
        self.background_volume_placeholder_: 0,
        self.foreground_volume_placeholder_: 1,
    }
    # Run the graph to produce the output audio.
    data_tensor = sess.run(self.output_, feed_dict=input_dict)
    return data_tensor

  def get_unprocessed_data(self, how_many, model_settings, mode):
    """Retrieve sample data for the given partition, with no transformations.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      model_settings: Information about the current model being trained.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.

    Returns:
      List of sample data for the samples, and list of labels in one-hot form.
    """
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = how_many
    desired_samples = model_settings['desired_samples']
    words_list = self.words_list
    data = np.zeros((sample_count, desired_samples))
    labels = []
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      foreground_volume_placeholder = tf.compat.v1.placeholder(tf.float32, [])
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      foreground_volume_placeholder)
      for i in range(sample_count):
        if how_many == -1:
          sample_index = i
        else:
          sample_index = np.random.randint(len(candidates))
        sample = candidates[sample_index]
        input_dict = {wav_filename_placeholder: sample['file']}
        if sample['label'] == SILENCE_LABEL:
          input_dict[foreground_volume_placeholder] = 0
        else:
          input_dict[foreground_volume_placeholder] = 1
        data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
        label_index = self.word_to_index[sample['label']]
        labels.append(words_list[label_index])
    return data, labels
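The time-shift augmentation in `get_data` above pads one side of the clip and then trims back to the original length. As a standalone illustration (not part of the project files; the function name here is hypothetical), the same padding/offset arithmetic can be reproduced in plain NumPy:

```python
import numpy as np

def time_shift(audio, shift_amount):
    """Shift a 1-D clip by shift_amount samples (positive = later),
    zero-padding the gap -- mirroring the padding/offset pairs that
    get_data() feeds into its TensorFlow graph."""
    desired = len(audio)
    if shift_amount > 0:
        padding, offset = (shift_amount, 0), 0              # pad the front
    else:
        padding, offset = (0, -shift_amount), -shift_amount  # pad the back
    padded = np.pad(audio, padding, mode='constant')
    return padded[offset:offset + desired]

clip = np.arange(1, 6, dtype=np.float32)   # stand-in for 5 audio samples
shifted_late = time_shift(clip, 2)         # [0, 0, 1, 2, 3]
shifted_early = time_shift(clip, -2)       # [3, 4, 5, 0, 0]
```

Either way the result keeps exactly `desired_samples` values, so every training example stays the same length regardless of the random shift drawn for it.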

freeze.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.

Once you've trained a model using the `train.py` script, you can use this tool
to convert it into a binary GraphDef file that can be loaded into the Android,
iOS, or Raspberry Pi example code. Here's an example of how to run it:

bazel run tensorflow/examples/speech_commands/freeze -- \
--sample_rate=16000 --dct_coefficient_count=40 --window_size_ms=20 \
--window_stride_ms=10 --clip_duration_ms=1000 \
--model_architecture=conv \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
--output_file=/tmp/my_frozen_graph.pb

One thing to watch out for is that you need to pass in the same arguments for
`sample_rate` and other command line variables here as you did for the training
script.

The resulting graph has an input for WAV-encoded data named 'wav_data', one for
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
and the output is called 'labels_softmax'.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import gen_audio_ops as audio_ops

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
# bazel run tensorflow/examples/speech_commands:freeze_graph
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

FLAGS = None


def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           feature_bin_count, model_architecture, preprocess):
  """Creates an audio model with the nodes needed for inference.

  Uses the supplied arguments to create a model, and inserts the input and
  output nodes that are needed to use the graph for inference.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: How many samples to analyze for the audio pattern.
    clip_stride_ms: How often to run recognition. Useful for models with cache.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    feature_bin_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
    preprocess: How the spectrogram is processed to produce features, for
      example 'mfcc', 'average', or 'micro'.

  Returns:
    Input and output tensor objects.

  Raises:
    Exception: If the preprocessing mode isn't recognized.
  """

  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, feature_bin_count, preprocess)
  runtime_settings = {'clip_stride_ms': clip_stride_ms}

  wav_data_placeholder = tf.compat.v1.placeholder(tf.string, [],
                                                  name='wav_data')
  decoded_sample_data = tf.audio.decode_wav(
      wav_data_placeholder,
      desired_channels=1,
      desired_samples=model_settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = audio_ops.audio_spectrogram(
      decoded_sample_data.audio,
      window_size=model_settings['window_size_samples'],
      stride=model_settings['window_stride_samples'],
      magnitude_squared=True)

  if preprocess == 'average':
    fingerprint_input = tf.nn.pool(
        input=tf.expand_dims(spectrogram, -1),
        window_shape=[1, model_settings['average_window_width']],
        strides=[1, model_settings['average_window_width']],
        pooling_type='AVG',
        padding='SAME')
  elif preprocess == 'mfcc':
    fingerprint_input = audio_ops.mfcc(
        spectrogram,
        sample_rate,
        dct_coefficient_count=model_settings['fingerprint_width'])
  elif preprocess == 'micro':
    if not frontend_op:
      raise Exception(
          'Micro frontend op is currently not available when running TensorFlow'
          ' directly from Python, you need to build and run through Bazel, for'
          ' example'
          ' `bazel run tensorflow/examples/speech_commands:freeze_graph`')
    sample_rate = model_settings['sample_rate']
    window_size_ms = (model_settings['window_size_samples'] *
                      1000) / sample_rate
    window_step_ms = (model_settings['window_stride_samples'] *
                      1000) / sample_rate
    int16_input = tf.cast(
        tf.multiply(decoded_sample_data.audio, 32767), tf.int16)
    micro_frontend = frontend_op.audio_microfrontend(
        int16_input,
        sample_rate=sample_rate,
        window_size=window_size_ms,
        window_step=window_step_ms,
        num_channels=model_settings['fingerprint_width'],
        out_scale=1,
        out_type=tf.float32)
    fingerprint_input = tf.multiply(micro_frontend, (10.0 / 256.0))
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (preprocess))

  fingerprint_size = model_settings['fingerprint_size']
  reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])

  logits = models.create_model(
      reshaped_input, model_settings, model_architecture, is_training=False,
      runtime_settings=runtime_settings)

  # Create an output to use for inference.
  softmax = tf.nn.softmax(logits, name='labels_softmax')

  return reshaped_input, softmax


def save_graph_def(file_name, frozen_graph_def):
  """Writes a graph def file out to disk.

  Args:
    file_name: Where to save the file.
    frozen_graph_def: GraphDef proto object to save.
  """
  tf.io.write_graph(
      frozen_graph_def,
      os.path.dirname(file_name),
      os.path.basename(file_name),
      as_text=False)
  tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)


def save_saved_model(file_name, sess, input_tensor, output_tensor):
  """Writes a SavedModel out to disk.

  Args:
    file_name: Where to save the file.
    sess: TensorFlow session containing the graph.
    input_tensor: Tensor object defining the input's properties.
    output_tensor: Tensor object defining the output's properties.
  """
  # Store the frozen graph as a SavedModel for v2 compatibility.
  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
  tensor_info_inputs = {
      'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
  }
  tensor_info_outputs = {
      'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
  }
  signature = (
      tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=tf.compat.v1.saved_model.signature_constants
          .PREDICT_METHOD_NAME))
  builder.add_meta_graph_and_variables(
      sess,
      [tf.compat.v1.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.compat.v1.saved_model.signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              signature,
      },
  )
  builder.save()


def main(_):
  if FLAGS.quantize:
    try:
      _ = tf.contrib
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n    `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  # Create the model and load its weights.
  sess = tf.compat.v1.InteractiveSession()
  input_tensor, output_tensor = create_inference_graph(
      FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
      FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
  if FLAGS.quantize:
    tf.contrib.quantize.create_eval_graph()
  models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

  # Turn all the variables into inline constants inside the graph and save it.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ['labels_softmax'])

  if FLAGS.save_format == 'graph_def':
    save_graph_def(FLAGS.output_file, frozen_graph_def)
  elif FLAGS.save_format == 'saved_model':
    save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
  else:
    raise Exception('Unknown save format "%s" (should be "graph_def" or'
                    ' "saved_model")' % (FLAGS.save_format))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--clip_stride_ms',
      type=int,
      default=30,
      help='How often to run recognition. Useful for models with cache.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices',)
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_file', type=str, help='Where to save the frozen graph.')
  parser.add_argument(
      '--quantize',
      type=bool,
      default=False,
      help='Whether to train the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--save_format',
      type=str,
      default='graph_def',
      help='How to save the result. Can be "graph_def" or "saved_model"')
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
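Passing the same flags to freeze.py as to training matters because they determine the fingerprint geometry via `models.prepare_model_settings`. As a rough self-check, here is a standalone sketch of that flag-to-dimension arithmetic (using the standard speech_commands sizing formulas; `fingerprint_dims` is an illustrative name, not a function from these files):

```python
def fingerprint_dims(sample_rate, clip_duration_ms, window_size_ms,
                     window_stride_ms, feature_bin_count):
    """Re-derive the spectrogram geometry that prepare_model_settings
    computes from the command-line flags."""
    desired_samples = int(sample_rate * clip_duration_ms / 1000)
    window_size_samples = int(sample_rate * window_size_ms / 1000)
    window_stride_samples = int(sample_rate * window_stride_ms / 1000)
    length_minus_window = desired_samples - window_size_samples
    if length_minus_window < 0:
        spectrogram_length = 0  # clip is shorter than one analysis window
    else:
        spectrogram_length = 1 + (length_minus_window // window_stride_samples)
    return spectrogram_length, feature_bin_count

# freeze.py defaults: 16 kHz, 1000 ms clips, 30 ms windows, 10 ms stride, 40 bins
slices, bins = fingerprint_dims(16000, 1000, 30.0, 10.0, 40)
# slices * bins is the flat fingerprint_size the model's input layer expects
```

If these dimensions disagree between the training and freeze runs, the restored checkpoint weights will not match the new graph's shapes, which is why the docstring above insists on passing identical `sample_rate` and window flags to both scripts.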

train.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Simple speech recognition to spot a limited number of keywords.

This is a self-contained example script that will train a very basic audio
recognition model in TensorFlow. It downloads the necessary training data and
runs with reasonable defaults to train within a few hours, even using only a CPU.
For more information, please see
https://www.tensorflow.org/tutorials/audio_recognition.

It is intended as an introduction to using neural networks for audio
recognition, and is not a full speech recognition system. For more advanced
speech systems, I recommend looking into Kaldi. This network uses a keyword
detection style to spot discrete words from a small vocabulary, consisting of
"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go".

To run the training process, use:

bazel run tensorflow/examples/speech_commands:train

This will write out checkpoints to /tmp/speech_commands_train/, and will
download over 1GB of open source training data, so you'll need enough free space
and a good internet connection. The default data is a collection of thousands of
one-second .wav files, each containing one spoken word. This data set is
collected from https://aiyprojects.withgoogle.com/open_speech_recording, please
consider contributing to help improve this and other models!

As training progresses, it will print out its accuracy metrics, which should
rise above 90% by the end. Once it's complete, you can run the freeze script to
get a binary GraphDef that you can easily deploy on mobile applications.

If you want to train on your own data, you'll need to create .wavs with your
recordings, all at a consistent length, and then arrange them into subfolders
organized by label. For example, here's a possible file structure:

my_wavs >
  up >
    audio_0.wav
    audio_1.wav
  down >
    audio_2.wav
    audio_3.wav
  other >
    audio_4.wav
    audio_5.wav

You'll also need to tell the script what labels to look for, using the
`--wanted_words` argument. In this case, 'up,down' might be what you want, and
the audio in the 'other' folder would be used to train an 'unknown' category.

To pull this all together, you'd run:

bazel run tensorflow/examples/speech_commands:train -- \
--data_dir=my_wavs --wanted_words=up,down

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def main(_):
  # Set the verbosity based on flags (default is INFO, so we see all messages)
  tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir,
      FLAGS.silence_percentage, FLAGS.unknown_percentage,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  input_placeholder = tf.compat.v1.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')
  if FLAGS.quantize:
    fingerprint_min, fingerprint_max = input_data.get_features_range(
        model_settings)
    fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
        input_placeholder, fingerprint_min, fingerprint_max)
  else:
    fingerprint_input = input_placeholder

  logits, dropout_rate = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)

  # Define loss and optimizer
  ground_truth_input = tf.compat.v1.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.compat.v1.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.compat.v1.name_scope('cross_entropy'):
    cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  if FLAGS.quantize:
    try:
      tf.contrib.quantize.create_training_graph(quant_delay=0)
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n    `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
      control_dependencies):
    learning_rate_input = tf.compat.v1.placeholder(
        tf.float32, [], name='learning_rate_input')
    if FLAGS.optimizer == 'gradient_descent':
      train_step = tf.compat.v1.train.GradientDescentOptimizer(
          learning_rate_input).minimize(cross_entropy_mean)
    elif FLAGS.optimizer == 'momentum':
      train_step = tf.compat.v1.train.MomentumOptimizer(
          learning_rate_input, .9,
          use_nesterov=True).minimize(cross_entropy_mean)
    else:
      raise Exception('Invalid Optimizer')
  predicted_indices = tf.argmax(input=logits, axis=1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
                                              predictions=predicted_indices,
                                              num_classes=label_count)
  evaluation_step = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
                                                        tf.float32))
  with tf.compat.v1.get_default_graph().name_scope('eval'):
    tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
    tf.compat.v1.summary.scalar('accuracy', evaluation_step)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)

  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                                 sess.graph)
  validation_writer = tf.compat.v1.summary.FileWriter(
      FLAGS.summaries_dir + '/validation')

  tf.compat.v1.global_variables_initializer().run()

  start_step = 1

  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)

  tf.compat.v1.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.io.write_graph(sess.graph_def, FLAGS.train_dir,
                    FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    training_steps_sum = 0
    for i in range(len(training_steps_list)):
      training_steps_sum += training_steps_list[i]
      if training_step <= training_steps_sum:
        learning_rate_value = learning_rates_list[i]
        break
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries,
            evaluation_step,
            cross_entropy_mean,
            train_step,
            increment_global_step,
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_rate: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.compat.v1.logging.debug(
        'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
        (training_step, learning_rate_value, train_accuracy * 100,
         cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      tf.compat.v1.logging.info(
          'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
          (training_step, learning_rate_value, train_accuracy * 100,
           cross_entropy_value))
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_rate: 0.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
      tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.compat.v1.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                                (training_step, total_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
                                training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  set_size = audio_processor.set_size('testing')
  tf.compat.v1.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_rate: 0.0
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.compat.v1.logging.warn('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.compat.v1.logging.warn('Final test accuracy = %.1f%% (N=%d)' %
                            (total_accuracy * 100, set_size))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset/',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--unknown_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be unknown words.
      """)
  parser.add_argument(
      '--time_shift_ms',
      type=float,
      default=100.0,
      help="""\
      Range to randomly shift the training audio by in time.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.',
  )
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=str,
      default='15000,3000',
      help='How many training loops to run',)
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=400,
      help='How often to evaluate the training results.')
  parser.add_argument(
      '--learning_rate',
      type=str,
      default='0.001,0.0001',
      help='How large a learning rate to use when training.')
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='How many items to train with at once',)
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='/tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/speech_commands_train',
      help='Directory to write event logs and checkpoint.')
  parser.add_argument(
      '--save_step_interval',
      type=int,
      default=100,
      help='Save model checkpoint every save_steps.')
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--check_nans',
      type=bool,
      default=False,
      help='Whether to check for invalid numbers during processing')
  parser.add_argument(
      '--quantize',
      type=bool,
      default=False,
      help='Whether to train the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')

  # Function used to parse --verbosity argument
  def verbosity_arg(value):
    """Parses the --verbosity argument.

    Args:
      value: A string naming a tf.logging level.
    Returns:
      The matching tf.compat.v1.logging constant.
    Raises:
      ArgumentTypeError: Not an expected value.
    """
    value = value.upper()
    if value == 'DEBUG':
      return tf.compat.v1.logging.DEBUG
    elif value == 'INFO':
      return tf.compat.v1.logging.INFO
    elif value == 'WARN':
      return tf.compat.v1.logging.WARN
    elif value == 'ERROR':
      return tf.compat.v1.logging.ERROR
    elif value == 'FATAL':
      return tf.compat.v1.logging.FATAL
    else:
      raise argparse.ArgumentTypeError('Not an expected value')
  parser.add_argument(
      '--verbosity',
      type=verbosity_arg,
      default=tf.compat.v1.logging.INFO,
      help='Log verbosity. Can be "DEBUG", "INFO", "WARN", "ERROR", or "FATAL"')
  parser.add_argument(
      '--optimizer',
      type=str,
      default='gradient_descent',
      help='Optimizer (gradient_descent or momentum)')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
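The training script above consumes `--how_many_training_steps` and `--learning_rate` as parallel comma-separated lists, one entry per training phase, matching the `TRAINING_STEPS` and `LEARNING_RATE` environment variables set earlier. A minimal sketch of that pairing, using a hypothetical `parse_schedule` helper that is not part of train.py:

```python
def parse_schedule(steps_csv, rates_csv):
    """Pair each training phase's step count with its learning rate,
    mirroring how train.py walks the two comma-separated flag values."""
    steps = [int(s) for s in steps_csv.split(',')]
    rates = [float(r) for r in rates_csv.split(',')]
    if len(steps) != len(rates):
        raise ValueError('steps and learning rates must have the same length')
    return list(zip(steps, rates))

# The values from the environment setup section:
schedule = parse_schedule('15000,3000', '0.001,0.0001')
# schedule == [(15000, 0.001), (3000, 0.0001)], 18000 steps in total,
# which is why TOTAL_STEPS was set to 18000.
```

The second, much smaller learning rate fine-tunes the model for the final 3000 steps without undoing what the first phase learned.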

wav_to_features.py

Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts WAV audio files into input features for neural networks.

The models used in this example take two-dimensional spectrograms as the
input to their neural network portions. For testing and porting purposes it's
useful to be able to generate these spectrograms outside of the full model, so
that on-device implementations using their own FFT and streaming code can be
tested against the version used in training. The output is written as a C
source file, so it can be easily linked into an embedded test application.

To use this, run:

bazel run tensorflow/examples/speech_commands:wav_to_features -- \
--input_wav=my.wav --output_c_file=my_wav_data.c

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
  """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to train the model for eight-bit deployment.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

  model_settings = models.prepare_model_settings(
      0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
      feature_bin_count, preprocess)
  audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                              model_settings, None)

  results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                 sess)
  features = results[0]

  variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

  # Save a C source file containing the feature data as an array.
  with gfile.GFile(output_c_file, 'w') as f:
    f.write('/* File automatically created by\n')
    f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
    f.write(' * --sample_rate=%d \\\n' % sample_rate)
    f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
    f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
    f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
    f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
    if quantize:
      f.write(' * --quantize=1 \\\n')
    f.write(' * --preprocess="%s" \\\n' % preprocess)
    f.write(' * --input_wav="%s" \\\n' % input_wav)
    f.write(' * --output_c_file="%s" \\\n' % output_c_file)
    f.write(' */\n\n')
    f.write('const int g_%s_width = %d;\n' %
            (variable_base, model_settings['fingerprint_width']))
    f.write('const int g_%s_height = %d;\n' %
            (variable_base, model_settings['spectrogram_length']))
    if quantize:
      features_min, features_max = input_data.get_features_range(model_settings)
      f.write('const unsigned char g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        quantized_value = int(
            round(
                (255 * (value - features_min)) / (features_max - features_min)))
        if quantized_value < 0:
          quantized_value = 0
        if quantized_value > 255:
          quantized_value = 255
        if i == 0:
          f.write('\n  ')
        f.write('%d, ' % (quantized_value))
        i = (i + 1) % 10
    else:
      f.write('const float g_%s_data[] = {\n' % variable_base)
      i = 0
      for value in features.flatten():
        if i == 0:
          f.write('\n  ')
        f.write('%f, ' % value)
        i = (i + 1) % 10
    f.write('\n};\n')


def main(_):
  # We want to see all the logging messages.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms,
                  FLAGS.window_size_ms, FLAGS.window_stride_ms,
                  FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess,
                  FLAGS.input_wav, FLAGS.output_c_file)
  tf.compat.v1.logging.info('Wrote to "%s"' % (FLAGS.output_c_file))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.',
  )
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--quantize',
      type=bool,
      default=False,
      help='Whether to train the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--input_wav',
      type=str,
      default=None,
      help='Path to the audio WAV file to read')
  parser.add_argument(
      '--output_c_file',
      type=str,
      default=None,
      help='Where to save the generated C source file containing the features')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
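When `--quantize` is set, wav_to_features maps each float feature onto an unsigned 8-bit value using the min/max range from `input_data.get_features_range`, clamping the result to [0, 255]. The same mapping, pulled out into a standalone helper (hypothetical, for illustration only):

```python
def quantize_feature(value, features_min, features_max):
    """Map a float feature into [0, 255], the scheme used when writing
    the unsigned char array in the generated C file."""
    q = int(round((255 * (value - features_min)) / (features_max - features_min)))
    # Clamp out-of-range values, as the script does before writing.
    return max(0, min(255, q))
```

Values at `features_min` become 0 and values at `features_max` become 255, so the on-device code only needs the same min/max constants to reverse the scaling.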

test_streaming_accuracy.py

Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to create accuracy statistics on a continuous stream of samples.

This is designed to be an environment for running experiments on new models and
settings to understand the effects they will have in a real application. You
need to supply it with a long audio file containing sounds you want to recognize
and a text file listing the labels of each sound along with the time they occur.
With this information, and a frozen model, the tool will process the audio
stream, apply the model, and keep track of how many mistakes and successes the
model achieved.

The matched percentage is the number of sounds that were correctly classified,
as a percentage of the total number of sounds listed in the ground truth file.
A correct classification is when the right label is chosen within a short time
of the expected ground truth, where the time tolerance is controlled by the
'time_tolerance_ms' command line flag.

The wrong percentage is how many sounds triggered a detection (the classifier
figured out it wasn't silence or background noise), but the detected class was
wrong. This is also a percentage of the total number of ground truth sounds.

The false positive percentage is how many sounds were detected when there was
only silence or background noise. This is also expressed as a percentage of the
total number of ground truth sounds, though since it can be large it may go
above 100%.

The easiest way to get an audio file and labels to test with is by using the
'generate_streaming_test_wav' script. This will synthesize a test file with
randomly placed sounds and background noise, and output a text file with the
ground truth.

If you want to test natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down where the
sounds occur in time. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in seconds from
the start of the file that it occurs.

Here's an example of how to run the tool:

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy_py -- \
--wav=/tmp/streaming_test_bg.wav \
--ground-truth=/tmp/streaming_test_labels.txt --verbose \
--model=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
--suppression_ms=500 --time_tolerance_ms=1500
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy
import tensorflow as tf

from accuracy_utils import StreamingAccuracyStats
from recognize_commands import RecognizeCommands
from recognize_commands import RecognizeResult
from tensorflow.python.ops import io_ops

FLAGS = None


def load_graph(model_file):
  """Reads a frozen TensorFlow model and imports it into a new graph."""
  graph = tf.Graph()
  with graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(model_file, 'rb') as fid:
      serialized_graph = fid.read()
      od_graph_def.ParseFromString(serialized_graph)
      tf.import_graph_def(od_graph_def, name='')
  return graph


def read_label_file(file_name):
  """Load a list of label."""
  label_list = []
  with open(file_name, 'r') as f:
    for line in f:
      label_list.append(line.strip())
  return label_list


def read_wav_file(filename):
  """Load a wav file and return sample_rate and numpy data of float64 type."""
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    res = sess.run(wav_decoder, feed_dict={wav_filename_placeholder: filename})
  return res.sample_rate, res.audio.flatten()


def main(_):
  label_list = read_label_file(FLAGS.labels)
  sample_rate, data = read_wav_file(FLAGS.wav)
  # Init instance of RecognizeCommands with given parameters.
  recognize_commands = RecognizeCommands(
      labels=label_list,
      average_window_duration_ms=FLAGS.average_window_duration_ms,
      detection_threshold=FLAGS.detection_threshold,
      suppression_ms=FLAGS.suppression_ms,
      minimum_count=4)

  # Init instance of StreamingAccuracyStats and load ground truth.
  stats = StreamingAccuracyStats()
  stats.read_ground_truth_file(FLAGS.ground_truth)
  recognize_element = RecognizeResult()
  all_found_words = []
  data_samples = data.shape[0]
  clip_duration_samples = int(FLAGS.clip_duration_ms * sample_rate / 1000)
  clip_stride_samples = int(FLAGS.clip_stride_ms * sample_rate / 1000)
  audio_data_end = data_samples - clip_duration_samples

  # Load model and create a tf session to process audio pieces
  recognize_graph = load_graph(FLAGS.model)
  with recognize_graph.as_default():
    with tf.compat.v1.Session() as sess:

      # Get input and output tensor
      data_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[0])
      sample_rate_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[1])
      output_softmax_tensor = sess.graph.get_tensor_by_name(FLAGS.output_name)

      # Inference along audio stream.
      for audio_data_offset in range(0, audio_data_end, clip_stride_samples):
        input_start = audio_data_offset
        input_end = audio_data_offset + clip_duration_samples
        outputs = sess.run(
            output_softmax_tensor,
            feed_dict={
                data_tensor:
                    numpy.expand_dims(data[input_start:input_end], axis=-1),
                sample_rate_tensor:
                    sample_rate
            })
        outputs = numpy.squeeze(outputs)
        current_time_ms = int(audio_data_offset * 1000 / sample_rate)
        try:
          recognize_commands.process_latest_result(outputs, current_time_ms,
                                                   recognize_element)
        except ValueError as e:
          tf.compat.v1.logging.error(
              'Recognition processing failed: {}'.format(e))
          return
        if (recognize_element.is_new_command and
            recognize_element.founded_command != '_silence_'):
          all_found_words.append(
              [recognize_element.founded_command, current_time_ms])
          if FLAGS.verbose:
            stats.calculate_accuracy_stats(all_found_words, current_time_ms,
                                           FLAGS.time_tolerance_ms)
            try:
              recognition_state = stats.delta()
            except ValueError as e:
              tf.compat.v1.logging.error(
                  'Statistics delta computing failed: {}'.format(e))
            else:
              tf.compat.v1.logging.info('{}ms {}:{}{}'.format(
                  current_time_ms, recognize_element.founded_command,
                  recognize_element.score, recognition_state))
              stats.print_accuracy_stats()
  stats.calculate_accuracy_stats(all_found_words, -1, FLAGS.time_tolerance_ms)
  stats.print_accuracy_stats()


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='test_streaming_accuracy')
  parser.add_argument(
      '--wav', type=str, default='', help='The wave file path to evaluate.')
  parser.add_argument(
      '--ground-truth',
      type=str,
      default='',
      help='The ground truth file path corresponding to wav file.')
  parser.add_argument(
      '--labels',
      type=str,
      default='',
      help='The label file path containing all possible classes.')
  parser.add_argument(
      '--model', type=str, default='', help='The model used for inference')
  parser.add_argument(
      '--input-names',
      type=str,
      nargs='+',
      default=['decoded_sample_data:0', 'decoded_sample_data:1'],
      help='Input name list involved in model graph.')
  parser.add_argument(
      '--output-name',
      type=str,
      default='labels_softmax:0',
      help='Output name involved in model graph.')
  parser.add_argument(
      '--clip-duration-ms',
      type=int,
      default=1000,
      help='Length of each audio clip fed into model.')
  parser.add_argument(
      '--clip-stride-ms',
      type=int,
      default=30,
      help='Stride in milliseconds between successive audio clips.')
  parser.add_argument(
      '--average_window_duration_ms',
      type=int,
      default=500,
      help='Length of average window used for smoothing results.')
  parser.add_argument(
      '--detection-threshold',
      type=float,
      default=0.7,
      help='The confidence for filtering unreliable commands')
  parser.add_argument(
      '--suppression_ms',
      type=int,
      default=500,
      help='The time interval between every two adjacent commands')
  parser.add_argument(
      '--time-tolerance-ms',
      type=int,
      default=1500,
      help='Time tolerance before and after the timestamp of this audio clip '
      'to match ground truth')
  parser.add_argument(
      '--verbose',
      action='store_true',
      default=False,
      help='Whether to print streaming accuracy on stdout.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
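The docstring at the top of this script defines three streaming metrics, each expressed as a percentage of the total number of ground-truth sounds, which is why the false-positive figure can exceed 100%. A hypothetical helper, not part of accuracy_utils.py, that makes the arithmetic explicit:

```python
def streaming_stats(correct, wrong, false_positives, ground_truth_total):
    """Express the three streaming-accuracy counters as percentages of the
    number of ground-truth sounds (matched %, wrong %, false positive %)."""
    def pct(n):
        return 100.0 * n / ground_truth_total
    return pct(correct), pct(wrong), pct(false_positives)

# 10 labeled sounds: 8 matched, 1 detected with the wrong label,
# 3 detections during silence/background noise.
matched, wrong, false_pos = streaming_stats(8, 1, 3, 10)
```

With 10 ground-truth sounds this yields 80% matched, 10% wrong, and 30% false positives; 15 spurious detections against the same 10 sounds would report 150% false positives.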

Credits

Yuyu Qian
YuetongYang
Xurui Zhang
Yifan Wang

Thanks to Yuyu Qian, Yifan Wang, Xurui Zhang, and Yuetong Yang.