Python 模板方法模式

模板方法模式是一种行为设计模式,它定义了一个操作中的算法骨架,而将一些步骤延迟到子类中实现。模板方法使得子类可以不改变算法结构的情况下,重新定义算法中的某些特定步骤。

简单来说,就像做菜的食谱一样:食谱规定了做菜的步骤顺序(洗菜、切菜、炒菜、装盘),但具体怎么洗、怎么切、怎么炒,可以由不同的厨师根据自己的风格来实现。


为什么需要模板方法模式?

代码复用

模板方法模式将不变的行为移到超类中,避免了子类中的代码重复。多个子类可以共享同一个模板方法定义的算法结构。

扩展性

子类可以通过重写钩子方法或具体步骤方法,来扩展或修改算法的部分行为,而不需要改变算法的整体结构。

控制反转

父类控制着算法的执行流程,子类只需要关注自己需要实现的特定步骤,实现了"好莱坞原则"——"不要打电话给我们,我们会打电话给你"。


模板方法模式的结构

让我们通过一个类图来理解模板方法模式的结构:

核心组件说明

组件 职责 说明
AbstractClass 定义算法骨架 包含模板方法和抽象步骤方法
ConcreteClass 实现具体步骤 实现父类定义的抽象方法
template_method 模板方法 定义算法的不变部分
step1, step2 抽象方法 需要子类实现的具体步骤
hook 钩子方法 可选步骤,子类可选择是否重写

基本语法和实现

抽象基类的定义

实例

from abc import ABC, abstractmethod

class AbstractClass(ABC):
    """模板方法模式的抽象基类"""
   
    def template_method(self):
        """模板方法 - 定义算法骨架"""
        self.step1()
        self.step2()
        self.hook()
        self.step3()
   
    @abstractmethod
    def step1(self):
        """抽象方法1 - 必须由子类实现"""
        pass
   
    @abstractmethod
    def step2(self):
        """抽象方法2 - 必须由子类实现"""
        pass
   
    def step3(self):
        """具体方法 - 已有默认实现"""
        print("执行步骤3 - 默认实现")
   
    def hook(self):
        """钩子方法 - 可选步骤,子类可选择重写"""
        print("执行钩子方法 - 默认什么都不做")

具体子类的实现

实例

class ConcreteClassA(AbstractClass):
    """具体实现类A"""
   
    def step1(self):
        print("ConcreteClassA - 执行步骤1")
   
    def step2(self):
        print("ConcreteClassA - 执行步骤2")
   
    def hook(self):
        print("ConcreteClassA - 重写钩子方法,添加额外功能")

class ConcreteClassB(AbstractClass):
    """具体实现类B"""
   
    def step1(self):
        print("ConcreteClassB - 执行步骤1")
   
    def step2(self):
        print("ConcreteClassB - 执行步骤2")
   
    # 不重写hook方法,使用默认实现

实际应用示例

让我们通过几个实际的例子来深入理解模板方法模式的应用。

示例1:数据处理的模板

实例

from abc import ABC, abstractmethod
import json
import csv

class DataProcessor(ABC):
    """数据处理模板"""
   
    def process_data(self, input_file, output_file):
        """数据处理模板方法"""
        print(f"开始处理数据: {input_file} -> {output_file}")
       
        # 读取数据
        data = self.read_data(input_file)
        print(f"读取到 {len(data)} 条数据")
       
        # 转换数据
        transformed_data = self.transform_data(data)
        print("数据转换完成")
       
        # 保存数据
        self.save_data(transformed_data, output_file)
        print("数据保存完成")
       
        # 清理工作
        self.cleanup()
   
    @abstractmethod
    def read_data(self, file_path):
        """读取数据 - 抽象方法"""
        pass
   
    @abstractmethod
    def transform_data(self, data):
        """转换数据 - 抽象方法"""
        pass
   
    def save_data(self, data, file_path):
        """保存数据 - 具体方法"""
        with open(file_path, 'w', encoding='utf-8') as f:
            if isinstance(data, list):
                for item in data:
                    f.write(str(item) + '\n')
            else:
                f.write(str(data))
   
    def cleanup(self):
        """清理工作 - 钩子方法"""
        print("清理工作完成")

class JSONProcessor(DataProcessor):
    """JSON 数据处理器"""
   
    def read_data(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
   
    def transform_data(self, data):
        # 简单的转换:将所有字符串值转为大写
        if isinstance(data, dict):
            return {k: v.upper() if isinstance(v, str) else v
                   for k, v in data.items()}
        elif isinstance(data, list):
            return [item.upper() if isinstance(item, str) else item
                   for item in data]
        return data

class CSVProcessor(DataProcessor):
    """CSV 数据处理器"""
   
    def read_data(self, file_path):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)
        return data
   
    def transform_data(self, data):
        # 为每条记录添加处理时间戳
        import datetime
        timestamp = datetime.datetime.now().isoformat()
        for record in data:
            record['processed_at'] = timestamp
        return data
   
    def cleanup(self):
        """重写清理方法,添加额外清理逻辑"""
        print("CSV 处理器清理完成")
        print("释放 CSV 解析器资源")

# 使用示例
if __name__ == "__main__":
    # 创建测试数据
    import os
   
    # 测试 JSON 处理器
    json_data = {'name': 'john', 'age': 30, 'city': 'new york'}
    with open('test.json', 'w') as f:
        json.dump(json_data, f)
   
    json_processor = JSONProcessor()
    json_processor.process_data('test.json', 'output_json.txt')
   
    print("\n" + "="*50 + "\n")
   
    # 测试 CSV 处理器
    csv_data = "name,age,city\njohn,30,new york\njane,25,los angeles"
    with open('test.csv', 'w') as f:
        f.write(csv_data)
   
    csv_processor = CSVProcessor()
    csv_processor.process_data('test.csv', 'output_csv.txt')
   
    # 清理测试文件
    for file in ['test.json', 'test.csv', 'output_json.txt', 'output_csv.txt']:
        if os.path.exists(file):
            os.remove(file)

示例2:饮料制作模板

实例

from abc import ABC, abstractmethod

class BeverageMaker(ABC):
    """饮料制作模板"""
   
    def make_beverage(self):
        """制作饮料的模板方法"""
        self.boil_water()
        self.brew()
        self.pour_in_cup()
        if self.customer_wants_condiments():
            self.add_condiments()
        self.serve()
   
    def boil_water(self):
        """烧水 - 具体方法"""
        print("烧开水")
   
    @abstractmethod
    def brew(self):
        """冲泡 - 抽象方法"""
        pass
   
    def pour_in_cup(self):
        """倒入杯子 - 具体方法"""
        print("倒入杯子中")
   
    @abstractmethod
    def add_condiments(self):
        """添加调料 - 抽象方法"""
        pass
   
    def customer_wants_condiments(self):
        """钩子方法 - 客户是否要调料"""
        return True
   
    def serve(self):
        """上饮料 - 具体方法"""
        print("饮料制作完成,请享用!")

class CoffeeMaker(BeverageMaker):
    """咖啡制作"""
   
    def brew(self):
        print("用沸水冲泡咖啡粉")
   
    def add_condiments(self):
        print("加入糖和牛奶")
   
    def customer_wants_condiments(self):
        answer = input("咖啡要加糖和牛奶吗?(y/n): ")
        return answer.lower() == 'y'

class TeaMaker(BeverageMaker):
    """茶制作"""
   
    def brew(self):
        print("用沸水浸泡茶叶")
   
    def add_condiments(self):
        print("加入柠檬")

# 使用示例
if __name__ == "__main__":
    print("制作咖啡:")
    coffee = CoffeeMaker()
    coffee.make_beverage()
   
    print("\n" + "="*30 + "\n")
   
    print("制作茶:")
    tea = TeaMaker()
    tea.make_beverage()

模板方法模式的变体

1. 带参数的模板方法

实例

class ConfigurableProcessor(ABC):
    """可配置的数据处理器"""
   
    def process_with_config(self, input_file, output_file, config):
        """带配置参数的模板方法"""
        self.validate_config(config)
        data = self.read_data(input_file)
        processed_data = self.process_with_config_impl(data, config)
        self.save_data(processed_data, output_file)
        self.post_process(config)
   
    def validate_config(self, config):
        """验证配置 - 具体方法"""
        required_keys = ['format', 'encoding']
        for key in required_keys:
            if key not in config:
                raise ValueError(f"配置中缺少必需的键: {key}")
   
    @abstractmethod
    def process_with_config_impl(self, data, config):
        """使用配置处理数据的实现"""
        pass
   
    def post_process(self, config):
        """后处理 - 钩子方法"""
        if config.get('cleanup', False):
            print("执行清理操作")

2. 多步骤模板方法

实例

class MultiStepProcessor(ABC):
    """多步骤处理器"""
   
    def complex_processing(self):
        """复杂处理流程"""
        self.initialize()
        self.pre_process()
        self.main_process()
        self.post_process()
        self.finalize()
   
    def initialize(self):
        print("初始化处理器")
   
    def pre_process(self):
        """预处理 - 钩子方法"""
        pass
   
    @abstractmethod
    def main_process(self):
        """主要处理逻辑"""
        pass
   
    def post_process(self):
        """后处理 - 钩子方法"""
        pass
   
    def finalize(self):
        print("处理完成")

最佳实践和注意事项

1. 合理使用抽象方法

  • 只将真正需要子类实现的方法声明为抽象方法
  • 为可选步骤提供默认实现的钩子方法

2. 模板方法的访问控制

  • 模板方法通常应该声明为 final(在 Python 中可以通过命名约定)
  • 步骤方法应该被保护,避免外部直接调用

3. 错误处理

实例

class RobustTemplate(ABC):
    """健壮的模板类"""
   
    def template_method(self):
        try:
            self.setup()
            self.execute_steps()
        except Exception as e:
            self.handle_error(e)
        finally:
            self.cleanup()
   
    def execute_steps(self):
        """执行步骤序列"""
        self.step1()
        self.step2()
        self.step3()
   
    def handle_error(self, error):
        """错误处理 - 钩子方法"""
        print(f"处理过程中发生错误: {error}")
   
    @abstractmethod
    def setup(self):
        pass
   
    @abstractmethod
    def step1(self):
        pass
   
    @abstractmethod
    def step2(self):
        pass
   
    @abstractmethod
    def step3(self):
        pass
   
    def cleanup(self):
        print("清理资源")

实践练习

现在轮到你来实践了!请完成以下练习来巩固对模板方法模式的理解:

练习1:实现文件导出模板

创建一个文件导出模板,支持导出为不同格式(TXT、HTML、JSON)。

要求:

  • 定义抽象基类 FileExporter
  • 实现具体子类 TXTExporterHTMLExporterJSONExporter
  • 模板方法应该包含:准备数据、格式化数据、保存文件、后处理等步骤

练习2:改进饮料制作模板

扩展之前的饮料制作模板,添加以下功能:

  • 支持选择饮料大小(小杯、中杯、大杯)
  • 添加价格计算功能
  • 支持自定义调料

总结

模板方法模式是一个强大而实用的设计模式,它通过以下方式帮助我们构建更好的代码:

主要优点

  1. 提高代码复用性:将公共代码放在父类中
  2. 提高扩展性:通过子类扩展特定步骤
  3. 符合开闭原则:对扩展开放,对修改关闭
  4. 反向控制:父类控制流程,子类实现细节

适用场景

  • 多个类有相同的方法,但具体实现不同
  • 需要控制子类的扩展点
  • 需要定义算法骨架,但允许某些步骤变化

核心要点

  • 模板方法定义算法的不变部分
  • 抽象方法定义算法的可变部分
  • 钩子方法提供可选的扩展点