介绍

随着物联网和嵌入式系统的普及,如何在资源受限的硬件上实现高效的机器学习推理成为一个重要课题。行空板 K10 AI 硬件结合 TensorFlow Lite Micro (TFLM),为嵌入式设备上的 TinyML 开发 提供了轻量级的解决方案。从模型训练到最终部署,本文将详细讲解如何通过 K10 硬件 和 Arduino TensorFlow 开发环境,快速实现 嵌入式机器学习 应用。同时,这一教程也适用于类似的硬件平台,例如 ESP32 S3 AI 设备。

本文将结合 DFRobot 行空板K10 硬件 和 TensorFlow-Lite-Mirco 库,介绍如何从模型训练到最终部署,实现一个完整的 TinyML 开发过程,帮助开发者快速上手资源受限设备上的机器学习推理。代码下载链接:https://github.com/polamaxu/TFLM

行空板K10

行空板 K10是一款专为信息科技课程中的编程学习、物联网及AI项目教学需求而开发的学习板。

核心

  • MCU:ESP32-S3,32位双核处理器,主频为240MHz.

无线通信

  • Wi-Fi:支持2.4G Wi-Fi
  • 蓝牙:支持蓝牙5.0

显示模块

  • 显示屏:2.8英寸彩色LCD屏

板载元件

  • 摄像头:集成摄像头,支持图像采集和检测,可以用于人脸识别、物体检测等人工智能应用.
  • 麦克风:配备两个麦克风,能够捕捉声音信号,支持语音识别和语音交互功能.
  • 扬声器:内置扬声器,可以播放音频,用于语音提示和音乐播放等.
  • 传感器:包括数字环境光传感器、温湿度传感器和加速度传感器,能够检测环境参数和运动状态,为物联网应用提供数据支持.
  • RGB灯:有三个RGB灯,可以显示多种颜色
  • 实体按键:有两个实体按键,可以用于输入和控制操作.

Tensorflow vs TensorFlow-Lite-Mirco

TensorFlow 是一个开源机器学习框架。TensorFlow 提供了丰富的 API,使得开发者可以轻松地构建、训练和部署机器学习模型。
TensorFlow Lite Micro 是 TensorFlow 的一个子集,专门为微控制器等资源受限的设备设计。它允许开发者将机器学习模型部署到这些设备上,使得设备能够直接在本地进行数据处理和推理,而不需要依赖云端计算。

  • TensorFlow:适合在桌面或服务器环境中进行复杂的模型训练与开发,支持高度定制化和复杂操作。
  • TensorFlow Lite for Microcontrollers (TFLM):为嵌入式设备设计,优化了存储与计算效率,适用于资源受限环境中快速运行机器学习模型。

    如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图1

     

    (图源:https://github.com/johnosbb/MicroTFLite)[/align]

TinyML开发流程(软件Arduino IDE、硬件行空板K10)

数据采集:
  • 创建一个 Arduino Sketch 用于从传感器或其他输入设备收集数据,以形成模型训练所需的数据集。
定义与训练模型:
  • 在 TensorFlow 开发环境中(如 Google Colab)定义一个深度神经网络模型(DNN),并使用采集到的数据对模型进行训练。
模型转换与保存:
  • 将训练好的模型转换为 TensorFlow Lite 格式,并保存为 model.h 文件,其中包含模型的二进制表示(Flat Buffer 格式)。
部署推理代码:
  • 在 Arduino IDE 中准备推理代码,包含以下步骤:
  • 引入必要头文件(如 TensorFlow Lite Micro 和模型文件 model.h)。
  • 定义 TensorArena(内存缓冲区)。
  • 初始化模型。
  • 设置输入数据并运行推理。
  • 读取推理输出。
测试与优化:
  • 通过串口调试工具查看推理结果,并根据需要优化模型性能。

MicroTFLite介绍

MicroTFLite 是一个为 Arduino 设计的 TensorFlow Lite Micro 库,旨在简化在 Arduino 板上使用 TensorFlow Lite Micro 的过程。MicroTFLite库适用于各种任务,如分类、回归、预测。

1. 特点

  • Arduino 风格的 API:MicroTFLite 提供了典型的 Arduino 风格的 API,避免了在 Arduino 代码中使用指针或其他 C++ 语法构造,这使得它更适合 Arduino 开发者使用。
  • 支持量化数据和浮点数据:该库能够处理量化数据和原始浮点值,根据模型的元数据自动检测适当的处理方式。
  • 调试功能:提供了多个函数来帮助开发者了解模型部署的过程,并在调试模型问题时提供帮助。

2. 目标平台

MicroTFLite 适用于多种嵌入式设备,包括:

  • Arduino Nano 系列(如 Nano 33 BLE、Nano ESP32)
  • Arduino Nicla、Portenta 系列
  • ESP32 和 Arduino Giga R1 WiFi

3. MicroTFLite支持的算子

MicroTFLite 支持多种常见的机器学习算子:

  • 卷积与池化:
    • 二维卷积(CONV_2D)、深度可分离卷积(DEPTHWISE_CONV_2D)
    • 最大池化(MAX_POOL_2D)、平均池化(AVERAGE_POOL_2D)
  • 全连接层:FULLY_CONNECTED
  • 激活函数:
    • ReLU、ReLU6、TANH、SIGMOID

4. MicroTFLite API一览

  • 模型初始化和部署
    • 初始化模型:通过 ModelInit() 函数,可以初始化 TensorFlow Lite 模型和解释器。这包括加载模型数据和分配必要的内存空间(如 tensorArena).
    • 加载模型文件:支持从二进制文件(如model.h)中加载预训练的 TensorFlow Lite 模型,使其能够在 Arduino 设备上运行.
  • 输入和输出处理
    • 设置输入数据:使用 ModelSetInput() 函数可以将输入数据设置到模型的输入张量中。该函数还支持量化处理,能够根据模型的量化参数自动调整输入值.
    • 获取输出结果:通过 ModelGetOutput() 函数可以读取模型推理后的输出结果。这使得开发者能够获取模型的预测值或分类结果.
  • 推理和调试
    • 运行推理:调用 ModelRunInference() 函数可以启动模型的推理过程。该函数执行模型的前向传播,生成输出结果.
    • 打印张量信息:提供了多个函数来打印模型的张量信息,如 ModelPrintInputTensorDimensions()和 ModelPrintOutputTensorDimensions(),这些函数可以帮助开发者了解输入和输出张量的维度.
    • 调试和调试信息:ModelPrintTensorQuantizationParams() 函数可以打印输入和输出张量的量化参数,这对于调试量化模型非常有用。ModelPrintMetadata()函数可以打印模型的元数据信息,如描述和版本,帮助开发者了解模型的基本信息.

MicroTFLite安装

1. arduinoIDE 1.8.19 preferences安装行空板K10环境
https://downloadcd.dfrobot.com.cn/UNIHIKER/package_unihiker_index.json

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图2


2. Library Manager搜索MicroTFLite安装

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图3

3. 示例测试:通过传感器数据实时预测工业设备故障
ArduinoLite_preventive_maintenance项目使用的是一个前馈神经网络架构,包含输入层、两个 Dropout 层、一个隐藏层和输出层,处理二元分类任务。使用 binary_crossentropy 作为损失函数和 adam 优化器进行训练,支持类别权重调整和早停回调,确保在类别不平衡时仍能有效训练。示例中使用了随机数模拟的传感器数据,比如转速、温度、振动、电流。为了减少模型大小,此模型量化成int8。

这个程序的主要功能是:

  • 模拟采集设备的各项参数(转速、温度、振动、电流)
  • 使用这些参数进行故障预测
  • 同时通过预设的阈值进行实际故障判断
  • 比较预测结果和实际故障情况,统计预测准确度提升

量化作用和影响:TensorFlow Lite for Microcontrollers (TFLM) 的量化是将模型中的数据从浮点数(float32)转换为整数(int8)的过程。

  • 减少模型大小:量化将模型中的浮点数转换为整数,显著减少了模型的存储需求。
  • 提高计算效率:在资源受限的设备上,整数计算通常比浮点计算更高效。整数运算器(如 ARM Cortex-M 的 DSP 扩展)可以更快地执行整数乘法和加法等操作,从而提高模型的推理速度.
  • 精度损失:量化会引入一定的精度损失,因为整数表示无法完全精确地表示浮点数。然而,通过合理的量化策略和模型优化,可以在保持较高精度的同时实现显著的性能
    /* Copyright 2024 John O'Sullivan, TensorFlow Authors. All Rights Reserved.
    
    这是一个使用 MicroTFLite 库运行 TensorFlow Lite 模型的示例程序
    主要用于设备的预防性维护,通过传感器数据预测可能的故障
    
    更多信息请参考库文档:
    https://github.com/johnosbb/MicroTFLite
    
    Licensed under the Apache License, Version 2.0 (the "License");
    ... 许可证信息 ...
    ==============================================================================*/
    
    #include <MicroTFLite.h>
    #include "model.h"
    
    // 特征数据的统计信息(缩放和平衡前):
    // 转速(RPM) - 平均值: 1603.866, 标准差: 195.843
    // 温度(°C) - 平均值: 24.354, 标准差: 4.987
    // 振动(g) - 平均值: 0.120, 标准差: 0.020
    // 电流(A) - 平均值: 3.494, 标准差: 0.308
    constexpr float tMean = 24.354f;      // 温度平均值
    constexpr float rpmMean = 1603.866f;  // 转速平均值
    constexpr float vMean = 0.120f;       // 振动平均值
    constexpr float cMean = 3.494f;       // 电流平均值
    constexpr float tStd = 4.987f;        // 温度标准差
    constexpr float rpmStd = 195.843f;    // 转速标准差
    constexpr float vStd = 0.020f;        // 振动标准差
    constexpr float cStd = 0.308f;        // 电流标准差
    
    // 定义故障条件的阈值
    const float highTempThreshold = 30.0f;            // 温度过高阈值(摄氏度)
    const float lowRpmThreshold = 1500.0f;            // 转速过低阈值
    const float highVibrationThreshold = 0.60f;       // 振动过高阈值(g)
    const float abnormalCurrentLowThreshold = 0.2f;   // 电流过低阈值(安培)
    const float abnormalCurrentHighThreshold = 10.8f; // 电流过高阈值(安培)
    
    // 预测统计计数器
    int totalPredictions = 0;    
    int truePositives = 0;       
    int falsePositives = 0;      
    int trueNegatives = 0;       
    int falseNegatives = 0;      
    float rollingAccuracy = 0.0f;// 滚动计算的准确率
    bool showStatistics = false; // 是否显示统计信息
    
    // 为TensorFlow Lite分配内存
    constexpr int kTensorArenaSize = 4 * 1024;
    alignas(16) uint8_t tensorArena[kTensorArenaSize];
    
    void setup() {
        // 初始化串口通信并等待串口监视器打开
        Serial.begin(115200);
        while (!Serial);
        delay(5000);
        Serial.println("Preventative Maintenance Example.");
        Serial.println("Initializing TensorFlow Lite Micro Interpreter...");
    
        // 初始化TensorFlow Lite模型
        if (!ModelInit(model, tensorArena, kTensorArenaSize)) {
            Serial.println("Model initialization failed!");
            while (true);
        }
        Serial.println("Model initialization done.");
        ModelPrintMetadata();
        ModelPrintTensorQuantizationParams();
        ModelPrintTensorInfo();
    }
    
    // 使用Box-Muller变换生成正态分布的随机值
    float GenerateRandomValue(float mean, float stddev) {
        float u1 = random(0, 10000) / 10000.0f;
        float u2 = random(0, 10000) / 10000.0f;
        // Box-Muller变换生成正态分布值
        float z0 = sqrt(-2.0f * log(u1)) * cos(2.0f * PI * u2);
        float value = mean + z0 * stddev;
        return value;
    }
    
    // 模拟读取传感器数据的函数
    float ReadRpm() {
        return GenerateRandomValue(rpmMean, rpmStd);
    }
    
    float ReadVibration() {
        return GenerateRandomValue(vMean, vStd);
    }
    
    float ReadTemperature() {
        return GenerateRandomValue(tMean, tStd);
    }
    
    float ReadCurrent() {
        return GenerateRandomValue(cMean, cStd);
    }
    
    // 检查是否存在故障条件
    bool CheckFailureConditions(float temperature, float rpm, float vibration, float current) {
        bool conditionMet = false;
        String failureReason = "";
    
        // 检查各项参数是否超过阈值
        if (temperature > highTempThreshold) {
            conditionMet = true;
            failureReason += "温度过高; ";
        }
        if (rpm < lowRpmThreshold) {
            conditionMet = true;
            failureReason += "转速过低; ";
        }
        if (vibration > highVibrationThreshold) {
            conditionMet = true;
            failureReason += "振动过大; ";
        }
        if (current < abnormalCurrentLowThreshold) {
            conditionMet = true;
            failureReason += "电流过低; ";
        }
        if (current > abnormalCurrentHighThreshold) {
            conditionMet = true;
            failureReason += "电流过高; ";
        }
    
        if (conditionMet) {
            Serial.print("注意:传感器读数表明可能存在故障。原因: ");
            Serial.println(failureReason);
        }
    
        return conditionMet;
    }
    
    void loop() {
        // 读取传感器数据(这里使用模拟数据)
        float rpm = ReadRpm();
        float temperature = ReadTemperature();
        float current = ReadCurrent();
        float vibration = ReadVibration();
        bool inputSetFailed = false;
    
        // 检查实际故障条件
        bool actualFailure = CheckFailureConditions(temperature, rpm, vibration, current);
    
        // 数据标准化处理
        float temperatureN = (temperature - tMean) / tStd;
        float rpmN = (rpm - rpmMean) / rpmStd;
        float currentN = (current - cMean) / cStd;
        float vibrationN = (vibration - vMean) / vStd;
    
        // 将数据输入到模型中
        if (!ModelSetInput(temperatureN, 0))
            inputSetFailed = true;
        if (!ModelSetInput(rpmN, 1))
            inputSetFailed = true;
        if (!ModelSetInput(currentN, 2))
            inputSetFailed = true;
        if (!ModelSetInput(vibrationN, 3))
            inputSetFailed = true;
    
        // 运行模型推理
        if (!ModelRunInference()) {
            Serial.println("模型推理失败!");
            return;
        }
    
        // 获取模型输出结果
        float prediction = ModelGetOutput(0);
        bool predictedFailure = (prediction > 0.50f); // 预测值>0.5表示可能发生故障
    
        // 更新预测统计信息
        if (predictedFailure && actualFailure) {
            truePositives++;
        } else if (predictedFailure && !actualFailure) {
            falsePositives++;
            showStatistics = true;
        } else if (!predictedFailure && actualFailure) {
            falseNegatives++;
            showStatistics = true;
        } else if (!predictedFailure && !actualFailure) {
            trueNegatives++;
        }
    
        totalPredictions++;
    
        // 计算滚动准确率
        rollingAccuracy = (float)(truePositives + trueNegatives) / totalPredictions * 100.0f;
    
        // 当出现假阳性或假阴性时显示统计信息
        if (showStatistics) {
            Serial.println("-----------------------------");
            Serial.print("预测置信度: ");
            Serial.println(prediction);
            Serial.print("预测是否故障: ");
            Serial.println(predictedFailure);
            Serial.print("实际是否故障: ");
            Serial.println(actualFailure);
            Serial.print("转速: ");
            Serial.print(rpm);
            Serial.print(" | 温度: ");
            Serial.print(temperature);
            Serial.print(" °C");
            Serial.print(" | 电流: ");
            Serial.print(current);
            Serial.print(" A");
            Serial.print(" | 振动: ");
            Serial.print(vibration);
            Serial.print(" m/s^2\n");
            Serial.println("总预测次数: " + String(totalPredictions) + 
                          ", 真阳性: " + String(truePositives) + 
                          ", 假阳性: " + String(falsePositives) + 
                          ", 真阴性: " + String(trueNegatives) + 
                          ", 假阴性: " + String(falseNegatives) + 
                          ", 准确率(%): " + String(rollingAccuracy));
    
            showStatistics = false;
        }
    
        delay(10000); // 等待10秒后进行下一次检测
    }

 

 4. 电脑连接k10,上传程序:

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图5

5. k10编译结果:程序占了很少的存储空间和动态内存
Sketch uses 539913 bytes (10%) of program storage space. Maximum is 5242880 bytes.
Global variables use 29740 bytes (9%) of dynamic memory, leaving 297940 bytes for local variables.

6. 串口监视器输出:

  • 转速(RPM):1493.15
  • 温度(Temperature):32.01°C
  • 电流(Current):3.52 A
  • 振动(Vibration):0.09 m/s^2
    在 7:59:07.295,串口输出指出传感器读数显示潜在的故障条件,原因是高温(High Temperature)。在 7:59:37.328,串口输出再次指出潜在的故障条件,原因包括高温和低转速。

    如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图6

自己训练模型并部署

1. PC上安装python环境
从 Python 官网下载并安装 Python 3.6 或更高版本:Python 官网

2. PC上安装TensorFlow和NumPy库
打开终端(Windows 上为命令提示符,macOS 或 Linux 上为终端),然后运行以下命令来安装所需的库:

# 安装 TensorFlow 2.x(默认为最新版本) 
pip install tensorflow  
# 安装 NumPy 
pip install numpy

3.  定义简单的回归网络,生成模型后再转换成C文件
目标:输入为1到5的数字,输出为其平方值
终端上运行test.py文件:

python test.py

生成的模型文件下载链接:https://github.com/polamaxu/TFLM
代码详情,请至原文查看

4. 将生成的模型文件上传至K10硬件,通过MicroTFLite库进行推理操作。
将训练生成的model.h与ino文件放在同一个文件夹中,在arduinoIDE软件中上传程序

 代码详情,请至原文查看

5. k10编译结果:程序占了很少的存储空间和动态内存
Sketch uses 530373 bytes (10%) of program storage space. Maximum is 5242880 bytes.
Global variables use 28492 bytes (8%) of dynamic memory, leaving 299188 bytes for local variables. Maximum is 327680 bytes.

6. k10输出:

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图7

结论

MicroTFLite 库为嵌入式开发者提供了高效的工具,使得在资源受限平台(如 K10)上运行机器学习模型变得简单易行。从模型训练到部署推理,开发者可以快速完成 TinyML 项目,推动智能嵌入式设备的发展。
如果你有想要了解的,也请告诉我们,我们会推出更多 tinyml 项目教程。

参考

  1. https://github.com/tensorflow/tflite-micro
  2. https://github.com/johnosbb/MicroTFLite
  3. https://github.com/polamaxu/TFLM
Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐