正点原子学习小组
直播中

jf_07365693

2年用户 380经验值
擅长:嵌入式技术 控制/MCU
私信 关注

【正点原子STM32MP257开发板试用】MNIST 手写数字识别

AI_monist_local

【正点原子STM32MP257开发板试用】MNIST 手写数字识别

本文介绍了正点原子 STM32MP257 开发板基于 MNIST 数据集实现手写数字识别的项目设计,包括 USB 摄像头驱动、模型训练与部署、板端推理、本地识别以及远程数字识别等。

项目介绍

  • 准备工作:包括 USB 摄像头的驱动显示;
  • 模型部署:模型训练、ONNX 转换;
  • 板端推理:开发板本地运行图片实现推理测试;
  • 远程识别:结合 http 协议实现识别画面的网络推送;
  • Home Assistant 连接:添加摄像头集成,实现网页摄像头画面的远程调用。

准备工作

包括 USB 摄像头的本地驱动显示和网页远程显示。

USB 摄像头

这里使用 USB 摄像头进行图像采集,型号为罗技 C270 (标准 UVC 设备,便于驱动)。

更新软件源并安装 OpenCV 库(默认已安装)

sudo apt-get update
sudo apt-get install python3-opencv

安装完成后,检查版本

python3
import cv2
print(cv2.__version__)

opencv_version.jpg

通过 v4l2-ctl --list-devices 指令获取当前 USB 设备列表

camera_usb_list.jpg

流程图

flowchart_camera.jpg

代码

终端执行指令 touch camera_test.py 新建 python 执行文件,添加如下代码

import cv2

# Create a VideoCapture object
# Parameter is the camera index (0 for first/default camera)
cap = cv2.VideoCapture(7)

# Check if camera opened successfully
if not cap.isOpened():
    print("Error: Could not open camera")
    exit()

# Continuous frame capture loop
while True:
    # Capture frame-by-frame
    ret, frame = cap.read()
    
    # If frame reading fails, break the loop
    if not ret:
        print("Error: Could not read frame")
        break
    
    # Display the resulting frame
    cv2.imshow('USB Camera Feed', frame)
    
    # Break the loop when 'q' key is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the capture and destroy all windows
cap.release()
cv2.destroyAllWindows()

终端执行 python3 camera_test.py 指令运行程序,可在屏幕获取 USB 摄像头采集的实时动态画面。

效果

usb_camera_test.jpg

通过如下指令查询系统信息

cat /etc/os-release

os_version_view.jpg

安装 nano 文本编辑器

sudo apt install nano

更新软件源可能遇到报错,可添加镜像实现加速下载

执行如下指令,编辑软件源列表

sudo vi /etc/apt/sources.list

添加如下软件源

deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse
deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse

保存并运行 sudo apt update .

开启时钟同步

更新或安装部分软件或 Python 库时,可能会要求系统时钟同步,下面介绍通过修改系统时钟配置实现同步的方案。

终端执行

sudo nano /etc/systemd/timesyncd.conf

启用 NTP 服务,将配置信息修改为

[Time]
NTP=pool.ntp.org
FallbackNTP=ntp.ubuntu.com

执行以下指令,应用上述修改

sudo systemctl restart systemd-timesyncd
timedatectl list-timezones # 列出所有时区
sudo timedatectl set-timezone Asia/Shanghai # 设置为上海时区

终端输入指令 timedatectl 验证时钟配置信息

NTP_time.jpg

系统时钟同步已激活,NTP 服务已开启。

网页摄像头

为了便于调试和验证摄像头画面采集效果,结合 flask 和 opencv 库实现摄像头画面的网页端显示。

流程图

flowchart_http_camera.jpg

代码

终端执行指令 touch camera_server.py 新建 python 执行文件,添加如下代码

# camera_server.py
from flask import Flask, Response
import cv2

app = Flask(__name__)

def get_frame():
    camera = cv2.VideoCapture(7, cv2.CAP_V4L2)
    camera.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    camera.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    while True:
        ret, frame = camera.read()
        if not ret:
            break
        ret, jpeg = cv2.imencode('.jpg', frame, [
            int(cv2.IMWRITE_JPEG_QUALITY), 70
        ])
        yield (b'--frame\\\\r\\\\n'
               b'Content-Type: image/jpeg\\\\r\\\\n\\\\r\\\\n' + jpeg.tobytes() + b'\\\\r\\\\n\\\\r\\\\n')

@app.route('/video_feed')
def video_feed():
    return Response(get_frame(),
                   mimetype='multipart/x-mixed-replace; boundary=frame')

@app.route('/')
def index():
    return """
    <html>
      <head>
        <title>STM32MP257 Camera</title>
        <link rel="icon" href="data:,">
      </head>
      <body>
        <h1>Live Camera</h1>
        <img src="/video_feed" width="640" height="480">
      </body>
    </html>
    """

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, threaded=True)
  • 终端执行 python3 camera_server.py 指令运行程序;
  • 打开浏览器并根据终端提示,输入对应的网址 192.168.1.119:5000 ,即可实时显示 USB 摄像头画面。

效果

camera_http.jpg

模型部署

ST Edge AI

ST Edge AI 开发者云是在 STM32 产品部署边缘 AI 的在线解决方案,通过云服务直接在 STM32 目标上对神经网络模型进行基准测试。当基准测试在 STM32MP2x 板上运行时,会自动生成 NBG 模型,并可下载。

ST Edge AI 开发者云支持 TensorFlow Lite,Keras 和 ONNX 模型。使用该在线工具进行转换模型,方便、快捷且高效。

量化类型非常重要。为了获得 GPU/NPU 的最佳性能,应该将模型量化为每张量 8 位。

打开 ST Edge AI 网站,

st-edge-ai_page.jpg

登录 ST 账号

stedgeai_login.jpg

自定义模型部署

AI 模型在 STM32MP257 开发板的常规部署流程:

  • 根据需求采集获得数据集
  • 选取合适的框架和模型
  • 训练模型
  • 模型量化
  • 使用云转换工具将模型转换为 nb模型
  • 编写程序,推理得到最终结果。

模型训练

介绍了手写数字识别的模型训练的主要流程,包括测试准备、代码及效果。

准备工作

安装必要的库,PC 终端执行指令

pip install torch torchvision onnx onnxruntime

训练代码

流程图

flowchart_train_model.jpg

新建 HWNR_train.py 文件,添加如下代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import onnx
import onnxruntime

# 定义神经网络模型
class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = torch.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载数据集
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 初始化模型、损失函数和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MNISTNet().to(device)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

# 训练函数
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\\\\tLoss: {loss.item():.6f}')

# 测试函数
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\\\\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\\\\n')

# 训练模型
epochs = 5
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

# 导出为ONNX模型
dummy_input = torch.randn(1, 1, 28, 28).to(device)
onnx_path = "mnist_model.onnx"

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

print(f"Model exported to {onnx_path}")

# 验证ONNX模型
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model check passed!")

# 测试ONNX模型推理
ort_session = onnxruntime.InferenceSession(onnx_path)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 使用测试数据验证
test_data, test_target = next(iter(test_loader))
test_data = test_data[0].unsqueeze(0).to(device)

# PyTorch推理
model.eval()
with torch.no_grad():
    torch_out = model(test_data)

# ONNX Runtime推理
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(test_data)}
ort_outs = ort_session.run(None, ort_inputs)

# 比较结果
print("PyTorch和ONNX Runtime输出是否接近:", 
      torch.allclose(torch_out, torch.tensor(ort_outs[0]), atol=1e-3))

执行代码,打印训练过程,输出 *.onnx 模型文件。

训练效果

HWNR_train_output.jpg

经过 5 轮训练,模型精度已达到 99% 满足识别要求。

本地识别测试

介绍了 PC 端对生成的 ONNX 模型的数字识别测试流程,包括测试代码和结果展示。

流程图

flowchart_model_test_local.jpg

测试代码

新建 HWNR_test.py 文件,添加如下代码

import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

ort_session = onnxruntime.InferenceSession("mnist_model.onnx") # 加载ONNX模型

# 预处理函数
def preprocess_image(image_path):
    image = Image.open(image_path).convert('L')  # 转换为灰度图像
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    return transform(image).unsqueeze(0).numpy()

# 推理函数
def predict(image_path):
    # 预处理图像
    input_data = preprocess_image(image_path)
    
    # 运行推理
    ort_inputs = {ort_session.get_inputs()[0].name: input_data}
    ort_outs = ort_session.run(None, ort_inputs)
    
    # 获取预测结果
    pred = np.argmax(ort_outs[0])
    confidence = np.max(np.exp(ort_outs[0]))  # 转换为概率
    
    return pred, confidence

# 图片测试
image_path = "number5.png"  # 手写数字图片路径
prediction, confidence = predict(image_path)
print(f"Predicted digit: {prediction} with confidence: {confidence:.2f}")

执行代码,输出模型测试结果。

测试结果

recog_result.jpg

模型转换

使用 ST Edge AI 工具转换 ONNX 模型,实现 STM32MP257 资源的高效利用。

  • 进入 ST Edge AI 官网,点击 START,登录 ST 账号;
  • 选择 ST Edge AI Core 版本 2.0.0 以适应 STM32MPU 开发板;

stedgeai_platform.jpg

  • 点击 Launch quantization 执行模型量化,完成后点击下一步

stedgeai_quantize.jpg

  • 点击 Optimize 按钮,执行模型优化

stedgeai_optimize.jpg

  • 点击 Start Benchmark 执行模型基准测试

stedgeai_benchmark.jpg

  • 完成后生成模型基准测试历史记录,点击下一步

stedgeai_results.jpg

  • 云端自动生成 nb 模型,点击下载

stedgeai_generate.jpg

得到目标 mnist_model.nb 板端部署模型文件。

板端推理

结合 stai_mpu 库以及生成的 nb 模型文件,实现手写数字识别的板端推理。

流程图

flowchart_MNIST_board.jpg

代码

终端执行 touch HWNR_inference.py 指令,新建 python 文件,添加如下代码

import cv2
import numpy as np
import time
from stai_mpu import stai_mpu_network  # STM32MPU专用AI推理库

class MNISTInference:
    def __init__(self, model_path):
        """初始化MNIST数字识别推理引擎"""
        print("正在加载模型...")
        self.model = stai_mpu_network(model_path=model_path)
                
        self.input_shape = self._get_input_shape()
        print(f"模型加载成功,输入尺寸: {self.input_shape}")

    def _get_input_shape(self):
        """获取模型输入张量形状"""
        input_info = self.model.get_input_infos()[0]
        
        shape = input_info.get_shape() # 使用get_shape()方法获取形状
        print(f"输入形状信息: {shape}")
        
        # 假设形状格式为[1, height, width, 1] (NHWC)
        if len(shape) == 4:
            return (shape[1], shape[2])  # (height, width)
        elif len(shape) == 2:
            return (shape[0], shape[1])  # (height, width)
        else:
            print("警告: 未知输入形状格式,使用默认尺寸28x28")
            return (28, 28)
            
    def preprocess_image(self, image):
        """
        图像预处理
        :param image: 输入图像(BGR格式)
        :return: 预处理后的张量(NHWC格式)
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 转换为灰度图
        
        # 自动检测并反转白底黑字
        if np.mean(gray) > 127:
            gray = 255 - gray
        
        resized = cv2.resize(gray, self.input_shape) # 调整尺寸匹配模型输入
        
        normalized = (resized.astype(np.float32) / 255.0 - 0.1307) / 0.3081 # 归一化并应用MNIST统计参数
        
        return np.expand_dims(np.expand_dims(normalized, 0), -1) # 添加batch和channel维度 (NHWC)

    def infer(self, image):
        """
        执行推理
        :param image: 输入图像
        :return: (预测结果, 推理时间ms)
        """
        input_data = self.preprocess_image(image) # 预处理
        self.model.set_input(0, input_data) # 设置输入
        # 推理
        start_time = time.perf_counter()
        self.model.run()
        inference_time = (time.perf_counter() - start_time) * 1000
        
        # 获取输出
        output = self.model.get_output(0)
        return output, inference_time

    @staticmethod
    def postprocess(output):
        """
        后处理输出结果
        :param output: 模型原始输出
        :return: (预测数字, 置信度)
        """
        probabilities = np.exp(output) / np.sum(np.exp(output))  # softmax
        predicted = np.argmax(probabilities)
        confidence = probabilities[0][predicted]
        return predicted, confidence

    def show_top5(self, output):
        """显示Top5预测结果"""
        probs = np.exp(output[0]) / np.sum(np.exp(output[0]))
        top5_idx = np.argsort(probs)[::-1][:5]
        
        print("-----TOP 5预测结果-----")
        for i, idx in enumerate(top5_idx):
            print(f"{i+1}. 数字 {idx}: {probs[idx]*100:.2f}%")

if __name__ == '__main__':
    # 配置参数
    MODEL_PATH = 'model/mnist_model_1.nb'
    TEST_IMAGE = 'model/number5.png'
    
    mnist = MNISTInference(MODEL_PATH) # 初始化推理引擎
    
    # 加载测试图像
    image = cv2.imread(TEST_IMAGE)
    if image is None:
        raise FileNotFoundError(f"无法加载图像: {TEST_IMAGE}")
	
    output, inference_time = mnist.infer(image) # 执行推理
        
    digit, confidence = mnist.postprocess(output) # 后处理结果
    
    # 打印结果
    print(f"\\\\n推理时间: {inference_time:.2f}ms")
    print(f"预测结果: 数字 {digit}, 置信度: {confidence*100:.2f}%")
    mnist.show_top5(output)
    
    # 推理结果可视化
    display_img = cv2.resize(image, (280, 280))
    cv2.putText(display_img, f"Prediction: {digit} ({confidence*100:.1f}%)", 
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    cv2.imshow("MNIST Digit Recognition", display_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

终端执行 python3 HWNR_inference.py 输出推理结果

效果

HWNR_inference_result.png

同时 LCD 屏显示推理结果图片

HWNR_inference_result_display.jpg

可获得较高的识别准确率。见顶部视频。

参考:正点原子官方 AI 例程,路径为 01、程序源码\\\\06、AI 例程源码\\\\01、例程源 码\\\\01、LENET 执行代码可获得识别结果的打印信息

lenet_inference_print.jpg

官方例程采用LeNet模型转换得到nb模型,输入采样图片的尺寸须为 28x28 才能正确识别。

远程数字识别

结合 OpenCV 自带的 http.server 网页服务器函数,结合 USB 摄像头实现数字识别的远程传递。

流程图

flowchart_http_camera_inference.jpg

代码

from http.server import BaseHTTPRequestHandler, HTTPServer
import cv2
import numpy as np
from stai_mpu import stai_mpu_network
import time
import threading

class CameraHandler:
    def __init__(self, model_path):
        self.model = stai_mpu_network(model_path=model_path)
        self.input_shape = (28, 28)
        self.cap = cv2.VideoCapture(7)
        self.latest_frame = None
        self.latest_result = "等待识别..."
        self.running = True
        self.thread = threading.Thread(target=self.process_frames)
        self.thread.start()

    def process_frames(self):
        while self.running:
            ret, frame = self.cap.read()
            if not ret:
                continue
                
            # 预处理和推理
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if np.mean(gray) > 127:
                gray = 255 - gray
            resized = cv2.resize(gray, self.input_shape, interpolation=cv2.INTER_AREA)
            normalized = (resized.astype(np.float32) / 255.0 - 0.1307) / 0.3081
            input_data = np.expand_dims(np.expand_dims(normalized, 0), -1)
            
            self.model.set_input(0, input_data)
            self.model.run()
            output = self.model.get_output(0)
            
            digit = np.argmax(output)
            self.latest_frame = frame
            self.latest_result = f"识别结果: {digit}"

    def stop(self):
        self.running = False
        self.thread.join()
        self.cap.release()

class HTTPRequestHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == '/':
            self.send_response(200)
            self.send_header('Content-type', 'text/html; charset=utf-8')
            self.end_headers()
            
            # 分开处理HTML内容和变量
            html_content = f"""
            <html>
            <head>
                <meta http-equiv="refresh" content="1">
                <title>MNIST数字识别</title>
            </head>
            <body>
                <h1>MNIST数字识别</h1>
                <img src="/video" width="640">
                <p>{camera.latest_result}</p>
            </body>
            </html>
            """
            self.wfile.write(html_content.encode('utf-8'))
            
        elif self.path == '/video':
            self.send_response(200)
            self.send_header('Content-type', 'image/jpeg')
            self.end_headers()
            if camera.latest_frame is not None:
                ret, jpeg = cv2.imencode('.jpg', camera.latest_frame)
                if ret:
                    self.wfile.write(jpeg.tobytes())

if __name__ == '__main__':
    camera = CameraHandler('model/LeNet5_mnist_model_1.nb')
    try:
        server = HTTPServer(('0.0.0.0', 8000), HTTPRequestHandler)
        print("服务器已启动: http://<开发板IP>:8000")
        server.serve_forever()
    except KeyboardInterrupt:
        print("\\\\n正在关闭服务器...")
        camera.stop()
        server.server_close()

效果

网页端

HWNR_cam_http.png

其他数字的识别效果

HWNR_http_Cam_numbers.jpg

手机端

使用连接同一局域网的手机浏览器访问网页摄像头并获取数字识别结果

HWNR_Cam_htpp_Phone.jpg

效果见底部视频。

Home Assistant

在完成网页端访问摄像头实时数字识别画面的基础上,进一步实现 Home Assistant (HA) 平台的远程数字识别画面显示的项目流程。

设备添加

  • Docker 启动 HA,打开 HA 网页界面;
  • 设置 - 设备和服务 - 添加集成(右下角);
  • 搜索 Camera - 选择 MJPEG IP Camera

HA_camera_device.jpg

  • 填写名称和摄像头设备 ip 地址 http://192.168.1.118:8000 - 确认即可
  • 回到 概览 页面,可见摄像头卡片。

总结

本文介绍了正点原子 STM32MP257 开发板实现基于MNIST数据集的手写数字识别的项目设计,包括 USB 摄像头驱动、模型训练与部署、板端推理、本地识别以及远程数字识别等。板端硬件资源充足,完全满足数字识别所需的硬件支持,识别速度极快,取得了令人满意的识别效果,该项目为人工智能和图像识别相关领域的开发提供了经验和参考。

AI_monist_http

回帖(2)

jf_07365693

2025-6-19 19:31:23
2025-06-19更新:增加代码流程图
举报

无垠的广袤

2025-6-20 10:20:01
内容真丰富啊
还有什么办法能提升识别准确率呢?
1 举报
  • jf_07365693: 感谢关注~若要提高数字识别的准确率,可以尝试优化模型,文中提到的方案包括LeNet、ONNX模型,其他模型如 Yolo;此外还可以创建更广泛的数据集、增加样本数量和迭代次数、优化python运行代码、增加不同场景的灰度处理和像素压缩等方案~

更多回帖

发帖
×
20
完善资料,
赚取积分