Python 策略模式

策略模式是一种行为设计模式,它允许你在运行时选择算法或行为。简单来说,策略模式就是将不同的算法封装成独立的类,让它们可以相互替换。

核心思想

想象一下你去餐厅点餐:你可以选择不同的支付方式 - 现金、信用卡、移动支付等。无论你选择哪种支付方式,最终都能完成支付,只是具体的支付流程不同。这就是策略模式的现实例子。

设计原则

策略模式遵循以下重要原则:

  • 开闭原则:对扩展开放,对修改关闭
  • 单一职责原则:每个策略类只负责一种算法
  • 依赖倒置原则:依赖于抽象,而不是具体实现

策略模式的结构

让我们通过 UML 类图来理解策略模式的组成结构:

组件说明

Context(上下文)

  • 维护一个策略对象的引用
  • 可以动态切换策略
  • 将工作委托给策略对象

Strategy(策略接口)

  • 定义所有支持的算法的公共接口
  • 声明算法执行的方法

ConcreteStrategy(具体策略)

  • 实现策略接口
  • 提供具体的算法实现

基础语法和实现

策略接口定义

在 Python 中,我们可以使用抽象基类(ABC)来定义策略接口:

实例

from abc import ABC, abstractmethod

class PaymentStrategy(ABC):
    """支付策略抽象基类"""
   
    @abstractmethod
    def pay(self, amount: float) -> bool:
        """支付方法"""
        pass

具体策略实现

实例

class CreditCardPayment(PaymentStrategy):
    """信用卡支付策略"""
   
    def __init__(self, card_number: str, expiry_date: str, cvv: str):
        self.card_number = card_number
        self.expiry_date = expiry_date
        self.cvv = cvv
   
    def pay(self, amount: float) -> bool:
        print(f"使用信用卡支付 {amount} 元")
        print(f"卡号: {self.card_number}")
        # 这里应该有实际的支付逻辑
        return True

class AlipayPayment(PaymentStrategy):
    """支付宝支付策略"""
   
    def __init__(self, account: str):
        self.account = account
   
    def pay(self, amount: float) -> bool:
        print(f"使用支付宝支付 {amount} 元")
        print(f"支付宝账号: {self.account}")
        # 这里应该有实际的支付逻辑
        return True

class WechatPayment(PaymentStrategy):
    """微信支付策略"""
   
    def __init__(self, openid: str):
        self.openid = openid
   
    def pay(self, amount: float) -> bool:
        print(f"使用微信支付 {amount} 元")
        print(f"微信 OpenID: {self.openid}")
        # 这里应该有实际的支付逻辑
        return True

上下文类实现

实例

class PaymentContext:
    """支付上下文类"""
   
    def __init__(self, strategy: PaymentStrategy = None):
        self._strategy = strategy
   
    def set_strategy(self, strategy: PaymentStrategy):
        """设置支付策略"""
        self._strategy = strategy
   
    def execute_payment(self, amount: float) -> bool:
        """执行支付"""
        if not self._strategy:
            raise ValueError("未设置支付策略")
       
        return self._strategy.pay(amount)

完整示例:电商支付系统

让我们通过一个完整的电商支付系统来演示策略模式的实际应用:

实例

from abc import ABC, abstractmethod
from typing import Dict, Any

# 策略接口
class DiscountStrategy(ABC):
    """折扣策略接口"""
   
    @abstractmethod
    def calculate_discount(self, original_price: float) -> float:
        """计算折扣后的价格"""
        pass

# 具体策略实现
class NoDiscountStrategy(DiscountStrategy):
    """无折扣策略"""
   
    def calculate_discount(self, original_price: float) -> float:
        return original_price

class PercentageDiscountStrategy(DiscountStrategy):
    """百分比折扣策略"""
   
    def __init__(self, percentage: float):
        if not 0 <= percentage <= 100:
            raise ValueError("折扣百分比必须在 0-100 之间")
        self.percentage = percentage
   
    def calculate_discount(self, original_price: float) -> float:
        discount_amount = original_price * (self.percentage / 100)
        return original_price - discount_amount

class FixedAmountDiscountStrategy(DiscountStrategy):
    """固定金额折扣策略"""
   
    def __init__(self, discount_amount: float):
        if discount_amount < 0:
            raise ValueError("折扣金额不能为负数")
        self.discount_amount = discount_amount
   
    def calculate_discount(self, original_price: float) -> float:
        return max(0, original_price - self.discount_amount)

class SeasonalDiscountStrategy(DiscountStrategy):
    """季节性折扣策略"""
   
    def __init__(self, base_discount: float, seasonal_multiplier: float):
        self.base_discount = base_discount
        self.seasonal_multiplier = seasonal_multiplier
   
    def calculate_discount(self, original_price: float) -> float:
        total_discount = self.base_discount * self.seasonal_multiplier
        return max(0, original_price - total_discount)

# 上下文类
class ShoppingCart:
    """购物车类"""
   
    def __init__(self):
        self.items = []
        self._discount_strategy = NoDiscountStrategy()
   
    def add_item(self, item: str, price: float):
        """添加商品"""
        self.items.append({"item": item, "price": price})
   
    def set_discount_strategy(self, strategy: DiscountStrategy):
        """设置折扣策略"""
        self._discount_strategy = strategy
   
    def calculate_total(self) -> float:
        """计算总价"""
        total = sum(item["price"] for item in self.items)
        return self._discount_strategy.calculate_discount(total)
   
    def display_cart(self):
        """显示购物车内容"""
        print("购物车内容:")
        for item in self.items:
            print(f"  - {item['item']}: {item['price']}元")
       
        original_total = sum(item["price"] for item in self.items)
        final_total = self.calculate_total()
       
        print(f"原价: {original_total}元")
        print(f"折后价: {final_total}元")
       
        if original_total != final_total:
            discount = original_total - final_total
            print(f"节省: {discount}元")

# 使用示例
def main():
    # 创建购物车
    cart = ShoppingCart()
   
    # 添加商品
    cart.add_item("Python编程书", 89.0)
    cart.add_item("无线鼠标", 129.0)
    cart.add_item("机械键盘", 399.0)
   
    print("=== 无折扣 ===")
    cart.set_discount_strategy(NoDiscountStrategy())
    cart.display_cart()
   
    print("\n=== 8折优惠 ===")
    cart.set_discount_strategy(PercentageDiscountStrategy(20))  # 8折
    cart.display_cart()
   
    print("\n=== 满减优惠(减50元)===")
    cart.set_discount_strategy(FixedAmountDiscountStrategy(50))
    cart.display_cart()
   
    print("\n=== 季节性优惠 ===")
    cart.set_discount_strategy(SeasonalDiscountStrategy(30, 1.5))  # 基础折扣30,季节性系数1.5
    cart.display_cart()

if __name__ == "__main__":
    main()

运行上述代码,你将看到以下输出:

=== 无折扣 ===
购物车内容:
  - Python编程书: 89.0元
  - 无线鼠标: 129.0元
  - 机械键盘: 399.0元
原价: 617.0元
折后价: 617.0元

=== 8折优惠 ===
购物车内容:
  - Python编程书: 89.0元
  - 无线鼠标: 129.0元
  - 机械键盘: 399.0元
原价: 617.0元
折后价: 493.6元
节省: 123.4元

=== 满减优惠(减50元)===
购物车内容:
  - Python编程书: 89.0元
  - 无线鼠标: 129.0元
  - 机械键盘: 399.0元
原价: 617.0元
折后价: 567.0元
节省: 50.0元

=== 季节性优惠 ===
购物车内容:
  - Python编程书: 89.0元
  - 无线鼠标: 129.0元
  - 机械键盘: 399.0元
原价: 617.0元
折后价: 572.0元
节省: 45.0元

策略模式的进阶用法

1. 策略工厂模式

结合工厂模式来管理策略的创建:

实例

class DiscountStrategyFactory:
    """折扣策略工厂"""
   
    @staticmethod
    def create_strategy(strategy_type: str, **kwargs) -> DiscountStrategy:
        """创建折扣策略"""
        strategies = {
            "no_discount": NoDiscountStrategy,
            "percentage": PercentageDiscountStrategy,
            "fixed_amount": FixedAmountDiscountStrategy,
            "seasonal": SeasonalDiscountStrategy
        }
       
        if strategy_type not in strategies:
            raise ValueError(f"不支持的策略类型: {strategy_type}")
       
        return strategies[strategy_type](**kwargs)

# 使用工厂模式
factory = DiscountStrategyFactory()

# 创建不同的策略
strategy1 = factory.create_strategy("percentage", percentage=15)  # 85折
strategy2 = factory.create_strategy("fixed_amount", discount_amount=100)  # 减100元

2. 动态策略选择

根据条件动态选择策略:

实例

class DynamicDiscountSelector:
    """动态折扣选择器"""
   
    @staticmethod
    def select_strategy(user_type: str, total_amount: float) -> DiscountStrategy:
        """根据用户类型和总金额选择策略"""
        if user_type == "vip":
            if total_amount > 500:
                return PercentageDiscountStrategy(25)  # VIP 满500打75折
            else:
                return PercentageDiscountStrategy(15)  # VIP 打85折
        elif user_type == "normal":
            if total_amount > 300:
                return FixedAmountDiscountStrategy(30)  # 普通用户满300减30
            else:
                return NoDiscountStrategy()
        else:
            return NoDiscountStrategy()

# 使用动态选择
cart = ShoppingCart()
cart.add_item("商品A", 200)
cart.add_item("商品B", 150)

selector = DynamicDiscountSelector()
strategy = selector.select_strategy("vip", cart.calculate_total())
cart.set_discount_strategy(strategy)
cart.display_cart()

策略模式的优势和适用场景

优势对比

特性 传统方式 策略模式
扩展性 需要修改原有代码 新增策略类即可
维护性 代码耦合度高 职责分离,易于维护
灵活性 运行时难以切换算法 可动态切换策略
测试性 难以单独测试算法 每个策略可独立测试

适用场景

  1. 多种算法变体:当你有多个相似的类,只在某些行为上不同时
  2. 避免条件语句:当想要避免使用大量的条件语句(if-else 或 switch)时
  3. 运行时算法选择:需要在运行时选择不同算法时
  4. 算法封装:希望将算法细节与使用算法的客户端隔离开时

不适用场景

  1. 简单算法:如果只有一两个很少变化的算法,可能过度设计
  2. 客户端需要了解策略细节:如果客户端必须知道策略的具体实现
  3. 策略数量过多:当策略类数量爆炸时,考虑其他模式

最佳实践和注意事项

代码组织建议

实例

# 推荐的文件结构
project/
├── strategies/
│   ├── __init__.py
│   ├── base_strategy.py      # 基础策略接口
│   ├── discount_strategies.py # 折扣相关策略
│   └── payment_strategies.py # 支付相关策略
├── contexts/
│   ├── __init__.py
│   └── shopping_cart.py      # 上下文类
└── main.py

错误处理

实例

class SafeDiscountStrategy(DiscountStrategy):
    """带错误处理的折扣策略"""
   
    def __init__(self, base_strategy: DiscountStrategy, fallback_strategy: DiscountStrategy = None):
        self.base_strategy = base_strategy
        self.fallback_strategy = fallback_strategy or NoDiscountStrategy()
   
    def calculate_discount(self, original_price: float) -> float:
        try:
            return self.base_strategy.calculate_discount(original_price)
        except Exception as e:
            print(f"折扣计算错误: {e},使用备用策略")
            return self.fallback_strategy.calculate_discount(original_price)

性能考虑

对于性能敏感的场景,可以考虑以下优化:

实例

from functools import lru_cache

class CachedDiscountStrategy(DiscountStrategy):
    """带缓存的折扣策略"""
   
    def __init__(self, base_strategy: DiscountStrategy):
        self.base_strategy = base_strategy
   
    @lru_cache(maxsize=128)
    def calculate_discount(self, original_price: float) -> float:
        return self.base_strategy.calculate_discount(original_price)

实践练习

练习 1:排序策略实现

实现一个支持多种排序算法的排序器:

实例

from abc import ABC, abstractmethod
from typing import List

class SortStrategy(ABC):
    @abstractmethod
    def sort(self, data: List) -> List:
        pass

# TODO: 实现冒泡排序策略
class BubbleSortStrategy(SortStrategy):
    def sort(self, data: List) -> List:
        # 你的实现代码
        pass

# TODO: 实现快速排序策略  
class QuickSortStrategy(SortStrategy):
    def sort(self, data: List) -> List:
        # 你的实现代码
        pass

# TODO: 实现归并排序策略
class MergeSortStrategy(SortStrategy):
    def sort(self, data: List) -> List:
        # 你的实现代码
        pass

class Sorter:
    def __init__(self, strategy: SortStrategy = None):
        self._strategy = strategy
   
    def set_strategy(self, strategy: SortStrategy):
        self._strategy = strategy
   
    def sort_data(self, data: List) -> List:
        if not self._strategy:
            raise ValueError("未设置排序策略")
        return self._strategy.sort(data)

# 测试你的实现
data = [64, 34, 25, 12, 22, 11, 90]
sorter = Sorter(BubbleSortStrategy())
result = sorter.sort_data(data)
print(f"排序结果: {result}")

练习 2:文件压缩策略

设计一个支持多种压缩格式的文件压缩器:

实例

from abc import ABC, abstractmethod

class CompressionStrategy(ABC):
    @abstractmethod
    def compress(self, file_path: str) -> str:
        pass
   
    @abstractmethod
    def decompress(self, file_path: str) -> str:
        pass

# TODO: 实现 ZIP 压缩策略
class ZipCompressionStrategy(CompressionStrategy):
    def compress(self, file_path: str) -> str:
        # 你的实现代码
        pass
   
    def decompress(self, file_path: str) -> str:
        # 你的实现代码
        pass

# TODO: 实现 GZIP 压缩策略
class GzipCompressionStrategy(CompressionStrategy):
    def compress(self, file_path: str) -> str:
        # 你的实现代码
        pass
   
    def decompress(self, file_path: str) -> str:
        # 你的实现代码
        pass

class FileCompressor:
    def __init__(self, strategy: CompressionStrategy = None):
        self._strategy = strategy
   
    def set_strategy(self, strategy: CompressionStrategy):
        self._strategy = strategy
   
    def compress_file(self, file_path: str) -> str:
        if not self._strategy:
            raise ValueError("未设置压缩策略")
        return self._strategy.compress(file_path)
   
    def decompress_file(self, file_path: str) -> str:
        if not self._strategy:
            raise ValueError("未设置压缩策略")
        return self._strategy.decompress(file_path)

总结

策略模式是 Python 中非常实用的设计模式,它通过将算法封装成独立的策略类,提供了良好的扩展性和灵活性。通过本文的学习,你应该能够:

  1. 理解策略模式的核心概念和适用场景
  2. 掌握策略模式的基本实现方法
  3. 在实际项目中合理应用策略模式
  4. 避免策略模式的常见误用