本文讲解 pytorch 中张量切片的核心原理,重点解决因误用索引维度导致的形状错误问题——如将 shape 为 `[2, 11938]` 的张量错误切分为 `[2, 64]` 所需的正确语法是 `tensor[:, start:end]`,而非 `tensor[0:2][start:end]`。
在 PyTorch 中,张量(torch.Tensor)的切片遵循与 NumPy 高度一致的多维索引规则:每个维度需显式指定索引范围,使用冒号 : 表示“该维度全部保留”。你遇到的问题根源在于对二维张量索引逻辑的理解偏差。
你的原始张量 X_train 形状为 torch.Size([2, 11938]),即:
你希望每次取 64 个连续样本,形成 shape 为 [2, 64] 的 batch,这本质上是对第 1 维(列方向)进行切片,而第 0 维应保持完整。
❌ 错误写法分析:
y_pred = model(X_train[0:2][batch:batch+BATCH_SIZE])
✅ 正确写法:使用逗号分隔各维索引,明确指定切片维度
BATCH_SIZE = 64
for start_idx in range(0, X_train.size(1), BATCH_SIZE): # 注意:遍历的是第1维长度!
end_idx = min(start_idx + BATCH_SIZE, X_train.size(1))
X_batch = X_train[:, start_idx:end_idx] # ✅ 保留所有行(:),切第1维 [start:end]
y_batch = y_train[:, start_idx:end_idx] # 同理处理标签(若 y_train 也是 [2, 11938])
# 训练流程
model.train()
y_pred = model(X_batch) # 输入 shape: [2, 64] —— 符合模型权重兼容性
# ... loss计算、反向传播等? 关
键要点总结:
掌握这一维度意识,是写出高效、无错 PyTorch 数据管道的基础。