Files
quant/PairsTradingSimple.py
2025-11-01 16:10:01 +08:00

549 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
from numba import njit
from collections import namedtuple
# Configure matplotlib for Chinese text: use the SimHei font and keep the
# minus sign rendering as a plain ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Download the daily bars for both Hong Kong tickers through akshare.
print("正在获取股票数据...")
stock_00981, stock_01347 = (
    ak.stock_hk_daily(symbol=ticker) for ticker in ("00981", "01347")
)
print("中芯国际数据列名:", stock_00981.columns.tolist())
print("华虹半导体数据列名:", stock_01347.columns.tolist())
# 数据预处理
def preprocess_data(df, symbol):
    """Normalize a raw daily-bar frame into a date-indexed OHLCV frame.

    The date is taken from a ``date`` column, else a ``日期`` column, else the
    existing index is parsed as datetimes. Chinese OHLCV headers are mapped
    to their English names regardless of which date column was found.

    Args:
        df: Raw price frame; it is copied, never mutated.
        symbol: Ticker label, used only to make error messages traceable.

    Returns:
        A frame sorted by date with exactly the columns
        ``open, high, low, close, volume``.

    Raises:
        KeyError: If any of the five required columns is absent after
            renaming.
    """
    required = ['open', 'high', 'low', 'close', 'volume']
    cn_to_en = {
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'volume',
    }
    df = df.copy()

    # Prefer an explicit date column; fall back to parsing the index.
    date_col = next((c for c in ('date', '日期') if c in df.columns), None)
    if date_col is not None:
        df.index = pd.DatetimeIndex(pd.to_datetime(df[date_col]), name='date')
    else:
        df.index = pd.to_datetime(df.index)

    # Rename only headers that are present and whose English target does not
    # already exist, so no duplicate column labels are created.
    rename_map = {
        cn: en for cn, en in cn_to_en.items()
        if cn in df.columns and en not in df.columns
    }
    df = df.rename(columns=rename_map)

    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(
            f"{symbol}: missing required columns {missing}; "
            f"available columns: {list(df.columns)}"
        )

    return df.sort_index()[required]
# Normalize both raw downloads into date-indexed OHLCV frames.
smic_data = preprocess_data(stock_00981, "00981")
hhic_data = preprocess_data(stock_01347, "01347")
# The range messages put an explicit " ~ " between the first and last date;
# without a separator the two timestamps were printed fused together.
print(f"中芯国际原始数据时间范围: {smic_data.index.min()} ~ {smic_data.index.max()}")
print(f"华虹半导体原始数据时间范围: {hhic_data.index.min()} ~ {hhic_data.index.max()}")

# Restrict the backtest to the 3*365 days ending at the last 00981 bar.
end_date = smic_data.index.max()
start_date = end_date - pd.Timedelta(days=3*365)
print(f"\n限制回测时间范围: {start_date} ~ {end_date}")
smic_data = smic_data.loc[start_date:end_date]
hhic_data = hhic_data.loc[start_date:end_date]
print(f"限制后中芯国际数据形状: {smic_data.shape}")
print(f"限制后华虹半导体数据形状: {hhic_data.shape}")

# Keep only the dates on which both tickers have a bar.
common_index = smic_data.index.intersection(hhic_data.index)
smic_data = smic_data.loc[common_index]
hhic_data = hhic_data.loc[common_index]
print(f"\n对齐后数据时间范围: {smic_data.index.min()} ~ {smic_data.index.max()}")
print(f"对齐后数据点数: {len(smic_data)}")

# Trading parameters.
initial_cash = 100000
commission = 0.001  # 0.1% commission per trade
position_size = 0.1  # fraction of initial_cash committed per leg
# 计算配对交易信号
def calculate_pair_signals(price1, price2, window=20, num_std=2):
    """Compute Bollinger-band signals on the price ratio ``price1 / price2``.

    The signal is evaluated bar by bar and carries no memory:

    *  1 -- ratio is below the lower band (long price1 / short price2)
    * -1 -- ratio is above the upper band (short price1 / long price2)
    *  0 -- ratio is inside the bands, or the bands are not yet available

    So a position signalled on one bar is released as soon as the ratio is
    back inside the bands, not when it reaches the moving average.

    Args:
        price1: Price series for the numerator leg.
        price2: Price series for the denominator leg.
        window: Rolling window length for the mean and standard deviation.
        num_std: Band half-width, in rolling standard deviations.

    Returns:
        Tuple ``(signals, ratio, ratio_ma, upper_band, lower_band)``, all
        indexed like the ratio; ``signals`` is named ``'signal'``.
    """
    # A zero denominator yields +/-inf, which would corrupt the rolling mean
    # and std for the windows that follow; treat such bars as missing.
    ratio = (price1 / price2).replace([np.inf, -np.inf], np.nan)

    rolling = ratio.rolling(window=window)
    ratio_ma = rolling.mean()
    ratio_std = rolling.std()
    band_width = num_std * ratio_std
    upper_band = ratio_ma + band_width
    lower_band = ratio_ma - band_width

    # Bars inside the bands simply keep the initial 0, so no separate
    # "close" pass is needed.
    ready = ratio_ma.notna()
    signals = pd.Series(0, index=ratio.index, name='signal')
    signals[(ratio < lower_band) & ready] = 1
    signals[(ratio > upper_band) & ready] = -1

    return signals, ratio, ratio_ma, upper_band, lower_band
# Run the signal calculation on the two aligned closing-price series.
close_smic, close_hhic = smic_data['close'], hhic_data['close']
(signals, ratio, ratio_ma,
 upper_band, lower_band) = calculate_pair_signals(
    close_smic,
    close_hhic,
    window=20,
    num_std=2,
)
print(f"信号计算完成,有效信号数量: {(signals != 0).sum()}")

# Two-column close frame keyed by ticker, used as the simulation price input.
close = pd.DataFrame({'00981': close_smic, '01347': close_hhic})
# 基于信号生成size数据
def generate_pair_size(signals, close_smic, close_hhic, position_size=0.1,
                       capital=None, warmup=20):
    """Turn pair signals into per-bar share orders for both legs.

    Each cell of the result is the number of shares to trade on that bar
    (positive = buy, negative = sell), which is how the frame is fed to
    ``Portfolio.from_orders`` as ``size``. Orders are emitted only when the
    signal differs from the position currently held:

    * entering from flat trades the full target for each leg;
    * a 0 signal trades the exact opposite of what is held, so the position
      is actually flattened (writing 0 would place no order at all);
    * flipping between long and short trades ``target - held``, so the old
      position is netted out rather than stacked on.

    Args:
        signals: Series of -1 / 0 / 1 signals; other values are ignored.
        close_smic: Close prices of the '00981' leg, positionally aligned
            with ``signals``.
        close_hhic: Close prices of the '01347' leg, positionally aligned
            with ``signals``.
        position_size: Fraction of ``capital`` committed to each leg.
        capital: Reference capital for sizing. ``None`` falls back to the
            module-level ``initial_cash``.
        warmup: Number of leading bars on which no order is ever placed
            (the band warm-up period).

    Returns:
        Integer DataFrame indexed like ``signals`` with columns
        ``'00981'`` and ``'01347'``.
    """
    if capital is None:
        capital = initial_cash  # module-level default used by the script
    notional = capital * position_size

    sig = signals.to_numpy()
    px_smic = close_smic.to_numpy(dtype=float)
    px_hhic = close_hhic.to_numpy(dtype=float)

    orders = np.zeros((len(sig), 2), dtype=np.int64)
    held_smic = 0
    held_hhic = 0
    current_position = 0

    for i in range(max(int(warmup), 0), len(sig)):
        signal = sig[i]
        if signal not in (-1, 0, 1) or signal == current_position:
            continue

        if signal == 0:
            target_smic = 0
            target_hhic = 0
        else:
            p_smic = px_smic[i]
            p_hhic = px_hhic[i]
            # Cannot size a leg without a usable price; wait for a later bar.
            if not (np.isfinite(p_smic) and np.isfinite(p_hhic)
                    and p_smic > 0 and p_hhic > 0):
                continue
            direction = int(signal)
            target_smic = direction * int(notional / p_smic)
            target_hhic = -direction * int(notional / p_hhic)

        # Trade the difference between the target and current holdings.
        orders[i, 0] = target_smic - held_smic
        orders[i, 1] = target_hhic - held_hhic
        held_smic, held_hhic = target_smic, target_hhic
        current_position = int(signal)

    return pd.DataFrame(orders, index=signals.index, columns=['00981', '01347'])
# Build the per-bar order sizes from the signals.
size = generate_pair_size(signals, close_smic, close_hhic, position_size)
print(f"size数据形状: {size.shape}")
print(f"非零交易数量 - 中芯国际: {(size['00981'] != 0).sum()}, 华虹半导体: {(size['01347'] != 0).sum()}")

# Simulate the orders with vectorbt and report on the result. Every chart
# section has its own try/except so one failing plot does not abort the rest;
# the outer handler only catches failures in the simulation itself or in the
# unguarded matplotlib section.
print("创建投资组合...")
try:
    portfolio = vbt.Portfolio.from_orders(
        close=close,
        size=size,
        init_cash=initial_cash,
        fees=commission,
        freq='D'
    )
    print("投资组合创建成功!")

    # Pair-trading performance summary.
    print("\n=== 配对交易策略表现 ===")
    portfolio_value = portfolio.value()

    # ---- matplotlib overview: ratio + signals, value curve, drawdown ----
    print("\n绘制图表...")
    fig, axes = plt.subplots(3, 1, figsize=(15, 12))

    # Panel 1: price ratio with its moving average and bands.
    axes[0].plot(ratio.index, ratio, label='价格比率(中芯/华虹)', linewidth=1)
    axes[0].plot(ratio.index, ratio_ma, label='移动平均', linewidth=1, alpha=0.7)
    axes[0].plot(ratio.index, upper_band, label='上轨', linewidth=1, alpha=0.7, linestyle='--', color='red')
    axes[0].plot(ratio.index, lower_band, label='下轨', linewidth=1, alpha=0.7, linestyle='--', color='green')
    axes[0].set_title('中芯国际-华虹半导体价格比率')
    axes[0].set_ylabel('价格比率')
    axes[0].grid(True, alpha=0.3)

    # Mark the bars that carry a long or short signal.
    long_signals = signals[signals == 1]
    short_signals = signals[signals == -1]
    if len(long_signals) > 0:
        axes[0].scatter(long_signals.index, ratio[long_signals.index],
                        color='green', marker='^', s=50, label='做多信号', zorder=5)
    if len(short_signals) > 0:
        axes[0].scatter(short_signals.index, ratio[short_signals.index],
                        color='red', marker='v', s=50, label='做空信号', zorder=5)
    # Build the legend only after the markers exist, otherwise the signal
    # entries are missing from it.
    axes[0].legend()

    # Panel 2: portfolio value.
    if len(portfolio_value) > 0:
        axes[1].plot(portfolio_value.index, portfolio_value, label='组合净值', linewidth=1)
        axes[1].set_title('组合净值曲线')
        axes[1].set_ylabel('组合价值')
        axes[1].grid(True, alpha=0.3)

    # Panel 3: drawdown.
    drawdown = portfolio.drawdown()
    if len(drawdown) > 0:
        axes[2].plot(drawdown.index, drawdown, label='回撤', linewidth=1, color='red')
        axes[2].set_title('回撤')
        axes[2].set_ylabel('回撤')
        axes[2].set_xlabel('日期')
        axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # ========== vectorbt built-in visualisations ==========
    print("\n=== VectorBT 专业可视化 ===")

    # 1. Orders, trade PnL, cumulative returns and drawdowns per ticker.
    print("\n绘制中芯国际分析...")
    try:
        fig1 = portfolio['00981'].plot(subplots=[
            'orders',       # orders
            'trade_pnl',    # trade PnL
            'cum_returns',  # cumulative returns
            'drawdowns'     # drawdowns
        ])
        fig1.update_layout(
            title='中芯国际配对交易分析',
            height=800
        )
        fig1.show()
    except Exception as e:
        print(f"中芯国际分析绘制失败: {e}")

    print("\n绘制华虹半导体分析...")
    try:
        fig2 = portfolio['01347'].plot(subplots=[
            'orders',       # orders
            'trade_pnl',    # trade PnL
            'cum_returns',  # cumulative returns
            'drawdowns'     # drawdowns
        ])
        fig2.update_layout(
            title='华虹半导体配对交易分析',
            height=800
        )
        fig2.show()
    except Exception as e:
        print(f"华虹半导体分析绘制失败: {e}")

    # 2. Portfolio value over time.
    print("\n绘制组合总价值变化...")
    try:
        fig = portfolio_value.vbt.plot(
            title='配对交易组合总价值'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='组合价值'
        )
        fig.show()
    except Exception as e:
        print(f"组合价值变化绘制失败: {e}")

    # 3. Cumulative returns.
    print("\n绘制组合累积收益...")
    try:
        cumulative_returns = portfolio.cumulative_returns()
        # With more than one column, collapse to the cross-column mean.
        if hasattr(cumulative_returns, 'columns') and len(cumulative_returns.columns) > 1:
            cumulative_returns = cumulative_returns.mean(axis=1)
        fig = cumulative_returns.vbt.plot(
            title='配对交易累积收益率'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='累积收益'
        )
        fig.show()
    except Exception as e:
        print(f"累积收益绘制失败: {e}")

    # 4. Drawdown.
    print("\n绘制组合回撤分析...")
    try:
        # With more than one column, collapse to the cross-column mean.
        if hasattr(drawdown, 'columns') and len(drawdown.columns) > 1:
            drawdown = drawdown.mean(axis=1)
        fig = drawdown.vbt.plot(
            title='配对交易回撤分析'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='回撤'
        )
        fig.show()
    except Exception as e:
        print(f"回撤分析绘制失败: {e}")

    # 5. Monthly-return heatmap.
    print("\n绘制月度收益热力图...")
    try:
        monthly_returns = portfolio.returns().vbt.to_monthly()
        # With more than one column, collapse to the cross-column mean.
        if hasattr(monthly_returns, 'columns') and len(monthly_returns.columns) > 1:
            monthly_returns = monthly_returns.mean(axis=1)
        fig = monthly_returns.vbt.heatmap(
            title='配对交易月度收益热力图'
        )
        fig.update_layout(
            xaxis_title='年份',
            yaxis_title='月份'
        )
        fig.show()
    except Exception as e:
        print(f"月度收益热力图绘制失败: {e}")

    # 6. Trade PnL and duration distributions per ticker.
    print("\n绘制交易分析...")
    try:
        trades_smic = portfolio['00981'].trades
        if len(trades_smic) > 0:
            fig = trades_smic.plot_pnl()
            fig.update_layout(title='中芯国际交易盈亏分布')
            fig.show()
            fig = trades_smic.plot_duration()
            fig.update_layout(title='中芯国际交易持续时间分布')
            fig.show()

        trades_hhic = portfolio['01347'].trades
        if len(trades_hhic) > 0:
            fig = trades_hhic.plot_pnl()
            fig.update_layout(title='华虹半导体交易盈亏分布')
            fig.show()
            fig = trades_hhic.plot_duration()
            fig.update_layout(title='华虹半导体交易持续时间分布')
            fig.show()
    except Exception as e:
        print(f"交易分析绘制失败: {e}")

    # 7. Order plots per ticker.
    print("\n绘制订单流分析...")
    try:
        fig1 = portfolio['00981'].orders.plot()
        fig1.update_layout(title='中芯国际订单流分析')
        fig1.show()

        fig2 = portfolio['01347'].orders.plot()
        fig2.update_layout(title='华虹半导体订单流分析')
        fig2.show()
    except Exception as e:
        print(f"订单流分析绘制失败: {e}")

    # 8. Holding value per ticker. asset_value() is the accessor this script
    #    also uses for the weight chart, and it matches the y-axis label
    #    (holding value).
    print("\n绘制持仓分析...")
    try:
        holdings_smic = portfolio['00981'].asset_value()
        fig1 = holdings_smic.vbt.plot(
            title='中芯国际持仓变化'
        )
        fig1.update_layout(
            xaxis_title='日期',
            yaxis_title='持仓价值'
        )
        fig1.show()

        holdings_hhic = portfolio['01347'].asset_value()
        fig2 = holdings_hhic.vbt.plot(
            title='华虹半导体持仓变化'
        )
        fig2.update_layout(
            xaxis_title='日期',
            yaxis_title='持仓价值'
        )
        fig2.show()
    except Exception as e:
        print(f"持仓分析绘制失败: {e}")

    # 9. Cash balance. cash is a method and has to be called; the bare
    #    attribute is a bound method with no .vbt accessor.
    print("\n绘制现金变化...")
    try:
        cash = portfolio.cash()
        # With more than one column, plot the first column only.
        if hasattr(cash, 'columns') and len(cash.columns) > 1:
            cash = cash.iloc[:, 0]
        fig = cash.vbt.plot(
            title='配对交易现金余额变化'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='现金余额'
        )
        fig.show()
    except Exception as e:
        print(f"现金变化绘制失败: {e}")

    # 10. Asset weights: asset value divided by total value, NaN -> 0.
    print("\n绘制资产配置比例...")
    try:
        asset_value = portfolio.asset_value()
        total_value = portfolio.value()
        weights = asset_value.div(total_value, axis=0).fillna(0)
        fig = weights.vbt.areaplot(
            title='资产配置比例变化'
        )
        fig.update_layout(
            xaxis_title='日期',
            yaxis_title='权重'
        )
        fig.show()
    except Exception as e:
        print(f"资产配置比例绘制失败: {e}")

    # Show the bars on which at least one leg has a non-zero order.
    non_zero_size = size[(size != 0).any(axis=1)]
    print(f"\n非零交易数量: {len(non_zero_size)}")
    if len(non_zero_size) > 0:
        print("交易记录示例:")
        print(non_zero_size.head(20))

    # Detailed statistics per ticker.
    print("\n=== 详细统计 ===")
    try:
        stats_smic = portfolio['00981'].stats()
        stats_hhic = portfolio['01347'].stats()

        def safe_get_stat(stat_dict, key, default="N/A"):
            """Look up ``key``; unwrap a length-1 pandas object to a scalar."""
            value = stat_dict.get(key, default)
            if hasattr(value, 'iloc'):
                return value.iloc[0] if len(value) == 1 else value
            return value

        print("\n中芯国际统计:")
        print(f"总收益率: {safe_get_stat(stats_smic, 'Total Return [%]', 'N/A')}%")
        print(f"年化收益率: {safe_get_stat(stats_smic, 'Annual Return [%]', 'N/A')}%")
        print(f"夏普比率: {safe_get_stat(stats_smic, 'Sharpe Ratio', 'N/A')}")
        print(f"最大回撤: {safe_get_stat(stats_smic, 'Max Drawdown [%]', 'N/A')}%")

        print("\n华虹半导体统计:")
        print(f"总收益率: {safe_get_stat(stats_hhic, 'Total Return [%]', 'N/A')}%")
        print(f"年化收益率: {safe_get_stat(stats_hhic, 'Annual Return [%]', 'N/A')}%")
        print(f"夏普比率: {safe_get_stat(stats_hhic, 'Sharpe Ratio', 'N/A')}")
        print(f"最大回撤: {safe_get_stat(stats_hhic, 'Max Drawdown [%]', 'N/A')}%")
    except Exception as e:
        print(f"获取详细统计时出错: {e}")

    # Per-trade summary from the readable trade records.
    try:
        trades_smic = portfolio['00981'].trades.records_readable
        trades_hhic = portfolio['01347'].trades.records_readable

        if len(trades_smic) > 0:
            print(f"\n中芯国际交易分析:")
            print(f"总交易次数: {len(trades_smic)}")
            if 'Duration' in trades_smic.columns:
                print(f"平均持仓时间: {trades_smic['Duration'].mean():.1f}")

        if len(trades_hhic) > 0:
            print(f"\n华虹半导体交易分析:")
            print(f"总交易次数: {len(trades_hhic)}")
            if 'Duration' in trades_hhic.columns:
                print(f"平均持仓时间: {trades_hhic['Duration'].mean():.1f}")
    except Exception as e:
        print(f"分析交易时出错: {e}")
except Exception as e:
    print(f"创建投资组合时出错: {e}")
    import traceback
    traceback.print_exc()
# Closing chart: the two raw close series on top, their ratio underneath.
print("\n绘制原始价格走势...")
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))

# Upper panel: closing prices of both tickers.
ax1.plot(close_smic.index, close_smic, linewidth=1, label='中芯国际')
ax1.plot(close_hhic.index, close_hhic, linewidth=1, label='华虹半导体')
ax1.set(title='股票价格走势', ylabel='价格')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Lower panel: the price ratio with a horizontal line at its overall mean.
ax2.plot(ratio.index, ratio, color='purple', linewidth=1, label='价格比率')
ax2.axhline(y=ratio.mean(), linestyle='--', color='red', alpha=0.7, label='平均比率')
ax2.set(title='价格比率走势', ylabel='比率', xlabel='日期')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print("程序执行完成!")