一、导入依赖模块

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
代码行逻辑与作用关键补充
import torch导入PyTorch核心库提供张量计算、设备管理(CPU/GPU)、模型前向传播等基础能力
import torch.utils.data as Data导入PyTorch数据加载工具封装数据集为可迭代的DataLoader,简化批次读取、打乱、多进程加载等逻辑
from torchvision import transforms导入图像预处理工具用于对FashionMNIST的图像做标准化、尺寸调整、格式转换(如PIL→Tensor)
from torchvision.datasets import FashionMNIST导入FashionMNIST数据集服装分类数据集(10类,如T恤、裤子),替代MNIST的手写数字分类,更贴近实际场景
from model import LeNet导入自定义的LeNet模型model.py中定义的LeNet卷积神经网络,是本次要测试的模型

二、测试数据处理函数 test_data_process

作用:加载并预处理测试集,返回可迭代的DataLoader(数据加载器)

def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader
代码行逻辑与作用关键补充
def test_data_process():定义测试数据处理函数(无入参)封装数据加载逻辑,提高代码复用性
test_data = FashionMNIST(...)加载FashionMNIST测试集+ root='./data':数据集保存到当前目录的data文件夹<br/>+ train=False:指定加载测试集train=True为训练集)<br/>+ transform=...:图像预处理管道(组合多个变换): - transforms.Resize(size=28):强制将图像缩放到28×28(确保输入模型的尺寸统一) - transforms.ToTensor():将PIL图像转为PyTorch张量,同时把像素值从0-255归一化到0-1<br/>+ download=True:如果data文件夹无数据集,自动从官网下载
test_dataloader = Data.DataLoader(...)封装测试集为DataLoader+ dataset=test_data:指定要加载的数据集<br/>+ batch_size=1:批次大小为1(测试时批次不影响准确率,设1方便逐样本验证)<br/>+ shuffle=True:打乱测试集顺序(测试时shuffle不影响最终准确率,仅改变遍历顺序)<br/>+ num_workers=0:数据加载的子进程数(Windows系统推荐设0,避免多进程冲突)
return test_dataloader返回构建好的测试数据加载器供后续模型测试函数调用

三、模型测试函数 test_model_process

作用:用测试集评估模型的预测准确率,核心是仅前向传播、不计算梯度

def test_model_process(model, test_dataloader):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    #初始化参数
    test_correct = 0.0
    test_num = 0

    #只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output = model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)  #沿着第一维度 (列)找最大值
            # 如果预测正确,则准确度test_corrects加1
            test_correct += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

        # 计算测试准确率
    test_acc = test_correct.double().item() / test_num
    print("测试的准确率为:", test_acc)
代码行逻辑与作用关键补充
def test_model_process(model, test_dataloader):定义模型测试函数入参:model(待测试的LeNet模型)、test_dataloader(测试数据加载器)
device = torch.device(...)定义计算设备优先使用GPU(cuda:0),无GPU则用CPU;保证模型和数据在同一设备上计算(否则报错)
model = model.to(device)将模型移动到指定设备模型的所有参数/计算都会在GPU/CPU上执行
test_correct = 0.0初始化“正确预测数”浮点型避免整数溢出,方便后续除法计算
test_num = 0初始化“测试样本总数”累加所有遍历过的样本数
with torch.no_grad():禁用梯度计算的上下文管理器测试阶段不需要反向传播更新参数,禁用梯度可大幅节省内存、提升计算速度
for test_data_x, test_data_y in test_dataloader:遍历测试集的每个批次test_data_x:图像特征张量(shape=[1,1,28,28],对应batch_size=1、单通道、28×28);test_data_y:真实标签(shape=[1])
test_data_x = test_data_x.to(device)将特征张量移到指定设备必须和模型同设备,否则无法计算
test_data_y = test_data_y.to(device)将标签张量移到指定设备保证和预测标签(pre_lab)同设备,才能比较
model.eval()将模型设为评估模式关键:关闭训练时的特殊层行为(如Dropout随机失活、BatchNorm的均值/方差更新),确保测试时模型行为一致
output = model(test_data_x)模型前向传播输入测试特征,输出output(shape=[1,10]):10个类别的预测得分(logits,未归一化)
pre_lab = torch.argmax(output, dim=1)取预测得分的最大值索引dim=1:沿“类别维度”(第1维,列方向)找最大值,结果是预测的类别标签(如0=T恤、1=裤子)
test_correct += torch.sum(pre_lab == test_data_y.data)累加本批次正确数pre_lab == test_data_y.data:逐元素比较(True=1,False=0);torch.sum()求和得到本批次正确数
test_num += test_data_x.size(0)累加本批次样本数test_data_x.size(0)=batch_size(这里是1),最终累加为测试集总样本数
test_acc = test_correct.double().item() / test_num计算整体准确率+ test_correct.double():转为双精度浮点型<br/>+ .item():将张量转为Python数值(避免张量运算)<br/>+ 除以总样本数得到准确率(0~1之间)
print("测试的准确率为:", test_acc)打印准确率直观展示模型测试效果

四、主程序入口

if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)
代码行逻辑与作用关键补充
if __name__=="__main__":脚本入口判断仅当脚本被直接运行时执行(被导入为模块时不执行),避免代码被复用时报错
model = LeNet()实例化LeNet模型初始化模型的网络结构(此时参数是随机的)
model.load_state_dict(torch.load('model.pth'))加载训练好的模型权重model.pth是训练阶段保存的模型参数文件,加载后模型才有“训练好的能力”
test_dataloader = test_data_process()调用函数获取测试数据加载器准备测试数据
test_model_process(model, test_dataloader)调用测试函数执行模型评估,最终打印测试准确率

总结

  1. 测试阶段的核心优化torch.no_grad() + model.eval() 是测试的“黄金组合”——禁用梯度+关闭训练层行为,保证效率和结果准确;
  2. 设备一致性:模型、特征、标签必须在同一设备(CPU/GPU),否则会报张量设备不匹配错误;
  3. 准确率计算逻辑:通过torch.argmax取预测类别,和真实标签比较后累加正确数,最终除以总样本数。

完整代码如下

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet

def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader

def test_model_process(model, test_dataloader):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    #初始化参数
    test_correct = 0.0
    test_num = 0

    #只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output = model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)  #沿着第一维度 (列)找最大值
            # 如果预测正确,则准确度test_corrects加1
            test_correct += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

        # 计算测试准确率
    test_acc = test_correct.double().item() / test_num
    print("测试的准确率为:", test_acc)


if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)
最后修改:2026 年 01 月 31 日
如果觉得我的文章对你有用,请随意赞赏