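# quant/PairsTradingSimple.py
# Pairs-trading backtest: SMIC (00981.HK) vs. Hua Hong Semiconductor (01347.HK).
# (Header recovered from a repository web view: 549 lines, 19 KiB, Python;
#  revision 2025-11-01 15:34:46 +08:00.)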
import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
from numba import njit
from collections import namedtuple
# Chinese-capable fonts for matplotlib. SimHei ships with Windows only, so the
# list continues with common macOS/Linux CJK fonts; matplotlib tries the
# entries in order, which keeps the Chinese labels from rendering as boxes.
plt.rcParams['font.sans-serif'] = [
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'Heiti SC',
    'Noto Sans CJK SC',
    'WenQuanYi Micro Hei',
    'Arial Unicode MS',
    'DejaVu Sans',
]
# Use an ASCII hyphen for minus signs; many CJK fonts lack the U+2212 glyph.
plt.rcParams['axes.unicode_minus'] = False

# Fetch daily bars for both legs of the pair (network call via akshare).
print("正在获取股票数据...")
stock_00981 = ak.stock_hk_daily(symbol="00981")
stock_01347 = ak.stock_hk_daily(symbol="01347")
print("中芯国际数据列名:", stock_00981.columns.tolist())
print("华虹半导体数据列名:", stock_01347.columns.tolist())
# Data preprocessing
def preprocess_data(df, symbol):
    """Normalize a raw daily-bar frame into a date-indexed OHLCV frame.

    Accepts English (``date``, ``open``, ...) or Chinese (``日期``, ``开盘``, ...)
    column names, in any mix. The ``date`` column (preferred) or the ``日期``
    column becomes a ``DatetimeIndex``. If neither exists, the existing index
    is converted to datetimes.

    Args:
        df: Raw daily-bar DataFrame. It is copied, never modified in place.
        symbol: Ticker of the series; used only in the error message.

    Returns:
        DataFrame sorted by date with exactly the columns
        ``open, high, low, close, volume``.

    Raises:
        KeyError: If any of the OHLCV columns is still missing after renaming.
    """
    df = df.copy()
    # Map Chinese column names to English. This runs for every input so that a
    # frame with an English ``date`` column but Chinese price columns works
    # too. A column is only renamed when its English target is not already
    # present, so renaming can never create duplicate columns.
    cn_to_en = {
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'volume',
    }
    df = df.rename(columns={src: dst for src, dst in cn_to_en.items()
                            if dst not in df.columns})
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date')
    elif '日期' in df.columns:
        df['date'] = pd.to_datetime(df['日期'])
        df = df.set_index('date')
    else:
        # No date column: assume the dates already live in the index.
        df.index = pd.to_datetime(df.index)
    df = df.sort_index()
    required = ['open', 'high', 'low', 'close', 'volume']
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(
            f"{symbol}: missing columns {missing}; available: {df.columns.tolist()}"
        )
    return df[required]
# Preprocess both legs
smic_data = preprocess_data(stock_00981, "00981")
hhic_data = preprocess_data(stock_01347, "01347")
print(f"中芯国际原始数据时间范围: {smic_data.index.min()} 至 {smic_data.index.max()}")
print(f"华虹半导体原始数据时间范围: {hhic_data.index.min()} 至 {hhic_data.index.max()}")

# Keep only the most recent three calendar years, anchored on SMIC's last bar.
# DateOffset(years=3) steps back to the same calendar date. A fixed
# 3*365-day span would end one day short whenever the window contains 29 Feb.
end_date = smic_data.index.max()
start_date = end_date - pd.DateOffset(years=3)
print(f"\n限制回测时间范围: {start_date} 至 {end_date}")
smic_data = smic_data.loc[start_date:end_date]
hhic_data = hhic_data.loc[start_date:end_date]
print(f"限制后中芯国际数据形状: {smic_data.shape}")
print(f"限制后华虹半导体数据形状: {hhic_data.shape}")

# Align both legs on the dates present in both series.
common_index = smic_data.index.intersection(hhic_data.index)
smic_data = smic_data.loc[common_index]
hhic_data = hhic_data.loc[common_index]
print(f"\n对齐后数据时间范围: {smic_data.index.min()} 至 {smic_data.index.max()}")
print(f"对齐后数据点数: {len(smic_data)}")

# Backtest parameters
initial_cash = 100000
commission = 0.001  # 0.1% commission, passed to vectorbt as `fees`
position_size = 0.1  # fraction of initial_cash allocated to each leg per entry
# Pair-trading signal generation
def calculate_pair_signals(price1, price2, window=20, num_std=2):
    """Compute Bollinger-band signals on the price ratio ``price1 / price2``.

    Signal codes, one per bar:
        1  -- ratio strictly below the lower band: long the spread
              (buy leg 1, sell leg 2)
        -1 -- ratio strictly above the upper band: short the spread
              (sell leg 1, buy leg 2)
        0  -- ratio inside the bands (inclusive), or bands not yet defined
              (the first ``window - 1`` bars)

    A bar is non-zero only while the ratio sits outside the band. A consumer
    that treats 0 as "flat" therefore exits on the first bar back inside the
    band, not when the ratio reaches the moving average.

    Args:
        price1: Price Series of leg 1.
        price2: Price Series of leg 2, aligned with ``price1``.
        window: Rolling window length for the mean and standard deviation.
        num_std: Band half-width, in rolling standard deviations.

    Returns:
        Tuple ``(signals, ratio, ratio_ma, upper_band, lower_band)``.
        ``signals`` is an int64 Series named ``'signal'`` on the ratio's index.
    """
    ratio = price1 / price2
    roll = ratio.rolling(window=window)
    ratio_ma = roll.mean()
    half_width = num_std * roll.std()
    upper_band = ratio_ma + half_width
    lower_band = ratio_ma - half_width

    # Only bars with a defined moving average can carry a signal.
    ready = ratio_ma.notna()
    codes = np.where(ready & (ratio < lower_band), 1, 0)
    # Applied second so a short wins if both conditions ever hold
    # (only possible with a negative num_std).
    codes = np.where(ready & (ratio > upper_band), -1, codes)
    signals = pd.Series(codes, index=ratio.index, name='signal', dtype='int64')
    return signals, ratio, ratio_ma, upper_band, lower_band
# Compute signals from the closing prices of the two aligned legs.
close_smic = smic_data['close']
close_hhic = hhic_data['close']
signals, ratio, ratio_ma, upper_band, lower_band = calculate_pair_signals(
    price1=close_smic,
    price2=close_hhic,
    window=20,
    num_std=2,
)
print(f"信号计算完成,有效信号数量: {(signals != 0).sum()}")

# Price panel fed to the backtest: one column per leg, keyed by HK ticker.
close = pd.DataFrame(
    {'00981': close_smic, '01347': close_hhic}
)
# Turn signals into per-bar order sizes
def generate_pair_size(signals, close_smic, close_hhic, position_size=0.1,
                       capital=None, warmup=20, columns=('00981', '01347')):
    """Translate pair signals into per-bar order quantities for both legs.

    The result is passed to ``vbt.Portfolio.from_orders(size=...)`` with the
    default size type. Each cell is therefore the number of shares to trade on
    that bar: positive means buy, negative means sell, and 0 means no order.
    It is an order *delta*, not a target position.

    This function tracks the shares currently held and emits the difference
    needed to reach each new target:

    * signal 1 while not long the spread: target +N1 of leg 1, -N2 of leg 2.
    * signal -1 while not short the spread: target -N1 of leg 1, +N2 of leg 2.
      A flip from long to short trades through flat in a single bar.
    * signal 0 while holding a position: target flat, closing both legs.

    Here ``Ni = int(capital * position_size / price_i)``, using the leg's price
    on the signal bar.

    Args:
        signals: Series of 1 / -1 / 0 codes (see ``calculate_pair_signals``).
        close_smic: Leg-1 prices, positionally aligned with ``signals``.
        close_hhic: Leg-2 prices, positionally aligned with ``signals``.
        position_size: Fraction of ``capital`` allocated to each leg per entry.
        capital: Sizing base. Defaults to the module-level ``initial_cash``.
        warmup: Number of leading bars to ignore (Bollinger-band warm-up).
        columns: Column labels for leg 1 and leg 2 in the returned frame.

    Returns:
        int64 DataFrame on ``signals.index`` with one column per leg, holding
        the shares to trade on each bar.
    """
    if capital is None:
        capital = initial_cash  # module-level backtest capital
    budget = capital * position_size

    size_df = pd.DataFrame(0, index=signals.index, columns=list(columns))
    held_1 = held_2 = 0      # shares currently held per leg (signed)
    current_position = 0     # 1 = long spread, -1 = short spread, 0 = flat

    for i, signal in enumerate(signals):
        if i < warmup:
            continue
        if signal == 1 and current_position != 1:
            # Long the spread: buy leg 1, short leg 2.
            target_1 = int(budget / close_smic.iloc[i])
            target_2 = -int(budget / close_hhic.iloc[i])
            new_position = 1
        elif signal == -1 and current_position != -1:
            # Short the spread: short leg 1, buy leg 2.
            target_1 = -int(budget / close_smic.iloc[i])
            target_2 = int(budget / close_hhic.iloc[i])
            new_position = -1
        elif signal == 0 and current_position != 0:
            # Close both legs.
            target_1 = target_2 = 0
            new_position = 0
        else:
            continue
        # Positional writes stay correct even if the index has duplicate dates.
        size_df.iloc[i, 0] = target_1 - held_1
        size_df.iloc[i, 1] = target_2 - held_2
        held_1, held_2 = target_1, target_2
        current_position = new_position

    return size_df
# Build the per-bar order-size panel and report how many bars carry orders.
size = generate_pair_size(
    signals,
    close_smic,
    close_hhic,
    position_size=position_size,
)
print(f"size数据形状: {size.shape}")
print(f"非零交易数量 - 中芯国际: {(size['00981'] != 0).sum()}, 华虹半导体: {(size['01347'] != 0).sum()}")
# Create the portfolio and run the analysis
print("创建投资组合...")
try:
    # NOTE: without group_by / cash_sharing, vectorbt simulates every column of
    # `close` as its own account funded with init_cash. Pair-level numbers
    # below are therefore built explicitly from the per-leg results.
    portfolio = vbt.Portfolio.from_orders(
        close=close,
        size=size,
        init_cash=initial_cash,
        fees=commission,
        freq='D'
    )
    print("投资组合创建成功!")

    # Display names of the two legs, keyed by column label.
    legs = {'00981': '中芯国际', '01347': '华虹半导体'}

    # Pair-level performance
    print("\n=== 配对交易策略表现 ===")
    # Combined equity of the pair. value() has one column per leg here, so sum
    # the columns; plotting the raw frame would draw two lines under one label.
    portfolio_value = portfolio.value()
    if isinstance(portfolio_value, pd.DataFrame):
        portfolio_value = portfolio_value.sum(axis=1)
    # Drawdown of the combined equity: value over its running peak, minus 1.
    # Averaging the per-leg drawdowns would not give the pair's drawdown.
    drawdown = portfolio_value / portfolio_value.cummax() - 1

    # Matplotlib overview
    print("\n绘制图表...")
    fig, axes = plt.subplots(3, 1, figsize=(15, 12))
    # 1. Price ratio with Bollinger bands and entry signals
    axes[0].plot(ratio.index, ratio, label='价格比率(中芯/华虹)', linewidth=1)
    axes[0].plot(ratio.index, ratio_ma, label='移动平均', linewidth=1, alpha=0.7)
    axes[0].plot(ratio.index, upper_band, label='上轨', linewidth=1, alpha=0.7, linestyle='--', color='red')
    axes[0].plot(ratio.index, lower_band, label='下轨', linewidth=1, alpha=0.7, linestyle='--', color='green')
    axes[0].set_title('中芯国际-华虹半导体价格比率')
    axes[0].set_ylabel('价格比率')
    axes[0].grid(True, alpha=0.3)
    long_signals = signals[signals == 1]
    short_signals = signals[signals == -1]
    if len(long_signals) > 0:
        axes[0].scatter(long_signals.index, ratio[long_signals.index],
                        color='green', marker='^', s=50, label='做多信号', zorder=5)
    if len(short_signals) > 0:
        axes[0].scatter(short_signals.index, ratio[short_signals.index],
                        color='red', marker='v', s=50, label='做空信号', zorder=5)
    # Build the legend after the scatters so the signal markers appear in it.
    axes[0].legend()

    # 2. Combined equity curve
    if len(portfolio_value) > 0:
        axes[1].plot(portfolio_value.index, portfolio_value, label='组合净值', linewidth=1)
    axes[1].set_title('组合净值曲线')
    axes[1].set_ylabel('组合价值')
    axes[1].grid(True, alpha=0.3)

    # 3. Combined drawdown
    if len(drawdown) > 0:
        axes[2].plot(drawdown.index, drawdown, label='回撤', linewidth=1, color='red')
    axes[2].set_title('回撤')
    axes[2].set_ylabel('回撤')
    axes[2].set_xlabel('日期')
    axes[2].grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # ========== vectorbt built-in visualizations ==========
    print("\n=== VectorBT 专业可视化 ===")

    # 1. Per leg: orders, trade PnL, cumulative returns, drawdowns.
    #    Each leg has its own try, so one failure does not skip the other.
    for code, name in legs.items():
        print(f"\n绘制{name}分析...")
        try:
            leg_fig = portfolio[code].plot(subplots=[
                'orders',       # orders
                'trade_pnl',    # per-trade PnL
                'cum_returns',  # cumulative returns
                'drawdowns'     # drawdowns
            ])
            leg_fig.update_layout(
                title=f'{name}配对交易分析',
                height=800
            )
            leg_fig.show()
        except Exception as e:
            print(f"{name}分析绘制失败: {e}")

    # 2. Combined portfolio value
    print("\n绘制组合总价值变化...")
    try:
        fig = portfolio_value.vbt.plot(
            title='配对交易组合总价值'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='组合价值'
        )
        fig.show()
    except Exception as e:
        print(f"组合价值变化绘制失败: {e}")

    # 3. Cumulative returns. Both legs share the same init_cash, so the mean of
    #    the per-leg curves is the pair's cumulative return.
    print("\n绘制组合累积收益...")
    try:
        cumulative_returns = portfolio.cumulative_returns()
        if isinstance(cumulative_returns, pd.DataFrame) and len(cumulative_returns.columns) > 1:
            cumulative_returns = cumulative_returns.mean(axis=1)
        fig = cumulative_returns.vbt.plot(
            title='配对交易累积收益率'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='累积收益'
        )
        fig.show()
    except Exception as e:
        print(f"累积收益绘制失败: {e}")

    # 4. Combined drawdown
    print("\n绘制组合回撤分析...")
    try:
        fig = drawdown.vbt.plot(
            title='配对交易回撤分析'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='回撤'
        )
        fig.show()
    except Exception as e:
        print(f"回撤分析绘制失败: {e}")

    # 5. Monthly return heatmap. Compound daily returns within each
    #    (year, month) using plain pandas, average across legs, then pivot to
    #    a month x year grid matching the axis titles.
    print("\n绘制月度收益热力图...")
    try:
        daily_returns = portfolio.returns()
        day_index = daily_returns.index
        monthly_returns = (1 + daily_returns).groupby(
            [day_index.year.rename('年份'), day_index.month.rename('月份')]
        ).prod() - 1
        if isinstance(monthly_returns, pd.DataFrame):
            monthly_returns = monthly_returns.mean(axis=1)
        monthly_grid = monthly_returns.unstack('年份')
        fig = monthly_grid.vbt.heatmap(
            title='配对交易月度收益热力图'
        )
        fig.update_layout(
            xaxis_title='年份',
            yaxis_title='月份'
        )
        fig.show()
    except Exception as e:
        print(f"月度收益热力图绘制失败: {e}")

    # 6. Trade PnL and duration distributions per leg
    print("\n绘制交易分析...")
    for code, name in legs.items():
        try:
            leg_trades = portfolio[code].trades
            if len(leg_trades) > 0:
                fig = leg_trades.plot_pnl()
                fig.update_layout(title=f'{name}交易盈亏分布')
                fig.show()
                fig = leg_trades.plot_duration()
                fig.update_layout(title=f'{name}交易持续时间分布')
                fig.show()
        except Exception as e:
            print(f"交易分析绘制失败: {e}")

    # 7. Order flow per leg
    print("\n绘制订单流分析...")
    for code, name in legs.items():
        try:
            fig = portfolio[code].orders.plot()
            fig.update_layout(title=f'{name}订单流分析')
            fig.show()
        except Exception as e:
            print(f"订单流分析绘制失败: {e}")

    # 8. Position value per leg. vectorbt's Portfolio has no `holdings`
    #    attribute; asset_value() is the value of the shares held, which is
    #    what the '持仓价值' axis describes.
    print("\n绘制持仓分析...")
    for code, name in legs.items():
        try:
            fig = portfolio[code].asset_value().vbt.plot(
                title=f'{name}持仓变化'
            )
            fig.update_layout(
                xaxis_title='日期',
                yaxis_title='持仓价值'
            )
            fig.show()
        except Exception as e:
            print(f"持仓分析绘制失败: {e}")

    # 9. Cash balance. cash() is a method and must be called. Each leg holds
    #    its own cash, so sum the columns for the pair's total.
    print("\n绘制现金变化...")
    try:
        cash = portfolio.cash()
        if isinstance(cash, pd.DataFrame):
            cash = cash.sum(axis=1)
        fig = cash.vbt.plot(
            title='配对交易现金余额变化'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='现金余额'
        )
        fig.show()
    except Exception as e:
        print(f"现金变化绘制失败: {e}")

    # 10. Allocation: signed position value of each leg as a fraction of the
    #     combined equity.
    print("\n绘制资产配置比例...")
    try:
        asset_value = portfolio.asset_value()
        weights = asset_value.div(portfolio_value, axis=0).fillna(0)
        fig = weights.vbt.areaplot(
            title='资产配置比例变化'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='权重'
        )
        fig.show()
    except Exception as e:
        print(f"资产配置比例绘制失败: {e}")

    # Bars that carry at least one order
    non_zero_size = size[(size != 0).any(axis=1)]
    print(f"\n非零交易数量: {len(non_zero_size)}")
    if len(non_zero_size) > 0:
        print("交易记录示例:")
        print(non_zero_size.head(20))

    # Detailed per-leg statistics
    print("\n=== 详细统计 ===")

    def safe_get_stat(stat_dict, key, default="N/A"):
        """Return ``stat_dict[key]``, or ``default`` if absent.

        A one-element Series result is unwrapped to its scalar.
        """
        value = stat_dict.get(key, default)
        if hasattr(value, 'iloc'):
            return value.iloc[0] if len(value) == 1 else value
        return value

    try:
        for code, name in legs.items():
            leg_stats = portfolio[code].stats()
            print(f"\n{name}统计:")
            print(f"总收益率: {safe_get_stat(leg_stats, 'Total Return [%]', 'N/A')}%")
            print(f"年化收益率: {safe_get_stat(leg_stats, 'Annual Return [%]', 'N/A')}%")
            print(f"夏普比率: {safe_get_stat(leg_stats, 'Sharpe Ratio', 'N/A')}")
            print(f"最大回撤: {safe_get_stat(leg_stats, 'Max Drawdown [%]', 'N/A')}%")
    except Exception as e:
        print(f"获取详细统计时出错: {e}")

    # Per-trade summary for each leg
    try:
        for code, name in legs.items():
            records = portfolio[code].trades.records_readable
            if len(records) == 0:
                continue
            print(f"\n{name}交易分析:")
            print(f"总交易次数: {len(records)}")
            if 'Duration' in records.columns:
                print(f"平均持仓时间: {records['Duration'].mean():.1f}")
            elif {'Entry Timestamp', 'Exit Timestamp'} <= set(records.columns):
                # No Duration column: derive holding time (in days) from the
                # entry/exit timestamps instead.
                held_days = (records['Exit Timestamp'] - records['Entry Timestamp']).dt.total_seconds() / 86400
                print(f"平均持仓时间: {held_days.mean():.1f} 天")
    except Exception as e:
        print(f"分析交易时出错: {e}")

except Exception as e:
    print(f"创建投资组合时出错: {e}")
    import traceback
    traceback.print_exc()
# Raw price history and price ratio for both legs. This runs whether or not
# the backtest above succeeded.
print("\n绘制原始价格走势...")
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))

# Top panel: closing prices of both legs.
ax1.plot(close_smic.index, close_smic, label='中芯国际', linewidth=1)
ax1.plot(close_hhic.index, close_hhic, label='华虹半导体', linewidth=1)
ax1.set(title='股票价格走势', ylabel='价格')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Bottom panel: price ratio against its full-sample mean (reference line only).
ax2.plot(ratio.index, ratio, label='价格比率', linewidth=1, color='purple')
ax2.axhline(y=ratio.mean(), color='red', linestyle='--', alpha=0.7, label='平均比率')
ax2.set(title='价格比率走势', ylabel='比率', xlabel='日期')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print("程序执行完成!")