ZetCode

Python __imatmul__ 方法

最后修改于 2025 年 4 月 8 日

本综合指南探讨了 Python 的 __imatmul__ 方法,这是用于原地(in-place)矩阵乘法的特殊方法。我们将涵盖基本用法、运算符重载、NumPy 集成和实际示例。

基本定义

__imatmul__ 方法使用 @= 运算符实现原地矩阵乘法。它会就地修改左操作数,而不是创建新对象。

关键特征:它必须返回修改后的对象,通常执行矩阵乘法,并在应用 @= 运算符时使用。它是 __matmul__ 的原地版本。

基本 __imatmul__ 实现

这是一个简单的实现,展示了 __imatmul__ 如何与自定义矩阵类一起工作。它演示了基本的原地矩阵乘法。

basic_imatmul.py
class Matrix:
    def __init__(self, data):
        self.data = data
    
    def __imatmul__(self, other):
        if len(self.data[0]) != len(other.data):
            raise ValueError("Incompatible matrix dimensions")
        
        result = [
            [sum(a*b for a,b in zip(row, col)) 
             for col in zip(*other.data)]
            for row in self.data
        ]
        self.data = result
        return self

m1 = Matrix([[1, 2], [3, 4]])
m2 = Matrix([[5, 6], [7, 8]])
m1 @= m2
print(m1.data)  # [[19, 22], [43, 50]]

此示例展示了原地执行的矩阵乘法。@= 运算符调用 __imatmul__,该方法会修改左操作数的数据。

该方法检查维度兼容性,计算乘积,更新 self.data,并返回 self 以维持操作的原地性质。

回退到 __matmul__

如果未实现 __imatmul__,Python 会回退到 __matmul__,然后进行赋值。此示例演示了该行为。

fallback.py
class Matrix:
    def __init__(self, data):
        self.data = data
    
    def __matmul__(self, other):
        print("__matmul__ called")
        result = [
            [sum(a*b for a,b in zip(row, col)) 
             for col in zip(*other.data)]
            for row in self.data
        ]
        return Matrix(result)

m1 = Matrix([[1, 2], [3, 4]])
m2 = Matrix([[5, 6], [7, 8]])
m1 @= m2  # Falls back to __matmul__ + assignment
print(m1.data)  # [[19, 22], [43, 50]]

__imatmul__ 缺失时,Python 会调用 __matmul__ 并将结果分配给左操作数。这会创建一个新对象,而不是原地修改。

输出显示 __matmul__ called,证明了回退行为。对于大型矩阵而言,这不如真正的原地操作高效。

NumPy 数组集成

NumPy 数组实现了 __imatmul__ 以进行高效的原地矩阵运算。此示例展示了其在 NumPy 中的用法。

numpy_imatmul.py
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])

print("Before @=:", id(a))
a @= b
print("After @=:", id(a))  # Same ID
print(a)
# Output:
# [[19 22]
# [43 50]]

NumPy 的实现会就地修改数组,而不创建新对象。操作后内存地址 (id) 保持不变。

这对于创建新对象会占用大量内存的大型矩阵尤其重要。NumPy 会优化这些操作以提高性能。

同时实现这两种方法的自定义类

此示例展示了一个同时实现 __matmul____imatmul__ 的类,以演示它们的不同行为。

both_methods.py
class Matrix:
    def __init__(self, data):
        self.data = data
    
    def __matmul__(self, other):
        print("__matmul__ called")
        result = [
            [sum(a*b for a,b in zip(row, col)) 
             for col in zip(*other.data)]
            for row in self.data
        ]
        return Matrix(result)
    
    def __imatmul__(self, other):
        print("__imatmul__ called")
        if len(self.data[0]) != len(other.data):
            raise ValueError("Incompatible dimensions")
        
        self.data = [
            [sum(a*b for a,b in zip(row, col)) 
             for col in zip(*other.data)]
            for row in self.data
        ]
        return self

m1 = Matrix([[1, 2], [3, 4]])
m2 = Matrix([[5, 6], [7, 8]])

m3 = m1 @ m2  # Calls __matmul__
print("m3 is new object:", m3 is not m1)

m1 @= m2  # Calls __imatmul__
print("m1 modified in place:", m1.data)

输出显示了为每个操作调用了哪个方法。@ 创建一个新对象,而 @= 则原地修改。

这表明 Python 根据操作是否为原地操作来选择适当的方法。这两种方法可以在同一个类中共存。

不可变对象和 __imatmul__

不可变对象无法实现真正的原地操作。此示例展示了它们如何通过返回新对象来处理 @=

immutable.py
class ImmutableMatrix:
    def __init__(self, data):
        self._data = tuple(tuple(row) for row in data)
    
    @property
    def data(self):
        return self._data
    
    def __imatmul__(self, other):
        print("Cannot modify immutable object, returning new instance")
        result = [
            [sum(a*b for a,b in zip(row, col)) 
             for col in zip(*other.data)]
            for row in self.data
        ]
        return ImmutableMatrix(result)

m1 = ImmutableMatrix([[1, 2], [3, 4]])
m2 = ImmutableMatrix([[5, 6], [7, 8]])

m1 @= m2  # Actually creates new object
print(m1.data)  # Shows new matrix data

尽管使用了 @=,但由于无法修改原始对象,此操作会创建一个新对象。该方法会对此行为发出警告。

当您想保持不可变性但仍支持原地运算符语法时,此模式很有用。该实现有效地使 @= 的行为类似于 @

最佳实践

资料来源

作者

我叫 Jan Bodnar,我是一名充满激情的程序员,拥有丰富的编程经验。我从 2007 年开始撰写编程文章。至今,我已撰写了 1400 多篇文章和 8 本电子书。我在编程教学方面拥有十多年的经验。

列出所有 Python 教程