Python 策略模式
策略模式是一种行为设计模式,它允许你在运行时选择算法或行为。简单来说,策略模式就是将不同的算法封装成独立的类,让它们可以相互替换。
核心思想
想象一下你去餐厅点餐:你可以选择不同的支付方式 - 现金、信用卡、移动支付等。无论你选择哪种支付方式,最终都能完成支付,只是具体的支付流程不同。这就是策略模式的现实例子。
设计原则
策略模式遵循以下重要原则:
- 开闭原则:对扩展开放,对修改关闭
- 单一职责原则:每个策略类只负责一种算法
- 依赖倒置原则:依赖于抽象,而不是具体实现
策略模式的结构
让我们通过 UML 类图来理解策略模式的组成结构:

组件说明
Context(上下文)
- 维护一个策略对象的引用
- 可以动态切换策略
- 将工作委托给策略对象
Strategy(策略接口)
- 定义所有支持的算法的公共接口
- 声明算法执行的方法
ConcreteStrategy(具体策略)
- 实现策略接口
- 提供具体的算法实现
基础语法和实现
策略接口定义
在 Python 中,我们可以使用抽象基类(ABC)来定义策略接口:
实例
from abc import ABC, abstractmethod
class PaymentStrategy(ABC):
"""支付策略抽象基类"""
@abstractmethod
def pay(self, amount: float) -> bool:
"""支付方法"""
pass
class PaymentStrategy(ABC):
"""支付策略抽象基类"""
@abstractmethod
def pay(self, amount: float) -> bool:
"""支付方法"""
pass
具体策略实现
实例
class CreditCardPayment(PaymentStrategy):
"""信用卡支付策略"""
def __init__(self, card_number: str, expiry_date: str, cvv: str):
self.card_number = card_number
self.expiry_date = expiry_date
self.cvv = cvv
def pay(self, amount: float) -> bool:
print(f"使用信用卡支付 {amount} 元")
print(f"卡号: {self.card_number}")
# 这里应该有实际的支付逻辑
return True
class AlipayPayment(PaymentStrategy):
"""支付宝支付策略"""
def __init__(self, account: str):
self.account = account
def pay(self, amount: float) -> bool:
print(f"使用支付宝支付 {amount} 元")
print(f"支付宝账号: {self.account}")
# 这里应该有实际的支付逻辑
return True
class WechatPayment(PaymentStrategy):
"""微信支付策略"""
def __init__(self, openid: str):
self.openid = openid
def pay(self, amount: float) -> bool:
print(f"使用微信支付 {amount} 元")
print(f"微信 OpenID: {self.openid}")
# 这里应该有实际的支付逻辑
return True
"""信用卡支付策略"""
def __init__(self, card_number: str, expiry_date: str, cvv: str):
self.card_number = card_number
self.expiry_date = expiry_date
self.cvv = cvv
def pay(self, amount: float) -> bool:
print(f"使用信用卡支付 {amount} 元")
print(f"卡号: {self.card_number}")
# 这里应该有实际的支付逻辑
return True
class AlipayPayment(PaymentStrategy):
"""支付宝支付策略"""
def __init__(self, account: str):
self.account = account
def pay(self, amount: float) -> bool:
print(f"使用支付宝支付 {amount} 元")
print(f"支付宝账号: {self.account}")
# 这里应该有实际的支付逻辑
return True
class WechatPayment(PaymentStrategy):
"""微信支付策略"""
def __init__(self, openid: str):
self.openid = openid
def pay(self, amount: float) -> bool:
print(f"使用微信支付 {amount} 元")
print(f"微信 OpenID: {self.openid}")
# 这里应该有实际的支付逻辑
return True
上下文类实现
实例
class PaymentContext:
"""支付上下文类"""
def __init__(self, strategy: PaymentStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: PaymentStrategy):
"""设置支付策略"""
self._strategy = strategy
def execute_payment(self, amount: float) -> bool:
"""执行支付"""
if not self._strategy:
raise ValueError("未设置支付策略")
return self._strategy.pay(amount)
"""支付上下文类"""
def __init__(self, strategy: PaymentStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: PaymentStrategy):
"""设置支付策略"""
self._strategy = strategy
def execute_payment(self, amount: float) -> bool:
"""执行支付"""
if not self._strategy:
raise ValueError("未设置支付策略")
return self._strategy.pay(amount)
完整示例:电商支付系统
让我们通过一个完整的电商支付系统来演示策略模式的实际应用:
实例
from abc import ABC, abstractmethod
from typing import Dict, Any
# 策略接口
class DiscountStrategy(ABC):
"""折扣策略接口"""
@abstractmethod
def calculate_discount(self, original_price: float) -> float:
"""计算折扣后的价格"""
pass
# 具体策略实现
class NoDiscountStrategy(DiscountStrategy):
"""无折扣策略"""
def calculate_discount(self, original_price: float) -> float:
return original_price
class PercentageDiscountStrategy(DiscountStrategy):
"""百分比折扣策略"""
def __init__(self, percentage: float):
if not 0 <= percentage <= 100:
raise ValueError("折扣百分比必须在 0-100 之间")
self.percentage = percentage
def calculate_discount(self, original_price: float) -> float:
discount_amount = original_price * (self.percentage / 100)
return original_price - discount_amount
class FixedAmountDiscountStrategy(DiscountStrategy):
"""固定金额折扣策略"""
def __init__(self, discount_amount: float):
if discount_amount < 0:
raise ValueError("折扣金额不能为负数")
self.discount_amount = discount_amount
def calculate_discount(self, original_price: float) -> float:
return max(0, original_price - self.discount_amount)
class SeasonalDiscountStrategy(DiscountStrategy):
"""季节性折扣策略"""
def __init__(self, base_discount: float, seasonal_multiplier: float):
self.base_discount = base_discount
self.seasonal_multiplier = seasonal_multiplier
def calculate_discount(self, original_price: float) -> float:
total_discount = self.base_discount * self.seasonal_multiplier
return max(0, original_price - total_discount)
# 上下文类
class ShoppingCart:
"""购物车类"""
def __init__(self):
self.items = []
self._discount_strategy = NoDiscountStrategy()
def add_item(self, item: str, price: float):
"""添加商品"""
self.items.append({"item": item, "price": price})
def set_discount_strategy(self, strategy: DiscountStrategy):
"""设置折扣策略"""
self._discount_strategy = strategy
def calculate_total(self) -> float:
"""计算总价"""
total = sum(item["price"] for item in self.items)
return self._discount_strategy.calculate_discount(total)
def display_cart(self):
"""显示购物车内容"""
print("购物车内容:")
for item in self.items:
print(f" - {item['item']}: {item['price']}元")
original_total = sum(item["price"] for item in self.items)
final_total = self.calculate_total()
print(f"原价: {original_total}元")
print(f"折后价: {final_total}元")
if original_total != final_total:
discount = original_total - final_total
print(f"节省: {discount}元")
# 使用示例
def main():
# 创建购物车
cart = ShoppingCart()
# 添加商品
cart.add_item("Python编程书", 89.0)
cart.add_item("无线鼠标", 129.0)
cart.add_item("机械键盘", 399.0)
print("=== 无折扣 ===")
cart.set_discount_strategy(NoDiscountStrategy())
cart.display_cart()
print("\n=== 8折优惠 ===")
cart.set_discount_strategy(PercentageDiscountStrategy(20)) # 8折
cart.display_cart()
print("\n=== 满减优惠(减50元)===")
cart.set_discount_strategy(FixedAmountDiscountStrategy(50))
cart.display_cart()
print("\n=== 季节性优惠 ===")
cart.set_discount_strategy(SeasonalDiscountStrategy(30, 1.5)) # 基础折扣30,季节性系数1.5
cart.display_cart()
if __name__ == "__main__":
main()
from typing import Dict, Any
# 策略接口
class DiscountStrategy(ABC):
"""折扣策略接口"""
@abstractmethod
def calculate_discount(self, original_price: float) -> float:
"""计算折扣后的价格"""
pass
# 具体策略实现
class NoDiscountStrategy(DiscountStrategy):
"""无折扣策略"""
def calculate_discount(self, original_price: float) -> float:
return original_price
class PercentageDiscountStrategy(DiscountStrategy):
"""百分比折扣策略"""
def __init__(self, percentage: float):
if not 0 <= percentage <= 100:
raise ValueError("折扣百分比必须在 0-100 之间")
self.percentage = percentage
def calculate_discount(self, original_price: float) -> float:
discount_amount = original_price * (self.percentage / 100)
return original_price - discount_amount
class FixedAmountDiscountStrategy(DiscountStrategy):
"""固定金额折扣策略"""
def __init__(self, discount_amount: float):
if discount_amount < 0:
raise ValueError("折扣金额不能为负数")
self.discount_amount = discount_amount
def calculate_discount(self, original_price: float) -> float:
return max(0, original_price - self.discount_amount)
class SeasonalDiscountStrategy(DiscountStrategy):
"""季节性折扣策略"""
def __init__(self, base_discount: float, seasonal_multiplier: float):
self.base_discount = base_discount
self.seasonal_multiplier = seasonal_multiplier
def calculate_discount(self, original_price: float) -> float:
total_discount = self.base_discount * self.seasonal_multiplier
return max(0, original_price - total_discount)
# 上下文类
class ShoppingCart:
"""购物车类"""
def __init__(self):
self.items = []
self._discount_strategy = NoDiscountStrategy()
def add_item(self, item: str, price: float):
"""添加商品"""
self.items.append({"item": item, "price": price})
def set_discount_strategy(self, strategy: DiscountStrategy):
"""设置折扣策略"""
self._discount_strategy = strategy
def calculate_total(self) -> float:
"""计算总价"""
total = sum(item["price"] for item in self.items)
return self._discount_strategy.calculate_discount(total)
def display_cart(self):
"""显示购物车内容"""
print("购物车内容:")
for item in self.items:
print(f" - {item['item']}: {item['price']}元")
original_total = sum(item["price"] for item in self.items)
final_total = self.calculate_total()
print(f"原价: {original_total}元")
print(f"折后价: {final_total}元")
if original_total != final_total:
discount = original_total - final_total
print(f"节省: {discount}元")
# 使用示例
def main():
# 创建购物车
cart = ShoppingCart()
# 添加商品
cart.add_item("Python编程书", 89.0)
cart.add_item("无线鼠标", 129.0)
cart.add_item("机械键盘", 399.0)
print("=== 无折扣 ===")
cart.set_discount_strategy(NoDiscountStrategy())
cart.display_cart()
print("\n=== 8折优惠 ===")
cart.set_discount_strategy(PercentageDiscountStrategy(20)) # 8折
cart.display_cart()
print("\n=== 满减优惠(减50元)===")
cart.set_discount_strategy(FixedAmountDiscountStrategy(50))
cart.display_cart()
print("\n=== 季节性优惠 ===")
cart.set_discount_strategy(SeasonalDiscountStrategy(30, 1.5)) # 基础折扣30,季节性系数1.5
cart.display_cart()
if __name__ == "__main__":
main()
运行上述代码,你将看到以下输出:
=== 无折扣 === 购物车内容: - Python编程书: 89.0元 - 无线鼠标: 129.0元 - 机械键盘: 399.0元 原价: 617.0元 折后价: 617.0元 === 8折优惠 === 购物车内容: - Python编程书: 89.0元 - 无线鼠标: 129.0元 - 机械键盘: 399.0元 原价: 617.0元 折后价: 493.6元 节省: 123.4元 === 满减优惠(减50元)=== 购物车内容: - Python编程书: 89.0元 - 无线鼠标: 129.0元 - 机械键盘: 399.0元 原价: 617.0元 折后价: 567.0元 节省: 50.0元 === 季节性优惠 === 购物车内容: - Python编程书: 89.0元 - 无线鼠标: 129.0元 - 机械键盘: 399.0元 原价: 617.0元 折后价: 572.0元 节省: 45.0元
策略模式的进阶用法
1. 策略工厂模式
结合工厂模式来管理策略的创建:
实例
class DiscountStrategyFactory:
"""折扣策略工厂"""
@staticmethod
def create_strategy(strategy_type: str, **kwargs) -> DiscountStrategy:
"""创建折扣策略"""
strategies = {
"no_discount": NoDiscountStrategy,
"percentage": PercentageDiscountStrategy,
"fixed_amount": FixedAmountDiscountStrategy,
"seasonal": SeasonalDiscountStrategy
}
if strategy_type not in strategies:
raise ValueError(f"不支持的策略类型: {strategy_type}")
return strategies[strategy_type](**kwargs)
# 使用工厂模式
factory = DiscountStrategyFactory()
# 创建不同的策略
strategy1 = factory.create_strategy("percentage", percentage=15) # 85折
strategy2 = factory.create_strategy("fixed_amount", discount_amount=100) # 减100元
"""折扣策略工厂"""
@staticmethod
def create_strategy(strategy_type: str, **kwargs) -> DiscountStrategy:
"""创建折扣策略"""
strategies = {
"no_discount": NoDiscountStrategy,
"percentage": PercentageDiscountStrategy,
"fixed_amount": FixedAmountDiscountStrategy,
"seasonal": SeasonalDiscountStrategy
}
if strategy_type not in strategies:
raise ValueError(f"不支持的策略类型: {strategy_type}")
return strategies[strategy_type](**kwargs)
# 使用工厂模式
factory = DiscountStrategyFactory()
# 创建不同的策略
strategy1 = factory.create_strategy("percentage", percentage=15) # 85折
strategy2 = factory.create_strategy("fixed_amount", discount_amount=100) # 减100元
2. 动态策略选择
根据条件动态选择策略:
实例
class DynamicDiscountSelector:
"""动态折扣选择器"""
@staticmethod
def select_strategy(user_type: str, total_amount: float) -> DiscountStrategy:
"""根据用户类型和总金额选择策略"""
if user_type == "vip":
if total_amount > 500:
return PercentageDiscountStrategy(25) # VIP 满500打75折
else:
return PercentageDiscountStrategy(15) # VIP 打85折
elif user_type == "normal":
if total_amount > 300:
return FixedAmountDiscountStrategy(30) # 普通用户满300减30
else:
return NoDiscountStrategy()
else:
return NoDiscountStrategy()
# 使用动态选择
cart = ShoppingCart()
cart.add_item("商品A", 200)
cart.add_item("商品B", 150)
selector = DynamicDiscountSelector()
strategy = selector.select_strategy("vip", cart.calculate_total())
cart.set_discount_strategy(strategy)
cart.display_cart()
"""动态折扣选择器"""
@staticmethod
def select_strategy(user_type: str, total_amount: float) -> DiscountStrategy:
"""根据用户类型和总金额选择策略"""
if user_type == "vip":
if total_amount > 500:
return PercentageDiscountStrategy(25) # VIP 满500打75折
else:
return PercentageDiscountStrategy(15) # VIP 打85折
elif user_type == "normal":
if total_amount > 300:
return FixedAmountDiscountStrategy(30) # 普通用户满300减30
else:
return NoDiscountStrategy()
else:
return NoDiscountStrategy()
# 使用动态选择
cart = ShoppingCart()
cart.add_item("商品A", 200)
cart.add_item("商品B", 150)
selector = DynamicDiscountSelector()
strategy = selector.select_strategy("vip", cart.calculate_total())
cart.set_discount_strategy(strategy)
cart.display_cart()
策略模式的优势和适用场景
优势对比
| 特性 | 传统方式 | 策略模式 |
|---|---|---|
| 扩展性 | 需要修改原有代码 | 新增策略类即可 |
| 维护性 | 代码耦合度高 | 职责分离,易于维护 |
| 灵活性 | 运行时难以切换算法 | 可动态切换策略 |
| 测试性 | 难以单独测试算法 | 每个策略可独立测试 |
适用场景
- 多种算法变体:当你有多个相似的类,只在某些行为上不同时
- 避免条件语句:当想要避免使用大量的条件语句(if-else 或 switch)时
- 运行时算法选择:需要在运行时选择不同算法时
- 算法封装:希望将算法细节与使用算法的客户端隔离开时
不适用场景
- 简单算法:如果只有一两个很少变化的算法,可能过度设计
- 客户端需要了解策略细节:如果客户端必须知道策略的具体实现
- 策略数量过多:当策略类数量爆炸时,考虑其他模式
最佳实践和注意事项
代码组织建议
实例
# 推荐的文件结构
project/
├── strategies/
│ ├── __init__.py
│ ├── base_strategy.py # 基础策略接口
│ ├── discount_strategies.py # 折扣相关策略
│ └── payment_strategies.py # 支付相关策略
├── contexts/
│ ├── __init__.py
│ └── shopping_cart.py # 上下文类
└── main.py
project/
├── strategies/
│ ├── __init__.py
│ ├── base_strategy.py # 基础策略接口
│ ├── discount_strategies.py # 折扣相关策略
│ └── payment_strategies.py # 支付相关策略
├── contexts/
│ ├── __init__.py
│ └── shopping_cart.py # 上下文类
└── main.py
错误处理
实例
class SafeDiscountStrategy(DiscountStrategy):
"""带错误处理的折扣策略"""
def __init__(self, base_strategy: DiscountStrategy, fallback_strategy: DiscountStrategy = None):
self.base_strategy = base_strategy
self.fallback_strategy = fallback_strategy or NoDiscountStrategy()
def calculate_discount(self, original_price: float) -> float:
try:
return self.base_strategy.calculate_discount(original_price)
except Exception as e:
print(f"折扣计算错误: {e},使用备用策略")
return self.fallback_strategy.calculate_discount(original_price)
"""带错误处理的折扣策略"""
def __init__(self, base_strategy: DiscountStrategy, fallback_strategy: DiscountStrategy = None):
self.base_strategy = base_strategy
self.fallback_strategy = fallback_strategy or NoDiscountStrategy()
def calculate_discount(self, original_price: float) -> float:
try:
return self.base_strategy.calculate_discount(original_price)
except Exception as e:
print(f"折扣计算错误: {e},使用备用策略")
return self.fallback_strategy.calculate_discount(original_price)
性能考虑
对于性能敏感的场景,可以考虑以下优化:
实例
from functools import lru_cache
class CachedDiscountStrategy(DiscountStrategy):
"""带缓存的折扣策略"""
def __init__(self, base_strategy: DiscountStrategy):
self.base_strategy = base_strategy
@lru_cache(maxsize=128)
def calculate_discount(self, original_price: float) -> float:
return self.base_strategy.calculate_discount(original_price)
class CachedDiscountStrategy(DiscountStrategy):
"""带缓存的折扣策略"""
def __init__(self, base_strategy: DiscountStrategy):
self.base_strategy = base_strategy
@lru_cache(maxsize=128)
def calculate_discount(self, original_price: float) -> float:
return self.base_strategy.calculate_discount(original_price)
实践练习
练习 1:排序策略实现
实现一个支持多种排序算法的排序器:
实例
from abc import ABC, abstractmethod
from typing import List
class SortStrategy(ABC):
@abstractmethod
def sort(self, data: List) -> List:
pass
# TODO: 实现冒泡排序策略
class BubbleSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
# 你的实现代码
pass
# TODO: 实现快速排序策略
class QuickSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
# 你的实现代码
pass
# TODO: 实现归并排序策略
class MergeSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
# 你的实现代码
pass
class Sorter:
def __init__(self, strategy: SortStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: SortStrategy):
self._strategy = strategy
def sort_data(self, data: List) -> List:
if not self._strategy:
raise ValueError("未设置排序策略")
return self._strategy.sort(data)
# 测试你的实现
data = [64, 34, 25, 12, 22, 11, 90]
sorter = Sorter(BubbleSortStrategy())
result = sorter.sort_data(data)
print(f"排序结果: {result}")
from typing import List
class SortStrategy(ABC):
@abstractmethod
def sort(self, data: List) -> List:
pass
# TODO: 实现冒泡排序策略
class BubbleSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
# 你的实现代码
pass
# TODO: 实现快速排序策略
class QuickSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
# 你的实现代码
pass
# TODO: 实现归并排序策略
class MergeSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
# 你的实现代码
pass
class Sorter:
def __init__(self, strategy: SortStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: SortStrategy):
self._strategy = strategy
def sort_data(self, data: List) -> List:
if not self._strategy:
raise ValueError("未设置排序策略")
return self._strategy.sort(data)
# 测试你的实现
data = [64, 34, 25, 12, 22, 11, 90]
sorter = Sorter(BubbleSortStrategy())
result = sorter.sort_data(data)
print(f"排序结果: {result}")
练习 2:文件压缩策略
设计一个支持多种压缩格式的文件压缩器:
实例
from abc import ABC, abstractmethod
class CompressionStrategy(ABC):
@abstractmethod
def compress(self, file_path: str) -> str:
pass
@abstractmethod
def decompress(self, file_path: str) -> str:
pass
# TODO: 实现 ZIP 压缩策略
class ZipCompressionStrategy(CompressionStrategy):
def compress(self, file_path: str) -> str:
# 你的实现代码
pass
def decompress(self, file_path: str) -> str:
# 你的实现代码
pass
# TODO: 实现 GZIP 压缩策略
class GzipCompressionStrategy(CompressionStrategy):
def compress(self, file_path: str) -> str:
# 你的实现代码
pass
def decompress(self, file_path: str) -> str:
# 你的实现代码
pass
class FileCompressor:
def __init__(self, strategy: CompressionStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: CompressionStrategy):
self._strategy = strategy
def compress_file(self, file_path: str) -> str:
if not self._strategy:
raise ValueError("未设置压缩策略")
return self._strategy.compress(file_path)
def decompress_file(self, file_path: str) -> str:
if not self._strategy:
raise ValueError("未设置压缩策略")
return self._strategy.decompress(file_path)
class CompressionStrategy(ABC):
@abstractmethod
def compress(self, file_path: str) -> str:
pass
@abstractmethod
def decompress(self, file_path: str) -> str:
pass
# TODO: 实现 ZIP 压缩策略
class ZipCompressionStrategy(CompressionStrategy):
def compress(self, file_path: str) -> str:
# 你的实现代码
pass
def decompress(self, file_path: str) -> str:
# 你的实现代码
pass
# TODO: 实现 GZIP 压缩策略
class GzipCompressionStrategy(CompressionStrategy):
def compress(self, file_path: str) -> str:
# 你的实现代码
pass
def decompress(self, file_path: str) -> str:
# 你的实现代码
pass
class FileCompressor:
def __init__(self, strategy: CompressionStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: CompressionStrategy):
self._strategy = strategy
def compress_file(self, file_path: str) -> str:
if not self._strategy:
raise ValueError("未设置压缩策略")
return self._strategy.compress(file_path)
def decompress_file(self, file_path: str) -> str:
if not self._strategy:
raise ValueError("未设置压缩策略")
return self._strategy.decompress(file_path)
总结
策略模式是 Python 中非常实用的设计模式,它通过将算法封装成独立的策略类,提供了良好的扩展性和灵活性。通过本文的学习,你应该能够:
- 理解策略模式的核心概念和适用场景
- 掌握策略模式的基本实现方法
- 在实际项目中合理应用策略模式
- 避免策略模式的常见误用
点我分享笔记