在深度学习代码中出现了张量如何转化为二维矩阵的例子,用AI整理过后形成了本文。

一、先理解:为什么要把张量展平成二维矩阵?

全连接层(仿射层)的核心运算是矩阵乘法y = X · W + b,而矩阵乘法有严格的维度要求:

  • 输入X必须是二维矩阵(形状:[样本数, 输入特征数]
  • 权重W是二维矩阵(形状:[输入特征数, 输出特征数]

但实际场景中,输入往往不是二维的:

  • 比如批量图片:形状是 [批量数, 通道数, 高, 宽](例:(100, 3, 28, 28) → 100 张 3 通道 28×28 的图片)
  • 比如批量序列数据:形状是 [批量数, 序列长度, 特征维度](例:(50, 10, 128)

这些高维张量无法直接和权重矩阵做乘法,必须先展平成二维——固定 “样本数” 为第一维,把剩下的所有维度合并成 “特征数” 一维

二、核心操作:X.reshape(X.shape[0], -1) 详解

reshape是 NumPy/PyTorch 中重塑数组 / 张量形状的函数,这里的关键是X.shape[0]-1的组合:

部分作用
X.shape[0]固定第一维为样本数(批量大小),保证每个样本的特征独立,不混淆。
-1NumPy 的 “自动计算维度” 标记:让程序自动计算这一维的大小,等于「总元素数 ÷ 样本数」。

例子

假设输入是 100 张 3 通道 28×28 的图片,张量形状为 X.shape = (100, 3, 28, 28)

import numpy as np

# 模拟输入:100张3通道28×28的图片
X = np.random.randn(100, 3, 28, 28)
print("原始形状:", X.shape)  # 输出:(100, 3, 28, 28)

# 展平成二维矩阵
X_flat = X.reshape(X.shape[0], -1)
print("展平后形状:", X_flat.shape)  # 输出:(100, 2352)
# 2352 = 3 × 28 × 28,刚好是单张图片的所有像素数(特征数)

再比如输入是 50 条序列数据(形状(50, 10, 128)):

X = np.random.randn(50, 10, 128)
X_flat = X.reshape(X.shape[0], -1)
print(X_flat.shape)  # 输出:(50, 1280) → 1280 = 10 × 128

三、original_x_shape 的关键作用:恢复梯度形状

反向传播时,我们计算出的输入梯度dX一开始是展平后的二维形状(比如(100, 2352)),但前一层的输出是高维的(比如卷积层输出是(100, 3, 28, 28)),如果直接把二维的梯度传给前一层,维度会不匹配,导致计算错误。

所以original_x_shape的作用就是保存输入的原始形状,反向传播时用它把梯度恢复成原始维度:

python

运行

# 反向传播时恢复梯度形状
dX = np.dot(dy, self.W.T)  # 计算出的梯度是二维:(100, 2352)
dX = dX.reshape(*self.original_x_shape)  # 恢复成(100, 3, 28, 28)

这里的*self.original_x_shape是 Python 的 “解包” 操作,把保存的形状元组(100, 3, 28, 28)拆解成参数传给reshape,等价于dX.reshape(100, 3, 28, 28)

总结

  1. X.reshape(X.shape[0], -1) 是高维张量适配全连接层的核心:固定样本数维度,自动展平剩余维度为一维,得到 “样本数 × 总特征数” 的二维矩阵。
  2. original_x_shape 是反向传播的 “维度桥梁”:保存原始形状,确保梯度能恢复成前一层需要的维度,避免维度不匹配错误。
  3. -1 在 reshape 中是 “懒人神器”:无需手动计算展平后的特征数,由程序自动推导,适配任意高维输入。

最后修改:2026 年 01 月 08 日
如果觉得我的文章对你有用,请随意赞赏