3.2.3

人工智能学习 人工智能考试
📅 2025-11-21 16:29 🔄 2025-11-21 16:41 👤 admin

import onnx
import numpy as np
from PIL import Image
import onnxruntime as ort
#
定义预处理函数,用于将图片转换为模型所需的输入格式
def preprocess(image_path):
    input_shape = (1, 1, 64, 64)    #
模型输入期望的形状,这里是 (N, C, H, W)N=batch size, C=channels, H=height, W=width
    img = Image.open(image_path).convert('L')    #
打开图像文件并将其转换为灰度图  1
    img = img.resize((64, 64), Image.Resampling.LANCZOS)    #
调整图像大小到模型输入所需的尺寸
    img_data = np.array(img, dtype=np.float32)    #
PIL图像对象转换为numpy数组,并确保数据类型是float32
    #
调整数组的形状以匹配模型输入的形状
    img_data = np.expand_dims(img_data, axis=0)  #
添加 batch 维度
    img_data = np.expand_dims(img_data, axis=1)  #
添加 channel 维度
    assert img_data.shape == input_shape, f"Expected shape {input_shape}, but got {img_data.shape}"    #
确保最终的形状与模型输入要求的形状一致
    return img_data    #
返回预处理后的图像数据
#
定义情感类别与数字标签的映射表 2
emotion_table = {'neutral':0, 'happiness':1, 'surprise':2, 'sadness':3, 'anger':4, 'disgust':5, 'fear':6, 'contempt':7}
#
加载模型 2
ort_session = ort.InferenceSession('emotion-ferplus.onnx')    #
使用onnxruntime创建一个会话,用于加载并运行模型
#
加载本地图片并进行预处理 2
input_data = preprocess('img_test.png')
#
准备输入数据,确保其符合模型输入的要求
ort_inputs = {ort_session.get_inputs()[0].name: input_data}    # ort_session.get_inputs()[0].name
是获取模型的第一个输入的名字
#
运行模型,进行预测 2
ort_outs = ort_session.run(None, ort_inputs)
#
解码模型输出,找到预测概率最高的情感类别 2
predicted_label = np.argmax(ort_outs[0])
print(predicted_label)
#
根据预测的标签找到对应的情感名称 2
predicted_emotion = list(emotion_table.keys())[predicted_label]
#
输出预测的情感
print(f"Predicted emotion: {predicted_emotion}")

相关笔记