3.2.2
人工智能学习
人工智能考试
📅 2025-11-21 16:15
🔄 2025-11-21 16:28
👤 admin
代码:
import onnxruntime
import numpy as np
from PIL import Image
# 加载ONNX模型 2分
ort_session = onnxruntime.InferenceSession("mnist.onnx")
# 加载图像 2分
image = Image.open("img_test.png").convert('L') # 转为灰度图
# 图像预处理 4分
image = image.resize((28, 28)) # 调整大小为MNIST模型的输入尺寸1分
image_array = np.array(image, dtype=np.float32) # 转为numpy数组1分
image_array = np.expand_dims(image_array, axis=0) # 添加batch维度1分
image_array = np.expand_dims(image_array, axis=0) # 添加通道维度1分
# 使用模型对图片进行识别 2分
ort_inputs = {ort_session.get_inputs()[0].name: image_array}
# 执行预测 2分
ort_outs = ort_session.run(None, ort_inputs)
# 获取预测结果 2分
predicted_class = np.argmax(ort_outs[0])
# 输出预测结果
print(f"Predicted class: {predicted_class}")
相关笔记