Python __matmul__ 方法
最后修改于 2025 年 4 月 8 日
本综合指南探讨了 Python 的 __matmul__
方法,这是一个实现矩阵乘法的特殊方法。我们将介绍基本用法、NumPy 集成、自定义实现和实际示例。
基本定义
__matmul__
方法实现了 Python 中的矩阵乘法运算 (@
)。它在 Python 3.5 中引入,为矩阵运算提供了一个专用的运算符,与逐元素乘法不同。
主要特征:它必须接受两个操作数(self 和 other),应返回矩阵乘法的结果,并在使用 @
运算符时调用。它通常用于像 NumPy 这样的数值计算库中。
基本 __matmul__ 实现
这是一个简单的实现,展示了如何在自定义类中使用 __matmul__
。这个例子创建了一个基本的 2x2 矩阵类,并支持矩阵乘法。
class Matrix: def __init__(self, data): self.data = data def __matmul__(self, other): if len(self.data[0]) != len(other.data): raise ValueError("Incompatible matrix dimensions") result = [[0 for _ in range(len(other.data[0]))] for _ in range(len(self.data))] for i in range(len(self.data)): for j in range(len(other.data[0])): for k in range(len(other.data)): result[i][j] += self.data[i][k] * other.data[k][j] return Matrix(result) def __repr__(self): return str(self.data) A = Matrix([[1, 2], [3, 4]]) B = Matrix([[5, 6], [7, 8]]) print(A @ B) # [[19, 22], [43, 50]]
这个例子展示了标准的矩阵乘法算法。__matmul__
方法检查维度兼容性,执行计算,并返回一个新的带有结果的 Matrix 实例。
该实现使用嵌套循环来计算行和列的点积。与调用像 multiply()
这样的方法相比,@
运算符提供了更简洁的语法。
NumPy 矩阵乘法
NumPy 的 ndarray
使用 __matmul__
进行矩阵乘法。这个例子演示了 NumPy 的实现,它针对性能进行了优化。
import numpy as np A = np.array([[1, 2], [3, 4]]) B = np.array([[5, 6], [7, 8]]) # Using @ operator (calls __matmul__) result = A @ B print(result) # Equivalent using matmul function result = np.matmul(A, B) print(result) # Note: * does element-wise multiplication print(A * B) # Different from @
NumPy 的实现使用 C 和 Fortran 库进行了高度优化。@
运算符为矩阵运算提供了一个清晰的语法,同时明确区分了逐元素乘法 (*
)。
对于大型矩阵,NumPy 的实现比纯 Python 快几个数量级。它还处理广播和高维数组。
向量乘法
__matmul__
方法也可以实现向量点积。这个例子展示了一个 Vector 类,它通过 @
支持点积。
class Vector: def __init__(self, components): self.components = components def __matmul__(self, other): if len(self.components) != len(other.components): raise ValueError("Vectors must have same length") return sum(a * b for a, b in zip(self.components, other.components)) def __repr__(self): return f"Vector({self.components})" v1 = Vector([1, 2, 3]) v2 = Vector([4, 5, 6]) print(v1 @ v2) # 32 (1*4 + 2*5 + 3*6)
此实现计算两个向量的点积。 __matmul__
方法检查向量长度是否匹配,然后计算分量乘积之和。
使用 @
进行点积提供了数学上的清晰性,尽管某些库专门将其用于矩阵-矩阵乘法。
链式矩阵运算
@
运算符可以像其他算术运算符一样链接。此示例演示了一个表达式中的多个矩阵乘法。
import numpy as np A = np.random.rand(3, 3) B = np.random.rand(3, 3) C = np.random.rand(3, 3) # Chained matrix multiplication result = A @ B @ C # Equivalent to: temp = A @ B result = temp @ C print(result.shape) # (3, 3)
矩阵乘法是结合律的,因此使用 @
链接操作会按预期工作。 这些操作从左到右执行,每个 @
在其左操作数上调用 __matmul__
。
NumPy 会在内部尽可能优化此类链,从而减少临时分配。无论操作计数如何,语法都保持清晰。
自定义线性变换
此示例演示了 __matmul__
如何实现线性变换,将变换矩阵应用于向量。
class Transform: def __init__(self, matrix): self.matrix = matrix def __matmul__(self, vector): if len(self.matrix[0]) != len(vector): raise ValueError("Incompatible dimensions") return [sum(m * v for m, v in zip(row, vector)) for row in self.matrix] def __repr__(self): return f"Transform({self.matrix})" # Rotation matrix (45 degrees) theta = 45 * (3.14159 / 180) rot = Transform([ [np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)] ]) point = [1, 0] # Point on x-axis transformed = rot @ point print(transformed) # [0.7071, 0.7071] (45° rotated)
此实现使用 @
运算符将变换矩阵应用于向量。 __matmul__
方法执行矩阵-向量乘法。
该示例显示了旋转变换,但任何线性变换都可以用这种方式表示。 简洁的语法使数学代码更具可读性。
最佳实践
- 遵循数学约定: 实现正确的矩阵乘法规则
- 检查维度: 验证输入形状是否符合乘法要求
- 返回适当的类型: 使用输入维护一致的返回类型
- 记录行为: 清楚地指定支持的操作和维度
- 考虑性能: 对于复杂的操作,优化或使用库
资料来源
作者
列出所有 Python 教程。