ZetCode

Python unittest

最后修改于 2025 年 2 月 25 日

unittest 模块是 Python 内置的用于编写和执行单元测试的框架,其灵感来源于 Java 生态系统中的 JUnit。它使开发人员能够验证其代码的各个组件(函数、方法或类)是否按预期工作。单元测试是可靠软件开发的基石,有助于您及早发现错误、验证功能并确保可维护性。

unittest 中的关键概念

设置 unittest

要开始使用,请导入 unittest 并通过子类化 unittest.TestCase 创建一个测试类。测试方法必须以 test_ 开头,才能被测试运行器识别。

BasicTest.py
import unittest

class MyTestCase(unittest.TestCase):

    def test_basic_arithmetic(self):
        self.assertEqual(1 + 1, 2)

if __name__ == '__main__':
    unittest.main()

此示例演示了使用 unittest 模块的基本单元测试。MyTestCase 类继承自 unittest.TestCase,定义了一个名为 test_basic_arithmetic 的测试方法。该方法使用 assertEqual 来验证 1 + 1 是否等于 2。通过 unittest.main() 执行时,测试运行器会检查此断言并报告结果。

从命令行运行测试

python -m unittest test_filename.py

要获得详细输出(显示测试名称和结果)

python -m unittest test_filename.py -v

unittest 实用示例

以下是在各种场景中展示 unittest 的实用示例,从基本函数到高级功能。

测试一个简单的函数

用多个用例测试一个基本的加法函数。

TestAddFunction.py
import unittest

def add(a, b):
    return a + b

class TestAddFunction(unittest.TestCase):

    def test_add_positive_numbers(self):
        self.assertEqual(add(2, 3), 5)

    def test_add_negative_numbers(self):
        self.assertEqual(add(-1, -1), -2)

    def test_add_mixed_numbers(self):
        self.assertEqual(add(-1, 1), 0)

if __name__ == '__main__':
    unittest.main()

此代码通过三个测试用例测试了一个简单的 add 函数。TestAddFunction 子类化了 unittest.TestCase,并包含了检查正数相加 (2 + 3 = 5)、负数相加 (-1 + -1 = -2) 以及混合数相加 (-1 + 1 = 0) 的方法。每个测试都使用 assertEqual 来确保函数返回预期的输出。

测试字符串方法

通过正面和负面测试来验证内置的字符串方法。

TestStringMethods.py
import unittest

class TestStringMethods(unittest.TestCase):

    def test_upper(self):
        self.assertEqual('hello'.upper(), 'HELLO')

    def test_isupper(self):
        self.assertTrue('HELLO'.isupper())
        self.assertFalse('Hello'.isupper())

    def test_split(self):
        s = 'hello there'
        self.assertEqual(s.split(), ['hello', 'there'])
        with self.assertRaises(TypeError):
            s.split(2)  # split() expects a string, not an integer

if __name__ == '__main__':
    unittest.main()

此示例测试 Python 的内置字符串方法。TestStringMethods 包含三个测试:test_upper 检查 'hello'.upper() 是否返回 'HELLO',test_isupper 验证 'HELLO'.isupper() 为 True 且 'Hello'.isupper() 为 False,而 test_split 确认 'hello there'.split() 产生 ['hello', 'there']。它还使用 assertRaises 来确保 split(2) 会引发一个 TypeError

测试列表方法

测试列表操作方法的正确性。

TestListMethods.py
import unittest

class TestListMethods(unittest.TestCase):

    def test_append(self):
        my_list = [1, 2, 3]
        my_list.append(4)
        self.assertEqual(my_list, [1, 2, 3, 4])

    def test_pop(self):
        my_list = [1, 2, 3]
        popped_value = my_list.pop()
        self.assertEqual(popped_value, 3)
        self.assertEqual(my_list, [1, 2])

if __name__ == '__main__':
    unittest.main()

此代码验证列表操作。TestListMethods 通过向 [1, 2, 3] 添加 4 来测试 append,并用 assertEqual 检查结果是否为 [1, 2, 3, 4]。test_pop 方法从 [1, 2, 3] 中移除最后一个元素,验证弹出的值是 3,剩余列表是 [1, 2],并对这两项检查都使用了 assertEqual

测试异常

确保函数按预期引发异常。

TestDivideFunction.py
import unittest

def divide(a, b):

    if b == 0:
        raise ValueError("Cannot divide by zero")
    return a / b

class TestDivideFunction(unittest.TestCase):

    def test_divide_valid(self):
        self.assertEqual(divide(10, 2), 5)

    def test_divide_by_zero(self):
        with self.assertRaises(ValueError):
            divide(10, 0)

if __name__ == '__main__':
    unittest.main()

此示例测试一个在除以零时会引发异常的 divide 函数。TestDivideFunction 包含 test_divide_valid,它使用 assertEqual 检查 divide(10, 2) 是否等于 5。test_divide_by_zero 方法使用 assertRaises 来确认 divide(10, 0) 会引发 ValueError,从而确保正确的异常处理。

使用 setUp 和 tearDown

通过 setUp 和 tearDown 模拟资源(例如数据库)。

TestDatabase.py
import unittest

class TestDatabase(unittest.TestCase):

    def setUp(self):
        # Simulate opening a database connection
        self.database = []

    def tearDown(self):
        # Simulate closing the connection
        self.database = None

    def test_insert(self):
        self.database.append('data')
        self.assertIn('data', self.database)

    def test_delete(self):
        self.database.append('data')
        self.database.remove('data')
        self.assertNotIn('data', self.database)

if __name__ == '__main__':
    unittest.main()

此代码演示了使用 setUptearDown 的测试固件。TestDatabase 使用 setUp 初始化一个空列表作为模拟数据库,并使用 tearDown 将其重置为 None。test_insert 添加了 'data' 并用 assertIn 检查其是否存在,而 test_delete 先添加再移除 'data',并用 assertNotIn 验证其不存在,从而模拟资源管理。

测试一个类

测试一个简单的 Calculator 类的方法。

TestCalculator.py
import unittest

class Calculator:
    def add(self, a, b):
        return a + b

    def subtract(self, a, b):
        return a - b

class TestCalculator(unittest.TestCase):

    def setUp(self):
        self.calc = Calculator()

    def test_add(self):
        self.assertEqual(self.calc.add(2, 3), 5)

    def test_subtract(self):
        self.assertEqual(self.calc.subtract(5, 3), 2)

if __name__ == '__main__':
    unittest.main()

此示例测试一个带有 add 和 subtract 方法的 Calculator 类。TestCalculator 使用 setUp 创建一个 Calculator 实例。test_add 方法检查 calc.add(2, 3) 是否返回 5,test_subtract 验证 calc.subtract(5, 3) 是否等于 2,两者都使用 assertEqual 来确保类方法功能正确。

跳过测试

演示如何有条件或无条件地跳过测试。

TestSkipExample.py
import unittest

class TestSkipExample(unittest.TestCase):

    @unittest.skip("Skipping this test for demonstration")
    def test_skip(self):
        self.fail("This test should be skipped")

    @unittest.skipIf(2 > 1, "Skipping because condition is true")
    def test_skip_if(self):
        self.assertEqual(1, 2)  # Would fail if not skipped

    def test_normal(self):
        self.assertEqual(1 + 1, 2)

if __name__ == '__main__':
    unittest.main()

此代码演示了如何跳过测试。TestSkipExample 使用 @unittest.skip 无条件地跳过 test_skip,该测试原本会因 self.fail 而失败。@unittest.skipIf 装饰器在 2 > 1 为真时跳过 test_skip_if,从而避免了一个失败的断言。test_normal 则正常运行,用 assertEqual 检查 1 + 1 = 2。

使用 assertAlmostEqual 进行测试

处理浮点数算术的精度问题。

TestFloatingPoint.py
import unittest

class TestFloatingPoint(unittest.TestCase):

    def test_almost_equal(self):
        self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)  # Accounts for float precision

    def test_not_almost_equal(self):
        self.assertNotAlmostEqual(0.1 + 0.2, 0.4, places=7)

if __name__ == '__main__':
    unittest.main()

此示例解决了浮点数精度问题。TestFloatingPointtest_almost_equal 中使用 assertAlmostEqual 来验证 0.1 + 0.2 ≈ 0.3,精度到小数点后 7 位,考虑了浮点数的不精确性。Test_not_almost_equal 使用 assertNotAlmostEqual 来确保 0.1 + 0.2 ≠ 0.4,展示了考虑精度的测试方法。

使用 assertRaises 进行测试

验证异常是否按预期被引发。

TestException.py
import unittest

def raise_exception():
    raise ValueError("An error occurred")

class TestException(unittest.TestCase):

    def test_raise_exception(self):
        with self.assertRaises(ValueError) as context:
            raise_exception()
        self.assertEqual(str(context.exception), "An error occurred")

if __name__ == '__main__':
    unittest.main()

此代码测试异常处理。TestException 定义了 raise_exception,它会引发一个 ValueErrortest_raise_exception 方法使用 assertRaises 来确认异常发生,并使用上下文管理器捕获它,然后用 assertEqual 检查异常消息是否与 'An error occurred' 匹配。

使用测试套件

手动创建并运行一个自定义测试套件。

TestSuiteExample.py
import unittest

class TestSuiteExample1(unittest.TestCase):
    def test_case1(self):
        self.assertEqual(1, 1)

class TestSuiteExample2(unittest.TestCase):
    def test_case2(self):
        self.assertEqual(2, 2)

def suite():

    suite = unittest.TestSuite()
    suite.addTest(TestSuiteExample1('test_case1'))
    suite.addTest(TestSuiteExample2('test_case2'))
    return suite

if __name__ == '__main__':

    runner = unittest.TextTestRunner()
    test_suite = suite()
    runner.run(test_suite)

此示例创建了一个自定义测试套件。TestSuiteExample1TestSuiteExample2 各包含一个检查相等性(1 = 1 和 2 = 2)的测试。suite 函数构建了一个 TestSuite,添加了特定的测试,然后由 TextTestRunner 执行,从而允许手动分组和运行测试。

测试文件操作

使用临时文件模拟文件操作。

TestFileOperations.py
import unittest
import os

class TestFileOperations(unittest.TestCase):

    def setUp(self):
        self.filename = "temp_test.txt"
        with open(self.filename, 'w') as f:
            f.write("Hello, there!")

    def tearDown(self):
        if os.path.exists(self.filename):
            os.remove(self.filename)

    def test_read_file(self):
        with open(self.filename, 'r') as f:
            content = f.read()
        self.assertEqual(content, "Hello, there!")

if __name__ == '__main__':
    unittest.main()

此示例使用固件测试文件操作。TestFileOperations 使用 setUp 创建一个包含“Hello, there!”内容的文件,并使用 tearDown 删除它。test_read_file 读取文件内容,并使用 assertEqual 确保内容与写入的字符串匹配,从而模拟文件 I/O 测试。

使用 unittest.mock 进行模拟测试

使用模拟来隔离依赖项。

TestMocking.py
import unittest
from unittest.mock import Mock

def fetch_data(api):
    return api.get_data()

class TestMocking(unittest.TestCase):

    def test_fetch_data(self):

        # Create a mock API object
        mock_api = Mock()
        mock_api.get_data.return_value = "mocked data"
        
        result = fetch_data(mock_api)
        self.assertEqual(result, "mocked data")
        mock_api.get_data.assert_called_once()

if __name__ == '__main__':
    unittest.main()

此代码演示了使用 unittest.mock 进行模拟。TestMockingtest_fetch_data 中模拟了一个 API,将 get_data 的返回值设置为 'mocked data'。它使用模拟对象调用 fetch_data,用 assertEqual 验证结果,并用 assert_called_once 检查该方法被调用了一次,从而隔离了依赖项。

测试类型检查

确保函数能正确处理输入类型。

TestTypeChecking.py
import unittest

def multiply(a, b):
    if not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
        raise TypeError("Inputs must be numbers")
    return a * b

class TestTypeChecking(unittest.TestCase):
    def test_valid_input(self):
        self.assertEqual(multiply(2, 3), 6)

    def test_invalid_input(self):
        with self.assertRaises(TypeError):
            multiply("2", 3)

if __name__ == '__main__':
    unittest.main()

此示例测试 multiply 函数中的类型验证。TestTypeChecking 使用 test_valid_input 检查 multiply(2, 3) = 6,使用了 assertEqualtest_invalid_input 使用 assertRaises 确保 multiply('2', 3) 会引发 TypeError,从而确认该函数强制要求数字输入。

测试边界情况

用边界条件测试一个函数。

TestClampFunction.py
import unittest

def clamp(value, min_val, max_val):
    """Clamp value between min_val and max_val."""
    return max(min_val, min(max_val, value))

class TestClampFunction(unittest.TestCase):

    def test_within_range(self):
        self.assertEqual(clamp(5, 0, 10), 5)

    def test_below_range(self):
        self.assertEqual(clamp(-1, 0, 10), 0)

    def test_above_range(self):
        self.assertEqual(clamp(15, 0, 10), 10)

if __name__ == '__main__':
    unittest.main()

此代码测试一个用于边界处理的 clamp 函数。TestClampFunction 检查了 test_within_range (clamp(5, 0, 10) = 5)、test_below_range (clamp(-1, 0, 10) = 0) 和 test_above_range (clamp(15, 0, 10) = 10),所有测试都使用 assertEqual 来验证值是否保持在指定范围内。

测试排序算法

sorting_algos.py
def bubble_sort(arr):

    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]

def selection_sort(arr, ascending=True):

    n = len(arr)
    for i in range(n):
        idx = i
        for j in range(i + 1, n):
            if (ascending and arr[j] < arr[idx]) or (not ascending and arr[j] > arr[idx]):
                idx = j
        arr[i], arr[idx] = arr[idx], arr[i]
    return arr

此文件定义了两种排序算法。bubble_sort 实现冒泡排序,如果相邻元素顺序错误则进行交换,原地修改数组。selection_sort 在每次迭代中找到最小(或最大)元素,根据 ascending 参数按升序或降序排序,并返回排序后的数组。

test_algos.py
import unittest
from sorting_algos import bubble_sort, selection_sort

class TestSortingAlgorithms(unittest.TestCase):

    def setUp(self):
        """Set up test fixtures with various input arrays."""
        self.unsorted_list = [64, 34, 25, 12, 22, 11, 90]
        self.sorted_asc_list = [11, 12, 22, 25, 34, 64, 90]
        self.sorted_desc_list = [90, 64, 34, 25, 22, 12, 11]
        self.empty_list = []
        self.single_element_list = [42]
        self.duplicates_list = [5, 2, 8, 5, 1, 9, 2]

    def test_bubble_sort_unsorted(self):
        """Test bubble_sort with an unsorted list."""
        arr = self.unsorted_list.copy()  # Use copy to preserve original fixture
        bubble_sort(arr)
        self.assertEqual(arr, self.sorted_asc_list)

    def test_bubble_sort_already_sorted(self):
        """Test bubble_sort with an already sorted list."""
        arr = self.sorted_asc_list.copy()
        bubble_sort(arr)
        self.assertEqual(arr, self.sorted_asc_list)

    def test_bubble_sort_empty(self):
        """Test bubble_sort with an empty list."""
        arr = self.empty_list.copy()
        bubble_sort(arr)
        self.assertEqual(arr, self.empty_list)

    def test_bubble_sort_single_element(self):
        """Test bubble_sort with a single-element list."""
        arr = self.single_element_list.copy()
        bubble_sort(arr)
        self.assertEqual(arr, self.single_element_list)

    def test_bubble_sort_duplicates(self):
        """Test bubble_sort with a list containing duplicates."""
        arr = self.duplicates_list.copy()
        bubble_sort(arr)
        self.assertEqual(arr, [1, 2, 2, 5, 5, 8, 9])

    def test_selection_sort_ascending_unsorted(self):
        """Test selection_sort with an unsorted list in ascending order."""
        arr = self.unsorted_list.copy()
        result = selection_sort(arr, ascending=True)
        self.assertEqual(result, self.sorted_asc_list)

    def test_selection_sort_descending_unsorted(self):
        """Test selection_sort with an unsorted list in descending order."""
        arr = self.unsorted_list.copy()
        result = selection_sort(arr, ascending=False)
        self.assertEqual(result, self.sorted_desc_list)

    def test_selection_sort_ascending_sorted(self):
        """Test selection_sort with an already sorted list in ascending order."""
        arr = self.sorted_asc_list.copy()
        result = selection_sort(arr, ascending=True)
        self.assertEqual(result, self.sorted_asc_list)

    def test_selection_sort_descending_sorted(self):
        """Test selection_sort with an already sorted list in descending order."""
        arr = self.sorted_desc_list.copy()
        result = selection_sort(arr, ascending=False)
        self.assertEqual(result, self.sorted_desc_list)

    def test_selection_sort_empty(self):
        """Test selection_sort with an empty list."""
        arr = self.empty_list.copy()
        result = selection_sort(arr, ascending=True)
        self.assertEqual(result, self.empty_list)

    def test_selection_sort_single_element(self):
        """Test selection_sort with a single-element list."""
        arr = self.single_element_list.copy()
        result = selection_sort(arr, ascending=True)
        self.assertEqual(result, self.single_element_list)

    def test_selection_sort_duplicates_ascending(self):
        """Test selection_sort with duplicates in ascending order."""
        arr = self.duplicates_list.copy()
        result = selection_sort(arr, ascending=True)
        self.assertEqual(result, [1, 2, 2, 5, 5, 8, 9])

    def test_selection_sort_duplicates_descending(self):
        """Test selection_sort with duplicates in descending order."""
        arr = self.duplicates_list.copy()
        result = selection_sort(arr, ascending=False)
        self.assertEqual(result, [9, 8, 5, 5, 2, 2, 1])

if __name__ == '__main__':
    unittest.main()

此文件测试各种情况下的排序算法。TestSortingAlgorithms 使用 setUp 定义了测试固件,如未排序和已排序的列表。像 test_bubble_sort_unsortedtest_selection_sort_ascending_unsorted 这样的测试使用 assertEqual 来验证未排序、已排序、空、单元素以及包含重复元素的列表的排序正确性。

测试扑克牌型

使用 python -m unittest test_rank_hands.py -v 运行

rank_hands.py
from itertools import combinations
from collections import Counter

def create_deck():

    signs = [2, 3, 4, 5, 6, 7, 8, 9, 10, 'J', 'Q', 'K', 'A']
    symbols = ['♠', '♥', '♦', '♣']  # spades, hearts, diamonds, clubs
    deck = [f'{si}{sy}' for si in signs for sy in symbols]
    return deck

def by_poker_order(card):

    poker_order = ["2", "3", "4", "5", "6",
                   "7", "8", "9", "10", "J", "Q", "K", "A"]
    return poker_order.index(card[:-1])

def calculate_combinations(hole: list, ccards: list):

    hands = hole + ccards
    hands.sort(key=by_poker_order)
    combs = combinations(hands, 5)
    return tuple(combs)

def check_rank(hole: list, ccards: list):

    combs = calculate_combinations(hole, ccards)

    match is_royal(combs):
        case (True, form):
            print(f'{form} is a royal flush')
            return

    match is_4_kind(combs):
        case (True, form):
            print(f'{form} is four of a kind')
            return

    match is_full_house(combs):
        case (True, form):
            print(f'{form} is a full house')
            return

    match is_flush(combs):
        case (True, form):
            print(f'{form} is a flush')
            return

    match is_3_kind(combs):
        case (True, form):
            print(f'{form} is three of a kind')
            return

    match is_straight(combs):
        case (True, form):
            print(f'{form} is a straight')
            return

    match is_two_pairs(combs):
        case (True, form):
            print(f'{form} is two pairs')
            return

    match is_pair(combs):
        case (True, form):
            print(f'{form} is a pair')
            return

    match is_high_card(combs):
        case (True, form):
            print(f'{form} is a high card')
            return

def is_royal(combs: list):

    royals = ["10♠ J♠ Q♠ K♠ A♠", "10♣ J♣ Q♣ K♣ A♣",
              "10♥ J♥ Q♥ K♥ A♥", "10♦ J♦ Q♦ K♦ A♦"]

    for comb in combs:
        form = ' '.join(comb)
        if form in royals:
            return True, form

def is_4_kind(combs: list):

    four_kinds = []

    for comb in combs:
        c = Counter([e[:-1] for e in comb])
        vals = c.values()

        if 4 in vals:
            form = ' '.join(comb)
            four_kinds.append(form)

    if len(four_kinds) > 0:
        return True, four_kinds[-1]

def is_full_house(combs: list):

    for comb in combs:
        c = Counter([e[:-1] for e in comb])
        vals = c.values()
        form = ' '.join(comb)

        if 2 in vals and 3 in vals:
            return True, form

def is_flush(combs: list):

    matches = ['♣ ♣ ♣ ♣ ♣', '♦ ♦ ♦ ♦ ♦', '♥ ♥ ♥ ♥ ♥', '♠ ♠ ♠ ♠ ♠']
    flushes = []  # there may be more flush combinations, we pick the strongest

    for comb in combs:
        psuits = ' '.join(e[-1] for e in comb)

        if psuits in matches:
            form = ' '.join(comb)
            flushes.append(form)

    if len(flushes) > 0:
        return True, flushes[-1]

def is_straight(combs: list):

    order = "2 3 4 5 6 7 8 9 10 J Q K A"
    strainghts = []

    for comb in combs:
        seq = [e[:-1] for e in comb]
        unit = ' '.join(seq)

        if unit in order:
            form = ' '.join(comb)
            strainghts.append(form)

    if len(strainghts) > 0:
        return True, strainghts[-1]

def is_3_kind(combs: list):

    three_kinds = []

    for comb in combs:
        c = Counter([e[:-1] for e in comb])
        vals = c.values()

        if 3 in vals:
            form = ' '.join(comb)
            three_kinds.append(form)

    if len(three_kinds) > 0:
        return True, three_kinds[-1]

def is_two_pairs(combs: list):

    two_pairs = []

    for comb in combs:
        c = Counter([e[:-1] for e in comb])
        vals = list(c.values())

        if vals.count(2) == 2:
            form = ' '.join(comb)
            two_pairs.append(form)

    if len(two_pairs) > 0:
        return True, two_pairs[-1]

def is_pair(combs: list):

    pairs = []

    for comb in combs:
        c = Counter([e[:-1] for e in comb])
        vals = list(c.values())

        if vals.count(2) == 1:
            form = ' '.join(comb)
            pairs.append(form)

    if len(pairs) > 0:
        return True, pairs[-1]

def is_high_card(combs: list):

    high_cards = []

    for comb in combs:
        form = ' '.join(comb)
        high_cards.append(form)

    if len(high_cards) > 0:
        return True, high_cards[-1]

holes = (['K♥', 'A♣'], ['6♥', '4♠'], ['Q♠', 'Q♣'], ['2♠', '4♣'], ['5♠', '3♠'],
         ['J♣', 'Q♣'], ['Q♦', 'K♦'], ['K♠', 'A♠'], ['6♣', '7♣'], ['2♠', '7♦'])

ccards = (['3♦', '6♠', '10♦', 'J♠', '2♣'],
          ['10♠', 'J♠', 'Q♠', '8♣', '6♠'],
          ['9♠', '10♠', 'J♠', '6♦', '4♥'],
          ['9♠', '3♠', '4♦', '5♦', '6♥'],
          ['9♠', '5♦', '6♦', 'J♠', '3♣'],
          ['9♣', '10♣', '2♣', '3♣', '4♥'],
          ['4♦', '7♥', '7♦', 'A♣', '6♠'],
          ['5♦', '6♦', '10♣', '2♦', '2♣'],
          ['5♦', '5♣', '5♥', '6♦', '2♣'],
          ['5♣', '5♥', '6♦', '2♣', '4♦'],
          ['A♣', '10♣', '6♦', '3♦', 'K♣'])

for hole in holes:
    for ccard in ccards:
        check_rank(hole, ccard)

此文件对扑克牌型进行排名。create_deck 生成一副 52 张牌的牌组,calculate_combinations 从底牌和公共牌中计算出 5 张牌的组合。像 is_royalis_pair 这样的函数检查特定的牌型等级,并返回最强的匹配项。该脚本遍历底牌和公共牌的组合,并打印出牌型等级。

test_poker_rank_hands.py
import unittest

from rank_hands import (create_deck, by_poker_order, calculate_combinations,
                       is_royal, is_4_kind, is_full_house, is_flush, is_straight,
                       is_3_kind, is_two_pairs, is_pair, is_high_card)
from collections import Counter

class TestPokerHandRanking(unittest.TestCase):

    def setUp(self):
        """Set up test fixtures with sample hole and community cards."""
        self.hole_royal = ['K♠', 'A♠']
        self.hole_pair = ['Q♠', 'Q♣']
        self.hole_high = ['K♥', 'A♣']
        
        self.ccards_royal = ['10♠', 'J♠', 'Q♠', '2♣', '3♦']
        self.ccards_4kind = ['Q♥', 'Q♦', 'Q♠', '5♣', '6♦']
        self.ccards_full = ['Q♥', 'Q♦', '5♠', '5♣', '6♦']
        self.ccards_flush = ['2♠', '5♠', '7♠', '9♠', 'J♠']
        self.ccards_straight = ['9♠', '10♦', 'J♣', 'Q♥', 'K♦']
        self.ccards_3kind = ['Q♥', 'Q♦', '5♠', '6♣', '7♦']
        self.ccards_twopair = ['Q♥', '5♠', '5♣', 'K♦', '6♠']
        self.ccards_pair = ['5♠', '6♣', '7♦', '8♠', '9♣']
        self.ccards_high = ['2♣', '5♦', '7♥', '9♠', 'J♦']

    def test_create_deck(self):
        """Test that create_deck generates a standard 52-card deck."""
        deck = create_deck()
        self.assertEqual(len(deck), 52)
        self.assertIn('A♠', deck)
        self.assertIn('2♣', deck)
        self.assertEqual(len(set(deck)), 52)

    def test_by_poker_order(self):
        """Test that by_poker_order ranks cards correctly."""
        self.assertEqual(by_poker_order('2♠'), 0)
        self.assertEqual(by_poker_order('10♦'), 8)
        self.assertEqual(by_poker_order('J♣'), 9)
        self.assertEqual(by_poker_order('A♥'), 12)
        self.assertTrue(by_poker_order('2♠') < by_poker_order('A♠'))

    def test_calculate_combinations(self):
        """Test that calculate_combinations generates correct 5-card combinations."""
        hole = ['K♥', 'A♣']  # 2 hole cards
        ccards = ['2♠', '3♦', '4♣', '5♠', '6♦']  # 5 community cards
        combs = calculate_combinations(hole, ccards)
        self.assertEqual(len(combs), 21)  # 7 choose 5 = 21
        sample_comb = combs[0]
        self.assertEqual(len(sample_comb), 5)
        self.assertTrue(all(card in hole + ccards for card in sample_comb))

    def test_is_royal(self):
        """Test detection of a royal flush."""
        combs = calculate_combinations(self.hole_royal, self.ccards_royal)
        result = is_royal(combs)
        self.assertTrue(result[0])
        self.assertEqual(result[1], '10♠ J♠ Q♠ K♠ A♠')

        combs = calculate_combinations(self.hole_high, self.ccards_flush)
        result = is_royal(combs)
        self.assertIsNone(result)

    def test_is_4_kind(self):
        """Test detection of four of a kind."""
        combs = calculate_combinations(self.hole_pair, self.ccards_4kind)
        result = is_4_kind(combs)
        self.assertTrue(result[0])
        ranks = [card[:-1] for card in result[1].split()]
        self.assertEqual(Counter(ranks)['Q'], 4)

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_4_kind(combs)
        self.assertIsNone(result)

    def test_is_full_house(self):
        """Test detection of a full house."""
        combs = calculate_combinations(self.hole_pair, self.ccards_full)
        result = is_full_house(combs)
        self.assertTrue(result[0])
        ranks = [card[:-1] for card in result[1].split()]
        c = Counter(ranks)
        self.assertTrue(3 in c.values() and 2 in c.values())

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_full_house(combs)
        self.assertIsNone(result)

    def test_is_flush(self):
        """Test detection of a flush."""
        combs = calculate_combinations(self.hole_high, self.ccards_flush)
        result = is_flush(combs)
        self.assertTrue(result[0])
        suits = [card[-1] for card in result[1].split()]
        self.assertEqual(len(set(suits)), 1)

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_flush(combs)
        self.assertIsNone(result)

    def test_is_straight(self):
        """Test detection of a straight."""
        combs = calculate_combinations(self.hole_high, self.ccards_straight)
        result = is_straight(combs)
        self.assertTrue(result[0])
        ranks = [card[:-1] for card in result[1].split()]
        self.assertEqual(len(set(ranks)), 5)

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_straight(combs)
        self.assertIsNone(result)

    def test_is_3_kind(self):
        """Test detection of three of a kind."""
        combs = calculate_combinations(self.hole_pair, self.ccards_3kind)
        result = is_3_kind(combs)
        self.assertTrue(result[0])
        ranks = [card[:-1] for card in result[1].split()]
        self.assertEqual(Counter(ranks)['Q'], 3)

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_3_kind(combs)
        self.assertIsNone(result)

    def test_is_two_pairs(self):
        """Test detection of two pairs."""
        combs = calculate_combinations(self.hole_pair, self.ccards_twopair)
        result = is_two_pairs(combs)
        self.assertTrue(result[0])
        ranks = [card[:-1] for card in result[1].split()]
        c = Counter(ranks)
        self.assertEqual(list(c.values()).count(2), 2)

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_two_pairs(combs)
        self.assertIsNone(result)

    def test_is_pair(self):
        """Test detection of a pair."""
        combs = calculate_combinations(self.hole_pair, self.ccards_pair)
        result = is_pair(combs)
        self.assertTrue(result[0])
        ranks = [card[:-1] for card in result[1].split()]
        c = Counter(ranks)
        self.assertEqual(list(c.values()).count(2), 1)

        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_pair(combs)
        self.assertIsNone(result)

    def test_is_high_card(self):
        """Test detection of a high card (always true if no other hand)."""
        combs = calculate_combinations(self.hole_high, self.ccards_high)
        result = is_high_card(combs)
        self.assertTrue(result[0])
        self.assertTrue(isinstance(result[1], str))

if __name__ == '__main__':
    unittest.main()

此文件测试扑克牌型排名函数。TestPokerHandRanking 使用 setUp 定义了示例手牌。像 test_is_royal 这样的测试验证皇家同花顺的检测,test_is_4_kind 检查四条,test_is_high_card 确保高牌检测,均使用断言来验证手牌识别和排名逻辑。

有效单元测试的技巧

在本文中,我们介绍了 unittest 模块,它为测试 Python 代码提供了一个通用而强大的框架。通过以上示例,您已经了解了如何测试函数、类、异常等,包括模拟和测试套件等高级功能。开始将单元测试纳入您的项目,以提高代码质量和信心。

作者

我的名字是 Jan Bodnar,我是一名充满热情的程序员,拥有多年的编程经验。我从 2007 年开始撰写编程文章。到目前为止,我已经写了超过 1400 篇文章和 8本电子书。我有超过八年的编程教学经验。

列出所有 Python 教程