Files
quant/bt.py
2025-10-31 09:32:53 +08:00

294 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import vectorbt as vbt
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Configure matplotlib so Chinese labels render (SimHei font) and the minus
# sign is drawn as ASCII, and make pandas display floats with 4 decimals.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
pd.options.display.float_format = '{:.4f}'.format
# Fetch daily Hong Kong quotes for 00981 (SMIC) and 01347 (Hua Hong), in
# that order, via `ak.stock_hk_daily`.
# NOTE(review): `ak` is never imported in this file; presumably
# `import akshare as ak` is intended — confirm and add it to the imports.
stock_00981, stock_01347 = (
    ak.stock_hk_daily(symbol=code) for code in ("00981", "01347")
)
class PairsTradingStrategy:
    """Mean-reversion pairs-trading strategy on the price ratio B / A.

    The ratio ``price_b / price_a`` is compared with a rolling band
    ``mean +/- k * std`` computed over ``lookback_window`` bars:

    * ratio above the upper band -> short B, long A;
    * ratio below the lower band -> long B, short A;
    * an open position is closed once the ratio has moved back to its
      rolling mean (at or past it, coming from the side it entered on).

    In this script leg A is SMIC (00981) and leg B is Hua Hong (01347).
    """

    def __init__(self, price_a, price_b, lookback_window=60, k=1.5,
                 position_size=0.2):
        """Initialise the pairs-trading strategy.

        Args:
            price_a: Close prices of leg A (pandas Series).
            price_b: Close prices of leg B (pandas Series, same index as A).
            lookback_window: Window length, in bars, of the rolling mean/std.
            k: Band half-width in rolling standard deviations.
            position_size: Target weight of each leg while a position is
                open, as a fraction of portfolio value (0.2 = 20%, the value
                that used to be hard-coded).
        """
        self.price_a = price_a
        self.price_b = price_b
        self.lookback_window = lookback_window
        self.k = k
        self.position_size = position_size

    def calculate_ratio(self):
        """Return the price ratio RS = B / A (Hua Hong / SMIC here)."""
        return self.price_b / self.price_a

    def generate_signals(self):
        """Compute the ratio, its rolling band, and entry/exit flags.

        Returns:
            dict of Series with keys ``ratio``, ``ratio_mean``,
            ``upper_band``, ``lower_band``; the entry conditions
            ``short_b_long_a`` / ``long_b_short_a``; the per-direction exit
            conditions ``exit_short_b_long_a`` / ``exit_long_b_short_a``; and
            ``exit_signal`` (the ratio's deviation from its mean changed sign
            or hit zero since the previous bar). Comparisons against the NaN
            warm-up statistics evaluate to False.
        """
        ratio = self.calculate_ratio()
        rolling = ratio.rolling(window=self.lookback_window)
        ratio_mean = rolling.mean()
        ratio_std = rolling.std()
        upper_band = ratio_mean + self.k * ratio_std
        lower_band = ratio_mean - self.k * ratio_std
        # Entry conditions.
        short_b_long_a = ratio > upper_band  # short B, long A
        long_b_short_a = ratio < lower_band  # long B, short A
        # Exit conditions per direction. A short-B position was opened with the
        # ratio above the band, so it closes once the ratio is back at or below
        # the mean; the long-B position mirrors that. The former
        # `(ratio <= mean) & (ratio >= mean)` only fired on exact float
        # equality, so positions were practically never closed.
        exit_short_b_long_a = ratio <= ratio_mean
        exit_long_b_short_a = ratio >= ratio_mean
        # Direction-agnostic "reverted to the mean" flag, kept for readers of
        # the 'exit_signal' key.
        deviation = ratio - ratio_mean
        exit_signal = (deviation * deviation.shift(1)) <= 0
        return {
            'ratio': ratio,
            'ratio_mean': ratio_mean,
            'upper_band': upper_band,
            'lower_band': lower_band,
            'short_b_long_a': short_b_long_a,
            'long_b_short_a': long_b_short_a,
            'exit_signal': exit_signal,
            'exit_short_b_long_a': exit_short_b_long_a,
            'exit_long_b_short_a': exit_long_b_short_a,
        }

    def backtest(self, initial_cash=100000, transaction_cost=0.001):
        """Run the backtest with vectorbt.

        Args:
            initial_cash: Starting cash, shared by both legs.
            transaction_cost: Proportional fee per order (0.001 = 0.1%).

        Returns:
            Tuple ``(portfolio, signals)``: the vectorbt Portfolio and the
            dict returned by ``generate_signals``.
        """
        signals = self.generate_signals()
        prices = pd.DataFrame({
            'SMIC': self.price_a,     # leg A
            'HuaHong': self.price_b,  # leg B
        })
        positions = self._generate_positions(signals, prices)
        # Send an order only on bars where a target weight changes; NaN means
        # "no order", so the book is not rebalanced (and charged fees) daily.
        orders = positions.where(positions.ne(positions.shift()))
        # Target weights -> target-percent orders on a single group sharing
        # one cash pool; call_seq='auto' runs sells before buys within a bar.
        # NOTE(review): orders fill at the close of the bar whose close produced
        # the signal; shift `positions` by one bar for a stricter backtest.
        portfolio = vbt.Portfolio.from_orders(
            prices,
            size=orders,
            size_type='targetpercent',
            direction='both',
            group_by=True,
            cash_sharing=True,
            call_seq='auto',
            init_cash=initial_cash,
            fees=transaction_cost,
            freq='1D',
        )
        return portfolio, signals

    def _generate_positions(self, signals, prices):
        """Turn signals into per-bar target weights for both legs.

        State machine: flat -> short B / long A on ``short_b_long_a``;
        flat -> long B / short A on ``long_b_short_a``; an open position
        returns to flat on its own direction's exit flag. Nothing happens
        before ``lookback_window`` bars, and a bar that exits does not
        re-enter. The state is written on every bar, so exits persist.

        Args:
            signals: dict from ``generate_signals``; its Series are read
                positionally and assumed to share ``prices``' row order.
            prices: DataFrame with 'SMIC' (leg A) and 'HuaHong' (leg B)
                columns; its index defines the output index.

        Returns:
            DataFrame of float target weights with the index and columns of
            ``prices``: ``+position_size`` long, ``-position_size`` short,
            ``0.0`` flat.
        """
        enter_short_b = signals['short_b_long_a'].to_numpy(dtype=bool)
        enter_long_b = signals['long_b_short_a'].to_numpy(dtype=bool)
        exit_short_b = signals['exit_short_b_long_a'].to_numpy(dtype=bool)
        exit_long_b = signals['exit_long_b_short_a'].to_numpy(dtype=bool)

        size = self.position_size
        # State -> (SMIC weight, HuaHong weight).
        # 0: flat, 1: short HuaHong / long SMIC, -1: long HuaHong / short SMIC.
        leg_weights = {0: (0.0, 0.0), 1: (size, -size), -1: (-size, size)}

        state = 0
        smic_weights, huahong_weights = [], []
        for i in range(len(prices)):
            # Wait until the rolling statistics have a full window.
            if i >= self.lookback_window:
                if state == 1 and exit_short_b[i]:
                    state = 0
                elif state == -1 and exit_long_b[i]:
                    state = 0
                elif state == 0:
                    if enter_short_b[i]:
                        state = 1
                    elif enter_long_b[i]:
                        state = -1
            smic_w, huahong_w = leg_weights[state]
            smic_weights.append(smic_w)
            huahong_weights.append(huahong_w)

        positions = pd.DataFrame(
            {'SMIC': smic_weights, 'HuaHong': huahong_weights},
            index=prices.index,
        )
        return positions.reindex(columns=prices.columns, fill_value=0.0)
# Data preprocessing
def prepare_data(stock_00981, stock_01347):
    """Align the two stocks' daily closes on their common trading dates.

    Args:
        stock_00981: SMIC data with at least 'date' and 'close' columns.
        stock_01347: Hua Hong data with the same columns.

    Returns:
        Tuple ``(smic_close, huahong_close)`` of 'close' Series indexed by
        the dates present in both inputs, in ascending order with one row
        per date. The input DataFrames are not modified.
    """
    def _close_by_date(raw):
        # assign() returns a new frame, so the caller's data is untouched.
        close = raw.assign(date=pd.to_datetime(raw['date'])).set_index('date')['close']
        # Rolling statistics downstream need one row per date in chronological
        # order; keep the last row for a duplicated date, then sort.
        close = close[~close.index.duplicated(keep='last')]
        return close.sort_index()

    smic_close = _close_by_date(stock_00981)
    huahong_close = _close_by_date(stock_01347)
    # Keep only the trading days both stocks have (both indexes are sorted
    # and unique, so the intersection is too).
    common_dates = smic_close.index.intersection(huahong_close.index)
    return smic_close.loc[common_dates], huahong_close.loc[common_dates]
# Run the pipeline: align the quotes and report the period and price ranges.
print("准备数据...")
smic_close, huahong_close = prepare_data(stock_00981, stock_01347)
# A separator between start and end dates (they were printed back to back).
print(f"数据时间范围: {smic_close.index.min()} 至 {smic_close.index.max()}")
print(f"总交易日数: {len(smic_close)}")
print(f"中芯国际价格范围: {smic_close.min():.2f} - {smic_close.max():.2f}")
print(f"华虹半导体价格范围: {huahong_close.min():.2f} - {huahong_close.max():.2f}")
# Build the strategy with SMIC as leg A and Hua Hong as leg B, a 60-day
# rolling window and bands at 1.5 rolling standard deviations.
strategy = PairsTradingStrategy(smic_close, huahong_close, lookback_window=60, k=1.5)
print("执行回测...")
# 100,000 starting cash and a 0.1% transaction cost.
backtest_kwargs = {'initial_cash': 100000, 'transaction_cost': 0.001}
portfolio, signals = strategy.backtest(**backtest_kwargs)
# Summarise the backtest results.
print("\n" + "="*50)
print("回测结果分析")
print("="*50)
# Headline metrics. `except Exception` (not a bare `except:`) so Ctrl-C and
# SystemExit are not swallowed; the cause is printed instead of hidden.
try:
    print(f"总收益率: {portfolio.total_return():.2%}")
    print(f"年化收益率: {portfolio.annualized_return():.2%}")
    print(f"最大回撤: {portfolio.max_drawdown():.2%}")
    print(f"夏普比率: {portfolio.sharpe_ratio():.2f}")
except Exception as exc:
    print("部分指标计算失败,继续显示其他结果...")
    print(f"  原因: {exc!r}")
# Trade statistics from the portfolio's stats() summary.
try:
    stats = portfolio.stats()
    print(f"总交易次数: {stats['Total Trades']}")
    win_rate = stats.get('Win Rate [%]')
    # Only append '%' to an actual value (avoid printing "N/A%").
    print(f"胜率: {'N/A' if win_rate is None else f'{win_rate}%'}")
    print(f"盈亏比: {stats.get('Profit Factor', 'N/A')}")
except Exception as exc:
    print("交易统计获取失败")
    print(f"  原因: {exc!r}")
# Visualise the backtest in three stacked panels.
plt.figure(figsize=(15, 12))

# 1. Price ratio, its rolling band, and the bars where an entry condition held.
ax1 = plt.subplot(3, 1, 1)
plt.plot(signals['ratio'].index, signals['ratio'].values, label='价格比率(华虹/中芯)', linewidth=1)
plt.plot(signals['ratio_mean'].index, signals['ratio_mean'].values, label='滚动均值', linestyle='--', alpha=0.7)
plt.plot(signals['upper_band'].index, signals['upper_band'].values, label=f'上轨(μ+{strategy.k}σ)', linestyle='--', color='red', alpha=0.7)
plt.plot(signals['lower_band'].index, signals['lower_band'].values, label=f'下轨(μ-{strategy.k}σ)', linestyle='--', color='green', alpha=0.7)
# Mark every bar on which an entry condition is true.
short_signals = signals['ratio'][signals['short_b_long_a'] & (signals['ratio'].notna())]
long_signals = signals['ratio'][signals['long_b_short_a'] & (signals['ratio'].notna())]
if len(short_signals) > 0:
    plt.scatter(short_signals.index, short_signals.values, color='red', marker='v', s=50, label='做空华虹/做多中芯')
if len(long_signals) > 0:
    plt.scatter(long_signals.index, long_signals.values, color='green', marker='^', s=50, label='做多华虹/做空中芯')
plt.title('价格比率与交易信号')
plt.legend()
plt.grid(True, alpha=0.3)

# 2. Strategy value vs. an equal-weight buy-and-hold of both stocks.
ax2 = plt.subplot(3, 1, 2)
try:
    # value() is a method in vectorbt, and vectorbt's `.vbt.plot()` draws with
    # Plotly rather than on a matplotlib Axes, so plot with matplotlib here.
    strategy_value = portfolio.value()
    if isinstance(strategy_value, pd.DataFrame):
        # Ungrouped portfolio: one value column per asset; plot the total.
        strategy_value = strategy_value.sum(axis=1)
    # The old "buy and hold" line re-compounded the strategy's own returns and
    # so duplicated the strategy curve. Benchmark instead: the starting value
    # split 50/50 between the two stocks on the first bar and held.
    legs = pd.concat([smic_close, huahong_close], axis=1, keys=['SMIC', 'HuaHong'])
    legs = legs.reindex(strategy_value.index)
    benchmark_value = strategy_value.iloc[0] * (legs / legs.iloc[0]).mean(axis=1)
    ax2.plot(strategy_value.index, strategy_value.values, label='策略价值')
    ax2.plot(benchmark_value.index, benchmark_value.values, label='买入持有')
    ax2.set_title('策略价值 vs 买入持有')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
except Exception as exc:
    print(f"收益图绘制失败: {exc!r}")
    ax2.text(0.5, 0.5, '收益数据无法显示', ha='center', va='center', transform=ax2.transAxes)

# 3. Shares held in each leg over time.
ax3 = plt.subplot(3, 1, 3)
try:
    # assets() gives the quantity held per column on every bar. The previous
    # code grouped position records by a 'Timestamp' column and fell back to
    # `portfolio.holdings`; neither matches vectorbt's Portfolio API.
    held = portfolio.assets()
    if isinstance(held, pd.Series):
        held = held.to_frame()
    for column in held.columns:
        ax3.plot(held.index, held[column].values, label=str(column))
    ax3.set_title('仓位变化')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
except Exception as exc:
    print(f"仓位图绘制失败: {exc!r}")
    ax3.text(0.5, 0.5, '仓位数据无法显示', ha='center', va='center', transform=ax3.transAxes)

plt.tight_layout()
plt.show()
# Show the 10 most recent trades.
print("\n交易记录详情:")
try:
    trades = portfolio.trades.records_readable
    if not trades.empty:
        # Readable column names vary between vectorbt versions (current ones
        # use 'Entry Timestamp' / 'Avg Entry Price' / 'Avg Exit Price'), so
        # show whichever of these candidate columns actually exist.
        wanted = ['Entry Timestamp', 'Entry Index', 'Column', 'Size',
                  'Avg Entry Price', 'Entry Price', 'Avg Exit Price',
                  'Exit Price', 'PnL']
        shown = [col for col in wanted if col in trades.columns]
        print(trades[shown].tail(10))
    else:
        print("没有交易记录")
except Exception as exc:
    print("无法获取交易记录")
    print(f"  原因: {exc!r}")
# Parameter sweep: re-run the backtest for several band widths k, with the
# window (60) and starting cash (100,000) fixed, and tabulate the metrics.
banner = "=" * 50
print("\n" + banner)
print("参数优化分析")
print(banner)
k_values = [1.0, 1.3, 1.5, 1.7, 2.0]
results = []
for k in k_values:
    try:
        sweep_portfolio, _ = PairsTradingStrategy(
            price_a=smic_close,
            price_b=huahong_close,
            lookback_window=60,
            k=k,
        ).backtest(initial_cash=100000)
        # Metrics are evaluated in the same order as before; stats() last.
        row = {'k': k}
        row['总收益率'] = sweep_portfolio.total_return()
        row['年化收益率'] = sweep_portfolio.annualized_return()
        row['最大回撤'] = sweep_portfolio.max_drawdown()
        row['夏普比率'] = sweep_portfolio.sharpe_ratio()
        stats_getter = getattr(sweep_portfolio.stats(), 'get', None)
        row['总交易次数'] = 'N/A' if stats_getter is None else stats_getter('Total Trades')
    except Exception as e:
        print(f"参数k={k}测试失败: {e}")
        continue
    else:
        results.append(row)
if not results:
    print("所有参数测试都失败了")
else:
    results_df = pd.DataFrame(results)
    print(results_df.round(4))
# Descriptive statistics of the B/A price ratio (NaNs dropped).
print("\n价格比率统计信息:")
ratio = signals['ratio'].dropna()
for label, method in (("均值", "mean"), ("标准差", "std"), ("最小值", "min"), ("最大值", "max")):
    print(f"{label}: {getattr(ratio, method)():.4f}")
print(f"当前值: {ratio.iloc[-1]:.4f}")
# Number of bars on which each entry condition was true (not trade counts).
print("\n策略信号统计:")
for label, key in (
    ("做空华虹/做多中芯信号次数", 'short_b_long_a'),
    ("做多华虹/做空中芯信号次数", 'long_b_short_a'),
):
    print(f"{label}: {signals[key].sum()}")