Sudhir Kshirsagar, Deven Mistry, Jacob Specht
Created December 15, 2023 © GPL3+

Improving Situational Awareness in Disasters with Edge AI

Extract critical information from the "techno-speak" in IoT data using a fast Small Language Model (SLM) running on NVIDIA Jetson Orin Nano.

Advanced · Over 1 day · 22
Improving Situational Awareness in Disasters with Edge AI

Things used in this project

Hardware components

Jetson Orin Nano Developer Kit
NVIDIA Jetson Orin Nano Developer Kit
We used one kit with a 128 GB SD card (JetPack 5) and the second with a 1 TB NVMe SSD (JetPack 6). We received one as a prize and bought the second.
×2

Software apps and online services

NVIDIA SDK Manager
VS Code
Microsoft VS Code

Story

Read more

Code

MQTT Subscribers and Model Inferencing

Python
This code is capable of spawning multiple MQTT subscribers that receive IoT JSON data and send that data, along with a leading prompt, to a Docker container running a Small Language Model (SLM).
import click
import fcntl
import json
import pandas as pd
import termios
import threading
import time
import torch
import yaml

from copy import deepcopy
from datetime import datetime
from fastcore.style import Style
from os.path import join
from paho.mqtt import client as mqtt_client
from peft import AutoPeftModelForCausalLM
import psutil
from queue import Queue, Empty
from transformers import AutoTokenizer, GenerationConfig

from iotgpt.db_utils import SqliteConnector, PgConnector, SENSOR_ID, DESC, \
  LON, LAT, UNIT, NAME, SENSOR_TYPE, MANUF, MODEL

from iotgpt.llms.bard import ask_bard
from iotgpt.llms.openai import ask_openai

from iotgpt.mqtt.models import Message, SensorData
from iotgpt.mqtt.utils import connect_mqtt, string_is_date_time, \
  pd_str_to_unixtime

from iotgpt.tts import text_to_speech
from iotgpt.utils import (generate_sentences_from_json,
                          standardize_keys)

# Console highlighters. Each is called on a string (e.g. ``R('text')``) to
# colour the log output printed throughout this script.
G, B, R, Y = Style().green, Style().blue, Style().red, Style().yellow
U = Style().underline
GB = Style().green.bold
YU = Style().yellow.underline

# Keys looked up in the ``actions`` section of the YAML config. Note that
# USE_LOOKUP maps to the config key 'LOOKUP'. STANDARDIZE and SENTENCIFY are
# not referenced in this file.
STANDARDIZE, SENTENCIFY, TTS = 'STANDARDIZE', 'SENTENCIFY', 'TTS'
USE_BARD, USE_OPENAI = 'USE_BARD', 'USE_OPENAI'
USE_LOCAL_LLAMA2 = 'USE_LOCAL_LLAMA2'
USE_LOCAL_LLAMA_DOCKER = 'USE_LOCAL_LLAMA_DOCKER'
USE_LOOKUP = 'LOOKUP'
# Top-level config key whose value is prefixed to the sentencified reading.
LLM_QUESTION = 'LLM_QUESTION'

# The Docker-hosted LLM is located as a process with this executable name
# whose command line contains LLM_PROCESS_DESC.
LLM_PROCESS_NAME, LLM_PROCESS_DESC = 'python3', 'local_llm.chat'
# Prompt prefixed to the raw JSON payload sent to the Docker-hosted LLM.
LLM_PROMPT = 'What are the highlights from the following JSON message? Are there any anomalies?'

# MQTT topic -> Subscriber registry, filled by subscribe() and read by
# generic_process_message().
all_subscribers = {}


class MessageHandler:
  """Runs the IoT -> LLM pipeline for queued MQTT messages.

  Subscribers ``push`` raw paho messages onto an internal queue. A daemon
  thread started by ``start`` pops them one at a time and calls ``handle``.
  Depending on the ``actions`` section of the config, ``handle`` does the
  following:

  * standardizes the payload keys;
  * stores readings in Postgres;
  * turns the reading into sentences;
  * asks Bard, OpenAI, a local PEFT model and/or a Docker-hosted LLM about
    the reading;
  * optionally speaks the answers.
  """

  message_queue = None
  model = None
  config_dict = None
  lookup_db_data = None
  actions = None
  tokenizer = None
  gen_config = None
  sqlite_conn = None
  # TODO(security): hard-coded credentials. Kept only as the fallback default
  # for existing deployments; set ``PSQL_CONN_STR`` in the config instead.
  psql_conn_str = 'postgresql://postgres:gq010102@localhost:5432/ai3'
  engine = None
  conn = None
  sqlite_connector: SqliteConnector = None
  pg_connector: PgConnector = None
  docker_llm_pid: int = None
  prev_time: datetime = None

  def __init__(self, config_dict):
    """Create the message queue and the SQLite/Postgres connectors.

    Args:
        config_dict (dict): parsed YAML config. It must contain ``actions``
          and ``LOOKUP_DB_PATH``. An optional ``PSQL_CONN_STR`` overrides
          the class-level Postgres connection string.
    """
    self.config_dict = config_dict
    self.actions = self.config_dict['actions']
    self.message_queue = Queue()
    self.psql_conn_str = config_dict.get('PSQL_CONN_STR') or self.psql_conn_str
    self.sqlite_connector = SqliteConnector(
      db_path=config_dict['LOOKUP_DB_PATH'])
    self.pg_connector = PgConnector(db_url=self.psql_conn_str)

  def load_model(self, model_location):
    """Load the local PEFT model, its tokenizer and its generation config.

    The model is loaded only once (4-bit weights, float16 compute). The
    tokenizer and generation config are re-read on every call.

    Args:
        model_location (str): directory or hub id of the saved model.
    """
    # NOTE(review): newer transformers releases expect 4-bit options via
    # ``quantization_config=BitsAndBytesConfig(...)``; confirm against the
    # installed version.
    if self.model is None:
      self.model = AutoPeftModelForCausalLM.from_pretrained(
        model_location,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
      )
    self.tokenizer = AutoTokenizer.from_pretrained(model_location)
    self.gen_config = GenerationConfig.from_pretrained(
      model_location, 'generation_config.json')

  def standardize(self, mqtt_dict):
    """Standardize the keys of a decoded MQTT payload.

    Args:
        mqtt_dict (dict): decoded JSON payload.

    Returns:
        dict: the payload as returned by ``standardize_keys``, using the
        lookup DB and the config's ``NK_KEY``.
    """
    print(G('Original'), mqtt_dict)
    std_json_msg = standardize_keys(mqtt_dict,
                                    self.sqlite_connector,
                                    self.config_dict['NK_KEY'],
                                    R)
    print(Y('Standardized'), std_json_msg)
    print('-' * 30)
    return std_json_msg

  def sentencify(self, std_json_msg):
    """Turn a standardized payload into sentences and an LLM question.

    Args:
        std_json_msg (dict): standardized payload. Its ``'changedKey'``
          entry, if any, is removed in place.

    Returns:
        tuple: ``(sentences, paragraph, to_ask)``, where ``paragraph`` is the
        sentences joined by spaces and ``to_ask`` is the config's
        ``LLM_QUESTION`` followed by the paragraph.
    """
    # 'changedKey' is bookkeeping from standardization, not sensor data.
    std_json_msg.pop('changedKey', None)

    sentences = generate_sentences_from_json(std_json_msg)
    paragraph = ' '.join(sentences)
    to_ask = f'{self.config_dict[LLM_QUESTION]} {paragraph}'

    print(GB('Sentencified Output: '))
    print(f'{Y("Sentence:")}', end=' ')
    # Terminate the line so the separator below is not glued to it.
    print(paragraph)
    print('-' * 30)

    return sentences, paragraph, to_ask

  def _strip_remove_phrase(self, text):
    """Remove the config's ``REMOVE_PHRASE`` prefix from ``text``.

    When the prefix is removed, leading whitespace is stripped and the first
    remaining letter is capitalized. Empty remainders are handled safely.
    """
    phrase = self.config_dict.get('REMOVE_PHRASE') or ''
    if text.startswith(phrase):
      text = text[len(phrase):].lstrip()
      # Slicing (not text[0]) so an empty remainder doesn't raise IndexError.
      text = text[:1].upper() + text[1:]
    return text

  def use_bard(self, questions, to_ask, bard_responses):
    """Ask Bard ``to_ask`` and return its answer.

    Bard has no official API any more, so this path is expected to fail
    until the ``ask_bard`` helper is updated.

    Args:
        questions (list): ``to_ask`` is appended to it.
        to_ask (str): the prompt.
        bard_responses (dict): updated in place with ``{to_ask: answer}``.

    Returns:
        str: the answer, or ``''`` when Bard returned nothing.
    """
    questions.append(to_ask)
    print(GB('Bard says: '))
    answer = ask_bard(self.config_dict['BARD_TOKEN'], to_ask)
    bard_responses[to_ask] = answer
    print(answer)
    print('-' * 30)
    # Callers call str methods on the result, so never return None.
    return answer or ''

  def use_openai(self, questions, to_ask, gpt_responses):
    """Ask the OpenAI API ``to_ask`` and return the cleaned-up answer.

    Args:
        questions (list): ``to_ask`` is appended to it.
        to_ask (str): the prompt.
        gpt_responses (dict): updated in place with ``{to_ask: answer}``
          when an answer is received.

    Returns:
        str: the answer with ``REMOVE_PHRASE`` stripped, or ``''`` when the
        API returned nothing.
    """
    questions.append(to_ask)
    resp = ask_openai(self.config_dict['OPENAI_TOKEN'], to_ask)
    if not resp:
      return ''
    # Store every answer, not only those that began with REMOVE_PHRASE.
    resp = self._strip_remove_phrase(resp)
    gpt_responses[to_ask] = resp
    return resp

  def use_local_llm(self, to_ask, in_house_responses):
    """Answer ``to_ask`` with the locally loaded PEFT model on CUDA.

    ``load_model`` must have been called first. The prompt echo is cut from
    the decoded output, surrounding whitespace is stripped, and the
    ``REMOVE_PHRASE`` prefix is removed.

    Args:
        to_ask (str): the prompt.
        in_house_responses (dict): updated in place with ``{to_ask: answer}``.

    Returns:
        str: the cleaned answer.
    """
    batch = self.tokenizer(to_ask, return_tensors='pt').to('cuda')
    with torch.autocast("cuda"):
      output_tokens = self.model.generate(
        **batch, max_new_tokens=250, generation_config=self.gen_config)

    model_resp = self.tokenizer.decode(
      output_tokens[0], skip_special_tokens=True)
    # The decoded text contains the prompt; keep only what follows it. If
    # decoding doesn't reproduce to_ask verbatim, the whole text is kept.
    # Strip first so a leading space/newline doesn't hide REMOVE_PHRASE.
    output = model_resp.split(to_ask, 1)[-1].strip()
    output = self._strip_remove_phrase(output)

    in_house_responses[to_ask] = output
    # Drop the request tensors and release cached GPU memory between messages.
    del batch, output_tokens
    torch.cuda.empty_cache()
    return output

  def populate_docker_llm_pid(self):
    """Find the PID of the Docker-hosted LLM chat process.

    The process is matched by name (``LLM_PROCESS_NAME``) and by its command
    line containing ``LLM_PROCESS_DESC``. The PID is stored in
    ``self.docker_llm_pid``, which is left unchanged if nothing matches.
    """
    for process in psutil.process_iter():
      try:
        if process.name() != LLM_PROCESS_NAME:
          continue
        cmdline = process.cmdline()
      except (psutil.NoSuchProcess, psutil.AccessDenied):
        # Processes may exit or be unreadable while we iterate; skip them
        # instead of crashing the message-processing thread.
        continue
      if any(LLM_PROCESS_DESC in arg for arg in cmdline):
        self.docker_llm_pid = process.pid
        return

  def msg_to_docker_llm(self, to_ask):
    """Type ``to_ask`` into the terminal of the Docker-hosted LLM process.

    Each byte is pushed into the input queue of the terminal behind the
    process's fd 1 via the TIOCSTI ioctl, followed by a newline to submit it.
    Failures are printed and not raised. After a failure the cached PID is
    forgotten so that it is looked up again next time.
    """
    # NOTE(review): this relies on the process's stdout being a TTY and on
    # TIOCSTI being permitted; newer kernels can disable TIOCSTI — confirm
    # on the target JetPack kernel.
    if not self.docker_llm_pid:
      self.populate_docker_llm_pid()
    if not self.docker_llm_pid:
      print(R(f'No running {LLM_PROCESS_DESC} process found; message not sent'))
      return

    # Every injected newline acts as Enter, so collapse the prompt to a
    # single line and submit it exactly once.
    msg = ' '.join(to_ask.splitlines()) + '\n'
    tty_path = join('/proc', str(self.docker_llm_pid), 'fd', '1')
    try:
      with open(tty_path, 'wb') as tty_fd:
        for byte in msg.encode('utf-8'):
          fcntl.ioctl(tty_fd, termios.TIOCSTI, bytes([byte]))
    except OSError as err:
      print(err)
      # The process may have exited or restarted; rediscover it next time.
      self.docker_llm_pid = None

  def tts(self, gpt_responses=None, in_house_responses=None):
    """Speak model answers aloud with ``text_to_speech``.

    ``actions['TTS_INPUT'] == 'OPENAI'`` selects the OpenAI answers. Any
    other value selects the local-model answers. Empty answers are skipped.

    Args:
        gpt_responses (dict, optional): ``{question: answer}`` from
          ``use_openai``. When omitted, ``self.gpt_responses`` is used if an
          external caller set it; otherwise nothing is spoken.
        in_house_responses (dict, optional): the same, from ``use_local_llm``.
    """
    print(GB(f'At {datetime.now()} Text-to-speech output'))
    if self.actions.get('TTS_INPUT') == 'OPENAI':
      responses = (gpt_responses if gpt_responses is not None
                   else getattr(self, 'gpt_responses', {}))
    else:
      responses = (in_house_responses if in_house_responses is not None
                   else getattr(self, 'in_house_responses', {}))
    for text in responses.values():
      if text:
        text_to_speech(text)

  def checkpoint(self, count, method):
    """Print the time since the previous checkpoint and reset the timer.

    Args:
        count (int): message number, used as the log prefix.
        method (str): name of the pipeline step that just finished.
    """
    current_time = datetime.now()
    previous = self.prev_time or current_time
    elapsed = (current_time - previous).total_seconds()
    print(R(f'{count}: Finished {method} in {elapsed:.3f} seconds'))
    self.prev_time = current_time

  def consolidate_original_std_dicts(self, mqtt_dict, std_dict):
    """Pair each original payload key with its standardized counterpart.

    Date-time string values are treated as the timestamp. Every other value
    is a reading. Standardized keys are matched to original keys by equal
    value, and each original slot is claimed at most once, so repeated
    values map to distinct keys.

    Example::

        {"time": "2023-04-06 10:03:21.274943488", "nk": 1.41}
        {'The time at which reading was collected': '2023-04-06 ...',
         'depth value': 1.41}

    Returns:
        dict: ``original_keys``, ``std_keys`` and ``values`` as parallel
        lists. When a timestamp is found, the dict also has ``timestamp``
        (unix time) and ``original_time_key``. It has ``std_time_key`` when
        a date-time value is found in ``std_dict``.
    """
    consolidated_dict = {'original_keys': [],
                         'std_keys': [],
                         'values': []}
    for k, v in mqtt_dict.items():
      if isinstance(v, str) and string_is_date_time(v):
        consolidated_dict['timestamp'] = pd_str_to_unixtime(v)
        consolidated_dict['original_time_key'] = k
      else:
        consolidated_dict['original_keys'].append(k)
        consolidated_dict['values'].append(v)
        # Placeholder, filled in by the matching pass below.
        consolidated_dict['std_keys'].append('')

    claimed = set()
    for k, v in std_dict.items():
      if k == 'changedKey':
        # Standardization bookkeeping, not a reading (see sentencify).
        continue
      if isinstance(v, str) and string_is_date_time(v):
        consolidated_dict['std_time_key'] = k
        continue
      for idx, val in enumerate(consolidated_dict['values']):
        if idx not in claimed and val == v:
          consolidated_dict['std_keys'][idx] = k
          claimed.add(idx)
          break
    return consolidated_dict

  def insert_values(self, mqtt_dict, std_dict, sensor_id, received_unixtime):
    """Insert one ``SensorData`` row per reading into Postgres.

    Args:
        mqtt_dict (dict): original decoded payload.
        std_dict (dict): standardized payload.
        sensor_id (str): sensor/topic identifier stored on each row.
        received_unixtime (int): time the message was received.
    """
    consolidated_dict = self.consolidate_original_std_dicts(mqtt_dict, std_dict)
    print(G('Inserting consolidated dict:'))
    print(G(consolidated_dict))
    # In tester mode (replayed data), use the receive time to avoid
    # duplicate-key inserts. Payloads without a recognizable timestamp also
    # fall back to the receive time instead of raising KeyError.
    tester_mode = self.config_dict.get('TESTER_MODE', False)
    measured_unixtime = consolidated_dict.get('timestamp', received_unixtime)
    rows = zip(consolidated_dict['original_keys'],
               consolidated_dict['values'],
               consolidated_dict['std_keys'])
    for key, value, nk_guess in rows:
      datum = SensorData()
      datum.sensor = sensor_id
      datum.key = key
      datum.value = value
      datum.timestamp = received_unixtime if tester_mode else measured_unixtime
      datum.nk_guess = nk_guess
      datum.insert(self.pg_connector)

  def handle(self, message):
    """Run the configured pipeline for one queued message.

    Args:
        message (dict): item produced by ``Subscriber.process_message``. It
          has the keys ``'msg'`` (paho message), ``'key'`` (sensor
          description) and ``'count'`` (per-subscriber message number).
    """
    # Per-message answer stores. TTS speaks these un-escaped texts.
    bard_responses = {}
    gpt_responses = {}
    in_house_responses = {}
    questions = []

    # Receive time, used as the DB timestamp in tester mode.
    received_unixtime = int(time.time())

    count = message['count']
    msg = message['msg']
    # NOTE(review): every 'Value' in the raw payload is replaced with the
    # sensor description, including inside values — confirm intended.
    msg_decoded = msg.payload.decode().replace('Value', message['key'])
    print('-' * 60)
    print(f"Received `{msg_decoded}` from `{msg.topic}` topic")
    print('-' * 30)

    mqtt_dict = json.loads(msg_decoded)
    self.lookup_db_data = self.sqlite_connector.select_phrase_synonym()

    db_message = Message(connector=self.pg_connector,
                         topic=msg.topic,
                         payload=msg_decoded,
                         received_unixtime=received_unixtime)

    start_time = datetime.now()
    self.prev_time = start_time

    if self.actions.get(USE_LOCAL_LLAMA_DOCKER):
      print(R('sending message to docker LLM'))
      # Re-serialize the parsed dict. json.dumps(msg_decoded) would
      # double-encode the already-JSON string into an escaped literal.
      raw_prompt = f'{LLM_PROMPT} {json.dumps(mqtt_dict)}'
      self.msg_to_docker_llm(raw_prompt)
      self.checkpoint(count, 'USE_LOCAL_LLAMA_DOCKER')

    if self.actions.get(USE_LOOKUP):
      std_json_msg = self.standardize(mqtt_dict)
      self.checkpoint(count, 'standardize')
      self.insert_values(mqtt_dict, std_json_msg, msg.topic, received_unixtime)

      _, _, to_ask = self.sentencify(std_json_msg)
      self.checkpoint(count, 'sentencify')

      # NOTE(review): doubling single quotes suggests Message.update builds
      # SQL by string interpolation; parameterized queries would be safer —
      # confirm in iotgpt.mqtt.models.
      if self.actions.get(USE_BARD):
        db_message.bard_output = self.use_bard(
          questions, to_ask, bard_responses).replace("'", "''")
        self.checkpoint(count, 'use_bard')

      if self.actions.get(USE_OPENAI):
        db_message.openai_output = self.use_openai(
          questions, to_ask, gpt_responses).replace("'", "''")
        self.checkpoint(count, 'use_openai')

      if self.actions.get(USE_LOCAL_LLAMA2):
        db_message.local_llama_output = self.use_local_llm(
          to_ask, in_house_responses).replace("'", "''")
        print(R(db_message.local_llama_output))
        self.checkpoint(count, 'use_local_llm')

      if self.actions.get(USE_LOCAL_LLAMA_DOCKER):
        # NOTE(review): with LOOKUP enabled too, the Docker LLM receives two
        # prompts per message (raw JSON above, sentences here) — confirm.
        print(R('sending message to docker LLM'))
        self.msg_to_docker_llm(to_ask)
        self.checkpoint(count, 'USE_LOCAL_LLAMA_DOCKER')

      if self.actions.get(TTS):
        self.tts(gpt_responses, in_house_responses)
        self.checkpoint(count, 'TTS')

    db_message.update(connector=self.pg_connector)

    total_elapsed = (datetime.now() - start_time).total_seconds()
    print(R(f'{count}: Finished flow in {total_elapsed:.3f} seconds'))

  def process_messages(self):
    """Worker loop: block on the queue and handle messages forever.

    A failure in one message is printed with its traceback and does not stop
    the worker.
    """
    import traceback

    while True:
      # Blocking get: polling get_nowait() in a tight loop pinned a CPU core
      # while idle (the thread is a daemon, so blocking never stalls exit).
      message = self.message_queue.get()
      print(B(f"Processing message: {message['count']}"))
      try:
        self.handle(message)
      except Exception:
        print(R(f"Failed to process message {message['count']}:"))
        traceback.print_exc()
      finally:
        self.message_queue.task_done()

  def start(self):
    """Start the daemon thread that runs ``process_messages``."""
    self.message_processing_thread = threading.Thread(
      target=self.process_messages)
    # Daemon, so the thread exits when the main program ends.
    self.message_processing_thread.daemon = True
    self.message_processing_thread.start()

  def push(self, msg):
    """Queue a message (dict from ``Subscriber.process_message``)."""
    self.message_queue.put(msg)


class Subscriber:
  """Per-topic MQTT listener that forwards raw messages to a handler.

  One instance exists per subscribed topic. Each incoming message is stamped
  with this subscriber's running message count and the sensor ``key`` from
  its config, then pushed onto the shared handler.
  """

  msg_count = 0
  config_dict = None
  handler = None
  key = ''

  def __init__(self, config_dict, handler) -> None:
    """Keep the per-topic config and the handler to forward messages to.

    Args:
        config_dict (dict): per-topic config. ``'actions'`` and ``'key'``
          are read immediately.
        handler (MessageHandler): object whose ``push`` receives messages.
    """
    self.config_dict = config_dict
    actions, sensor_key = config_dict['actions'], config_dict['key']
    self.actions = actions
    self.key = sensor_key
    self.handler = handler

  def process_message(self, client, userdata, msg):
    """Wrap one paho message with metadata and push it to the handler.

    Args:
        client: the paho client that received the message.
        userdata: opaque user data passed through by paho.
        msg: the received paho message.
    """
    self.msg_count = next_count = self.msg_count + 1
    self.handler.push(dict(client=client, userdata=userdata, msg=msg,
                           count=next_count, key=self.key))


def publish(client, message, topic):
  """Publish ``message`` on ``topic`` and print whether it succeeded.

  Args:
      client: paho client used to publish.
      message: payload to send.
      topic (str): destination topic.
  """
  # The first element of the publish result is the return code (0 = success).
  rc = client.publish(topic, message)[0]
  if rc != 0:
    print(f"Failed to send message to topic {topic}")
    return
  print(f"Send `{message}` to topic `{topic}`")


def generic_process_message(client, userdata, msg):
  """paho ``on_message`` callback that routes a message to its Subscriber.

  Looks up ``msg.topic`` in the module-level ``all_subscribers`` registry,
  which ``subscribe`` fills, and delegates to that subscriber. A message on
  a topic with no registered subscriber is reported and dropped instead of
  raising KeyError from inside the paho callback.
  """
  subscriber = all_subscribers.get(msg.topic)
  if subscriber is None:
    print(f'No subscriber registered for topic `{msg.topic}`; dropping message')
    return
  subscriber.process_message(client, userdata, msg)


def subscribe(client: mqtt_client, config_dict: dict, handler: MessageHandler):
  """Subscribe ``client`` to one topic and register its Subscriber.

  The shared ``generic_process_message`` callback is (re)installed as the
  client's ``on_message`` on every call.

  Args:
      client (mqtt_client): paho client to subscribe with.
      config_dict (dict): per-topic config. ``config_dict['TOPIC']`` is the
        topic; the dict is also handed to the new ``Subscriber``.
      handler (MessageHandler): shared handler the Subscriber pushes to.
  """
  topic = config_dict['TOPIC']
  client.subscribe(topic)
  all_subscribers[topic] = Subscriber(config_dict, handler)
  client.on_message = generic_process_message


@click.command()
@click.option('--config', '-c', help='path to the yml config file',
              required=True)
def main(config: str):
  """Entry-point for the multi-subscriber script.

  \b
  1. Loads the provided config.yml
  2. Creates a paho-mqtt client
  3. Creates an iotgpt MessageHandler
  4. Reads the topics CSV (TOPIC IDs plus sensor metadata)
  5. Upserts sensors in the PG DB
  6. If local LLM is enabled, loads the model
  7. Starts the MessageHandler
  8. Iterates over topics and subscribes to them
  9. Finally calls client.loop_forever

  Args:
      config (str): path to the YAML config file.
  """
  # Close the config file instead of leaking the handle.
  with open(config) as config_file:
    config_dict = yaml.safe_load(config_file) or {}

  # Fail fast with a clear message rather than pd.read_csv(None).
  topics_path = config_dict.get('TOPICS_CSV')
  if not topics_path:
    raise click.ClickException('TOPICS_CSV is not set in the config file')

  client = connect_mqtt(
    config_dict['BROKER'], config_dict['PORT'], config_dict['CLIENT_ID'],
    config_dict.get('USERNAME', None), config_dict.get('PASSWORD', None))

  handler = MessageHandler(config_dict)

  topics = pd.read_csv(topics_path, header=0)
  handler.pg_connector.upsert_sensors(topics)

  if config_dict['actions'].get(USE_LOCAL_LLAMA2):
    handler.load_model(config_dict['SAVED_MODEL_LOCATION'])

  handler.start()

  for _, topic in topics.iterrows():
    new_config_dict = deepcopy(config_dict)
    new_config_dict.update({
      # MQTT topics (and msg.topic, the registry lookup key) are strings,
      # but pandas may parse the sensor-id column as numbers.
      'TOPIC': str(topic[SENSOR_ID]),
      'key': topic[DESC],
      'longitude': topic[LON],
      'latitude': topic[LAT],
      'unit': topic[UNIT],
      'name': topic[NAME],
      'sensor_type': topic[SENSOR_TYPE],
      'manuf': topic[MANUF],
      'model': topic[MODEL],
    })
    subscribe(client, new_config_dict, handler)

  client.loop_forever()


if __name__ == '__main__':
  main()

Credits

Sudhir Kshirsagar
6 projects • 2 followers
Deven Mistry
0 projects • 0 followers
Jacob Specht
1 project • 1 follower
Thanks to Dustin Franklin.

Comments