Python unittest
最后修改于 2025 年 2 月 25 日
unittest 模块是 Python 内置的用于编写和执行单元测试的框架,其灵感来源于 Java 生态系统中的 JUnit。它使开发人员能够验证其代码的各个组件(函数、方法或类)是否按预期工作。单元测试是可靠软件开发的基石,有助于您及早发现错误、验证功能并确保可维护性。
unittest 中的关键概念
- 测试用例 (Test Case):测试的最小单元,通常用于验证给定输入的特定行为或输出。测试用例通过子类化
unittest.TestCase来定义。 - 测试固件 (Test Fixture):测试所需的设置和拆卸逻辑,例如在测试运行前初始化资源(如文件、数据库)并在之后进行清理。
- 测试套件 (Test Suite):测试用例或其他测试套件的集合,允许您将相关测试分组执行。
- 测试运行器 (Test Runner):执行测试并报告结果的机制,例如通过/失败计数和错误详情。
- 断言 (Assertions):像 assertEqual、assertTrue、assertRaises 等方法,用于检查代码的行为是否符合预期。失败的断言会将测试标记为失败。
设置 unittest
要开始使用,请导入 unittest 并通过子类化 unittest.TestCase 创建一个测试类。测试方法必须以 test_ 开头,才能被测试运行器识别。
import unittest
class MyTestCase(unittest.TestCase):
def test_basic_arithmetic(self):
self.assertEqual(1 + 1, 2)
if __name__ == '__main__':
unittest.main()
此示例演示了使用 unittest 模块的基本单元测试。MyTestCase 类继承自 unittest.TestCase,定义了一个名为 test_basic_arithmetic 的测试方法。该方法使用 assertEqual 来验证 1 + 1 是否等于 2。通过 unittest.main() 执行时,测试运行器会检查此断言并报告结果。
从命令行运行测试
python -m unittest test_filename.py
要获得详细输出(显示测试名称和结果)
python -m unittest test_filename.py -v
unittest 实用示例
以下是在各种场景中展示 unittest 的实用示例,从基本函数到高级功能。
测试一个简单的函数
用多个用例测试一个基本的加法函数。
import unittest
def add(a, b):
return a + b
class TestAddFunction(unittest.TestCase):
def test_add_positive_numbers(self):
self.assertEqual(add(2, 3), 5)
def test_add_negative_numbers(self):
self.assertEqual(add(-1, -1), -2)
def test_add_mixed_numbers(self):
self.assertEqual(add(-1, 1), 0)
if __name__ == '__main__':
unittest.main()
此代码通过三个测试用例测试了一个简单的 add 函数。TestAddFunction 子类化了 unittest.TestCase,并包含了检查正数相加 (2 + 3 = 5)、负数相加 (-1 + -1 = -2) 以及混合数相加 (-1 + 1 = 0) 的方法。每个测试都使用 assertEqual 来确保函数返回预期的输出。
测试字符串方法
通过正面和负面测试来验证内置的字符串方法。
import unittest
class TestStringMethods(unittest.TestCase):
def test_upper(self):
self.assertEqual('hello'.upper(), 'HELLO')
def test_isupper(self):
self.assertTrue('HELLO'.isupper())
self.assertFalse('Hello'.isupper())
def test_split(self):
s = 'hello there'
self.assertEqual(s.split(), ['hello', 'there'])
with self.assertRaises(TypeError):
s.split(2) # split() expects a string, not an integer
if __name__ == '__main__':
unittest.main()
此示例测试 Python 的内置字符串方法。TestStringMethods 包含三个测试:test_upper 检查 'hello'.upper() 是否返回 'HELLO',test_isupper 验证 'HELLO'.isupper() 为 True 且 'Hello'.isupper() 为 False,而 test_split 确认 'hello there'.split() 产生 ['hello', 'there']。它还使用 assertRaises 来确保 split(2) 会引发一个 TypeError。
测试列表方法
测试列表操作方法的正确性。
import unittest
class TestListMethods(unittest.TestCase):
def test_append(self):
my_list = [1, 2, 3]
my_list.append(4)
self.assertEqual(my_list, [1, 2, 3, 4])
def test_pop(self):
my_list = [1, 2, 3]
popped_value = my_list.pop()
self.assertEqual(popped_value, 3)
self.assertEqual(my_list, [1, 2])
if __name__ == '__main__':
unittest.main()
此代码验证列表操作。TestListMethods 通过向 [1, 2, 3] 添加 4 来测试 append,并用 assertEqual 检查结果是否为 [1, 2, 3, 4]。test_pop 方法从 [1, 2, 3] 中移除最后一个元素,验证弹出的值是 3,剩余列表是 [1, 2],并对这两项检查都使用了 assertEqual。
测试异常
确保函数按预期引发异常。
import unittest
def divide(a, b):
if b == 0:
raise ValueError("Cannot divide by zero")
return a / b
class TestDivideFunction(unittest.TestCase):
def test_divide_valid(self):
self.assertEqual(divide(10, 2), 5)
def test_divide_by_zero(self):
with self.assertRaises(ValueError):
divide(10, 0)
if __name__ == '__main__':
unittest.main()
此示例测试一个在除以零时会引发异常的 divide 函数。TestDivideFunction 包含 test_divide_valid,它使用 assertEqual 检查 divide(10, 2) 是否等于 5。test_divide_by_zero 方法使用 assertRaises 来确认 divide(10, 0) 会引发 ValueError,从而确保正确的异常处理。
使用 setUp 和 tearDown
通过 setUp 和 tearDown 模拟资源(例如数据库)。
import unittest
class TestDatabase(unittest.TestCase):
def setUp(self):
# Simulate opening a database connection
self.database = []
def tearDown(self):
# Simulate closing the connection
self.database = None
def test_insert(self):
self.database.append('data')
self.assertIn('data', self.database)
def test_delete(self):
self.database.append('data')
self.database.remove('data')
self.assertNotIn('data', self.database)
if __name__ == '__main__':
unittest.main()
此代码演示了使用 setUp 和 tearDown 的测试固件。TestDatabase 使用 setUp 初始化一个空列表作为模拟数据库,并使用 tearDown 将其重置为 None。test_insert 添加了 'data' 并用 assertIn 检查其是否存在,而 test_delete 先添加再移除 'data',并用 assertNotIn 验证其不存在,从而模拟资源管理。
测试一个类
测试一个简单的 Calculator 类的方法。
import unittest
class Calculator:
def add(self, a, b):
return a + b
def subtract(self, a, b):
return a - b
class TestCalculator(unittest.TestCase):
def setUp(self):
self.calc = Calculator()
def test_add(self):
self.assertEqual(self.calc.add(2, 3), 5)
def test_subtract(self):
self.assertEqual(self.calc.subtract(5, 3), 2)
if __name__ == '__main__':
unittest.main()
此示例测试一个带有 add 和 subtract 方法的 Calculator 类。TestCalculator 使用 setUp 创建一个 Calculator 实例。test_add 方法检查 calc.add(2, 3) 是否返回 5,test_subtract 验证 calc.subtract(5, 3) 是否等于 2,两者都使用 assertEqual 来确保类方法功能正确。
跳过测试
演示如何有条件或无条件地跳过测试。
import unittest
class TestSkipExample(unittest.TestCase):
@unittest.skip("Skipping this test for demonstration")
def test_skip(self):
self.fail("This test should be skipped")
@unittest.skipIf(2 > 1, "Skipping because condition is true")
def test_skip_if(self):
self.assertEqual(1, 2) # Would fail if not skipped
def test_normal(self):
self.assertEqual(1 + 1, 2)
if __name__ == '__main__':
unittest.main()
此代码演示了如何跳过测试。TestSkipExample 使用 @unittest.skip 无条件地跳过 test_skip,该测试原本会因 self.fail 而失败。@unittest.skipIf 装饰器在 2 > 1 为真时跳过 test_skip_if,从而避免了一个失败的断言。test_normal 则正常运行,用 assertEqual 检查 1 + 1 = 2。
使用 assertAlmostEqual 进行测试
处理浮点数算术的精度问题。
import unittest
class TestFloatingPoint(unittest.TestCase):
def test_almost_equal(self):
self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7) # Accounts for float precision
def test_not_almost_equal(self):
self.assertNotAlmostEqual(0.1 + 0.2, 0.4, places=7)
if __name__ == '__main__':
unittest.main()
此示例解决了浮点数精度问题。TestFloatingPoint 在 test_almost_equal 中使用 assertAlmostEqual 来验证 0.1 + 0.2 ≈ 0.3,精度到小数点后 7 位,考虑了浮点数的不精确性。Test_not_almost_equal 使用 assertNotAlmostEqual 来确保 0.1 + 0.2 ≠ 0.4,展示了考虑精度的测试方法。
使用 assertRaises 进行测试
验证异常是否按预期被引发。
import unittest
def raise_exception():
raise ValueError("An error occurred")
class TestException(unittest.TestCase):
def test_raise_exception(self):
with self.assertRaises(ValueError) as context:
raise_exception()
self.assertEqual(str(context.exception), "An error occurred")
if __name__ == '__main__':
unittest.main()
此代码测试异常处理。TestException 定义了 raise_exception,它会引发一个 ValueError。test_raise_exception 方法使用 assertRaises 来确认异常发生,并使用上下文管理器捕获它,然后用 assertEqual 检查异常消息是否与 'An error occurred' 匹配。
使用测试套件
手动创建并运行一个自定义测试套件。
import unittest
class TestSuiteExample1(unittest.TestCase):
def test_case1(self):
self.assertEqual(1, 1)
class TestSuiteExample2(unittest.TestCase):
def test_case2(self):
self.assertEqual(2, 2)
def suite():
suite = unittest.TestSuite()
suite.addTest(TestSuiteExample1('test_case1'))
suite.addTest(TestSuiteExample2('test_case2'))
return suite
if __name__ == '__main__':
runner = unittest.TextTestRunner()
test_suite = suite()
runner.run(test_suite)
此示例创建了一个自定义测试套件。TestSuiteExample1 和 TestSuiteExample2 各包含一个检查相等性(1 = 1 和 2 = 2)的测试。suite 函数构建了一个 TestSuite,添加了特定的测试,然后由 TextTestRunner 执行,从而允许手动分组和运行测试。
测试文件操作
使用临时文件模拟文件操作。
import unittest
import os
class TestFileOperations(unittest.TestCase):
def setUp(self):
self.filename = "temp_test.txt"
with open(self.filename, 'w') as f:
f.write("Hello, there!")
def tearDown(self):
if os.path.exists(self.filename):
os.remove(self.filename)
def test_read_file(self):
with open(self.filename, 'r') as f:
content = f.read()
self.assertEqual(content, "Hello, there!")
if __name__ == '__main__':
unittest.main()
此示例使用固件测试文件操作。TestFileOperations 使用 setUp 创建一个包含“Hello, there!”内容的文件,并使用 tearDown 删除它。test_read_file 读取文件内容,并使用 assertEqual 确保内容与写入的字符串匹配,从而模拟文件 I/O 测试。
使用 unittest.mock 进行模拟测试
使用模拟来隔离依赖项。
import unittest
from unittest.mock import Mock
def fetch_data(api):
return api.get_data()
class TestMocking(unittest.TestCase):
def test_fetch_data(self):
# Create a mock API object
mock_api = Mock()
mock_api.get_data.return_value = "mocked data"
result = fetch_data(mock_api)
self.assertEqual(result, "mocked data")
mock_api.get_data.assert_called_once()
if __name__ == '__main__':
unittest.main()
此代码演示了使用 unittest.mock 进行模拟。TestMocking 在 test_fetch_data 中模拟了一个 API,将 get_data 的返回值设置为 'mocked data'。它使用模拟对象调用 fetch_data,用 assertEqual 验证结果,并用 assert_called_once 检查该方法被调用了一次,从而隔离了依赖项。
测试类型检查
确保函数能正确处理输入类型。
import unittest
def multiply(a, b):
if not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
raise TypeError("Inputs must be numbers")
return a * b
class TestTypeChecking(unittest.TestCase):
def test_valid_input(self):
self.assertEqual(multiply(2, 3), 6)
def test_invalid_input(self):
with self.assertRaises(TypeError):
multiply("2", 3)
if __name__ == '__main__':
unittest.main()
此示例测试 multiply 函数中的类型验证。TestTypeChecking 使用 test_valid_input 检查 multiply(2, 3) = 6,使用了 assertEqual。test_invalid_input 使用 assertRaises 确保 multiply('2', 3) 会引发 TypeError,从而确认该函数强制要求数字输入。
测试边界情况
用边界条件测试一个函数。
import unittest
def clamp(value, min_val, max_val):
"""Clamp value between min_val and max_val."""
return max(min_val, min(max_val, value))
class TestClampFunction(unittest.TestCase):
def test_within_range(self):
self.assertEqual(clamp(5, 0, 10), 5)
def test_below_range(self):
self.assertEqual(clamp(-1, 0, 10), 0)
def test_above_range(self):
self.assertEqual(clamp(15, 0, 10), 10)
if __name__ == '__main__':
unittest.main()
此代码测试一个用于边界处理的 clamp 函数。TestClampFunction 检查了 test_within_range (clamp(5, 0, 10) = 5)、test_below_range (clamp(-1, 0, 10) = 0) 和 test_above_range (clamp(15, 0, 10) = 10),所有测试都使用 assertEqual 来验证值是否保持在指定范围内。
测试排序算法
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(0, n-i-1):
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
def selection_sort(arr, ascending=True):
n = len(arr)
for i in range(n):
idx = i
for j in range(i + 1, n):
if (ascending and arr[j] < arr[idx]) or (not ascending and arr[j] > arr[idx]):
idx = j
arr[i], arr[idx] = arr[idx], arr[i]
return arr
此文件定义了两种排序算法。bubble_sort 实现冒泡排序,如果相邻元素顺序错误则进行交换,原地修改数组。selection_sort 在每次迭代中找到最小(或最大)元素,根据 ascending 参数按升序或降序排序,并返回排序后的数组。
import unittest
from sorting_algos import bubble_sort, selection_sort
class TestSortingAlgorithms(unittest.TestCase):
def setUp(self):
"""Set up test fixtures with various input arrays."""
self.unsorted_list = [64, 34, 25, 12, 22, 11, 90]
self.sorted_asc_list = [11, 12, 22, 25, 34, 64, 90]
self.sorted_desc_list = [90, 64, 34, 25, 22, 12, 11]
self.empty_list = []
self.single_element_list = [42]
self.duplicates_list = [5, 2, 8, 5, 1, 9, 2]
def test_bubble_sort_unsorted(self):
"""Test bubble_sort with an unsorted list."""
arr = self.unsorted_list.copy() # Use copy to preserve original fixture
bubble_sort(arr)
self.assertEqual(arr, self.sorted_asc_list)
def test_bubble_sort_already_sorted(self):
"""Test bubble_sort with an already sorted list."""
arr = self.sorted_asc_list.copy()
bubble_sort(arr)
self.assertEqual(arr, self.sorted_asc_list)
def test_bubble_sort_empty(self):
"""Test bubble_sort with an empty list."""
arr = self.empty_list.copy()
bubble_sort(arr)
self.assertEqual(arr, self.empty_list)
def test_bubble_sort_single_element(self):
"""Test bubble_sort with a single-element list."""
arr = self.single_element_list.copy()
bubble_sort(arr)
self.assertEqual(arr, self.single_element_list)
def test_bubble_sort_duplicates(self):
"""Test bubble_sort with a list containing duplicates."""
arr = self.duplicates_list.copy()
bubble_sort(arr)
self.assertEqual(arr, [1, 2, 2, 5, 5, 8, 9])
def test_selection_sort_ascending_unsorted(self):
"""Test selection_sort with an unsorted list in ascending order."""
arr = self.unsorted_list.copy()
result = selection_sort(arr, ascending=True)
self.assertEqual(result, self.sorted_asc_list)
def test_selection_sort_descending_unsorted(self):
"""Test selection_sort with an unsorted list in descending order."""
arr = self.unsorted_list.copy()
result = selection_sort(arr, ascending=False)
self.assertEqual(result, self.sorted_desc_list)
def test_selection_sort_ascending_sorted(self):
"""Test selection_sort with an already sorted list in ascending order."""
arr = self.sorted_asc_list.copy()
result = selection_sort(arr, ascending=True)
self.assertEqual(result, self.sorted_asc_list)
def test_selection_sort_descending_sorted(self):
"""Test selection_sort with an already sorted list in descending order."""
arr = self.sorted_desc_list.copy()
result = selection_sort(arr, ascending=False)
self.assertEqual(result, self.sorted_desc_list)
def test_selection_sort_empty(self):
"""Test selection_sort with an empty list."""
arr = self.empty_list.copy()
result = selection_sort(arr, ascending=True)
self.assertEqual(result, self.empty_list)
def test_selection_sort_single_element(self):
"""Test selection_sort with a single-element list."""
arr = self.single_element_list.copy()
result = selection_sort(arr, ascending=True)
self.assertEqual(result, self.single_element_list)
def test_selection_sort_duplicates_ascending(self):
"""Test selection_sort with duplicates in ascending order."""
arr = self.duplicates_list.copy()
result = selection_sort(arr, ascending=True)
self.assertEqual(result, [1, 2, 2, 5, 5, 8, 9])
def test_selection_sort_duplicates_descending(self):
"""Test selection_sort with duplicates in descending order."""
arr = self.duplicates_list.copy()
result = selection_sort(arr, ascending=False)
self.assertEqual(result, [9, 8, 5, 5, 2, 2, 1])
if __name__ == '__main__':
unittest.main()
此文件测试各种情况下的排序算法。TestSortingAlgorithms 使用 setUp 定义了测试固件,如未排序和已排序的列表。像 test_bubble_sort_unsorted 和 test_selection_sort_ascending_unsorted 这样的测试使用 assertEqual 来验证未排序、已排序、空、单元素以及包含重复元素的列表的排序正确性。
测试扑克牌型
使用 python -m unittest test_rank_hands.py -v 运行
from itertools import combinations
from collections import Counter
def create_deck():
signs = [2, 3, 4, 5, 6, 7, 8, 9, 10, 'J', 'Q', 'K', 'A']
symbols = ['♠', '♥', '♦', '♣'] # spades, hearts, diamonds, clubs
deck = [f'{si}{sy}' for si in signs for sy in symbols]
return deck
def by_poker_order(card):
poker_order = ["2", "3", "4", "5", "6",
"7", "8", "9", "10", "J", "Q", "K", "A"]
return poker_order.index(card[:-1])
def calculate_combinations(hole: list, ccards: list):
hands = hole + ccards
hands.sort(key=by_poker_order)
combs = combinations(hands, 5)
return tuple(combs)
def check_rank(hole: list, ccards: list):
combs = calculate_combinations(hole, ccards)
match is_royal(combs):
case (True, form):
print(f'{form} is a royal flush')
return
match is_4_kind(combs):
case (True, form):
print(f'{form} is four of a kind')
return
match is_full_house(combs):
case (True, form):
print(f'{form} is a full house')
return
match is_flush(combs):
case (True, form):
print(f'{form} is a flush')
return
match is_3_kind(combs):
case (True, form):
print(f'{form} is three of a kind')
return
match is_straight(combs):
case (True, form):
print(f'{form} is a straight')
return
match is_two_pairs(combs):
case (True, form):
print(f'{form} is two pairs')
return
match is_pair(combs):
case (True, form):
print(f'{form} is a pair')
return
match is_high_card(combs):
case (True, form):
print(f'{form} is a high card')
return
def is_royal(combs: list):
royals = ["10♠ J♠ Q♠ K♠ A♠", "10♣ J♣ Q♣ K♣ A♣",
"10♥ J♥ Q♥ K♥ A♥", "10♦ J♦ Q♦ K♦ A♦"]
for comb in combs:
form = ' '.join(comb)
if form in royals:
return True, form
def is_4_kind(combs: list):
four_kinds = []
for comb in combs:
c = Counter([e[:-1] for e in comb])
vals = c.values()
if 4 in vals:
form = ' '.join(comb)
four_kinds.append(form)
if len(four_kinds) > 0:
return True, four_kinds[-1]
def is_full_house(combs: list):
for comb in combs:
c = Counter([e[:-1] for e in comb])
vals = c.values()
form = ' '.join(comb)
if 2 in vals and 3 in vals:
return True, form
def is_flush(combs: list):
matches = ['♣ ♣ ♣ ♣ ♣', '♦ ♦ ♦ ♦ ♦', '♥ ♥ ♥ ♥ ♥', '♠ ♠ ♠ ♠ ♠']
flushes = [] # there may be more flush combinations, we pick the strongest
for comb in combs:
psuits = ' '.join(e[-1] for e in comb)
if psuits in matches:
form = ' '.join(comb)
flushes.append(form)
if len(flushes) > 0:
return True, flushes[-1]
def is_straight(combs: list):
order = "2 3 4 5 6 7 8 9 10 J Q K A"
strainghts = []
for comb in combs:
seq = [e[:-1] for e in comb]
unit = ' '.join(seq)
if unit in order:
form = ' '.join(comb)
strainghts.append(form)
if len(strainghts) > 0:
return True, strainghts[-1]
def is_3_kind(combs: list):
three_kinds = []
for comb in combs:
c = Counter([e[:-1] for e in comb])
vals = c.values()
if 3 in vals:
form = ' '.join(comb)
three_kinds.append(form)
if len(three_kinds) > 0:
return True, three_kinds[-1]
def is_two_pairs(combs: list):
two_pairs = []
for comb in combs:
c = Counter([e[:-1] for e in comb])
vals = list(c.values())
if vals.count(2) == 2:
form = ' '.join(comb)
two_pairs.append(form)
if len(two_pairs) > 0:
return True, two_pairs[-1]
def is_pair(combs: list):
pairs = []
for comb in combs:
c = Counter([e[:-1] for e in comb])
vals = list(c.values())
if vals.count(2) == 1:
form = ' '.join(comb)
pairs.append(form)
if len(pairs) > 0:
return True, pairs[-1]
def is_high_card(combs: list):
high_cards = []
for comb in combs:
form = ' '.join(comb)
high_cards.append(form)
if len(high_cards) > 0:
return True, high_cards[-1]
holes = (['K♥', 'A♣'], ['6♥', '4♠'], ['Q♠', 'Q♣'], ['2♠', '4♣'], ['5♠', '3♠'],
['J♣', 'Q♣'], ['Q♦', 'K♦'], ['K♠', 'A♠'], ['6♣', '7♣'], ['2♠', '7♦'])
ccards = (['3♦', '6♠', '10♦', 'J♠', '2♣'],
['10♠', 'J♠', 'Q♠', '8♣', '6♠'],
['9♠', '10♠', 'J♠', '6♦', '4♥'],
['9♠', '3♠', '4♦', '5♦', '6♥'],
['9♠', '5♦', '6♦', 'J♠', '3♣'],
['9♣', '10♣', '2♣', '3♣', '4♥'],
['4♦', '7♥', '7♦', 'A♣', '6♠'],
['5♦', '6♦', '10♣', '2♦', '2♣'],
['5♦', '5♣', '5♥', '6♦', '2♣'],
['5♣', '5♥', '6♦', '2♣', '4♦'],
['A♣', '10♣', '6♦', '3♦', 'K♣'])
for hole in holes:
for ccard in ccards:
check_rank(hole, ccard)
此文件对扑克牌型进行排名。create_deck 生成一副 52 张牌的牌组,calculate_combinations 从底牌和公共牌中计算出 5 张牌的组合。像 is_royal 和 is_pair 这样的函数检查特定的牌型等级,并返回最强的匹配项。该脚本遍历底牌和公共牌的组合,并打印出牌型等级。
import unittest
from rank_hands import (create_deck, by_poker_order, calculate_combinations,
is_royal, is_4_kind, is_full_house, is_flush, is_straight,
is_3_kind, is_two_pairs, is_pair, is_high_card)
from collections import Counter
class TestPokerHandRanking(unittest.TestCase):
def setUp(self):
"""Set up test fixtures with sample hole and community cards."""
self.hole_royal = ['K♠', 'A♠']
self.hole_pair = ['Q♠', 'Q♣']
self.hole_high = ['K♥', 'A♣']
self.ccards_royal = ['10♠', 'J♠', 'Q♠', '2♣', '3♦']
self.ccards_4kind = ['Q♥', 'Q♦', 'Q♠', '5♣', '6♦']
self.ccards_full = ['Q♥', 'Q♦', '5♠', '5♣', '6♦']
self.ccards_flush = ['2♠', '5♠', '7♠', '9♠', 'J♠']
self.ccards_straight = ['9♠', '10♦', 'J♣', 'Q♥', 'K♦']
self.ccards_3kind = ['Q♥', 'Q♦', '5♠', '6♣', '7♦']
self.ccards_twopair = ['Q♥', '5♠', '5♣', 'K♦', '6♠']
self.ccards_pair = ['5♠', '6♣', '7♦', '8♠', '9♣']
self.ccards_high = ['2♣', '5♦', '7♥', '9♠', 'J♦']
def test_create_deck(self):
"""Test that create_deck generates a standard 52-card deck."""
deck = create_deck()
self.assertEqual(len(deck), 52)
self.assertIn('A♠', deck)
self.assertIn('2♣', deck)
self.assertEqual(len(set(deck)), 52)
def test_by_poker_order(self):
"""Test that by_poker_order ranks cards correctly."""
self.assertEqual(by_poker_order('2♠'), 0)
self.assertEqual(by_poker_order('10♦'), 8)
self.assertEqual(by_poker_order('J♣'), 9)
self.assertEqual(by_poker_order('A♥'), 12)
self.assertTrue(by_poker_order('2♠') < by_poker_order('A♠'))
def test_calculate_combinations(self):
"""Test that calculate_combinations generates correct 5-card combinations."""
hole = ['K♥', 'A♣'] # 2 hole cards
ccards = ['2♠', '3♦', '4♣', '5♠', '6♦'] # 5 community cards
combs = calculate_combinations(hole, ccards)
self.assertEqual(len(combs), 21) # 7 choose 5 = 21
sample_comb = combs[0]
self.assertEqual(len(sample_comb), 5)
self.assertTrue(all(card in hole + ccards for card in sample_comb))
def test_is_royal(self):
"""Test detection of a royal flush."""
combs = calculate_combinations(self.hole_royal, self.ccards_royal)
result = is_royal(combs)
self.assertTrue(result[0])
self.assertEqual(result[1], '10♠ J♠ Q♠ K♠ A♠')
combs = calculate_combinations(self.hole_high, self.ccards_flush)
result = is_royal(combs)
self.assertIsNone(result)
def test_is_4_kind(self):
"""Test detection of four of a kind."""
combs = calculate_combinations(self.hole_pair, self.ccards_4kind)
result = is_4_kind(combs)
self.assertTrue(result[0])
ranks = [card[:-1] for card in result[1].split()]
self.assertEqual(Counter(ranks)['Q'], 4)
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_4_kind(combs)
self.assertIsNone(result)
def test_is_full_house(self):
"""Test detection of a full house."""
combs = calculate_combinations(self.hole_pair, self.ccards_full)
result = is_full_house(combs)
self.assertTrue(result[0])
ranks = [card[:-1] for card in result[1].split()]
c = Counter(ranks)
self.assertTrue(3 in c.values() and 2 in c.values())
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_full_house(combs)
self.assertIsNone(result)
def test_is_flush(self):
"""Test detection of a flush."""
combs = calculate_combinations(self.hole_high, self.ccards_flush)
result = is_flush(combs)
self.assertTrue(result[0])
suits = [card[-1] for card in result[1].split()]
self.assertEqual(len(set(suits)), 1)
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_flush(combs)
self.assertIsNone(result)
def test_is_straight(self):
"""Test detection of a straight."""
combs = calculate_combinations(self.hole_high, self.ccards_straight)
result = is_straight(combs)
self.assertTrue(result[0])
ranks = [card[:-1] for card in result[1].split()]
self.assertEqual(len(set(ranks)), 5)
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_straight(combs)
self.assertIsNone(result)
def test_is_3_kind(self):
"""Test detection of three of a kind."""
combs = calculate_combinations(self.hole_pair, self.ccards_3kind)
result = is_3_kind(combs)
self.assertTrue(result[0])
ranks = [card[:-1] for card in result[1].split()]
self.assertEqual(Counter(ranks)['Q'], 3)
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_3_kind(combs)
self.assertIsNone(result)
def test_is_two_pairs(self):
"""Test detection of two pairs."""
combs = calculate_combinations(self.hole_pair, self.ccards_twopair)
result = is_two_pairs(combs)
self.assertTrue(result[0])
ranks = [card[:-1] for card in result[1].split()]
c = Counter(ranks)
self.assertEqual(list(c.values()).count(2), 2)
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_two_pairs(combs)
self.assertIsNone(result)
def test_is_pair(self):
"""Test detection of a pair."""
combs = calculate_combinations(self.hole_pair, self.ccards_pair)
result = is_pair(combs)
self.assertTrue(result[0])
ranks = [card[:-1] for card in result[1].split()]
c = Counter(ranks)
self.assertEqual(list(c.values()).count(2), 1)
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_pair(combs)
self.assertIsNone(result)
def test_is_high_card(self):
"""Test detection of a high card (always true if no other hand)."""
combs = calculate_combinations(self.hole_high, self.ccards_high)
result = is_high_card(combs)
self.assertTrue(result[0])
self.assertTrue(isinstance(result[1], str))
if __name__ == '__main__':
unittest.main()
此文件测试扑克牌型排名函数。TestPokerHandRanking 使用 setUp 定义了示例手牌。像 test_is_royal 这样的测试验证皇家同花顺的检测,test_is_4_kind 检查四条,test_is_high_card 确保高牌检测,均使用断言来验证手牌识别和排名逻辑。
有效单元测试的技巧
- 一次只测试一件事:每个测试都应专注于单一的行为或条件。
- 使用描述性的名称:像 test_add_positive_numbers 这样的方法名使失败更容易诊断。
- 覆盖边界情况:测试边界、异常和不寻常的输入。
- 保持测试独立:使用 setUp 和 tearDown 避免测试之间的依赖关系。
- 频繁运行测试:将测试集成到您的工作流程中,以及早发现问题。
在本文中,我们介绍了 unittest 模块,它为测试 Python 代码提供了一个通用而强大的框架。通过以上示例,您已经了解了如何测试函数、类、异常等,包括模拟和测试套件等高级功能。开始将单元测试纳入您的项目,以提高代码质量和信心。
作者
列出所有 Python 教程。