Files
quant/PairsTrading.py
2025-10-31 19:31:03 +08:00

401 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
from numba import njit
from collections import namedtuple
# Configure matplotlib so Chinese labels render with the SimHei font and the
# minus sign is drawn without relying on a Unicode glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Download the daily bars for both legs of the pair via akshare.
print("正在获取股票数据...")
stock_00981 = ak.stock_hk_daily(symbol="00981")  # SMIC
stock_01347 = ak.stock_hk_daily(symbol="01347")  # Hua Hong Semiconductor

# Echo the column layout returned by the data source for inspection.
for header, raw_frame in (("中芯国际数据列名:", stock_00981), ("华虹半导体数据列名:", stock_01347)):
    print(header, raw_frame.columns.tolist())
# Data preprocessing
def preprocess_data(df, symbol):
    """Normalise a raw daily-bar frame into a date-indexed OHLCV frame.

    The date may live in an English ``date`` column, a Chinese ``日期``
    column, or already in the index. Price/volume columns may be English
    (``open``, ``high``, ...) or Chinese (``开盘``, ``最高``, ...).

    Args:
        df: Raw frame as returned by the data source. It is not modified.
        symbol: Ticker of the frame; used to make error messages identifiable.

    Returns:
        A copy indexed by a sorted ``DatetimeIndex`` containing exactly the
        columns ``open, high, low, close, volume`` in that order.

    Raises:
        KeyError: if any required OHLCV column is missing after renaming.
    """
    df = df.copy()

    # Locate the date: English column, Chinese column, or the existing index.
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        df.set_index('date', inplace=True)
    elif '日期' in df.columns:
        df['date'] = pd.to_datetime(df['日期'])
        df.set_index('date', inplace=True)
    else:
        df.index = pd.to_datetime(df.index)

    # Map Chinese column names to English ones regardless of where the date
    # came from. Targets that already exist are skipped so no duplicate
    # columns are created; rename() ignores keys that are absent.
    chinese_to_english = {
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'volume',
    }
    rename_dict = {
        zh: en for zh, en in chinese_to_english.items() if en not in df.columns
    }
    df = df.rename(columns=rename_dict)

    required = ['open', 'high', 'low', 'close', 'volume']
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"{symbol}: missing required columns {missing}")
    return df.sort_index()[required]
# Preprocess both legs into date-indexed OHLCV frames.
smic_data = preprocess_data(stock_00981, "00981")
hhic_data = preprocess_data(stock_01347, "01347")
print(f"中芯国际原始数据时间范围: {smic_data.index.min()} 至 {smic_data.index.max()}")
print(f"华虹半导体原始数据时间范围: {hhic_data.index.min()} 至 {hhic_data.index.max()}")

# Restrict to the most recent three calendar years, anchored on SMIC's last
# bar. DateOffset honours leap years, unlike a fixed 3*365-day Timedelta.
end_date = smic_data.index.max()
start_date = end_date - pd.DateOffset(years=3)
print(f"\n限制回测时间范围: {start_date} 至 {end_date}")
smic_data = smic_data.loc[start_date:end_date]
hhic_data = hhic_data.loc[start_date:end_date]
print(f"限制后中芯国际数据形状: {smic_data.shape}")
print(f"限制后华虹半导体数据形状: {hhic_data.shape}")

# Keep only the days on which both stocks traded so the two columns line up.
# Check for an empty overlap before using it.
common_index = smic_data.index.intersection(hhic_data.index)
if len(common_index) == 0:
    print("错误: 没有共同交易日数据!")
    # Non-zero exit status so callers and schedulers can see the failure.
    raise SystemExit(1)
smic_data = smic_data.loc[common_index]
hhic_data = hhic_data.loc[common_index]
print(f"对齐后数据时间范围: {common_index.min()} 至 {common_index.max()}")
print(f"总交易日数: {len(common_index)}")
# ---- Pair-trading parameters ----
PERIOD = 30            # rolling look-back window (rows) for the spread and z-score
UPPER = 2.0            # previous z-score above this opens the short-SMIC / long-Hua-Hong leg
LOWER = -2.0           # previous z-score below this opens the long-SMIC / short-Hua-Hong leg
ORDER_PCT1 = 0.1       # target percent for SMIC (00981), used as SizeType.TargetPercent
ORDER_PCT2 = 0.1       # target percent for Hua Hong (01347), used as SizeType.TargetPercent
COMMPERC = 0.002       # proportional fee rate, 0.2%
MODE = 'OLS'           # spread model: 'OLS' or 'log_return'
INITIAL_CASH = 100000  # starting cash

# Containers handed from pre_group_func_nb to pre_segment_func_nb.
Memory = namedtuple("Memory", "spread zscore status")
Params = namedtuple("Params", "period upper lower order_pct1 order_pct2")
@njit
def ols_spread_nb(a, b):
    """Return the latest residual of an OLS fit of log(a) on log(b).

    Fits ``log(a) = slope * log(b) + intercept`` over the whole window and
    returns the residual at the last element, i.e. how far the most recent
    log price of ``a`` sits from the fitted relationship.

    Args:
        a: 1-D array of prices for the first leg (non-positive values give
           nan/-inf through np.log).
        b: 1-D array of prices for the second leg, same length as ``a``.

    Returns:
        float: ``log(a[-1]) - (slope * log(b[-1]) + intercept)``.
    """
    log_a = np.log(a)
    log_b = np.log(b)
    # Design matrix: one regressor column plus a constant column.
    design = np.vstack((log_b, np.ones(len(log_b)))).T
    # Solve the normal equations (X'X) beta = X'y directly rather than forming
    # the explicit inverse, which is slower and less numerically stable.
    beta = np.linalg.solve(np.dot(design.T, design), np.dot(design.T, log_a))
    slope = beta[0]
    intercept = beta[1]
    # Only the last residual is needed, so skip building the whole spread.
    return log_a[-1] - (slope * log_b[-1] + intercept)
@njit
def _select_param_nb(values, group):
    """Return the parameter for ``group``, broadcasting a 1-element array to every group."""
    if values.shape[0] == 1:
        return values[0]
    return values[group]


@njit
def pre_group_func_nb(c, _period, _upper, _lower, _order_pct1, _order_pct2):
    """Prepare per-group state for one pair of columns.

    Args:
        c: vectorbt group context (uses ``c.group``, ``c.group_len`` and
           ``c.target_shape``).
        _period, _upper, _lower, _order_pct1, _order_pct2: 1-D parameter
            arrays holding either one value per group or a single value that
            is shared by every group.

    Returns:
        tuple: ``(memory, params, size)``, which become the extra arguments of
        ``pre_segment_func_nb``.

    Raises:
        ValueError: if the group does not contain exactly two columns.
    """
    # The strategy needs exactly two legs per group.
    if c.group_len != 2:
        raise ValueError("pre_group_func_nb expects exactly two columns per group")

    # Per-row scratch arrays for this group.
    spread = np.full(c.target_shape[0], np.nan, dtype=np.float64)
    zscore = np.full(c.target_shape[0], np.nan, dtype=np.float64)
    # Namedtuples are immutable, so the position status is a 1-element array
    # that can be updated in place (0 = flat, 1 = short spread, 2 = long spread).
    status = np.full(1, 0, dtype=np.int64)
    memory = Memory(spread, zscore, status)

    # Index each parameter array by group (broadcasting single values) so every
    # field is a scalar; numba cannot unify a scalar branch with an array branch.
    params = Params(
        _select_param_nb(_period, c.group),
        _select_param_nb(_upper, c.group),
        _select_param_nb(_lower, c.group),
        _select_param_nb(_order_pct1, c.group),
        _select_param_nb(_order_pct2, c.group),
    )

    # Target percentages for the two legs, filled in per row.
    size = np.empty(c.group_len, dtype=np.float64)
    return (memory, params, size)
@njit
def pre_segment_func_nb(c, memory, params, size, mode):
    """Compute the spread/z-score for the current row and set target sizes.

    Args:
        c: vectorbt segment context (uses row ``c.i``, first group column
           ``c.from_col``, ``c.close``, ``c.call_seq_now``, ``c.last_val_price``).
        memory: ``Memory`` namedtuple with per-row ``spread``/``zscore`` arrays
            and a 1-element ``status`` array (0 = flat, 1 = short spread,
            2 = long spread).
        params: ``Params`` namedtuple of strategy parameters.
        size: 2-element array overwritten with target percentages for the two
            legs; NaN where no signal is emitted.
        mode: ``'OLS'`` or ``'log_return'``.

    Returns:
        tuple: ``(size,)``, forwarded to ``order_func_nb``.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    # Not enough history yet: place no orders.
    if c.i < params.period - 1:
        size[0] = np.nan
        size[1] = np.nan
        return (size,)

    # Trailing window of `period` rows ending at (and including) the current row.
    window_slice = slice(max(0, c.i + 1 - params.period), c.i + 1)

    # Spread for the current row.
    if mode == 'OLS':
        a = c.close[window_slice, c.from_col]      # SMIC
        b = c.close[window_slice, c.from_col + 1]  # Hua Hong
        memory.spread[c.i] = ols_spread_nb(a, b)
    elif mode == 'log_return':
        logret_a = np.log(c.close[c.i, c.from_col] / c.close[c.i - 1, c.from_col])
        logret_b = np.log(c.close[c.i, c.from_col + 1] / c.close[c.i - 1, c.from_col + 1])
        memory.spread[c.i] = logret_a - logret_b
    else:
        raise ValueError("Unknown mode")

    # Rolling z-score. np.mean/np.std propagate NaN, so the z-score stays NaN
    # while the window still holds warm-up NaNs. A flat window (std == 0)
    # would raise ZeroDivisionError under numba's default error model, so it
    # is skipped and the z-score is left NaN as well.
    spread_mean = np.mean(memory.spread[window_slice])
    spread_std = np.std(memory.spread[window_slice])
    if spread_std > 0:
        memory.zscore[c.i] = (memory.spread[c.i] - spread_mean) / spread_std

    # Trade on the previous row's z-score to avoid using today's close to
    # decide today's order.
    if c.i > 0 and not np.isnan(memory.zscore[c.i - 1]):
        # Short signal: z-score above the upper bound and not already short.
        if memory.zscore[c.i - 1] > params.upper and memory.status[0] != 1:
            size[0] = -params.order_pct1  # short SMIC
            size[1] = params.order_pct2   # buy Hua Hong
            # Execution order: SMIC (the sell) first, then Hua Hong.
            c.call_seq_now[0] = 0
            c.call_seq_now[1] = 1
            memory.status[0] = 1
        # Long signal: z-score below the lower bound and not already long.
        elif memory.zscore[c.i - 1] < params.lower and memory.status[0] != 2:
            size[0] = params.order_pct1   # buy SMIC
            size[1] = -params.order_pct2  # short Hua Hong
            # Execution order: Hua Hong (the sell) first, then SMIC.
            c.call_seq_now[0] = 1
            c.call_seq_now[1] = 0
            memory.status[0] = 2
        else:
            size[0] = np.nan
            size[1] = np.nan
    else:
        size[0] = np.nan
        size[1] = np.nan

    # Value the group at the previous close for this row's sizing.
    c.last_val_price[c.from_col] = c.close[c.i - 1, c.from_col]
    c.last_val_price[c.from_col + 1] = c.close[c.i - 1, c.from_col + 1]
    return (size,)
@njit
def order_func_nb(c, size, price, commperc):
    """Build the order for one column of the pair at the current row.

    Args:
        c: vectorbt order context (row ``c.i``, column ``c.col`` and the
           group's first column ``c.from_col``).
        size: per-leg target percentages prepared by ``pre_segment_func_nb``
              (NaN where no signal was emitted).
        price: 2-D price array indexed ``[row, column]``.
        commperc: proportional fee rate.

    Returns:
        The order created by ``vbt.portfolio.nb.order_nb`` with a
        target-percent size type.
    """
    leg = c.col - c.from_col  # position of this column within its group
    target_pct = size[leg]
    fill_price = price[c.i, c.col]
    return vbt.portfolio.nb.order_nb(
        size=target_pct,
        price=fill_price,
        size_type=vbt.portfolio.enums.SizeType.TargetPercent,
        fees=commperc,
    )
# Build the two-column close-price frame the simulation runs on.
print("准备回测数据...")
price_data = pd.DataFrame({
    'SMIC': smic_data['close'],
    'HHIC': hhic_data['close'],
})
print(f"价格数据形状: {price_data.shape}")
print(f"价格数据时间范围: {price_data.index.min()} 至 {price_data.index.max()}")
print(f"价格数据前5行:\n{price_data.head()}")

# Run the pair-trading backtest.
print("运行配对交易回测...")
try:
    portfolio = vbt.Portfolio.from_order_func(
        price_data,
        order_func_nb,
        # from_order_func forwards extra *positional* arguments to
        # order_func_nb after the tuple returned by pre_segment_func_nb,
        # matching its signature (c, size, price, commperc).
        price_data.values,
        COMMPERC,
        pre_group_func_nb=pre_group_func_nb,
        pre_group_args=(
            np.array([PERIOD]),
            np.array([UPPER]),
            np.array([LOWER]),
            np.array([ORDER_PCT1]),
            np.array([ORDER_PCT2]),
        ),
        pre_segment_func_nb=pre_segment_func_nb,
        pre_segment_args=(MODE,),
        group_by=np.array([0, 0]),  # both columns form a single group (the pair)
        cash_sharing=True,          # the two legs share one cash pool
        init_cash=INITIAL_CASH,
        freq='1D',
    )
except Exception as e:
    # Top-level boundary: show the full traceback and stop with a non-zero
    # status, since every later step needs `portfolio`.
    print(f"回测出错: {e}")
    import traceback
    traceback.print_exc()
    raise SystemExit(1)
else:
    print("回测完成!")
# Compute additional metrics
def calculate_additional_metrics(price_data, period, mode):
    """Recompute the spread and z-score series exactly as the backtest does.

    Mirrors ``pre_segment_func_nb`` so that plotted/exported signals line up
    with what the simulation traded on:

    * the spread is defined only from row ``period - 1`` onward;
    * the z-score at row ``i`` uses plain (NaN-propagating) mean/std over the
      trailing ``period`` spread values, so it stays NaN until that window is
      fully populated;
    * a window with zero standard deviation yields a NaN z-score.

    Args:
        price_data: DataFrame with 'SMIC' and 'HHIC' close-price columns.
        period: rolling window length in rows.
        mode: ``'OLS'`` for the rolling log-price regression residual; any
            other value selects the log-return difference.

    Returns:
        tuple[pd.Series, pd.Series]: ``(spread, zscore)`` aligned to
        ``price_data.index``.
    """
    smic_close = price_data['SMIC'].to_numpy(dtype=np.float64)
    hhic_close = price_data['HHIC'].to_numpy(dtype=np.float64)
    n = len(smic_close)
    start = max(period - 1, 0)

    spread = np.full(n, np.nan)
    if mode == 'OLS':
        for i in range(start, n):
            window = slice(max(0, i + 1 - period), i + 1)
            spread[i] = ols_spread_nb(smic_close[window], hhic_close[window])
    else:
        # Log-return difference; row 0 has no previous bar to compare with.
        first = max(start, 1)
        spread[first:] = (
            np.log(smic_close[first:] / smic_close[first - 1:-1])
            - np.log(hhic_close[first:] / hhic_close[first - 1:-1])
        )

    zscore = np.full(n, np.nan)
    for i in range(start, n):
        window = spread[max(0, i + 1 - period):i + 1]
        spread_std = np.std(window)
        # NaN std (warm-up) and zero std both fail this test and stay NaN.
        if spread_std > 0:
            zscore[i] = (spread[i] - np.mean(window)) / spread_std

    return pd.Series(spread, index=price_data.index), pd.Series(zscore, index=price_data.index)
# Recompute spread / z-score outside the simulation for plotting and export.
spread_series, zscore_series = calculate_additional_metrics(price_data, PERIOD, MODE)
# Bars where the z-score lies beyond a threshold (raw signal conditions).
long_signals = zscore_series.lt(LOWER).rename('long_signals')
short_signals = zscore_series.gt(UPPER).rename('short_signals')
# Visualise the results.
print("生成分析图表...")
fig, (ax_price, ax_spread, ax_value) = plt.subplots(3, 1, figsize=(15, 12))

# 1. Price history of both legs.
ax_price.plot(price_data.index, price_data['SMIC'], label='中芯国际(00981)', linewidth=1)
ax_price.plot(price_data.index, price_data['HHIC'], label='华虹半导体(01347)', linewidth=1)
ax_price.set_title(f'股票价格走势 ({start_date.strftime("%Y-%m-%d")} 至 {end_date.strftime("%Y-%m-%d")})')
ax_price.set_ylabel('价格(港元)')
ax_price.legend()
ax_price.grid(True, alpha=0.3)

# 2. Spread on the left axis; z-score, thresholds and signals on a twin axis.
ax_spread.plot(spread_series.index, spread_series.values, label='价差', color='blue', linewidth=1)
ax_spread.set_title('价差走势')
ax_spread.set_ylabel('价差')
ax_spread.grid(True, alpha=0.3)

ax_z = ax_spread.twinx()
ax_z.plot(zscore_series.index, zscore_series.values, label='Z-Score', color='red', linewidth=1, alpha=0.7)
ax_z.axhline(y=UPPER, color='red', linestyle='--', alpha=0.5, label=f'上界({UPPER})')
ax_z.axhline(y=0, color='black', linestyle='-', alpha=0.3)
ax_z.axhline(y=LOWER, color='green', linestyle='--', alpha=0.5, label=f'下界({LOWER})')
ax_z.set_ylabel('Z-Score')

short_points = zscore_series[short_signals]
long_points = zscore_series[long_signals]
ax_z.scatter(short_points.index, short_points.values, color='red', marker='v', s=30, label='做空信号')
ax_z.scatter(long_points.index, long_points.values, color='green', marker='^', s=30, label='做多信号')

# One combined legend, built after every artist exists so the signal markers
# and the lines from both axes appear in it (instead of two overlapping ones).
spread_handles, spread_labels = ax_spread.get_legend_handles_labels()
z_handles, z_labels = ax_z.get_legend_handles_labels()
ax_z.legend(spread_handles + z_handles, spread_labels + z_labels)

# 3. Portfolio value. Series.vbt.plot renders with Plotly and has no `ax`
# argument, so the equity curve is drawn with matplotlib directly.
portfolio_value = portfolio.value()
ax_value.plot(portfolio_value.index, portfolio_value.values, linewidth=1)
ax_value.set_title('投资组合价值')
ax_value.set_ylabel('资产价值(港元)')
ax_value.set_xlabel('日期')
ax_value.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()
# Print the headline backtest results.
print("\n" + "="*50)
print("配对交易回测结果(最近三年)")
print("="*50)

stats = portfolio.stats()
print(f"回测期间: {start_date.strftime('%Y-%m-%d')} 至 {end_date.strftime('%Y-%m-%d')}")
print(f"初始资金: {INITIAL_CASH:,.2f} 港元")
print(f"最终资产: {portfolio.value().iloc[-1]:,.2f} 港元")
# stats() reports percentage metrics under keys such as 'Total Return [%]'
# (already multiplied by 100), so the ratio-returning methods are used for
# the ':.2%' formatting below.
print(f"总收益率: {portfolio.total_return():.2%}")
print(f"年化收益率: {portfolio.annualized_return():.2%}")
print(f"夏普比率: {portfolio.sharpe_ratio():.2f}")
print(f"最大回撤: {portfolio.max_drawdown():.2%}")
print(f"总交易次数: {stats['Total Trades']}")
# Per-leg order counts and average holding time.
# `portfolio.orders` is a records object, not a callable; its readable form is
# a DataFrame with 'Timestamp', 'Column', 'Size', 'Price' and 'Fees' columns,
# which is what the counts here and the recent-orders listing below rely on.
orders = portfolio.orders.records_readable
if len(orders) > 0:
    print(f"\n交易统计:")
    print(f"中芯国际交易次数: {len(orders[orders['Column'] == 'SMIC'])}")
    print(f"华虹半导体交易次数: {len(orders[orders['Column'] == 'HHIC'])}")
    # Average position duration in bars (the data is daily, freq='1D').
    position_durations = portfolio.positions.duration.values
    if len(position_durations) > 0:
        avg_holding = float(np.mean(position_durations))
        print(f"平均持仓时间: {avg_holding:.1f}")
# Count the bars on which each raw signal condition held.
n_short = short_signals.sum()
n_long = long_signals.sum()
print(f"\n信号统计:")
print(f"做空信号数量: {n_short}")
print(f"做多信号数量: {n_long}")
print(f"总信号数量: {n_short + n_long}")
# Correlation of the two close-price series, then simple per-leg statistics.
smic_px = price_data['SMIC']
hhic_px = price_data['HHIC']
correlation = smic_px.corr(hhic_px)
print(f"\n股票相关性: {correlation:.4f}")

# First/last close and the standard deviation of daily percentage changes.
print(f"\n价格统计:")
print(f"中芯国际 - 期初价格: {smic_px.iloc[0]:.2f} 港元")
print(f"中芯国际 - 期末价格: {smic_px.iloc[-1]:.2f} 港元")
print(f"华虹半导体 - 期初价格: {hhic_px.iloc[0]:.2f} 港元")
print(f"华虹半导体 - 期末价格: {hhic_px.iloc[-1]:.2f} 港元")
print(f"中芯国际 - 价格波动率: {smic_px.pct_change().std():.4f}")
print(f"华虹半导体 - 价格波动率: {hhic_px.pct_change().std():.4f}")
# Assemble the per-day series into one table and export it as CSV. The
# index is dropped because the dates are stored in the 'Date' column.
export_columns = {
    'Date': price_data.index,
    'SMIC_Price': price_data['SMIC'],
    'HHIC_Price': price_data['HHIC'],
    'Spread': spread_series,
    'ZScore': zscore_series,
    'Portfolio_Value': portfolio.value(),
    'Short_Signals': short_signals,
    'Long_Signals': long_signals,
}
results_df = pd.DataFrame(export_columns)
results_df.to_csv('pair_trading_results_3years.csv', index=False)
print(f"\n详细结果已保存到: pair_trading_results_3years.csv")
# Show the ten most recent orders, if there are any.
recent_trades = orders.tail(10)
if len(recent_trades):
    print(f"\n最近10笔交易:")
    display_cols = ['Timestamp', 'Column', 'Size', 'Price', 'Fees']
    print(recent_trades[display_cols])
# Yearly performance derived from the equity curve.
print(f"\n年度绩效分析:")
portfolio_value = portfolio.value()
# Year-end equity; each year's return is measured against the previous
# year-end value, and the first year against the initial cash.
year_end_value = portfolio_value.groupby(portfolio_value.index.year).last()
year_start_value = year_end_value.shift(1, fill_value=INITIAL_CASH)
yearly_returns = year_end_value / year_start_value - 1
for year, ret in yearly_returns.items():
    print(f"{year}年收益率: {ret:.2%}")