ZetCode

Python __matmul__ 方法

最后修改于 2025 年 4 月 8 日

本综合指南探讨了 Python 的 __matmul__ 方法,这是一个实现矩阵乘法的特殊方法。我们将介绍基本用法、NumPy 集成、自定义实现和实际示例。

基本定义

__matmul__ 方法实现了 Python 中的矩阵乘法运算 (@)。它在 Python 3.5 中引入,为矩阵运算提供了一个专用的运算符,与逐元素乘法不同。

主要特征:它必须接受两个操作数(self 和 other),应返回矩阵乘法的结果,并在使用 @ 运算符时调用。它通常用于像 NumPy 这样的数值计算库中。

基本 __matmul__ 实现

这是一个简单的实现,展示了如何在自定义类中使用 __matmul__。这个例子创建了一个基本的 2x2 矩阵类,并支持矩阵乘法。

basic_matmul.py
class Matrix:
    def __init__(self, data):
        self.data = data
        
    def __matmul__(self, other):
        if len(self.data[0]) != len(other.data):
            raise ValueError("Incompatible matrix dimensions")
            
        result = [[0 for _ in range(len(other.data[0]))] 
                 for _ in range(len(self.data))]
        
        for i in range(len(self.data)):
            for j in range(len(other.data[0])):
                for k in range(len(other.data)):
                    result[i][j] += self.data[i][k] * other.data[k][j]
        return Matrix(result)
    
    def __repr__(self):
        return str(self.data)

A = Matrix([[1, 2], [3, 4]])
B = Matrix([[5, 6], [7, 8]])
print(A @ B)  # [[19, 22], [43, 50]]

这个例子展示了标准的矩阵乘法算法。__matmul__ 方法检查维度兼容性,执行计算,并返回一个新的带有结果的 Matrix 实例。

该实现使用嵌套循环来计算行和列的点积。与调用像 multiply() 这样的方法相比,@ 运算符提供了更简洁的语法。

NumPy 矩阵乘法

NumPy 的 ndarray 使用 __matmul__ 进行矩阵乘法。这个例子演示了 NumPy 的实现,它针对性能进行了优化。

numpy_matmul.py
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# Using @ operator (calls __matmul__)
result = A @ B
print(result)

# Equivalent using matmul function
result = np.matmul(A, B)
print(result)

# Note: * does element-wise multiplication
print(A * B)  # Different from @

NumPy 的实现使用 C 和 Fortran 库进行了高度优化。@ 运算符为矩阵运算提供了一个清晰的语法,同时明确区分了逐元素乘法 (*)。

对于大型矩阵,NumPy 的实现比纯 Python 快几个数量级。它还处理广播和高维数组。

向量乘法

__matmul__ 方法也可以实现向量点积。这个例子展示了一个 Vector 类,它通过 @ 支持点积。

vector_dot.py
class Vector:
    def __init__(self, components):
        self.components = components
        
    def __matmul__(self, other):
        if len(self.components) != len(other.components):
            raise ValueError("Vectors must have same length")
        return sum(a * b for a, b in zip(self.components, other.components))
    
    def __repr__(self):
        return f"Vector({self.components})"

v1 = Vector([1, 2, 3])
v2 = Vector([4, 5, 6])
print(v1 @ v2)  # 32 (1*4 + 2*5 + 3*6)

此实现计算两个向量的点积。 __matmul__ 方法检查向量长度是否匹配,然后计算分量乘积之和。

使用 @ 进行点积提供了数学上的清晰性,尽管某些库专门将其用于矩阵-矩阵乘法。

链式矩阵运算

@ 运算符可以像其他算术运算符一样链接。此示例演示了一个表达式中的多个矩阵乘法。

chained_matmul.py
import numpy as np

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)

# Chained matrix multiplication
result = A @ B @ C

# Equivalent to:
temp = A @ B
result = temp @ C

print(result.shape)  # (3, 3)

矩阵乘法是结合律的,因此使用 @ 链接操作会按预期工作。 这些操作从左到右执行,每个 @ 在其左操作数上调用 __matmul__

NumPy 会在内部尽可能优化此类链,从而减少临时分配。无论操作计数如何,语法都保持清晰。

自定义线性变换

此示例演示了 __matmul__ 如何实现线性变换,将变换矩阵应用于向量。

linear_transform.py
class Transform:
    def __init__(self, matrix):
        self.matrix = matrix
        
    def __matmul__(self, vector):
        if len(self.matrix[0]) != len(vector):
            raise ValueError("Incompatible dimensions")
        return [sum(m * v for m, v in zip(row, vector))
                for row in self.matrix]

    def __repr__(self):
        return f"Transform({self.matrix})"

# Rotation matrix (45 degrees)
theta = 45 * (3.14159 / 180)
rot = Transform([
    [np.cos(theta), -np.sin(theta)],
    [np.sin(theta), np.cos(theta)]
])

point = [1, 0]  # Point on x-axis
transformed = rot @ point
print(transformed)  # [0.7071, 0.7071] (45° rotated)

此实现使用 @ 运算符将变换矩阵应用于向量。 __matmul__ 方法执行矩阵-向量乘法。

该示例显示了旋转变换,但任何线性变换都可以用这种方式表示。 简洁的语法使数学代码更具可读性。

最佳实践

资料来源

作者

我叫 Jan Bodnar,是一位充满热情的程序员,拥有丰富的编程经验。自 2007 年以来,我一直在撰写编程文章。到目前为止,我已经撰写了超过 1,400 篇文章和 8 本电子书。我拥有超过十年的编程教学经验。

列出所有 Python 教程