3.2.2

人工智能学习 人工智能考试
📅 2025-11-21 16:15 🔄 2025-11-21 16:28 👤 admin

代码:

import onnxruntime

import numpy as np

from PIL import Image

# 加载ONNX模型  2

ort_session = onnxruntime.InferenceSession("mnist.onnx")

# 加载图像 2

image = Image.open("img_test.png").convert('L')  # 转为灰度图

# 图像预处理 4

image = image.resize((28, 28))  # 调整大小为MNIST模型的输入尺寸1

image_array = np.array(image, dtype=np.float32)  # 转为numpy数组1

image_array = np.expand_dims(image_array, axis=0)  # 添加batch维度1

image_array = np.expand_dims(image_array, axis=0)  # 添加通道维度1

# 使用模型对图片进行识别 2

ort_inputs = {ort_session.get_inputs()[0].name: image_array}

# 执行预测 2

ort_outs = ort_session.run(None, ort_inputs)

# 获取预测结果 2

predicted_class = np.argmax(ort_outs[0])

# 输出预测结果
print(f"Predicted class: {predicted_class}")

相关笔记