Python __matmul__ 方法
最后修改于 2025 年 4 月 8 日
本综合指南探讨了 Python 的 __matmul__ 方法,这是一个实现矩阵乘法的特殊方法。我们将介绍基本用法、NumPy 集成、自定义实现和实际示例。
基本定义
__matmul__ 方法实现了 Python 中的矩阵乘法运算 (@)。它在 Python 3.5 中引入,为矩阵运算提供了一个专用的运算符,与逐元素乘法不同。
主要特征:它必须接受两个操作数(self 和 other),应返回矩阵乘法的结果,并在使用 @ 运算符时调用。它通常用于像 NumPy 这样的数值计算库中。
基本 __matmul__ 实现
这是一个简单的实现,展示了如何在自定义类中使用 __matmul__。这个例子创建了一个基本的 2x2 矩阵类,并支持矩阵乘法。
class Matrix:
def __init__(self, data):
self.data = data
def __matmul__(self, other):
if len(self.data[0]) != len(other.data):
raise ValueError("Incompatible matrix dimensions")
result = [[0 for _ in range(len(other.data[0]))]
for _ in range(len(self.data))]
for i in range(len(self.data)):
for j in range(len(other.data[0])):
for k in range(len(other.data)):
result[i][j] += self.data[i][k] * other.data[k][j]
return Matrix(result)
def __repr__(self):
return str(self.data)
A = Matrix([[1, 2], [3, 4]])
B = Matrix([[5, 6], [7, 8]])
print(A @ B) # [[19, 22], [43, 50]]
这个例子展示了标准的矩阵乘法算法。__matmul__ 方法检查维度兼容性,执行计算,并返回一个新的带有结果的 Matrix 实例。
该实现使用嵌套循环来计算行和列的点积。与调用像 multiply() 这样的方法相比,@ 运算符提供了更简洁的语法。
NumPy 矩阵乘法
NumPy 的 ndarray 使用 __matmul__ 进行矩阵乘法。这个例子演示了 NumPy 的实现,它针对性能进行了优化。
import numpy as np A = np.array([[1, 2], [3, 4]]) B = np.array([[5, 6], [7, 8]]) # Using @ operator (calls __matmul__) result = A @ B print(result) # Equivalent using matmul function result = np.matmul(A, B) print(result) # Note: * does element-wise multiplication print(A * B) # Different from @
NumPy 的实现使用 C 和 Fortran 库进行了高度优化。@ 运算符为矩阵运算提供了一个清晰的语法,同时明确区分了逐元素乘法 (*)。
对于大型矩阵,NumPy 的实现比纯 Python 快几个数量级。它还处理广播和高维数组。
向量乘法
__matmul__ 方法也可以实现向量点积。这个例子展示了一个 Vector 类,它通过 @ 支持点积。
class Vector:
def __init__(self, components):
self.components = components
def __matmul__(self, other):
if len(self.components) != len(other.components):
raise ValueError("Vectors must have same length")
return sum(a * b for a, b in zip(self.components, other.components))
def __repr__(self):
return f"Vector({self.components})"
v1 = Vector([1, 2, 3])
v2 = Vector([4, 5, 6])
print(v1 @ v2) # 32 (1*4 + 2*5 + 3*6)
此实现计算两个向量的点积。 __matmul__ 方法检查向量长度是否匹配,然后计算分量乘积之和。
使用 @ 进行点积提供了数学上的清晰性,尽管某些库专门将其用于矩阵-矩阵乘法。
链式矩阵运算
@ 运算符可以像其他算术运算符一样链接。此示例演示了一个表达式中的多个矩阵乘法。
import numpy as np A = np.random.rand(3, 3) B = np.random.rand(3, 3) C = np.random.rand(3, 3) # Chained matrix multiplication result = A @ B @ C # Equivalent to: temp = A @ B result = temp @ C print(result.shape) # (3, 3)
矩阵乘法是结合律的,因此使用 @ 链接操作会按预期工作。 这些操作从左到右执行,每个 @ 在其左操作数上调用 __matmul__ 。
NumPy 会在内部尽可能优化此类链,从而减少临时分配。无论操作计数如何,语法都保持清晰。
自定义线性变换
此示例演示了 __matmul__ 如何实现线性变换,将变换矩阵应用于向量。
class Transform:
def __init__(self, matrix):
self.matrix = matrix
def __matmul__(self, vector):
if len(self.matrix[0]) != len(vector):
raise ValueError("Incompatible dimensions")
return [sum(m * v for m, v in zip(row, vector))
for row in self.matrix]
def __repr__(self):
return f"Transform({self.matrix})"
# Rotation matrix (45 degrees)
theta = 45 * (3.14159 / 180)
rot = Transform([
[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]
])
point = [1, 0] # Point on x-axis
transformed = rot @ point
print(transformed) # [0.7071, 0.7071] (45° rotated)
此实现使用 @ 运算符将变换矩阵应用于向量。 __matmul__ 方法执行矩阵-向量乘法。
该示例显示了旋转变换,但任何线性变换都可以用这种方式表示。 简洁的语法使数学代码更具可读性。
最佳实践
- 遵循数学约定: 实现正确的矩阵乘法规则
- 检查维度: 验证输入形状是否符合乘法要求
- 返回适当的类型: 使用输入维护一致的返回类型
- 记录行为: 清楚地指定支持的操作和维度
- 考虑性能: 对于复杂的操作,优化或使用库
资料来源
作者
列出所有 Python 教程。