einsum字符串需确保输入维度标签与输出标签严格匹配,字母顺序须与张量ndim一致,重复字母表求和或对角线,跨输入重复触发求和,空输出表示标量,省略号要求前缀维度对齐。
核心是让输入张量的维度标签和输出标签严格匹配,einsum 不会自动广播或对齐轴,写错一个字母就直接抛 ValueError: operands could not be broadcast together 或 IndexError。
"ij"、"jk")表示其形状,字母顺序必须和 ndim 一致"ii" 表示对角线),跨输入重复则触
"ij,jk" 中的 j)"ii->")表示标量结果,不能漏掉箭头后的空字符串多数矩阵乘、转置、迹运算都能用 einsum 更清晰地表达,且避免临时数组分配。
A @ B → np.einsum("ij,jk->ik", A, B)(比 dot 更显式控制哪维参与计算)B @ C,其中 B.shape = (b, i, j), C.shape = (b, j, k) → np.einsum("bij,bjk->bik", B, C)
np.diag(A) → np.einsum("ii->i", A);求迹 → np.einsum("ii->", A)
np.outer(u, v) → np.einsum("i,j->ij", u, v)
einsum 默认走通用路径,对简单操作(如二维矩阵乘)不如高度优化的 BLAS 后端快;是否加速取决于操作复杂度和数据规模。
):通常 matmul 或 dot 更快,einsum 有解析字符串开销
"ab,cd,be->acde"):einsum 可能显著胜出,因避免多个中间数组optimize=True(如 np.einsum("...,...->...", A, B, optimize=True))可启用路径优化,对三阶及以上张量尤其重要optimize="greedy" 或 "optimal" 会增加预处理时间,仅当反复调用同结构时值得开启einsum 默认按输入中最高精度 dtype 输出,但不会自动提升整数精度;同时,它不共享内存,结果总是新分配数组。
int32 矩阵相乘,结果仍是 int32,可能溢出;需显式转成 float64 或用 dtype=np.float64 参数指定out= 参数复用内存(einsum 不支持 out 参数)... 时(如 "...ij,...jk->...ik"),要确保前缀维度完全对齐,否则运行时报错而非静默截断einsum 版本验证逻辑,再测性能;别为了“看起来高级”硬套,尤其是二维场景下 @ 还是最稳的。