Files
quant/bt.py

294 lines
10 KiB
Python
Raw Permalink Normal View History

2025-10-31 09:32:53 +08:00
import vectorbt as vbt
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Plot styling: a CJK-capable font for the Chinese labels, and an ASCII hyphen
# instead of the Unicode minus sign (which CJK fonts often lack).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Show floats in pandas output with four decimal places.
pd.set_option('display.float_format', '{:.4f}'.format)

# Daily Hong Kong quotes: 00981 (SMIC) and 01347 (Hua Hong Semiconductor).
# NOTE(review): `ak` (presumably `import akshare as ak`) is not imported in the
# visible source; add the import if it is not provided elsewhere.
stock_00981 = ak.stock_hk_daily(symbol="00981")
stock_01347 = ak.stock_hk_daily(symbol="01347")
class PairsTradingStrategy:
    """Mean-reversion pairs trading on the price ratio ``price_b / price_a``.

    In this script leg A is SMIC (00981) and leg B is Hua Hong (01347).

    Trading rules:

    * Open a pair position when the ratio leaves the rolling band
      ``mean +/- k * std``.
    * Close the position once the ratio reverts to its rolling mean.
    """

    def __init__(self, price_a, price_b, lookback_window=60, k=1.5,
                 position_size=0.2):
        """Store the price series and strategy parameters.

        Args:
            price_a: Close prices of leg A (pandas Series).
            price_b: Close prices of leg B, on the same index as ``price_a``.
            lookback_window: Window length (bars) of the rolling mean/std.
                Bars before this index are always kept flat.
            k: Band half-width in rolling standard deviations.
            position_size: Target weight of each leg as a fraction of
                portfolio value. The default 0.2 means 20% long one leg and
                20% short the other.
        """
        self.price_a = price_a
        self.price_b = price_b
        self.lookback_window = lookback_window
        self.k = k
        self.position_size = position_size

    def calculate_ratio(self):
        """Return the price ratio RS = price_b / price_a (Hua Hong / SMIC)."""
        return self.price_b / self.price_a

    def generate_signals(self):
        """Compute the ratio, its rolling bands and the entry/exit signals.

        Returns:
            dict of pandas Series with these keys:

            * ``ratio``, ``ratio_mean``, ``upper_band``, ``lower_band``.
            * ``short_b_long_a``: ratio above the upper band
              (short B / long A).
            * ``long_b_short_a``: ratio below the lower band
              (long B / short A).
            * ``exit_signal``: the ratio crossed or touched its rolling mean
              since the previous bar.

            Comparisons involving NaN (warm-up bars) are False.
        """
        ratio = self.calculate_ratio()
        rolling = ratio.rolling(window=self.lookback_window)
        ratio_mean = rolling.mean()
        ratio_std = rolling.std()

        upper_band = ratio_mean + self.k * ratio_std
        lower_band = ratio_mean - self.k * ratio_std

        short_b_long_a = ratio > upper_band   # short Hua Hong, long SMIC
        long_b_short_a = ratio < lower_band   # long Hua Hong, short SMIC

        # Two floats are almost never exactly equal, so "back to the mean" is
        # detected as a sign change (or zero) of the deviation from the mean
        # between consecutive bars.
        deviation = ratio - ratio_mean
        exit_signal = (deviation * deviation.shift(1)) <= 0

        return {
            'ratio': ratio,
            'ratio_mean': ratio_mean,
            'upper_band': upper_band,
            'lower_band': lower_band,
            'short_b_long_a': short_b_long_a,
            'long_b_short_a': long_b_short_a,
            'exit_signal': exit_signal,
        }

    def backtest(self, initial_cash=100000, transaction_cost=0.001):
        """Run the backtest with vectorbt.

        Args:
            initial_cash: Starting cash shared by both legs.
            transaction_cost: Proportional fee per order (0.001 = 0.1%).

        Returns:
            ``(portfolio, signals)``. ``portfolio`` is the vectorbt Portfolio,
            with both legs grouped into one cash-sharing portfolio.
            ``signals`` is the dict returned by :meth:`generate_signals`.
        """
        signals = self.generate_signals()
        prices = pd.DataFrame({
            'SMIC': self.price_a,      # SMIC
            'HuaHong': self.price_b,   # Hua Hong Semiconductor
        })
        positions = self._generate_positions(signals, prices)

        # Send orders only on bars where the target weights change. NaN size
        # means "no order", so open positions are not rebalanced (and charged
        # fees) every day.
        changed = positions.ne(positions.shift()).any(axis=1)
        orders = positions.copy()
        orders.loc[~changed] = np.nan

        # The weights are target fractions of the shared portfolio value.
        # NOTE(review): orders fill at the same close that produced the
        # signal; shift `orders` by one bar for a more conservative fill.
        portfolio = vbt.Portfolio.from_orders(
            prices,
            size=orders,
            size_type='targetpercent',
            direction='both',        # the pair always has a short leg
            group_by=True,           # both legs form one portfolio
            cash_sharing=True,
            call_seq='auto',         # execute sells before buys within a bar
            init_cash=initial_cash,
            fees=transaction_cost,
            freq='1D',
        )
        return portfolio, signals

    def _generate_positions(self, signals, prices):
        """Turn entry signals into per-bar target weights for both legs.

        The position runs through three states:

        * 0: flat.
        * 1: short HuaHong / long SMIC.
        * -1: long HuaHong / short SMIC.

        A position is opened only when flat. A position is closed when the
        ratio reverts to its rolling mean: ``ratio <= mean`` for state 1,
        ``ratio >= mean`` for state -1. Bars before ``lookback_window``
        stay flat.

        Args:
            signals: dict providing ``ratio``, ``ratio_mean``,
                ``short_b_long_a`` and ``long_b_short_a`` aligned with
                ``prices``.
            prices: DataFrame with ``SMIC`` and ``HuaHong`` columns.

        Returns:
            Float DataFrame shaped like ``prices`` holding each leg's target
            weight on every bar (0.0 when flat).
        """
        ratio = signals['ratio'].to_numpy()
        ratio_mean = signals['ratio_mean'].to_numpy()
        enter_short_b = signals['short_b_long_a'].to_numpy()
        enter_long_b = signals['long_b_short_a'].to_numpy()

        states = np.zeros(len(prices), dtype=int)
        state = 0
        for i in range(self.lookback_window, len(prices)):
            if state == 1 and ratio[i] <= ratio_mean[i]:
                state = 0
            elif state == -1 and ratio[i] >= ratio_mean[i]:
                state = 0
            elif state == 0:
                if enter_short_b[i]:
                    state = 1
                elif enter_long_b[i]:
                    state = -1
            # Record the state every bar so an exit really returns to flat.
            states[i] = state

        positions = pd.DataFrame(0.0, index=prices.index, columns=prices.columns)
        positions['SMIC'] = states * self.position_size
        positions['HuaHong'] = -states * self.position_size
        return positions
# Data preprocessing
def prepare_data(stock_00981, stock_01347):
    """Align two daily quote tables on the trading dates they share.

    Args:
        stock_00981: DataFrame with at least ``date`` and ``close`` columns
            (SMIC, HK 00981).
        stock_01347: DataFrame with the same layout for Hua Hong
            Semiconductor (HK 01347).

    Returns:
        ``(smic_close, huahong_close)``: two ``close`` Series. Both share the
        same ascending DatetimeIndex, restricted to dates present in both
        inputs. The input frames are not modified.
    """
    def _close_by_date(quotes):
        # Close prices indexed by parsed date. Drop duplicate dates first
        # (keep the last row); otherwise .loc would return extra rows for that
        # date and the two legs would no longer line up. Then sort oldest
        # first, because the rolling statistics downstream assume
        # chronological order.
        frame = quotes.copy()
        frame['date'] = pd.to_datetime(frame['date'])
        frame = frame.set_index('date')
        frame = frame[~frame.index.duplicated(keep='last')].sort_index()
        return frame['close']

    smic_close = _close_by_date(stock_00981)
    huahong_close = _close_by_date(stock_01347)

    # Keep only the trading days on which both stocks have a quote.
    common_dates = smic_close.index.intersection(huahong_close.index).sort_values()
    return smic_close.loc[common_dates], huahong_close.loc[common_dates]
# Run the backtest
print("准备数据...")
smic_close, huahong_close = prepare_data(stock_00981, stock_01347)
print(f"数据时间范围: {smic_close.index.min()} 至 {smic_close.index.max()}")
print(f"总交易日数: {len(smic_close)}")
print(f"中芯国际价格范围: {smic_close.min():.2f} - {smic_close.max():.2f}")
print(f"华虹半导体价格范围: {huahong_close.min():.2f} - {huahong_close.max():.2f}")

# Strategy instance: leg A = SMIC, leg B = Hua Hong
strategy = PairsTradingStrategy(
    price_a=smic_close,      # SMIC
    price_b=huahong_close,   # Hua Hong Semiconductor
    lookback_window=60,      # 60-day rolling window
    k=1.5,                   # band width: 1.5 standard deviations
)

print("执行回测...")
portfolio, signals = strategy.backtest(
    initial_cash=100000,     # starting cash: 100,000
    transaction_cost=0.001,  # 0.1% fee per order
)
# Analyse the backtest results
print("\n" + "="*50)
print("回测结果分析")
print("="*50)

# Headline metrics. Each metric gets its own try, so one failure does not
# suppress the others, and the reason for a failure is shown.
for label, method_name, fmt in (
    ("总收益率", "total_return", "{:.2%}"),
    ("年化收益率", "annualized_return", "{:.2%}"),
    ("最大回撤", "max_drawdown", "{:.2%}"),
    ("夏普比率", "sharpe_ratio", "{:.2f}"),
):
    try:
        print(f"{label}: {fmt.format(getattr(portfolio, method_name)())}")
    except Exception as e:
        print(f"{label}: 计算失败 ({e})")

# Trade statistics
try:
    stats = portfolio.stats()
    print(f"总交易次数: {stats.get('Total Trades', 'N/A')}")
    # The win rate may be missing or NaN (e.g. with no closed trades); avoid
    # printing "N/A%" or "nan%".
    win_rate = stats.get('Win Rate [%]')
    print(f"胜率: {win_rate:.2f}%" if pd.notna(win_rate) else "胜率: N/A")
    print(f"盈亏比: {stats.get('Profit Factor', 'N/A')}")
except Exception as e:
    print(f"交易统计获取失败: {e}")
# Visualisation: three stacked panels
fig = plt.figure(figsize=(15, 12))

# 1. Price ratio with its rolling mean/bands and the bars that triggered entries
ax1 = plt.subplot(3, 1, 1)
ax1.plot(signals['ratio'].index, signals['ratio'].values, label='价格比率(华虹/中芯)', linewidth=1)
ax1.plot(signals['ratio_mean'].index, signals['ratio_mean'].values, label='滚动均值', linestyle='--', alpha=0.7)
ax1.plot(signals['upper_band'].index, signals['upper_band'].values, label=f'上轨(μ+{strategy.k}σ)', linestyle='--', color='red', alpha=0.7)
ax1.plot(signals['lower_band'].index, signals['lower_band'].values, label=f'下轨(μ-{strategy.k}σ)', linestyle='--', color='green', alpha=0.7)
short_signals = signals['ratio'][signals['short_b_long_a'] & signals['ratio'].notna()]
long_signals = signals['ratio'][signals['long_b_short_a'] & signals['ratio'].notna()]
if len(short_signals) > 0:
    ax1.scatter(short_signals.index, short_signals.values, color='red', marker='v', s=50, label='做空华虹/做多中芯')
if len(long_signals) > 0:
    ax1.scatter(long_signals.index, long_signals.values, color='green', marker='^', s=50, label='做多华虹/做空中芯')
ax1.set_title('价格比率与交易信号')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Strategy equity vs. an equal-weight buy-and-hold of both stocks
ax2 = plt.subplot(3, 1, 2)
try:
    strategy_value = portfolio.value()  # a method in vectorbt, not a property
    if isinstance(strategy_value, pd.DataFrame):
        # Ungrouped portfolio: one value column per leg -> total account value.
        strategy_value = strategy_value.sum(axis=1)
    # The strategy holds no position during the warm-up window, so its first
    # value is the starting capital. Invest that capital half in each stock on
    # day one and hold it.
    start_value = strategy_value.iloc[0]
    pair_prices = pd.concat([smic_close, huahong_close], axis=1, keys=['SMIC', 'HuaHong'])
    buy_and_hold = start_value * (pair_prices / pair_prices.iloc[0]).mean(axis=1)
    strategy_value.plot(ax=ax2, label='策略价值')
    buy_and_hold.plot(ax=ax2, label='买入持有')
    ax2.set_title('策略价值 vs 买入持有')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
except Exception as e:
    print(f"收益数据无法显示: {e}")
    ax2.text(0.5, 0.5, '收益数据无法显示', ha='center', va='center', transform=ax2.transAxes)

# 3. Per-leg holdings over time (negative values are short positions)
ax3 = plt.subplot(3, 1, 3)
try:
    portfolio.assets().plot(ax=ax3)
    ax3.set_title('仓位变化')
    ax3.grid(True, alpha=0.3)
except Exception as e:
    print(f"仓位数据无法显示: {e}")
    ax3.text(0.5, 0.5, '仓位数据无法显示', ha='center', va='center', transform=ax3.transAxes)

plt.tight_layout()
plt.show()
# Show the most recent trades
print("\n交易记录详情:")
try:
    trades = portfolio.trades.records_readable
    if trades.empty:
        print("没有交易记录")
    else:
        # The readable column names vary across vectorbt versions
        # (e.g. 'Entry Timestamp'/'Avg Entry Price' vs 'Entry Date'/'Entry Price'),
        # so show whichever of the wanted columns are present.
        wanted = ['Entry Timestamp', 'Entry Date', 'Column', 'Size',
                  'Avg Entry Price', 'Entry Price', 'Avg Exit Price', 'Exit Price', 'PnL']
        present = [col for col in wanted if col in trades.columns]
        print(trades[present].tail(10))
except Exception as e:
    print(f"无法获取交易记录: {e}")
# Parameter sensitivity: rerun the backtest for several band widths k
print("\n" + "="*50)
print("参数优化分析")
print("="*50)

k_values = [1.0, 1.3, 1.5, 1.7, 2.0]
results = []
for k in k_values:
    try:
        # Same lookback as the main run, so only k varies between rows.
        test_strategy = PairsTradingStrategy(
            price_a=smic_close,
            price_b=huahong_close,
            lookback_window=strategy.lookback_window,
            k=k,
        )
        test_portfolio, _ = test_strategy.backtest(initial_cash=100000)
        results.append({
            'k': k,
            '总收益率': test_portfolio.total_return(),
            '年化收益率': test_portfolio.annualized_return(),
            '最大回撤': test_portfolio.max_drawdown(),
            '夏普比率': test_portfolio.sharpe_ratio(),
            '总交易次数': test_portfolio.stats().get('Total Trades', 'N/A'),
        })
    except Exception as e:
        print(f"参数k={k}测试失败: {e}")

if results:
    results_df = pd.DataFrame(results)
    print(results_df.round(4))
else:
    print("所有参数测试都失败了")
# Descriptive statistics of the price ratio (warm-up NaNs excluded)
print("\n价格比率统计信息:")
ratio = signals['ratio'].dropna()
# Each statistic is computed lazily, right before it is printed.
for label, compute in (
    ("均值", ratio.mean),
    ("标准差", ratio.std),
    ("最小值", ratio.min),
    ("最大值", ratio.max),
    ("当前值", lambda: ratio.iloc[-1]),
):
    print(f"{label}: {compute():.4f}")

# Number of bars on which each entry condition held
print("\n策略信号统计:")
for label, key in (
    ("做空华虹/做多中芯信号次数", 'short_b_long_a'),
    ("做多华虹/做空中芯信号次数", 'long_b_short_a'),
):
    print(f"{label}: {signals[key].sum()}")