# quant/PairsTrading.py
# Snapshot taken 2025-11-01 09:32:26 +08:00 (458 lines, 16 KiB, Python).
#
# NOTE: the hosting page flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# Some full-width punctuation may therefore be missing from the text below.

import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
from numba import njit
from collections import namedtuple
# Matplotlib setup for Chinese text: use the SimHei font and keep the minus
# sign from being drawn with the Unicode glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Download the daily bars of both Hong Kong tickers through akshare.
print("正在获取股票数据...")
stock_00981, stock_01347 = (
    ak.stock_hk_daily(symbol=ticker) for ticker in ("00981", "01347")
)
print("中芯国际数据列名:", stock_00981.columns.tolist())
print("华虹半导体数据列名:", stock_01347.columns.tolist())
# Data preprocessing
def preprocess_data(df, symbol):
    """Normalize a raw daily-bar frame into a date-indexed OHLCV frame.

    Args:
        df: Raw frame. Dates are taken from a ``date`` column, else from a
            ``日期`` column, else from the existing index.
        symbol: Ticker code; only used to make error messages identifiable.

    Returns:
        A new frame (the input is not modified), sorted by date, holding
        exactly the columns ``open``, ``high``, ``low``, ``close``, ``volume``.

    Raises:
        KeyError: if any of the five columns is still missing after the
            Chinese headers have been translated.
    """
    required = ['open', 'high', 'low', 'close', 'volume']
    chinese_to_english = {
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'volume',
    }

    df = df.copy()

    # Build the datetime index from whichever date source is present.
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        df.set_index('date', inplace=True)
    elif '日期' in df.columns:
        df['date'] = pd.to_datetime(df['日期'])
        df.set_index('date', inplace=True)
    else:
        df.index = pd.to_datetime(df.index)

    # Translate Chinese headers regardless of how the dates were found, but
    # never on top of an existing English column (that would duplicate it).
    rename_dict = {
        source: target
        for source, target in chinese_to_english.items()
        if source in df.columns and target not in df.columns
    }
    df = df.rename(columns=rename_dict)

    missing = [column for column in required if column not in df.columns]
    if missing:
        raise KeyError(f"{symbol}: missing required columns {missing}")

    df = df.sort_index()
    return df[required]
# Normalize both raw frames to date-indexed OHLCV data.
smic_data = preprocess_data(stock_00981, "00981")
hhic_data = preprocess_data(stock_01347, "01347")
# The two range endpoints are separated by " ~ " so they stay readable.
print(f"中芯国际原始数据时间范围: {smic_data.index.min()} ~ {smic_data.index.max()}")
print(f"华虹半导体原始数据时间范围: {hhic_data.index.min()} ~ {hhic_data.index.max()}")

# Keep only the last three years (3 * 365 days), measured back from the most
# recent date of the first ticker.
end_date = smic_data.index.max()
start_date = end_date - pd.Timedelta(days=3 * 365)
print(f"\n限制回测时间范围: {start_date} ~ {end_date}")
smic_data = smic_data.loc[start_date:end_date]
hhic_data = hhic_data.loc[start_date:end_date]
print(f"限制后中芯国际数据形状: {smic_data.shape}")
print(f"限制后华虹半导体数据形状: {hhic_data.shape}")

# Align both frames on the dates they have in common.
common_index = smic_data.index.intersection(hhic_data.index)
smic_data = smic_data.loc[common_index]
hhic_data = hhic_data.loc[common_index]
print(f"对齐后数据时间范围: {common_index.min()} ~ {common_index.max()}")
print(f"总交易日数: {len(common_index)}")

if len(common_index) == 0:
    print("错误: 没有共同交易日数据!")
    # Abort with a non-zero status: a bare exit() reports success (status 0)
    # and is a helper injected by the site module, which is not always loaded.
    raise SystemExit(1)
# Pair-trading configuration
PERIOD = 30                   # look-back window length
UPPER, LOWER = 2.0, -2.0      # upper / lower z-score bounds
ORDER_PCT1 = 0.1              # order fraction for SMIC
ORDER_PCT2 = 0.1              # order fraction for Hua Hong Semiconductor
COMMPERC = 0.002              # commission rate, 0.2%
MODE = 'OLS'                  # spread method: OLS
INITIAL_CASH = 100000         # starting capital

# Record types handed between the simulation callbacks.
Memory = namedtuple("Memory", "spread zscore status")
Params = namedtuple("Params", "period upper lower order_pct1 order_pct2")
@njit
def ols_spread_nb(a, b):
    """Return the most recent residual of an OLS fit of log(a) on log(b).

    The fit is ``log(a) = slope * log(b) + intercept`` and is solved through
    the normal equations. Only the residual of the last observation is
    returned.
    """
    log_a = np.log(a)
    log_b = np.log(b)
    # Design matrix with one row per observation: [log(b), 1].
    design = np.vstack((log_b, np.ones(len(log_b)))).T
    gram = np.dot(design.T, design)
    moment = np.dot(design.T, log_a)
    slope, intercept = np.dot(np.linalg.inv(gram), moment)
    residuals = log_a - (slope * log_b + intercept)
    return residuals[-1]
@njit
def pre_group_func_nb(c, _period, _upper, _lower, _order_pct1, _order_pct2):
    """Set up the per-group state for one asset pair.

    Returns a tuple ``(memory, params, size)``:
      * ``memory``: NaN-filled spread and z-score arrays with one slot per row
        of ``c.target_shape``, plus a one-element status array starting at 0.
      * ``params``: the scalar arguments bundled into a ``Params`` record.
      * ``size``: an uninitialized array with one slot per column of the group.
    """
    assert c.group_len == 2

    n_rows = c.target_shape[0]
    memory = Memory(
        np.full(n_rows, np.nan, dtype=np.float64),
        np.full(n_rows, np.nan, dtype=np.float64),
        np.full(1, 0, dtype=np.int64),
    )
    # The arguments are used as plain scalars, so they go in unchanged.
    params = Params(_period, _upper, _lower, _order_pct1, _order_pct2)
    size = np.empty(c.group_len, dtype=np.float64)
    return (memory, params, size)
@njit
def pre_segment_func_nb(c, memory, params, size, mode):
    """Prepare one row of the group: update the spread/z-score and pick sizes.

    Args:
        c: Segment context supplied by the simulator.
        memory: ``Memory`` record; ``spread`` and ``zscore`` are written at
            index ``c.i`` and ``status[0]`` holds the current state
            (0 = initial, 1 = short, 2 = long).
        params: ``Params`` record (period, bounds, order fractions).
        size: Two-element array filled in place with the target sizes for
            this row; NaN means "no order" for that column.
        mode: ``'OLS'`` or ``'log_return'``.

    Returns:
        ``(size,)``.

    Raises:
        ValueError: if ``mode`` is not one of the two known values.
    """
    # Warm-up: not enough rows yet to fill one look-back window.
    if c.i < params.period - 1:
        size[0] = np.nan
        size[1] = np.nan
        return (size,)

    # Rows covered by the look-back window that ends at the current row.
    window_slice = slice(max(0, c.i + 1 - params.period), c.i + 1)

    # Spread for the current row, according to the selected mode.
    if mode == 'OLS':
        a = c.close[window_slice, c.from_col]       # first column (SMIC)
        b = c.close[window_slice, c.from_col + 1]   # second column (Hua Hong)
        memory.spread[c.i] = ols_spread_nb(a, b)
    elif mode == 'log_return':
        # Difference between the two one-row log returns.
        logret_a = np.log(c.close[c.i, c.from_col] / c.close[c.i - 1, c.from_col])
        logret_b = np.log(c.close[c.i, c.from_col + 1] / c.close[c.i - 1, c.from_col + 1])
        memory.spread[c.i] = logret_a - logret_b
    else:
        raise ValueError("Unknown mode")

    # Z-score of the current spread against the window.
    spread_mean = np.mean(memory.spread[window_slice])
    spread_std = np.std(memory.spread[window_slice])
    # Only divide when the dispersion is a positive number. A zero standard
    # deviation would otherwise raise inside the compiled callback; a NaN one
    # (window not yet full of spreads) leaves the z-score slot untouched,
    # which is NaN as initialised in pre_group_func_nb.
    if spread_std > 0:
        memory.zscore[c.i] = (memory.spread[c.i] - spread_mean) / spread_std

    # Signals are taken from the previous row's z-score to avoid using data
    # from the row being traded.
    if c.i > 0 and not np.isnan(memory.zscore[c.i - 1]):
        # Short signal: z-score above the upper bound and not already short.
        if memory.zscore[c.i - 1] > params.upper and memory.status[0] != 1:
            size[0] = -params.order_pct1   # short the first column
            size[1] = params.order_pct2    # buy the second column
            # Execution order: sell first, then buy.
            c.call_seq_now[0] = 0
            c.call_seq_now[1] = 1
            memory.status[0] = 1
        # Long signal: z-score below the lower bound and not already long.
        elif memory.zscore[c.i - 1] < params.lower and memory.status[0] != 2:
            size[0] = params.order_pct1    # buy the first column
            size[1] = -params.order_pct2   # short the second column
            # Execution order: sell first, then buy.
            c.call_seq_now[0] = 1
            c.call_seq_now[1] = 0
            memory.status[0] = 2
        else:
            size[0] = np.nan
            size[1] = np.nan
    else:
        size[0] = np.nan
        size[1] = np.nan

    # Value both columns at the previous row's close.
    c.last_val_price[c.from_col] = c.close[c.i - 1, c.from_col]
    c.last_val_price[c.from_col + 1] = c.close[c.i - 1, c.from_col + 1]
    return (size,)
# Order callback
@njit
def order_func_nb(c, size):
    """Create the order for the current column and row.

    Args:
        c: Order context supplied by the simulator.
        size: Per-group array of target percentages produced by the segment
            preparation step; indexed by the column's position in the group.

    Returns:
        The order built by ``order_nb``: target-percent sizing, priced at the
        current row's close.
    """
    # Position of this column inside its group.
    group_col = c.col - c.from_col
    # The fee comes from the module-level COMMPERC constant rather than a
    # second hard-coded copy, so the configured rate is the one applied. It
    # is still read as a global, so no extra callback argument is needed.
    return vbt.portfolio.nb.order_nb(
        size=size[group_col],
        price=c.close[c.i, c.col],
        size_type=vbt.portfolio.enums.SizeType.TargetPercent,
        fees=COMMPERC
    )
# Build the two-column close-price frame used by the backtest.
print("准备回测数据...")
price_data = pd.DataFrame({
    'SMIC': smic_data['close'],
    'HHIC': hhic_data['close'],
})
print(f"价格数据形状: {price_data.shape}")
# The two range endpoints are separated by " ~ " so they stay readable.
print(f"价格数据时间范围: {price_data.index.min()} ~ {price_data.index.max()}")
print(f"价格数据前5行:\n{price_data.head()}")

# Debug information: library versions in use.
print("\n=== 调试信息 ===")
print(f"vectorbt版本: {vbt.__version__}")
print(f"pandas版本: {pd.__version__}")
print(f"numpy版本: {np.__version__}")
# Run the pair-trading backtest.
print("\n运行配对交易回测...")
# NOTE(review): the message below says "without the cash parameter", yet
# init_cash is passed in the call; the text is kept as-is.
print("尝试方法1: 不使用cash参数...")
portfolio = vbt.Portfolio.from_order_func(
    price_data,
    order_func_nb,
    # Capital and grouping: both columns form one group sharing one cash pool.
    init_cash=INITIAL_CASH,
    cash_sharing=True,
    group_by=True,
    freq='d',
    # Per-group setup receives the scalar strategy parameters.
    pre_group_func_nb=pre_group_func_nb,
    pre_group_args=(PERIOD, UPPER, LOWER, ORDER_PCT1, ORDER_PCT2),
    # Per-row preparation receives the spread mode.
    pre_segment_func_nb=pre_segment_func_nb,
    pre_segment_args=(MODE,),
)
print("方法1成功!")
print("回测完成!")
# Extra indicators used for analysis and plotting
def calculate_additional_metrics(price_data, period, mode):
    """Recompute the spread and z-score series the way the strategy does.

    The series mirror ``pre_segment_func_nb``: the spread is defined from row
    ``period - 1`` on (the first row with a full price window), and the
    z-score uses plain ``np.mean``/``np.std`` over the last ``period``
    spreads, so it stays NaN until that window holds no NaN.

    Args:
        price_data: Frame with ``SMIC`` and ``HHIC`` close-price columns.
        period: Look-back window length.
        mode: ``'OLS'`` for the regression spread; any other value selects
            the log-return spread.

    Returns:
        ``(spread, zscore)`` as two Series indexed like ``price_data``.
    """
    smic_close = price_data['SMIC'].values
    hhic_close = price_data['HHIC'].values
    n_rows = len(smic_close)
    # First row for which a full look-back window exists.
    first_row = max(period - 1, 0)

    spread = np.full(n_rows, np.nan)
    if mode == 'OLS':
        for i in range(first_row, n_rows):
            window_slice = slice(i - period + 1, i + 1)
            spread[i] = ols_spread_nb(smic_close[window_slice], hhic_close[window_slice])
    else:
        # Log-return spread; row 0 has no previous close to compare with.
        start = max(first_row, 1)
        spread[start:] = (
            np.log(smic_close[start:] / smic_close[start - 1:-1])
            - np.log(hhic_close[start:] / hhic_close[start - 1:-1])
        )

    zscore = np.full(n_rows, np.nan)
    for i in range(first_row, n_rows):
        window_values = spread[max(0, i - period + 1):i + 1]
        spread_std = np.std(window_values)
        # NaN (window not yet full) and zero dispersion both leave NaN.
        if spread_std > 0:
            zscore[i] = (spread[i] - np.mean(window_values)) / spread_std

    return pd.Series(spread, index=price_data.index), pd.Series(zscore, index=price_data.index)
# Indicator series for the analysis below.
spread_series, zscore_series = calculate_additional_metrics(price_data, PERIOD, MODE)

# Boolean signal series: z-score above the upper / below the lower bound.
short_signals = zscore_series.gt(UPPER).rename('short_signals')
long_signals = zscore_series.lt(LOWER).rename('long_signals')
# Visualisation, drawn with matplotlib rather than vectorbt's plotting.
print("生成分析图表...")
fig, axes = plt.subplots(3, 1, figsize=(15, 12))

# 1. Price history
ax1 = axes[0]
ax1.plot(price_data.index, price_data['SMIC'], label='中芯国际(00981)', linewidth=1)
ax1.plot(price_data.index, price_data['HHIC'], label='华虹半导体(01347)', linewidth=1)
# The two dates in the title are separated by " ~ " so they stay readable.
ax1.set_title(f'股票价格走势 ({start_date.strftime("%Y-%m-%d")} ~ {end_date.strftime("%Y-%m-%d")})')
ax1.set_ylabel('价格(港元)')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Spread (left axis) and z-score (right axis)
ax2 = axes[1]
ax2.plot(spread_series.index, spread_series.values, label='价差', color='blue', linewidth=1)
ax2.set_title('价差走势')
ax2.set_ylabel('价差')
ax2.legend(loc='upper left')
ax2.grid(True, alpha=0.3)

ax3 = ax2.twinx()
ax3.plot(zscore_series.index, zscore_series.values, label='Z-Score', color='red', linewidth=1, alpha=0.7)
ax3.axhline(y=UPPER, color='red', linestyle='--', alpha=0.5, label=f'上界({UPPER})')
ax3.axhline(y=0, color='black', linestyle='-', alpha=0.3)
ax3.axhline(y=LOWER, color='green', linestyle='--', alpha=0.5, label=f'下界({LOWER})')
ax3.set_ylabel('Z-Score')

# Mark the rows where a signal is active.
short_points = zscore_series[short_signals]
long_points = zscore_series[long_signals]
ax3.scatter(short_points.index, short_points.values, color='red', marker='v', s=30, label='做空信号')
ax3.scatter(long_points.index, long_points.values, color='green', marker='^', s=30, label='做多信号')
# The legend is built after the markers are added; creating it earlier would
# leave the two signal labels out of it.
ax3.legend(loc='upper right')

# 3. Portfolio value
ax4 = axes[2]
portfolio_value = portfolio.value()
ax4.plot(portfolio_value.index, portfolio_value.values, label='投资组合价值', color='purple', linewidth=1)
ax4.set_title('投资组合价值')
ax4.set_ylabel('资产价值(港元)')
ax4.set_xlabel('日期')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
# Report header
banner = "=" * 50
print("\n" + banner)
print("配对交易回测结果(最近三年)")
print(banner)

# Debug information about the portfolio object and its value series.
value_series = portfolio.value()
print("\n=== 投资组合调试信息 ===")
print(f"投资组合类型: {type(portfolio)}")
print(f"投资组合值形状: {value_series.shape}")
print(f"投资组合值前5行:\n{value_series.head()}")
print("投资组合值统计:")
print(f" 最小值: {value_series.min():.2f}")
print(f" 最大值: {value_series.max():.2f}")
print(f" 最终值: {value_series.iloc[-1]:.2f}")
# Headline statistics
try:
    stats = portfolio.stats()
    # The two dates are separated by " ~ " so they stay readable.
    print(f"\n回测期间: {start_date.strftime('%Y-%m-%d')} ~ {end_date.strftime('%Y-%m-%d')}")
    print(f"初始资金: {INITIAL_CASH:,.2f} 港元")
    print(f"最终资产: {portfolio.value().iloc[-1]:,.2f} 港元")
    # NOTE(review): per the vectorbt documentation the stats labels carry a
    # "[%]" suffix and are already in percent, and there is no annual-return
    # entry. Confirm against the installed version. The values are divided
    # by 100 because the ".2%" format multiplies by 100 again.
    print(f"总收益率: {stats['Total Return [%]'] / 100:.2%}")
    # Presumably a scalar for this single-group portfolio; verify.
    print(f"年化收益率: {portfolio.annualized_return():.2%}")
    print(f"夏普比率: {stats['Sharpe Ratio']:.2f}")
    print(f"最大回撤: {stats['Max Drawdown [%]'] / 100:.2%}")
    print(f"总交易次数: {stats['Total Trades']}")
except Exception as e:
    # Best-effort reporting: a failure here must not stop the rest of the run.
    print(f"统计计算错误: {e}")
# Order and trade statistics
try:
    # NOTE(review): per the vectorbt documentation `orders`, `positions` and
    # `trades` are attributes (not callables) and `records_readable` is the
    # DataFrame view with the 'Column' field. Confirm against the installed
    # version.
    orders_df = portfolio.orders.records_readable
    if len(orders_df) > 0:
        per_column = orders_df['Column'].value_counts()
        print(f"\n交易统计:")
        print(f"总订单数: {len(orders_df)}")
        print(f"中芯国际交易次数: {per_column.get('SMIC', 0)}")
        print(f"华虹半导体交易次数: {per_column.get('HHIC', 0)}")

        # Average holding time, taken over the raw per-position values so the
        # result is a single number.
        durations = portfolio.positions.duration.values
        if len(durations) > 0:
            print(f"平均持仓时间: {durations.mean():.1f}")

        # Profit and loss over all trades.
        if hasattr(portfolio, 'trades'):
            trade_pnl = portfolio.trades.pnl.values
            if len(trade_pnl) > 0:
                print(f"总交易盈亏: {trade_pnl.sum():.2f} 港元")
                print(f"平均每笔交易盈亏: {trade_pnl.mean():.2f} 港元")
    else:
        print("没有交易记录")
except Exception as e:
    # Best-effort reporting: a failure here must not stop the rest of the run.
    print(f"交易统计错误: {e}")
# Signal counts
short_count = short_signals.sum()
long_count = long_signals.sum()
print("\n信号统计:")
print(f"做空信号数量: {short_count}")
print(f"做多信号数量: {long_count}")
print(f"总信号数量: {short_count + long_count}")

# Correlation between the two close-price series
smic_close_series = price_data['SMIC']
hhic_close_series = price_data['HHIC']
correlation = smic_close_series.corr(hhic_close_series)
print(f"\n股票相关性: {correlation:.4f}")

# Basic price statistics: first/last price, total change, daily volatility
print("\n价格统计:")
print(f"中芯国际 - 期初价格: {smic_close_series.iloc[0]:.2f} 港元")
print(f"中芯国际 - 期末价格: {smic_close_series.iloc[-1]:.2f} 港元")
print(f"中芯国际 - 期间涨跌幅: {(smic_close_series.iloc[-1] / smic_close_series.iloc[0] - 1):.2%}")
print(f"华虹半导体 - 期初价格: {hhic_close_series.iloc[0]:.2f} 港元")
print(f"华虹半导体 - 期末价格: {hhic_close_series.iloc[-1]:.2f} 港元")
print(f"华虹半导体 - 期间涨跌幅: {(hhic_close_series.iloc[-1] / hhic_close_series.iloc[0] - 1):.2%}")
print(f"中芯国际 - 价格波动率: {smic_close_series.pct_change().std():.4f}")
print(f"华虹半导体 - 价格波动率: {hhic_close_series.pct_change().std():.4f}")
# Write the per-day results to a CSV file.
try:
    export_columns = {
        'Date': price_data.index,
        'SMIC_Price': price_data['SMIC'],
        'HHIC_Price': price_data['HHIC'],
        'Spread': spread_series,
        'ZScore': zscore_series,
        'Portfolio_Value': portfolio.value(),
        'Short_Signals': short_signals,
        'Long_Signals': long_signals,
    }
    results_df = pd.DataFrame(export_columns)
    # The dates are already in the 'Date' column, so the index is not written.
    results_df.to_csv('pair_trading_results_3years.csv', index=False)
    print("\n详细结果已保存到: pair_trading_results_3years.csv")
except Exception as e:
    # Best-effort export: a failure here must not stop the rest of the run.
    print(f"保存结果错误: {e}")
# Show the most recent orders
try:
    # NOTE(review): per the vectorbt documentation `orders` is an attribute
    # (not a callable) and `records_readable` is the DataFrame view with the
    # fields read below. Confirm against the installed version.
    orders_df = portfolio.orders.records_readable
    if len(orders_df) > 0:
        print(f"\n最近10笔交易:")
        for _, trade in orders_df.tail(10).iterrows():
            print(f" 时间: {trade['Timestamp']}, 股票: {trade['Column']}, "
                  f"数量: {trade['Size']:.4f}, 价格: {trade['Price']:.2f}, "
                  f"手续费: {trade['Fees']:.2f}")
    else:
        print("没有交易记录")
except Exception as e:
    # Best-effort reporting: a failure here must not stop the rest of the run.
    print(f"显示最近交易错误: {e}")
# Per-year performance
try:
    print(f"\n年度绩效分析:")
    yearly_returns = portfolio.annual_returns()
    for period_label, ret in yearly_returns.items():
        # Presumably the labels are timestamps (to be verified); print only
        # their year so the line reads "<year>年收益率". Labels without a
        # `year` attribute are printed unchanged.
        year = getattr(period_label, 'year', period_label)
        print(f"{year}年收益率: {ret:.2%}")
except Exception as e:
    # Best-effort reporting: a failure here must not stop the rest of the run.
    print(f"年度绩效分析错误: {e}")

print("\n回测完成!")