一、导入依赖模块
import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet| 代码行 | 逻辑与作用 | 关键补充 |
|---|---|---|
import torch | 导入PyTorch核心库 | 提供张量计算、设备管理(CPU/GPU)、模型前向传播等基础能力 |
import torch.utils.data as Data | 导入PyTorch数据加载工具 | 封装数据集为可迭代的DataLoader,简化批次读取、打乱、多进程加载等逻辑 |
from torchvision import transforms | 导入图像预处理工具 | 用于对FashionMNIST的图像做标准化、尺寸调整、格式转换(如PIL→Tensor) |
from torchvision.datasets import FashionMNIST | 导入FashionMNIST数据集 | 服装分类数据集(10类,如T恤、裤子),替代MNIST的手写数字分类,更贴近实际场景 |
from model import LeNet | 导入自定义的LeNet模型 | model.py中定义的LeNet卷积神经网络,是本次要测试的模型 |
二、测试数据处理函数 test_data_process
作用:加载并预处理测试集,返回可迭代的DataLoader(数据加载器)
def test_data_process():
test_data = FashionMNIST(root='./data',
train=False,
transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
download=True)
test_dataloader = Data.DataLoader(dataset=test_data,
batch_size=1,
shuffle=True,
num_workers=0)
return test_dataloader| 代码行 | 逻辑与作用 | 关键补充 |
|---|---|---|
def test_data_process(): | 定义测试数据处理函数(无入参) | 封装数据加载逻辑,提高代码复用性 |
test_data = FashionMNIST(...) | 加载FashionMNIST测试集 | + root='./data':数据集保存到当前目录的data文件夹<br/>+ train=False:指定加载测试集(train=True为训练集)<br/>+ transform=...:图像预处理管道(组合多个变换): - transforms.Resize(size=28):强制将图像缩放到28×28(确保输入模型的尺寸统一) - transforms.ToTensor():将PIL图像转为PyTorch张量,同时把像素值从0-255归一化到0-1<br/>+ download=True:如果data文件夹无数据集,自动从官网下载 |
test_dataloader = Data.DataLoader(...) | 封装测试集为DataLoader | + dataset=test_data:指定要加载的数据集<br/>+ batch_size=1:批次大小为1(测试时批次不影响准确率,设1方便逐样本验证)<br/>+ shuffle=True:打乱测试集顺序(测试时shuffle不影响最终准确率,仅改变遍历顺序)<br/>+ num_workers=0:数据加载的子进程数(Windows系统推荐设0,避免多进程冲突) |
return test_dataloader | 返回构建好的测试数据加载器 | 供后续模型测试函数调用 |
三、模型测试函数 test_model_process
作用:用测试集评估模型的预测准确率,核心是仅前向传播、不计算梯度
def test_model_process(model, test_dataloader):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
#初始化参数
test_correct = 0.0
test_num = 0
#只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
with torch.no_grad():
for test_data_x, test_data_y in test_dataloader:
# 将特征放入到测试设备中
test_data_x = test_data_x.to(device)
# 将标签放入到测试设备中
test_data_y = test_data_y.to(device)
# 设置模型为评估模式
model.eval()
# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
output = model(test_data_x)
# 查找每一行中最大值对应的行标
pre_lab = torch.argmax(output, dim=1) #沿着第一维度 (列)找最大值
# 如果预测正确,则准确度test_corrects加1
test_correct += torch.sum(pre_lab == test_data_y.data)
# 将所有的测试样本进行累加
test_num += test_data_x.size(0)
# 计算测试准确率
test_acc = test_correct.double().item() / test_num
print("测试的准确率为:", test_acc)| 代码行 | 逻辑与作用 | 关键补充 |
|---|---|---|
def test_model_process(model, test_dataloader): | 定义模型测试函数 | 入参:model(待测试的LeNet模型)、test_dataloader(测试数据加载器) |
device = torch.device(...) | 定义计算设备 | 优先使用GPU(cuda:0),无GPU则用CPU;保证模型和数据在同一设备上计算(否则报错) |
model = model.to(device) | 将模型移动到指定设备 | 模型的所有参数/计算都会在GPU/CPU上执行 |
test_correct = 0.0 | 初始化“正确预测数” | 浮点型避免整数溢出,方便后续除法计算 |
test_num = 0 | 初始化“测试样本总数” | 累加所有遍历过的样本数 |
with torch.no_grad(): | 禁用梯度计算的上下文管理器 | 测试阶段不需要反向传播更新参数,禁用梯度可大幅节省内存、提升计算速度 |
for test_data_x, test_data_y in test_dataloader: | 遍历测试集的每个批次 | test_data_x:图像特征张量(shape=[1,1,28,28],对应batch_size=1、单通道、28×28);test_data_y:真实标签(shape=[1]) |
test_data_x = test_data_x.to(device) | 将特征张量移到指定设备 | 必须和模型同设备,否则无法计算 |
test_data_y = test_data_y.to(device) | 将标签张量移到指定设备 | 保证和预测标签(pre_lab)同设备,才能比较 |
model.eval() | 将模型设为评估模式 | 关键:关闭训练时的特殊层行为(如Dropout随机失活、BatchNorm的均值/方差更新),确保测试时模型行为一致 |
output = model(test_data_x) | 模型前向传播 | 输入测试特征,输出output(shape=[1,10]):10个类别的预测得分(logits,未归一化) |
pre_lab = torch.argmax(output, dim=1) | 取预测得分的最大值索引 | dim=1:沿“类别维度”(第1维,列方向)找最大值,结果是预测的类别标签(如0=T恤、1=裤子) |
test_correct += torch.sum(pre_lab == test_data_y.data) | 累加本批次正确数 | pre_lab == test_data_y.data:逐元素比较(True=1,False=0);torch.sum()求和得到本批次正确数 |
test_num += test_data_x.size(0) | 累加本批次样本数 | test_data_x.size(0)=batch_size(这里是1),最终累加为测试集总样本数 |
test_acc = test_correct.double().item() / test_num | 计算整体准确率 | + test_correct.double():转为双精度浮点型<br/>+ .item():将张量转为Python数值(避免张量运算)<br/>+ 除以总样本数得到准确率(0~1之间) |
print("测试的准确率为:", test_acc) | 打印准确率 | 直观展示模型测试效果 |
四、主程序入口
if __name__=="__main__":
# 加载模型
model = LeNet()
model.load_state_dict(torch.load('model.pth'))
# 加载测试数据
test_dataloader = test_data_process()
# 加载模型测试的函数
test_model_process(model, test_dataloader)| 代码行 | 逻辑与作用 | 关键补充 |
|---|---|---|
if __name__=="__main__": | 脚本入口判断 | 仅当脚本被直接运行时执行(被导入为模块时不执行),避免代码被复用时报错 |
model = LeNet() | 实例化LeNet模型 | 初始化模型的网络结构(此时参数是随机的) |
model.load_state_dict(torch.load('model.pth')) | 加载训练好的模型权重 | model.pth是训练阶段保存的模型参数文件,加载后模型才有“训练好的能力” |
test_dataloader = test_data_process() | 调用函数获取测试数据加载器 | 准备测试数据 |
test_model_process(model, test_dataloader) | 调用测试函数 | 执行模型评估,最终打印测试准确率 |
总结
- 测试阶段的核心优化:
torch.no_grad()+model.eval()是测试的“黄金组合”——禁用梯度+关闭训练层行为,保证效率和结果准确; - 设备一致性:模型、特征、标签必须在同一设备(CPU/GPU),否则会报张量设备不匹配错误;
- 准确率计算逻辑:通过
torch.argmax取预测类别,和真实标签比较后累加正确数,最终除以总样本数。
完整代码如下
import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
def test_data_process():
test_data = FashionMNIST(root='./data',
train=False,
transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
download=True)
test_dataloader = Data.DataLoader(dataset=test_data,
batch_size=1,
shuffle=True,
num_workers=0)
return test_dataloader
def test_model_process(model, test_dataloader):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
#初始化参数
test_correct = 0.0
test_num = 0
#只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
with torch.no_grad():
for test_data_x, test_data_y in test_dataloader:
# 将特征放入到测试设备中
test_data_x = test_data_x.to(device)
# 将标签放入到测试设备中
test_data_y = test_data_y.to(device)
# 设置模型为评估模式
model.eval()
# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
output = model(test_data_x)
# 查找每一行中最大值对应的行标
pre_lab = torch.argmax(output, dim=1) #沿着第一维度 (列)找最大值
# 如果预测正确,则准确度test_corrects加1
test_correct += torch.sum(pre_lab == test_data_y.data)
# 将所有的测试样本进行累加
test_num += test_data_x.size(0)
# 计算测试准确率
test_acc = test_correct.double().item() / test_num
print("测试的准确率为:", test_acc)
if __name__=="__main__":
# 加载模型
model = LeNet()
model.load_state_dict(torch.load('model.pth'))
# 加载测试数据
test_dataloader = test_data_process()
# 加载模型测试的函数
test_model_process(model, test_dataloader)