Skip to content

集成测试套件 #10

@realm520

Description

@realm520

## 概述

为Strategy-21创建完整的集成测试套件,包含策略逻辑测试、回测验证、风险管理测试、性能基准测试等,确保策略在各种市场条件下的稳定性和可靠性,建立持续集成测试流程。

## 目标

- 创建全面的集成测试框架
- 验证策略在不同市场条件下的表现
- 建立自动化测试流程
- 确保代码质量和策略稳定性

## 技术要求

### 核心功能

1. **策略逻辑测试**
   - VATSM指标计算正确性
   - 信号生成逻辑验证
   - 仓位管理规则测试
   - 风险控制机制验证

2. **市场场景测试**
   - 牛市场景测试
   - 熊市场景测试
   - 震荡市场景测试
   - 极端波动场景测试

3. **回测验证测试**
   - 历史数据回测一致性
   - 不同时间段表现验证
   - 参数敏感性测试
   - 样本外测试验证

4. **性能基准测试**
   - 策略执行性能测试
   - 内存使用率测试
   - 数据处理速度测试
   - 并发处理能力测试

### 实现细节

#### 测试框架基础 (tests/conftest.py)

```python
import pytest
import asyncio
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Any
from unittest.mock import Mock, MagicMock, AsyncMock
from pathlib import Path
import json
import logging

# 测试数据生成器
class MarketDataGenerator:
    """Generate reproducible synthetic OHLCV candles for strategy tests."""

    # Candle spacing for every supported timeframe label. Labels that are not
    # listed fall back to daily candles, which was the historical default.
    _TIMEFRAME_STEPS = {
        '1m': timedelta(minutes=1),
        '5m': timedelta(minutes=5),
        '15m': timedelta(minutes=15),
        '1h': timedelta(hours=1),
        '4h': timedelta(hours=4),
        '1d': timedelta(days=1),
    }

    @staticmethod
    def generate_ohlcv_data(
        start_date: datetime,
        end_date: datetime,
        timeframe: str = '5m',
        trend: str = 'sideways',
        volatility: float = 0.02,
        seed: int = 42
    ) -> pd.DataFrame:
        """Build a synthetic OHLCV frame between two timestamps (inclusive).

        Args:
            start_date: Timestamp of the first candle.
            end_date: Upper bound for the last candle.
            timeframe: Candle size label ('1m', '5m', '15m', '1h', '4h',
                '1d'); any other label produces daily candles.
            trend: 'bullish' drifts from 0% to +50% over the whole range,
                'bearish' from 0% to -30%; anything else oscillates +/-10%
                around the base price (two full sine cycles).
            volatility: Standard deviation of the per-candle random return.
            seed: Seed of the private random generator, so equal arguments
                always yield equal frames.

        Returns:
            DataFrame indexed by 'timestamp' with the columns open, high,
            low, close and volume. The frame is empty (but keeps its
            columns) when the range contains no candle.
        """
        # A private generator keeps the output reproducible without
        # reseeding numpy's global random state as a side effect.
        rng = np.random.default_rng(seed)

        # Using timedelta steps avoids pandas' version-dependent frequency
        # aliases (e.g. the deprecated '1H').
        step = MarketDataGenerator._TIMEFRAME_STEPS.get(timeframe, timedelta(days=1))
        timestamps = pd.date_range(start_date, end_date, freq=step)
        n_periods = len(timestamps)

        # The trend is the cumulative drift itself. Summing it again would
        # make the total move grow with the number of candles and could push
        # bearish prices below zero.
        if trend == 'bullish':
            drift = np.linspace(0, 0.5, n_periods)
        elif trend == 'bearish':
            drift = np.linspace(0, -0.3, n_periods)
        else:
            drift = np.sin(np.linspace(0, 4 * np.pi, n_periods)) * 0.1

        # Random returns are compounded geometrically, so the random walk
        # can never cross zero.
        random_walk = np.cumsum(rng.normal(0, volatility, n_periods))

        base_price = 50000  # BTC reference price
        close = base_price * (1 + drift) * np.exp(random_walk)

        # Every candle opens at the previous close; the first one opens at
        # its own close.
        open_ = np.roll(close, 1)
        if n_periods:
            open_[0] = close[0]

        # Per-candle wick sizes, then clamp so that high/low always bracket
        # the open and close of the same candle.
        candle_volatility = volatility * rng.uniform(0.5, 1.5, n_periods)
        upper_wick = close * (1 + candle_volatility * rng.uniform(0, 1, n_periods))
        lower_wick = close * (1 - candle_volatility * rng.uniform(0, 1, n_periods))
        high = np.maximum(upper_wick, np.maximum(open_, close))
        low = np.minimum(lower_wick, np.minimum(open_, close))

        volume = rng.lognormal(15, 1, n_periods)  # log-normally distributed volume

        return pd.DataFrame(
            {
                'open': open_,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume,
            },
            index=timestamps.rename('timestamp'),
        )

class MockFreqtradeBot:
    """Lightweight stand-in for a Freqtrade bot used by the tests."""

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and start with no strategy and no trades."""
        self.config = config
        # The wallet mock always reports a total stake of 10000.
        wallet_stub = Mock()
        wallet_stub.get_total_stake_amount.return_value = 10000
        self.wallets = wallet_stub
        self.strategy = None
        self.trades = []
        self.active_trades = []

    def get_pair_dataframe(self, pair: str, timeframe: str) -> pd.DataFrame:
        """Return generated candles covering the 30 days up to now.

        The pair argument is accepted but not used; the data depends only on
        the requested timeframe.
        """
        window_start = datetime.now() - timedelta(days=30)
        window_end = datetime.now()
        candles = MarketDataGenerator.generate_ohlcv_data(
            window_start,
            window_end,
            timeframe
        )
        return candles

# 测试配置
@pytest.fixture(scope="session")
def test_config():
    """Session-wide base configuration shared by all tests."""
    exchange_settings = {
        "name": "binance",
        "pair_whitelist": ["BTC/USDT", "ETH/USDT"],
    }
    settings = {
        "strategy": "Strategy21VATSM",
        "timeframe": "5m",
        "startup_candle_count": 500,
        "stake_currency": "USDT",
        "stake_amount": 100,
        "dry_run": True,
        "exchange": exchange_settings,
    }
    return settings

@pytest.fixture
def mock_bot(test_config):
    """Provide a fresh mock bot built from the shared test configuration."""
    bot = MockFreqtradeBot(test_config)
    return bot

@pytest.fixture
def market_data_generator():
    """Provide a market data generator instance to the requesting test."""
    generator = MarketDataGenerator()
    return generator

@pytest.fixture
def strategy_params():
    """Strategy parameter set: indicator periods, signal rules and risk limits."""
    # Look-back periods of the five VATSM indicator groups.
    indicator_periods = {
        "volume_sma_period": 20,
        "atr_period": 14,
        "trend_sma_period": 50,
        "strength_rsi_period": 14,
        "momentum_period": 14,
    }

    # Entry thresholds and the weight of each component score.
    signal_rules = {
        "entry_signals": {
            "volume_threshold": 1.5,
            "trend_alignment_required": True,
            "strength_min_score": 0.6,
        },
        "signal_weights": {
            "volume_weight": 0.20,
            "atr_weight": 0.15,
            "trend_weight": 0.25,
            "strength_weight": 0.25,
            "momentum_weight": 0.15,
        },
    }

    # Stop-loss layers and drawdown protection.
    risk_limits = {
        "three_layer_stop_loss": {
            "fixed_stop_pct": 0.02,
            "trailing_stop_pct": 0.015,
            "atr_multiplier": 2.0,
        },
        "drawdown_protection": {
            "max_drawdown_pct": 0.15,
            "position_scaling_enabled": True,
        },
    }

    return {
        "vatsm_indicators": indicator_periods,
        "signal_system": signal_rules,
        "risk_management": risk_limits,
    }

# 异步测试支持
@pytest.fixture
def event_loop():
    """Yield a new event loop for an async test and close it afterwards."""
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    fresh_loop.close()

# 日志配置
@pytest.fixture(autouse=True)
def configure_logging():
    """Request DEBUG-level root logging with a timestamped format for every test."""
    record_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=record_format)

```

#### 策略逻辑测试 (tests/test_strategy_logic.py)

```python

import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from unittest.mock import Mock, patch

from strategies.strategy21_vatsm import Strategy21VATSM
from strategies.vatsm_indicators import VATSMIndicators
from strategies.signal_system import VATSMSignalSystem

class TestVATSMIndicators:
    """Tests for the individual VATSM indicator calculations."""

    def test_volume_indicator_calculation(self, market_data_generator):
        """Volume indicators keep the row count, expose their columns and match a manual SMA."""
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 1, 31),
            trend='bullish'
        )

        indicators = VATSMIndicators()
        volume_data = indicators.calculate_volume_indicators(data, period=20)

        # Output must be complete.
        assert len(volume_data) == len(data)
        for column in ('volume_sma', 'volume_ratio', 'volume_score'):
            assert column in volume_data.columns, f"Missing column: {column}"

        # Values must be plausible.
        assert volume_data['volume_sma'].iloc[-1] > 0
        assert 0 <= volume_data['volume_score'].iloc[-1] <= 1

        # The moving average must match an independent calculation. Volumes
        # are large numbers, so the comparison uses a relative tolerance
        # instead of a fixed absolute one.
        manual_sma = data['volume'].rolling(20).mean().iloc[-1]
        assert volume_data['volume_sma'].iloc[-1] == pytest.approx(manual_sma, rel=1e-9)

    def test_atr_indicator_calculation(self, market_data_generator):
        """ATR indicators expose their columns and grow with input volatility."""
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 1, 31),
            volatility=0.05
        )

        indicators = VATSMIndicators()
        atr_data = indicators.calculate_atr_indicators(data, period=14)

        for column in ('atr', 'atr_pct', 'atr_score'):
            assert column in atr_data.columns, f"Missing column: {column}"

        # Values must be plausible.
        assert atr_data['atr'].iloc[-1] > 0
        assert 0 <= atr_data['atr_score'].iloc[-1] <= 1

        # Data generated with a higher volatility must yield a higher ATR.
        high_vol_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 1, 31),
            volatility=0.1
        )
        high_vol_atr = indicators.calculate_atr_indicators(high_vol_data, period=14)

        assert high_vol_atr['atr'].iloc[-1] > atr_data['atr'].iloc[-1]

    def test_trend_indicator_calculation(self, market_data_generator):
        """Trend indicators report the direction of bullish and bearish series."""
        indicators = VATSMIndicators()

        # Bullish series.
        bullish_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28),
            trend='bullish'
        )
        trend_data = indicators.calculate_trend_indicators(bullish_data)

        expected_columns = (
            'sma_50', 'ema_21', 'trend_direction', 'trend_strength', 'trend_score'
        )
        for column in expected_columns:
            assert column in trend_data.columns, f"Missing column: {column}"

        assert trend_data['trend_direction'].iloc[-1] > 0  # rising trend
        assert trend_data['trend_score'].iloc[-1] > 0.5    # strong trend

        # Bearish series.
        bearish_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28),
            trend='bearish'
        )
        bearish_trend = indicators.calculate_trend_indicators(bearish_data)
        assert bearish_trend['trend_direction'].iloc[-1] < 0  # falling trend

    def test_strength_indicator_calculation(self, market_data_generator):
        """Strength indicators expose RSI/MACD columns with consistent values."""
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28)
        )

        indicators = VATSMIndicators()
        strength_data = indicators.calculate_strength_indicators(data)

        expected_columns = (
            'rsi', 'macd', 'macd_signal', 'macd_histogram', 'strength_score'
        )
        for column in expected_columns:
            assert column in strength_data.columns, f"Missing column: {column}"

        # RSI is bounded.
        assert 0 <= strength_data['rsi'].iloc[-1] <= 100

        # The histogram is MACD minus its signal line. Compare with a
        # tolerance: exact float equality breaks on harmless rounding
        # differences between implementations.
        expected_histogram = (
            strength_data['macd'].iloc[-1] - strength_data['macd_signal'].iloc[-1]
        )
        assert strength_data['macd_histogram'].iloc[-1] == pytest.approx(expected_histogram)

    def test_momentum_indicator_calculation(self, market_data_generator):
        """Momentum indicators expose their columns and a non-negative score."""
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28),
            trend='bullish'
        )

        indicators = VATSMIndicators()
        momentum_data = indicators.calculate_momentum_indicators(data)

        for column in ('momentum', 'momentum_sma', 'momentum_score'):
            assert column in momentum_data.columns, f"Missing column: {column}"

        assert momentum_data['momentum_score'].iloc[-1] >= 0

class TestSignalSystem:
    """Tests for the VATSM signal system."""

    @staticmethod
    def _mock_vatsm_scores(index, weights, volume, atr, trend, strength, momentum):
        """Build a constant score frame whose total follows the configured weights.

        Deriving 'vatsm_total_score' from the weights keeps the mocked total
        consistent with its components instead of relying on a hand-computed
        constant.
        """
        components = {
            'volume': volume,
            'atr': atr,
            'trend': trend,
            'strength': strength,
            'momentum': momentum,
        }
        scores = pd.DataFrame(
            {f'{name}_score': value for name, value in components.items()},
            index=index,
        )
        scores['vatsm_total_score'] = sum(
            weights[f'{name}_weight'] * value for name, value in components.items()
        )
        return scores

    def test_signal_generation(self, market_data_generator, strategy_params):
        """Generated signals expose the expected columns, dtypes and value range."""
        # Strongly trending data.
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28),
            trend='bullish',
            volatility=0.03
        )

        # Inflate the volume of the last candles.
        data.loc[data.index[-5:], 'volume'] *= 3

        signal_system = VATSMSignalSystem(strategy_params['signal_system'])
        signals = signal_system.generate_signals(data)

        for column in ('entry_signal', 'exit_signal', 'signal_strength', 'confidence'):
            assert column in signals.columns, f"Missing column: {column}"

        # Value domains.
        assert signals['entry_signal'].dtype == bool
        assert signals['exit_signal'].dtype == bool
        assert signals['signal_strength'].between(0, 1).all()

    def test_entry_signal_conditions(self, market_data_generator, strategy_params):
        """Uniformly strong component scores produce an entry signal."""
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28),
            trend='bullish'
        )

        signal_system = VATSMSignalSystem(strategy_params['signal_system'])
        weights = strategy_params['signal_system']['signal_weights']

        # Force indicator scores that satisfy the entry conditions.
        with patch.object(signal_system, '_calculate_vatsm_scores') as mock_vatsm:
            mock_vatsm.return_value = self._mock_vatsm_scores(
                data.index, weights,
                volume=0.8, atr=0.7, trend=0.9, strength=0.8, momentum=0.7
            )

            signals = signal_system.generate_signals(data)

            # A strong score must trigger an entry.
            assert signals['entry_signal'].iloc[-1]
            assert signals['signal_strength'].iloc[-1] > 0.7

    def test_exit_signal_conditions(self, market_data_generator, strategy_params):
        """Uniformly weak component scores produce an exit signal."""
        data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28),
            trend='bearish'  # trend reversal
        )

        signal_system = VATSMSignalSystem(strategy_params['signal_system'])
        weights = strategy_params['signal_system']['signal_weights']

        # Simulate a reversal: weak trend, fading strength and momentum.
        with patch.object(signal_system, '_calculate_vatsm_scores') as mock_vatsm:
            mock_vatsm.return_value = self._mock_vatsm_scores(
                data.index, weights,
                volume=0.3, atr=0.4, trend=0.2, strength=0.2, momentum=0.1
            )

            signals = signal_system.generate_signals(data)

            # A weak score must trigger an exit.
            assert signals['exit_signal'].iloc[-1]

class TestStrategy21VATSM:
    """Tests for the Strategy21 main strategy class."""

    @staticmethod
    def _build_strategy(config, params, data_provider=None):
        """Create a strategy with parameters loaded and, optionally, a data provider attached."""
        strategy = Strategy21VATSM(config)
        strategy.load_params(params)
        if data_provider is not None:
            strategy.dp = data_provider
        return strategy

    @staticmethod
    def _two_months_of_candles(generator):
        """Generate the January-February 2023 candle set shared by the tests."""
        return generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 28)
        )

    def test_strategy_initialization(self, test_config, strategy_params):
        """A configured strategy exposes its settings and all sub-components."""
        strategy = self._build_strategy(test_config, strategy_params)

        assert strategy.timeframe == test_config['timeframe']
        assert strategy.startup_candle_count == test_config['startup_candle_count']
        assert strategy.vatsm_indicators is not None
        assert strategy.signal_system is not None
        assert strategy.position_manager is not None
        assert strategy.risk_manager is not None

    # The data generator is obtained through the market_data_generator
    # fixture, like the sibling test classes do: the generator class lives in
    # conftest.py and is not imported by this module.
    def test_populate_indicators(self, mock_bot, test_config, strategy_params,
                                 market_data_generator):
        """populate_indicators adds every expected indicator column."""
        strategy = self._build_strategy(test_config, strategy_params, mock_bot)
        dataframe = self._two_months_of_candles(market_data_generator)

        result = strategy.populate_indicators(dataframe, {})

        expected_columns = [
            'volume_sma', 'volume_ratio', 'volume_score',
            'atr', 'atr_score',
            'sma_50', 'ema_21', 'trend_score',
            'rsi', 'macd', 'strength_score',
            'momentum', 'momentum_score',
            'vatsm_total_score'
        ]

        for col in expected_columns:
            assert col in result.columns, f"Missing indicator column: {col}"

    def test_populate_entry_trend(self, mock_bot, test_config, strategy_params,
                                  market_data_generator):
        """Strong scores on the last candles produce at least one long entry."""
        strategy = self._build_strategy(test_config, strategy_params, mock_bot)
        dataframe = self._two_months_of_candles(market_data_generator)
        dataframe = strategy.populate_indicators(dataframe, {})

        # Force scores that satisfy the entry conditions.
        recent = dataframe.index[-10:]
        dataframe.loc[recent, 'vatsm_total_score'] = 0.75
        dataframe.loc[recent, 'volume_score'] = 0.8
        dataframe.loc[recent, 'trend_score'] = 0.9

        result = strategy.populate_entry_trend(dataframe, {})

        assert 'enter_long' in result.columns
        assert result['enter_long'].any()  # an entry signal is expected

    def test_populate_exit_trend(self, mock_bot, test_config, strategy_params,
                                 market_data_generator):
        """Weak scores on the last candles produce at least one long exit."""
        strategy = self._build_strategy(test_config, strategy_params, mock_bot)
        dataframe = self._two_months_of_candles(market_data_generator)
        dataframe = strategy.populate_indicators(dataframe, {})

        # Force scores that satisfy the exit conditions.
        recent = dataframe.index[-5:]
        dataframe.loc[recent, 'vatsm_total_score'] = 0.2  # weakening signal
        dataframe.loc[recent, 'trend_score'] = 0.3  # weakening trend

        result = strategy.populate_exit_trend(dataframe, {})

        assert 'exit_long' in result.columns
        assert result['exit_long'].any()  # an exit signal is expected

```

#### 市场场景测试 (tests/test_market_scenarios.py)

```python

import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

class TestMarketScenarios:
    """Scenario tests that replay the strategy over synthetic market regimes."""

    # Starting capital and per-trade allocation of the simplified simulator.
    INITIAL_CAPITAL = 10000
    POSITION_FRACTION = 0.1

    def test_bullish_market_performance(self, mock_bot, test_config, strategy_params, market_data_generator):
        """The strategy trades profitably with bounded drawdown in a bull market."""
        bullish_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 6, 30),
            trend='bullish',
            volatility=0.02
        )

        result = self._simulate(bullish_data, mock_bot, test_config, strategy_params)

        assert result['total_trades'] > 0
        assert result['win_rate'] > 0.4  # win rate above 40%
        assert result['total_profit'] > 0  # a bull market should be profitable
        assert result['max_drawdown'] < 0.2  # drawdown below 20%

    def test_bearish_market_performance(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Losses and drawdown stay bounded in a bear market."""
        bearish_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 6, 30),
            trend='bearish',
            volatility=0.03
        )

        result = self._simulate(bearish_data, mock_bot, test_config, strategy_params)

        assert result['max_drawdown'] < 0.25  # drawdown below 25%
        assert result['total_trades'] >= 0  # few or no trades are acceptable
        # A small loss is tolerated, a large one is not.
        assert result['total_profit'] > -0.15  # loss below 15%

    def test_sideways_market_performance(self, mock_bot, test_config, strategy_params, market_data_generator):
        """The strategy trades and roughly breaks even in a ranging market."""
        sideways_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 6, 30),
            trend='sideways',
            volatility=0.025
        )

        result = self._simulate(sideways_data, mock_bot, test_config, strategy_params)

        assert result['total_trades'] > 0  # a ranging market should produce trades
        assert result['max_drawdown'] < 0.15  # tighter drawdown limit
        # Target: a small profit or break-even.
        assert result['total_profit'] > -0.05  # loss below 5%

    def test_high_volatility_scenario(self, mock_bot, test_config, strategy_params, market_data_generator):
        """The strategy still produces sane trades under high volatility."""
        high_vol_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 3, 31),
            trend='sideways',
            volatility=0.08  # high volatility
        )

        result = self._simulate(high_vol_data, mock_bot, test_config, strategy_params)

        assert result['avg_trade_duration'] > 0  # trades must have a duration
        # NOTE(review): this compares the position size with the drawdown
        # limit; presumably a dedicated position-size limit was intended -
        # confirm which parameter should bound the position size.
        assert result['max_position_size'] <= strategy_params['risk_management']['drawdown_protection']['max_drawdown_pct']

    def test_market_crash_scenario(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Daily and total losses stay bounded when a calm market crashes."""
        normal_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 2, 15),
            trend='sideways'
        )

        crash_data = market_data_generator.generate_ohlcv_data(
            datetime(2023, 2, 16),
            datetime(2023, 2, 20),
            trend='bearish',
            volatility=0.15  # extreme volatility
        )

        # Each generated series starts from the generator's own base price.
        # Rescale the crash segment so it continues from the last normal
        # close instead of jumping to an unrelated price level.
        price_columns = ['open', 'high', 'low', 'close']
        continuation = normal_data['close'].iloc[-1] / crash_data['open'].iloc[0]
        crash_data[price_columns] = crash_data[price_columns] * continuation

        combined_data = pd.concat([normal_data, crash_data])

        result = self._simulate(combined_data, mock_bot, test_config, strategy_params)

        assert result['max_single_day_loss'] > -0.1  # worst day loses less than 10%
        assert result['total_profit'] > -0.2  # total loss below 20%

    def _simulate(self, data, mock_bot, test_config, strategy_params):
        """Build a configured strategy and run the trade simulation on data."""
        # Imported here because this module's import block does not provide it.
        from strategies.strategy21_vatsm import Strategy21VATSM

        strategy = Strategy21VATSM(test_config)
        strategy.load_params(strategy_params)
        strategy.dp = mock_bot
        return self._run_strategy_simulation(strategy, data)

    @staticmethod
    def _worst_daily_return(equity_curve, index, initial_capital):
        """Return the most negative day-over-day equity change as a fraction (0 if none)."""
        if not equity_curve:
            return 0
        equity = pd.Series(equity_curve, index=index)
        day_end = equity.groupby(equity.index.normalize()).last()
        # The first day is measured against the starting capital.
        day_start = day_end.shift(1, fill_value=initial_capital)
        return min(float((day_end / day_start - 1).min()), 0)

    def _run_strategy_simulation(self, strategy, dataframe) -> dict:
        """Replay the strategy's signals as a long-only, one-position simulation.

        Args:
            strategy: Object providing populate_indicators,
                populate_entry_trend and populate_exit_trend.
            dataframe: Candles with a datetime index and a 'close' column.

        Returns:
            Dict with total_trades, win_rate, total_profit (realised,
            fraction of the initial capital), max_drawdown (fraction,
            marked to market), avg_profit and avg_loss (currency),
            avg_trade_duration (hours), max_single_day_loss (fraction,
            <= 0) and max_position_size (fraction).
        """
        dataframe = strategy.populate_indicators(dataframe, {})
        dataframe = strategy.populate_entry_trend(dataframe, {})
        dataframe = strategy.populate_exit_trend(dataframe, {})

        # Missing signal cells (NaN) must count as "no signal"; NaN itself is
        # truthy and would otherwise open or close a position on every row.
        entry_flags = dataframe['enter_long'].fillna(0).astype(bool).to_numpy()
        exit_flags = dataframe['exit_long'].fillna(0).astype(bool).to_numpy()
        closes = dataframe['close'].to_numpy()

        trades = []
        position = None
        portfolio_value = self.INITIAL_CAPITAL  # realised equity
        peak_value = portfolio_value
        max_drawdown = 0
        equity_curve = []

        for timestamp, close, wants_entry, wants_exit in zip(
            dataframe.index, closes, entry_flags, exit_flags
        ):
            if wants_entry and position is None:
                position = {
                    'entry_price': close,
                    'entry_time': timestamp,
                    'amount': portfolio_value * self.POSITION_FRACTION
                }

            if wants_exit and position is not None:
                profit = (close - position['entry_price']) * position['amount'] / position['entry_price']
                portfolio_value += profit

                trades.append({
                    'entry_time': position['entry_time'],
                    'exit_time': timestamp,
                    'entry_price': position['entry_price'],
                    'exit_price': close,
                    'profit': profit,
                    'duration': (timestamp - position['entry_time']).total_seconds() / 3600  # hours
                })

                position = None

            # Value the account after this bar's trades, including the
            # unrealised result of an open position, so the drawdown
            # reflects the current bar rather than the previous one.
            current_value = portfolio_value
            if position is not None:
                current_value += (close - position['entry_price']) * position['amount'] / position['entry_price']
            equity_curve.append(current_value)

            peak_value = max(peak_value, current_value)
            max_drawdown = max(max_drawdown, (peak_value - current_value) / peak_value)

        gains = [trade['profit'] for trade in trades if trade['profit'] > 0]
        losses = [trade['profit'] for trade in trades if trade['profit'] < 0]

        return {
            'total_trades': len(trades),
            'win_rate': len(gains) / len(trades) if trades else 0,
            'total_profit': (portfolio_value - self.INITIAL_CAPITAL) / self.INITIAL_CAPITAL,
            'max_drawdown': max_drawdown,
            'avg_profit': sum(gains) / len(gains) if gains else 0,
            'avg_loss': sum(losses) / len(losses) if losses else 0,
            'avg_trade_duration': sum(trade['duration'] for trade in trades) / len(trades) if trades else 0,
            # Expressed as a fraction of equity so it is comparable with the
            # percentage thresholds used by the scenario tests.
            'max_single_day_loss': self._worst_daily_return(
                equity_curve, dataframe.index, self.INITIAL_CAPITAL
            ),
            'max_position_size': self.POSITION_FRACTION
        }

```

#### 回测验证测试 (tests/test_backtest_validation.py)

```python

import pytest
from freqtrade.optimize.backtesting import Backtesting
from freqtrade.configuration import Configuration

class TestBacktestValidation:
    """Backtest validation tests."""

    def test_historical_data_consistency(self, test_config):
        """A backtest over fixed historical data reports the key metrics."""
        config = test_config.copy()
        config.update({
            'strategy_list': ['Strategy21VATSM'],
            'timerange': '20230101-20230630',
            'max_open_trades': 5,
            'stake_amount': 100
        })

        backtesting = Backtesting(config)

        # NOTE(review): verify against the installed Freqtrade version that
        # start() returns the statistics dict and that it has this layout.
        results = backtesting.start()

        assert results is not None
        assert 'results_metrics' in results
        assert 'strategy_comparison' in results

        metrics = results['results_metrics']['Strategy21VATSM']
        for key in ('total_trades', 'profit_mean', 'profit_total', 'max_drawdown'):
            assert key in metrics, f"Missing metric: {key}"

    def test_parameter_sensitivity(self, test_config, strategy_params):
        """Changing a strategy parameter changes the backtest outcome."""
        import copy

        base_config = test_config.copy()
        base_results = self._run_backtest_with_params(base_config, strategy_params)

        # Volume threshold. A deep copy is required: a shallow copy shares
        # the nested dicts, so the edit would also change the baseline.
        modified_params = copy.deepcopy(strategy_params)
        modified_params['signal_system']['entry_signals']['volume_threshold'] = 2.0
        modified_results = self._run_backtest_with_params(base_config, modified_params)

        assert modified_results['total_trades'] != base_results['total_trades']

        # Stop-loss percentage.
        modified_params2 = copy.deepcopy(strategy_params)
        modified_params2['risk_management']['three_layer_stop_loss']['fixed_stop_pct'] = 0.03
        modified_results2 = self._run_backtest_with_params(base_config, modified_params2)

        assert modified_results2['max_drawdown'] <= base_results['max_drawdown'] * 1.1  # 10% slack

    def test_out_of_sample_validation(self, test_config, strategy_params):
        """Out-of-sample results do not degrade sharply against in-sample results."""
        # In-sample period (January-April 2023).
        train_config = test_config.copy()
        train_config['timerange'] = '20230101-20230430'
        train_results = self._run_backtest_with_params(train_config, strategy_params)

        # Out-of-sample period (May-August 2023).
        test_config_oos = test_config.copy()
        test_config_oos['timerange'] = '20230501-20230831'
        test_results = self._run_backtest_with_params(test_config_oos, strategy_params)

        # Relative ratios are undefined for a zero baseline, so each check
        # only runs when its denominator is non-zero.
        train_profit = train_results['profit_total']
        if train_profit != 0:
            performance_degradation = (train_profit - test_results['profit_total']) / abs(train_profit)
            assert performance_degradation < 0.5  # at most 50% worse

        train_drawdown = train_results['max_drawdown']
        if train_drawdown != 0:
            drawdown_ratio = test_results['max_drawdown'] / train_drawdown
            assert 0.5 < drawdown_ratio < 2.0  # similar risk profile

    def test_different_timeframes(self, test_config, strategy_params):
        """Backtests on several timeframes give sane, mutually consistent results."""
        timeframes = ['1m', '5m', '15m', '1h']
        results = {}
        failures = {}

        for tf in timeframes:
            config = test_config.copy()
            config['timeframe'] = tf
            config['startup_candle_count'] = 500

            # Only the backtest run is guarded. The assertions stay outside
            # the try block so a failed assertion is reported as a failure
            # instead of being converted into a skip.
            try:
                results[tf] = self._run_backtest_with_params(config, strategy_params)
            except Exception as e:
                failures[tf] = e
                continue

            assert results[tf]['total_trades'] >= 0
            assert results[tf]['max_drawdown'] <= 1.0

        if not results:
            pytest.skip(f"No timeframe could be backtested: {failures}")

        # Shorter candles are expected to trade at least as often.
        if '1m' in results and '1h' in results:
            assert results['1m']['total_trades'] >= results['1h']['total_trades']

    def test_cross_validation(self, test_config, strategy_params):
        """Results stay reasonably stable across consecutive two-month periods."""
        import numpy as np

        periods = [
            ('20230101', '20230228'),  # January-February 2023
            ('20230301', '20230430'),  # March-April 2023
            ('20230501', '20230630'),  # May-June 2023
            ('20230701', '20230831'),  # July-August 2023
        ]

        period_results = []

        for start, end in periods:
            config = test_config.copy()
            config['timerange'] = f'{start}-{end}'

            # Best effort: a period that cannot be backtested is reported
            # and left out of the comparison.
            try:
                period_results.append(self._run_backtest_with_params(config, strategy_params))
            except Exception as e:
                print(f"Period {start}-{end} failed: {e}")

        if len(period_results) >= 3:
            profit_ratios = [r['profit_total'] for r in period_results]
            drawdowns = [r['max_drawdown'] for r in period_results]

            # The spread of the profit should not be excessive.
            profit_std = np.std(profit_ratios)
            profit_mean = np.mean(profit_ratios)

            if profit_mean != 0:
                cv = abs(profit_std / profit_mean)  # coefficient of variation
                assert cv < 2.0  # below 200%

            # The maximum drawdown should be fairly stable.
            drawdown_max = max(drawdowns)
            drawdown_min = min(drawdowns)
            if drawdown_min > 0:
                assert drawdown_max / drawdown_min < 3.0  # within a factor of 3

    def _run_backtest_with_params(self, config: dict, params: dict) -> dict:
        """Return placeholder backtest metrics for a configuration.

        This is a stand-in until the real backtest is wired in. The metrics
        are pseudo-random but derived from the inputs: the same config and
        params always give the same result, so the tests are reproducible.
        Assertions that compare the metrics of different runs only become
        meaningful once this method runs an actual backtest.
        """
        import json
        import zlib

        import numpy as np

        # A stable fingerprint of the inputs seeds the generator; str() is a
        # deliberate fallback for values JSON cannot encode.
        fingerprint = json.dumps([config, params], sort_keys=True, default=str)
        rng = np.random.default_rng(zlib.crc32(fingerprint.encode('utf-8')))

        return {
            'total_trades': int(rng.integers(50, 200)),
            'profit_total': float(rng.uniform(-0.1, 0.3)),
            'profit_mean': float(rng.uniform(-0.01, 0.02)),
            'max_drawdown': float(rng.uniform(0.05, 0.25)),
            'win_rate': float(rng.uniform(0.3, 0.7)),
            'sharpe_ratio': float(rng.uniform(0.5, 2.0))
        }

```

#### 性能基准测试 (tests/test_performance_benchmarks.py)

```python

import pytest
import time
import memory_profiler
import asyncio
from concurrent.futures import ThreadPoolExecutor

class TestPerformanceBenchmarks:
    """Performance benchmarks for Strategy21VATSM.

    Covers single-run execution speed, memory growth, thread-based and
    asyncio-based concurrency, and a year-long 1-minute-candle simulation.
    Intervals are measured with ``time.perf_counter()`` because it is
    monotonic and is not affected by system clock adjustments.

    NOTE(review): the tests reference ``datetime``, ``timedelta`` and
    ``Strategy21VATSM``; the import list visible above this class does not
    provide them -- confirm they are imported in the real module.
    """

    def test_strategy_execution_speed(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Check that the full indicator/entry/exit pipeline is fast enough.

        Runs the three populate_* steps over 5-minute candles spanning
        2023-01-01 to 2023-12-31 and asserts on the total time and the
        average time per row.
        """
        strategy = Strategy21VATSM(test_config)
        strategy.load_params(strategy_params)
        strategy.dp = mock_bot

        # Generate a large test dataset.
        large_dataset = market_data_generator.generate_ohlcv_data(
            datetime(2023, 1, 1),
            datetime(2023, 12, 31),
            timeframe='5m'
        )
        data_points = len(large_dataset)
        # Fail clearly instead of dividing by zero below.
        assert data_points > 0

        # Measure execution time.
        start_time = time.perf_counter()

        result = strategy.populate_indicators(large_dataset, {})
        result = strategy.populate_entry_trend(result, {})
        result = strategy.populate_exit_trend(result, {})

        execution_time = time.perf_counter() - start_time

        # Performance requirements.
        time_per_point = execution_time / data_points

        assert execution_time < 30.0  # total execution time must stay under 30 s
        assert time_per_point < 0.01  # at most 10 ms per data point

        print(f"Processing {data_points} data points in {execution_time:.2f}s")
        print(f"Time per data point: {time_per_point*1000:.2f}ms")

    @memory_profiler.profile
    def test_memory_usage(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Check that processing ten datasets does not grow memory by 500 MB.

        Memory is sampled while the workload runs, so the reported peak
        reflects the highest sampled footprint during processing rather than
        only the footprint left over once processing has finished.
        """
        strategy = Strategy21VATSM(test_config)
        strategy.load_params(strategy_params)
        strategy.dp = mock_bot

        # Results are kept alive on purpose so that they count towards usage.
        datasets = []

        def build_datasets():
            """Generate ten datasets and keep their indicator frames."""
            for i in range(10):
                data = market_data_generator.generate_ohlcv_data(
                    datetime(2023, 1, 1) + timedelta(days=i*30),
                    datetime(2023, 1, 31) + timedelta(days=i*30),
                    timeframe='5m'
                )
                datasets.append(strategy.populate_indicators(data, {}))

        # Initial memory usage.
        initial_memory = memory_profiler.memory_usage()[0]

        # Passing a (callable, args, kwargs) tuple makes memory_usage run the
        # callable and sample memory during its execution.
        samples = list(memory_profiler.memory_usage((build_datasets, (), {}), interval=0.1))
        # Also include one reading taken after the workload, so the peak is
        # never lower than the final footprint.
        samples.extend(memory_profiler.memory_usage())

        # Peak memory usage.
        peak_memory = max(samples)
        memory_increase = peak_memory - initial_memory

        # Memory requirement.
        assert memory_increase < 500  # growth must stay below 500 MB

        print(f"Initial memory: {initial_memory:.1f}MB")
        print(f"Peak memory: {peak_memory:.1f}MB")
        print(f"Memory increase: {memory_increase:.1f}MB")

        # Release the retained frames.
        datasets.clear()

    def test_concurrent_processing(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Compare serial and thread-pool processing of five pairs.

        Asserts that both approaches return the same row counts and that the
        thread pool is at least 1.5x faster.
        """
        strategy = Strategy21VATSM(test_config)
        strategy.load_params(strategy_params)
        strategy.dp = mock_bot

        pairs = ["BTC/USDT", "ETH/USDT", "BNB/USDT", "ADA/USDT", "SOL/USDT"]

        def process_pair(pair):
            """Generate data for one pair and return the resulting row count."""
            data = market_data_generator.generate_ohlcv_data(
                datetime(2023, 1, 1),
                datetime(2023, 6, 30),
                timeframe='5m'
            )

            result = strategy.populate_indicators(data, {'pair': pair})
            return len(result)

        # Serial processing.
        start_time = time.perf_counter()
        serial_results = [process_pair(pair) for pair in pairs]
        serial_time = time.perf_counter() - start_time

        # Parallel processing. The single strategy instance is shared by all
        # worker threads.
        start_time = time.perf_counter()
        with ThreadPoolExecutor(max_workers=len(pairs)) as executor:
            parallel_results = list(executor.map(process_pair, pairs))
        parallel_time = time.perf_counter() - start_time

        # Results must be identical.
        assert serial_results == parallel_results

        # NOTE(review): the 1.5x threshold presumes populate_indicators
        # releases the GIL for most of its work -- confirm, otherwise threads
        # cannot deliver this speedup and the assertion will be flaky.
        speedup = serial_time / parallel_time
        assert speedup > 1.5  # at least a 1.5x speedup

        print(f"Serial processing: {serial_time:.2f}s")
        print(f"Parallel processing: {parallel_time:.2f}s")
        print(f"Speedup: {speedup:.2f}x")

    @pytest.mark.asyncio
    async def test_async_performance(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Check that five simulated-IO tasks overlap when gathered.

        Each task sleeps 0.1 s and then computes indicators; the whole batch
        must finish in under 80% of the summed sleep time.
        """
        strategy = Strategy21VATSM(test_config)
        strategy.load_params(strategy_params)
        # Same setup as the other benchmarks, which all attach the mock bot.
        strategy.dp = mock_bot

        async def async_process_data(data):
            """Simulate an async IO wait, then compute indicators."""
            # Simulated asynchronous IO.
            await asyncio.sleep(0.1)

            # NOTE(review): this call is synchronous, so only the sleeps
            # overlap; the time budget below also has to absorb all five
            # indicator computations -- confirm that is intended.
            return strategy.populate_indicators(data, {})

        # Build several datasets.
        datasets = [
            market_data_generator.generate_ohlcv_data(
                datetime(2023, 1, 1) + timedelta(days=i*30),
                datetime(2023, 1, 31) + timedelta(days=i*30),
                timeframe='5m'
            )
            for i in range(5)
        ]

        # Process them concurrently.
        start_time = time.perf_counter()
        tasks = [async_process_data(data) for data in datasets]
        results = await asyncio.gather(*tasks)
        async_time = time.perf_counter() - start_time

        # Validate the results.
        assert len(results) == len(datasets)
        assert all(len(result) > 0 for result in results)

        # Concurrent waits should beat waiting for each task in turn.
        expected_serial_time = len(datasets) * 0.1  # 0.1 s per task
        assert async_time < expected_serial_time * 0.8  # at least 20% faster

        print(f"Async processing time: {async_time:.2f}s")
        print(f"Expected serial time: {expected_serial_time:.2f}s")

    def test_large_scale_simulation(self, mock_bot, test_config, strategy_params, market_data_generator):
        """Process twelve months of 1-minute candles and check throughput.

        The data comes from the ``market_data_generator`` fixture, like in the
        other benchmarks, so the test does not depend on the generator class
        being importable from this module.
        """
        strategy = Strategy21VATSM(test_config)
        strategy.load_params(strategy_params)
        strategy.dp = mock_bot

        # Simulate one year of data processing.
        start_time = time.perf_counter()

        total_processed = 0
        for month in range(1, 13):  # 12 months
            monthly_data = market_data_generator.generate_ohlcv_data(
                datetime(2023, month, 1),
                datetime(2023, month, 28),
                timeframe='1m'  # 1-minute candles give a larger volume
            )

            result = strategy.populate_indicators(monthly_data, {})
            total_processed += len(result)

        total_time = time.perf_counter() - start_time

        # Performance baseline.
        throughput = total_processed / total_time  # data points per second

        assert throughput > 1000  # at least 1000 data points per second
        assert total_time < 300  # total time must stay under 5 minutes

        print(f"Processed {total_processed} data points in {total_time:.2f}s")
        print(f"Throughput: {throughput:.0f} points/second")

#### 测试执行脚本 (run_tests.py)

#!/usr/bin/env python3

import subprocess
import sys
import os
from datetime import datetime

def run_test_suite():
    """Run every test stage in order, print a summary and write a report.

    Each stage is a pytest command executed as a subprocess with a
    ten-minute timeout. A stage is recorded as PASS, FAIL, TIMEOUT or ERROR.

    Returns:
        1 if any critical stage did not pass, otherwise 0 (non-critical
        failures are reported but do not fail the run).
    """
    banner = "=" * 60
    print(banner)
    print("Strategy21 集成测试套件")
    print(banner)

    # (display name, command line, whether a failure breaks the build)
    stages = [
        ('策略逻辑测试', 'pytest tests/test_strategy_logic.py -v --tb=short', True),
        ('市场场景测试', 'pytest tests/test_market_scenarios.py -v --tb=short', True),
        ('回测验证测试', 'pytest tests/test_backtest_validation.py -v --tb=short', False),
        ('性能基准测试', 'pytest tests/test_performance_benchmarks.py -v --tb=short', False),
        ('代码覆盖率测试', 'pytest --cov=strategies --cov-report=html --cov-report=term', False),
    ]

    results = []

    for name, command, critical in stages:
        print(f"\n运行 {name}...")
        print("-" * 40)

        try:
            completed = subprocess.run(
                command.split(),
                capture_output=True,
                text=True,
                timeout=600,  # ten-minute timeout
            )

            if completed.returncode != 0:
                print(f"✗ {name} 失败")
                print("STDOUT:", completed.stdout)
                print("STDERR:", completed.stderr)
                status = 'FAIL'
            else:
                print(f"✓ {name} 通过")
                status = 'PASS'
        except subprocess.TimeoutExpired:
            print(f"✗ {name} 超时")
            status = 'TIMEOUT'
        except Exception as e:
            print(f"✗ {name} 错误: {e}")
            status = 'ERROR'

        results.append({'name': name, 'status': status, 'critical': critical})

    # Print the summary.
    print("\n" + banner)
    print("测试结果总结")
    print(banner)

    total_tests = len(results)
    passed_tests = sum(1 for entry in results if entry['status'] == 'PASS')
    critical_failures = sum(
        1 for entry in results if entry['critical'] and entry['status'] != 'PASS'
    )

    for entry in results:
        if entry['status'] == 'PASS':
            status_symbol = "✓"
        else:
            status_symbol = "✗"
        critical_mark = " (关键)" if entry['critical'] else ""
        print(f"{status_symbol} {entry['name']}: {entry['status']}{critical_mark}")

    print(f"\n总测试数: {total_tests}")
    print(f"通过测试: {passed_tests}")
    print(f"失败测试: {total_tests - passed_tests}")
    print(f"关键失败: {critical_failures}")

    # Write the HTML report.
    generate_test_report(results)

    # Only critical failures produce a non-zero exit code.
    if critical_failures > 0:
        print("\n❌ 关键测试失败,构建不通过")
        return 1

    if passed_tests == total_tests:
        print("\n✅ 所有测试通过")
    else:
        print("\n⚠️  部分非关键测试失败")
    return 0

def generate_test_report(results):
    """Write an HTML summary of a test run into the current directory.

    The page declares UTF-8 explicitly because it contains non-ASCII text and
    is written to disk as UTF-8; without the declaration a browser opening the
    local file may guess a different encoding.

    Args:
        results: List of dicts with the keys ``name`` (display name),
            ``status`` (``'PASS'`` marks success, anything else is rendered
            as a failure) and ``critical`` (truthy for critical stages).

    Side effects:
        Creates ``test_report_<YYYYmmdd_HHMMSS>.html`` in the working
        directory and prints its path.
    """
    from html import escape

    # Take a single timestamp so the file name and the page header agree.
    generated_at = datetime.now()
    report_path = f"test_report_{generated_at.strftime('%Y%m%d_%H%M%S')}.html"

    parts = [f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Strategy21 测试报告</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 40px; }}
            .header {{ background: #f0f0f0; padding: 20px; border-radius: 5px; }}
            .test-result {{ margin: 10px 0; padding: 10px; border-radius: 5px; }}
            .pass {{ background: #d4edda; color: #155724; }}
            .fail {{ background: #f8d7da; color: #721c24; }}
            .critical {{ border-left: 5px solid #ff6b6b; }}
        </style>
    </head>
    <body>
        <div class="header">
            <h1>Strategy21 集成测试报告</h1>
            <p>生成时间: {generated_at.strftime('%Y-%m-%d %H:%M:%S')}</p>
        </div>
        
        <h2>测试结果</h2>
    """]

    for result in results:
        css_class = "pass" if result['status'] == 'PASS' else "fail"
        if result['critical']:
            css_class += " critical"

        # Escape the values so a name containing markup cannot break the page.
        parts.append(f"""
        <div class="test-result {css_class}">
            <strong>{escape(str(result['name']))}</strong>: {escape(str(result['status']))}
            {' (关键测试)' if result['critical'] else ''}
        </div>
        """)

    parts.append("""
    </body>
    </html>
    """)

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("".join(parts))

    print(f"\n测试报告已生成: {report_path}")

if __name__ == "__main__":
    exit_code = run_test_suite()
    sys.exit(exit_code)

验收标准

功能测试

  1. 测试覆盖率

    • 策略逻辑测试覆盖率 > 90%
    • 核心模块单元测试覆盖率 > 85%
    • 集成测试覆盖主要场景
    • 边界条件和异常情况测试完整
  2. 测试质量

    • 所有关键测试用例通过
    • 测试数据生成器产生合理数据
    • 测试结果可重现
    • 测试执行时间合理(总时间 < 30分钟)
  3. 场景覆盖

    • 牛市/熊市/震荡市场景测试通过
    • 极端市场条件处理正确
    • 不同时间框架测试一致
    • 参数敏感性测试合理

性能测试

  1. 执行性能
    • 策略执行速度满足要求
    • 内存使用控制在合理范围
    • 并发处理能力达标
    • 大规模数据处理稳定

实现计划

阶段1:测试框架搭建 (1.5小时)

  • 创建测试基础设施
  • 实现数据生成器
  • 配置测试环境
  • 建立Mock对象

阶段2:策略逻辑测试 (2小时)

  • VATSM指标计算测试
  • 信号系统测试
  • 策略主逻辑测试
  • 边界条件测试

阶段3:市场场景测试 (1.5小时)

  • 不同市场环境测试
  • 极端场景处理测试
  • 风险控制测试
  • 交叉验证测试

阶段4:性能和集成测试 (1小时)

  • 性能基准测试
  • 回测验证测试
  • 自动化测试流程
  • 测试报告生成

依赖关系

  • 前置依赖: 007 (风险管理系统) - 需要完整的策略实现
  • 并行执行: 不可与其他任务并行执行
  • 后续集成: 为任务010提供质量保证

风险与缓解

技术风险

  1. 测试数据质量

    • 风险:生成的测试数据不够真实
    • 缓解:基于历史数据模式生成,添加真实市场特征
  2. 测试执行时间

    • 风险:测试套件执行时间过长
    • 缓解:优化测试数据量,使用并行测试

业务风险

  1. 测试覆盖不足
    • 风险:重要场景未覆盖导致线上问题
    • 缓解:基于实际交易场景设计测试用例

成功指标

  1. 测试覆盖率 - 代码覆盖率 > 85%,场景覆盖率 > 90%
  2. 测试通过率 - 所有关键测试 100% 通过
  3. 性能达标率 - 性能基准测试 100% 达标
  4. 自动化程度 - 测试流程完全自动化
  5. 问题发现率 - 测试能发现 95% 以上的已知问题

交付物

  1. 测试框架

    • conftest.py (测试配置和工具)
    • market_data_generator.py (测试数据生成器)
    • mock_objects.py (模拟对象)
  2. 测试套件

    • test_strategy_logic.py (策略逻辑测试)
    • test_market_scenarios.py (市场场景测试)
    • test_backtest_validation.py (回测验证测试)
    • test_performance_benchmarks.py (性能基准测试)
  3. 自动化工具

    • run_tests.py (测试执行脚本)
    • pytest.ini (pytest配置)
    • requirements-test.txt (测试依赖)
  4. 文档

    • 测试套件使用指南
    • 测试用例设计文档
    • 性能基准说明
    • 故障排查指南

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions