458 lines
16 KiB
Python
458 lines
16 KiB
Python
import vectorbt as vbt
|
||
import akshare as ak
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import datetime
|
||
from numba import njit
|
||
from collections import namedtuple
|
||
|
||
# Matplotlib setup for Chinese labels: use the SimHei font and disable the
# Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Download the daily history of both Hong Kong tickers through akshare
# (00981 first, then 01347 - same order as the unpacking targets).
print("正在获取股票数据...")
stock_00981, stock_01347 = (
    ak.stock_hk_daily(symbol=ticker) for ticker in ("00981", "01347")
)

# Print the raw column names of each download.
print("中芯国际数据列名:", stock_00981.columns.tolist())
print("华虹半导体数据列名:", stock_01347.columns.tolist())
|
||
|
||
# 数据预处理
|
||
def preprocess_data(df, symbol):
    """Normalize a raw daily-bar frame into a date-indexed OHLCV frame.

    Args:
        df: Raw frame. The date is taken from a ``date`` column if present,
            otherwise from a ``日期`` column, otherwise the existing index is
            parsed as dates.
        symbol: Ticker code; used only to label the error message.

    Returns:
        A new frame sorted by date containing exactly the columns
        ``open, high, low, close, volume``. The input frame is not modified.

    Raises:
        KeyError: If any of the five OHLCV columns is still missing after the
            Chinese headers have been translated.
    """
    required = ['open', 'high', 'low', 'close', 'volume']
    cn_to_en = {
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'volume',
    }

    df = df.copy()

    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date')
    elif '日期' in df.columns:
        df['date'] = pd.to_datetime(df['日期'])
        df = df.set_index('date')
    else:
        # No date column at all: assume the index already holds the dates.
        df.index = pd.to_datetime(df.index)

    # Translate Chinese headers no matter how the date was supplied, but never
    # onto an English name that already exists (that would duplicate a column).
    rename_map = {
        cn: en
        for cn, en in cn_to_en.items()
        if cn in df.columns and en not in df.columns
    }
    df = df.rename(columns=rename_map)

    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(
            f"{symbol}: missing required column(s) {missing}; "
            f"available columns: {list(df.columns)}"
        )

    return df.sort_index()[required]
|
||
|
||
# Normalize both downloads into date-indexed OHLCV frames.
smic_data = preprocess_data(stock_00981, "00981")
hhic_data = preprocess_data(stock_01347, "01347")

print(f"中芯国际原始数据时间范围: {smic_data.index.min()} 到 {smic_data.index.max()}")
print(f"华虹半导体原始数据时间范围: {hhic_data.index.min()} 到 {hhic_data.index.max()}")

# Restrict to the trailing 3 * 365 calendar days, counted back from the last
# SMIC bar (the HHIC frame is cut with the same two dates).
end_date = smic_data.index.max()
start_date = end_date - pd.Timedelta(days=3*365)

print(f"\n限制回测时间范围: {start_date} 到 {end_date}")

smic_data = smic_data.loc[start_date:end_date]
hhic_data = hhic_data.loc[start_date:end_date]

print(f"限制后中芯国际数据形状: {smic_data.shape}")
print(f"限制后华虹半导体数据形状: {hhic_data.shape}")

# Keep only the dates present in both frames so the rows line up.
common_index = smic_data.index.intersection(hhic_data.index)
smic_data = smic_data.loc[common_index]
hhic_data = hhic_data.loc[common_index]

print(f"对齐后数据时间范围: {common_index.min()} 到 {common_index.max()}")
print(f"总交易日数: {len(common_index)}")

if len(common_index) == 0:
    print("错误: 没有共同交易日数据!")
    # `exit()` is an interactive helper injected by the `site` module and it
    # exits with status 0; raise SystemExit with a non-zero code instead so
    # the failure is visible to the caller.
    raise SystemExit(1)
|
||
|
||
# Pairs-trading parameters
PERIOD = 30             # look-back window length, in rows
UPPER = 2.0             # upper z-score bound
LOWER = -2.0            # lower z-score bound
ORDER_PCT1 = 0.1        # target percentage for SMIC (00981)
ORDER_PCT2 = 0.1        # target percentage for HHIC (01347)
COMMPERC = 0.002        # commission rate, 0.2%
                        # NOTE: order_func_nb hard-codes the same rate and
                        # does not read this constant.
MODE = 'OLS'            # spread method: OLS
INITIAL_CASH = 100_000  # starting capital

# Containers handed between the simulation callbacks:
#   Memory - per-row spread and z-score arrays plus a one-element status flag
#   Params - the scalar strategy settings
Memory = namedtuple("Memory", "spread zscore status")
Params = namedtuple("Params", "period upper lower order_pct1 order_pct2")
|
||
|
||
@njit
def ols_spread_nb(a, b):
    """Return the latest residual of an OLS fit of log(a) on log(b).

    Fits ``log(a) = slope * log(b) + intercept`` over the whole window and
    returns the residual of the last observation.

    The coefficients come from the closed-form simple-regression formulas
    instead of inverting the normal-equations matrix. That matrix is singular
    when ``b`` is constant over the window, which made ``np.linalg.inv`` fail.
    In that degenerate case every least-squares fit predicts ``mean(log(a))``,
    so the residual is measured against that mean.

    Args:
        a: 1-D array of positive prices (dependent series).
        b: 1-D array of positive prices (regressor), same length as ``a``.

    Returns:
        The residual (spread) at the last position of the window.
    """
    log_a = np.log(a)
    log_b = np.log(b)

    mean_a = np.mean(log_a)
    mean_b = np.mean(log_b)
    dev_b = log_b - mean_b
    var_b = np.sum(dev_b * dev_b)

    if var_b == 0.0:
        # Regressor has no variation: the fit collapses to the mean of log(a).
        return log_a[-1] - mean_a

    slope = np.sum(dev_b * (log_a - mean_a)) / var_b
    intercept = mean_a - slope * mean_b
    return log_a[-1] - (slope * log_b[-1] + intercept)
|
||
|
||
@njit
def pre_group_func_nb(c, _period, _upper, _lower, _order_pct1, _order_pct2):
    """Set up the state for the current group (the asset pair).

    Args:
        c: Group context; ``c.group_len`` must be 2 and ``c.target_shape[0]``
            gives the length of the per-row arrays.
        _period: Look-back window length.
        _upper: Upper z-score bound.
        _lower: Lower z-score bound.
        _order_pct1: Target percentage for the first column.
        _order_pct2: Target percentage for the second column.

    Returns:
        ``(memory, params, size)``: a ``Memory`` of NaN-filled spread/z-score
        arrays plus a one-element status array starting at 0, a ``Params``
        holding the scalars as given, and an uninitialised float array with
        one slot per column for the order sizes.
    """
    assert c.group_len == 2

    n_rows = c.target_shape[0]
    memory = Memory(
        np.full(n_rows, np.nan, dtype=np.float64),  # spread per row
        np.full(n_rows, np.nan, dtype=np.float64),  # z-score per row
        np.full(1, 0, dtype=np.int64),              # status flag
    )

    # The settings are plain scalars, so they go into Params unchanged.
    params = Params(_period, _upper, _lower, _order_pct1, _order_pct2)

    # Filled in by pre_segment_func_nb on every row before orders are placed.
    size = np.empty(c.group_len, dtype=np.float64)

    return (memory, params, size)
|
||
|
||
@njit
def pre_segment_func_nb(c, memory, params, size, mode):
    """Update spread and z-score for the current row and pick target sizes.

    Args:
        c: Segment context; reads ``c.i``, ``c.close`` and ``c.from_col`` and
            writes ``c.call_seq_now`` and ``c.last_val_price``.
        memory: ``Memory`` tuple. ``spread[c.i]`` and ``zscore[c.i]`` are
            filled in; ``status[0]`` records the last signal taken
            (0 = none yet, 1 = short first column / long second column,
            2 = the reverse).
        params: ``Params`` tuple with the look-back period, the z-score
            bounds and the per-column target percentages.
        size: Two-element output array. NaN in both slots is used for
            "place no order on this row".
        mode: ``'OLS'`` or ``'log_return'``.

    Returns:
        ``(size,)``.

    Raises:
        ValueError: If ``mode`` is neither ``'OLS'`` nor ``'log_return'``.
    """
    # Not enough rows for a full look-back window yet.
    if c.i < params.period - 1:
        size[0] = np.nan
        size[1] = np.nan
        return (size,)

    # Rows of the current look-back window, current row included.
    window_slice = slice(max(0, c.i + 1 - params.period), c.i + 1)

    # Spread for the current row, according to the selected mode.
    if mode == 'OLS':
        a = c.close[window_slice, c.from_col]      # first column (SMIC)
        b = c.close[window_slice, c.from_col + 1]  # second column (HHIC)
        memory.spread[c.i] = ols_spread_nb(a, b)
    elif mode == 'log_return':
        # Difference of the two one-row log returns.
        logret_a = np.log(c.close[c.i, c.from_col] / c.close[c.i - 1, c.from_col])
        logret_b = np.log(c.close[c.i, c.from_col + 1] / c.close[c.i - 1, c.from_col + 1])
        memory.spread[c.i] = logret_a - logret_b
    else:
        raise ValueError("Unknown mode")

    # z-score of the current spread against its own window. The window still
    # contains NaN until `period` spreads exist, which makes the z-score NaN.
    spread_mean = np.mean(memory.spread[window_slice])
    spread_std = np.std(memory.spread[window_slice])
    if spread_std > 0:
        memory.zscore[c.i] = (memory.spread[c.i] - spread_mean) / spread_std
    else:
        # A zero (or NaN) deviation has no meaningful z-score. Dividing by
        # zero here would raise under numba's default Python error model, so
        # record NaN - the same guard calculate_additional_metrics applies.
        memory.zscore[c.i] = np.nan

    # Act on the previous row's z-score so no same-row information is used.
    if c.i > 0 and not np.isnan(memory.zscore[c.i - 1]):
        if memory.zscore[c.i - 1] > params.upper and memory.status[0] != 1:
            # Above the upper bound and not already in this state:
            # short the first column, long the second.
            size[0] = -params.order_pct1
            size[1] = params.order_pct2
            # Process the sold column first, then the bought one.
            c.call_seq_now[0] = 0
            c.call_seq_now[1] = 1
            memory.status[0] = 1
        elif memory.zscore[c.i - 1] < params.lower and memory.status[0] != 2:
            # Below the lower bound and not already in this state:
            # long the first column, short the second.
            size[0] = params.order_pct1
            size[1] = -params.order_pct2
            # Process the sold column first, then the bought one.
            c.call_seq_now[0] = 1
            c.call_seq_now[1] = 0
            memory.status[0] = 2
        else:
            size[0] = np.nan
            size[1] = np.nan
    else:
        size[0] = np.nan
        size[1] = np.nan

    # Value both columns at the previous row's close.
    c.last_val_price[c.from_col] = c.close[c.i - 1, c.from_col]
    c.last_val_price[c.from_col + 1] = c.close[c.i - 1, c.from_col + 1]

    return (size,)
|
||
|
||
# Order callback (fixed version - uses the correct argument order)
@njit
def order_func_nb(c, size):
    """Build the order for the current column - simplified, minimal arguments.

    Args:
        c: Order context; reads ``c.col``, ``c.from_col``, ``c.i`` and
            ``c.close``.
        size: Per-column target sizes for the current row, indexed by the
            column's position inside its group.

    Returns:
        The result of ``vbt.portfolio.nb.order_nb`` for a target-percent order
        priced at the current row's close.
    """
    # The fee rate is fixed here on purpose rather than passed as an argument.
    fee_rate = 0.002  # 0.2%

    return vbt.portfolio.nb.order_nb(
        size=size[c.col - c.from_col],  # position of this column in the group
        price=c.close[c.i, c.col],      # current row's price
        size_type=vbt.portfolio.enums.SizeType.TargetPercent,
        fees=fee_rate
    )
|
||
|
||
# Combine the two close series into one frame: SMIC first, HHIC second.
print("准备回测数据...")
close_by_name = {
    'SMIC': smic_data['close'],
    'HHIC': hhic_data['close'],
}
price_data = pd.DataFrame(close_by_name)

print(f"价格数据形状: {price_data.shape}")
print(f"价格数据时间范围: {price_data.index.min()} 到 {price_data.index.max()}")
print(f"价格数据前5行:\n{price_data.head()}")

# Debug output: library versions in use.
print("\n=== 调试信息 ===")
print(f"vectorbt版本: {vbt.__version__}")
print(f"pandas版本: {pd.__version__}")
print(f"numpy版本: {np.__version__}")

# Run the pairs-trading backtest.
print("\n运行配对交易回测...")

# NOTE: the progress message below says "without the cash argument", but
# init_cash is in fact passed to the call.
print("尝试方法1: 不使用cash参数...")

# Scalars forwarded to pre_group_func_nb, in its parameter order:
# period, upper, lower, order_pct1, order_pct2.
strategy_args = (PERIOD, UPPER, LOWER, ORDER_PCT1, ORDER_PCT2)

portfolio = vbt.Portfolio.from_order_func(
    price_data,
    order_func_nb,
    pre_group_func_nb=pre_group_func_nb,
    pre_group_args=strategy_args,
    pre_segment_func_nb=pre_segment_func_nb,
    pre_segment_args=(MODE,),  # spread mode
    group_by=True,             # both columns form one group
    cash_sharing=True,         # the group shares one cash balance
    init_cash=INITIAL_CASH,
    freq='d'
)
print("方法1成功!")

print("回测完成!")
|
||
|
||
# 计算额外指标用于分析
|
||
def calculate_additional_metrics(price_data, period, mode):
    """Recompute the spread and its rolling z-score for analysis and plotting.

    The windows follow the same convention as ``pre_segment_func_nb``: each
    window holds ``period`` rows ending at the current row, and a z-score is
    produced only once the window contains ``period`` valid spreads.

    Args:
        price_data: Frame with ``'SMIC'`` and ``'HHIC'`` price columns.
        period: Look-back window length in rows.
        mode: ``'OLS'`` for the rolling regression residual; any other value
            selects the difference of one-row log returns.

    Returns:
        ``(spread, zscore)`` as two Series on ``price_data.index``. Rows
        without a value are NaN.
    """
    smic_close = price_data['SMIC'].values
    hhic_close = price_data['HHIC'].values
    n_rows = len(smic_close)

    spread = np.full(n_rows, np.nan)
    if mode == 'OLS':
        # The first complete window ends at row period - 1. The loop used to
        # start at `period`, which dropped that first value.
        for i in range(period - 1, n_rows):
            window = slice(i - period + 1, i + 1)
            spread[i] = ols_spread_nb(smic_close[window], hhic_close[window])
    else:
        # Log-return spread; row 0 has no previous row and stays NaN.
        spread[1:] = (
            np.log(smic_close[1:] / smic_close[:-1])
            - np.log(hhic_close[1:] / hhic_close[:-1])
        )

    zscore = np.full(n_rows, np.nan)
    for i in range(period - 1, n_rows):
        window_values = spread[i - period + 1:i + 1]
        # Skip partially filled windows. Averaging only the available values
        # (nanmean/nanstd) gave z-scores from a handful of points that the
        # strategy itself never sees.
        if np.isnan(window_values).any():
            continue
        spread_std = np.std(window_values)
        if spread_std > 0:
            zscore[i] = (spread[i] - np.mean(window_values)) / spread_std

    return (
        pd.Series(spread, index=price_data.index),
        pd.Series(zscore, index=price_data.index),
    )
|
||
|
||
# Spread and z-score series used for the analysis below.
spread_series, zscore_series = calculate_additional_metrics(
    price_data, period=PERIOD, mode=MODE
)

# Boolean signal series: True where the z-score is beyond a bound
# (NaN rows compare as False).
short_signals = zscore_series.gt(UPPER).rename('short_signals')
long_signals = zscore_series.lt(LOWER).rename('long_signals')
|
||
|
||
# Visualise the results with matplotlib rather than vectorbt's own plotting.
print("生成分析图表...")

# Three stacked panels: prices, spread with z-score, portfolio value.
fig, axes = plt.subplots(3, 1, figsize=(15, 12))

# 1. Price history
ax1 = axes[0]
ax1.plot(price_data.index, price_data['SMIC'], label='中芯国际(00981)', linewidth=1)
ax1.plot(price_data.index, price_data['HHIC'], label='华虹半导体(01347)', linewidth=1)
ax1.set_title(f'股票价格走势 ({start_date.strftime("%Y-%m-%d")} 至 {end_date.strftime("%Y-%m-%d")})')
ax1.set_ylabel('价格(港元)')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Spread (left axis) and z-score (right axis)
ax2 = axes[1]
ax2.plot(spread_series.index, spread_series.values, label='价差', color='blue', linewidth=1)
ax2.set_title('价差走势')
ax2.set_ylabel('价差')
ax2.legend(loc='upper left')
ax2.grid(True, alpha=0.3)

ax3 = ax2.twinx()
ax3.plot(zscore_series.index, zscore_series.values, label='Z-Score', color='red', linewidth=1, alpha=0.7)
ax3.axhline(y=UPPER, color='red', linestyle='--', alpha=0.5, label=f'上界({UPPER})')
ax3.axhline(y=0, color='black', linestyle='-', alpha=0.3)
ax3.axhline(y=LOWER, color='green', linestyle='--', alpha=0.5, label=f'下界({LOWER})')
ax3.set_ylabel('Z-Score')

# Mark the rows where the z-score is beyond a bound.
short_points = zscore_series[short_signals]
long_points = zscore_series[long_signals]
ax3.scatter(short_points.index, short_points.values, color='red', marker='v', s=30, label='做空信号')
ax3.scatter(long_points.index, long_points.values, color='green', marker='^', s=30, label='做多信号')

# Build the legend only after every labelled artist exists. It used to be
# created before the scatter calls, so the signal markers were left out.
ax3.legend(loc='upper right')

# 3. Portfolio value
ax4 = axes[2]
portfolio_value = portfolio.value()
ax4.plot(portfolio_value.index, portfolio_value.values, label='投资组合价值', color='purple', linewidth=1)
ax4.set_title('投资组合价值')
ax4.set_ylabel('资产价值(港元)')
ax4.set_xlabel('日期')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
|
||
|
||
# Results banner
separator = "=" * 50
print("\n" + separator)
print("配对交易回测结果(最近三年)")
print(separator)

# Debug output: type and basic shape/statistics of the portfolio value series.
value_series = portfolio.value()
print("\n=== 投资组合调试信息 ===")
print(f"投资组合类型: {type(portfolio)}")
print(f"投资组合值形状: {value_series.shape}")
print(f"投资组合值前5行:\n{value_series.head()}")
print("投资组合值统计:")
print(f" 最小值: {value_series.min():.2f}")
print(f" 最大值: {value_series.max():.2f}")
print(f" 最终值: {value_series.iloc[-1]:.2f}")
|
||
|
||
# Headline statistics.
# vectorbt's stats() Series labels its percentage metrics with a "[%]" suffix
# ('Total Return [%]', 'Max Drawdown [%]') and has no 'Annual Return' entry.
# The old bare keys raised KeyError, so this section only ever printed the
# error line. Values under "[%]" keys are already percentages, hence the
# plain `.2f` format plus a literal percent sign.
# NOTE(review): confirm the labels and annualized_return() against the
# installed vectorbt version.
try:
    stats = portfolio.stats()
    print(f"\n回测期间: {start_date.strftime('%Y-%m-%d')} 至 {end_date.strftime('%Y-%m-%d')}")
    print(f"初始资金: {INITIAL_CASH:,.2f} 港元")
    print(f"最终资产: {portfolio.value().iloc[-1]:,.2f} 港元")
    print(f"总收益率: {stats['Total Return [%]']:.2f}%")
    # The annualized return comes from the portfolio itself, as a fraction.
    print(f"年化收益率: {portfolio.annualized_return():.2%}")
    print(f"夏普比率: {stats['Sharpe Ratio']:.2f}")
    print(f"最大回撤: {stats['Max Drawdown [%]']:.2f}%")
    print(f"总交易次数: {stats['Total Trades']}")
except Exception as e:
    # Best-effort reporting: a failure here must not abort the script.
    print(f"统计计算错误: {e}")
|
||
|
||
# Order, position and trade statistics.
# Written for vectorbt releases in which `orders`, `positions` and `trades`
# are properties returning record objects. The old code called them as
# functions and indexed the result like a DataFrame, so it always ended up in
# the except branch.
# NOTE(review): confirm against the installed vectorbt version.
try:
    # Human-readable order table with 'Column', 'Timestamp', 'Size', ... columns.
    orders = portfolio.orders.records_readable
    if len(orders) > 0:
        print(f"\n交易统计:")
        print(f"总订单数: {len(orders)}")
        print(f"中芯国际交易次数: {len(orders[orders['Column'] == 'SMIC'])}")
        print(f"华虹半导体交易次数: {len(orders[orders['Column'] == 'HHIC'])}")

        # Average holding time over all positions of both columns. `.values`
        # is the flat array, so the mean is one number; the mapped array's
        # own mean() reduces per column and cannot be formatted with `.1f`.
        durations = portfolio.positions.duration.values
        if durations.size > 0:
            avg_holding = durations.mean()
            print(f"平均持仓时间: {avg_holding:.1f} 天")

        # Profit and loss across all trades, flattened for the same reason.
        trade_pnl = portfolio.trades.pnl.values
        if trade_pnl.size > 0:
            print(f"总交易盈亏: {trade_pnl.sum():.2f} 港元")
            print(f"平均每笔交易盈亏: {trade_pnl.mean():.2f} 港元")
    else:
        print("没有交易记录")
except Exception as e:
    # Best-effort reporting: a failure here must not abort the script.
    print(f"交易统计错误: {e}")
|
||
|
||
# Signal counts: rows where the analysis z-score is beyond each bound.
n_short = short_signals.sum()
n_long = long_signals.sum()
print("\n信号统计:")
print(f"做空信号数量: {n_short}")
print(f"做多信号数量: {n_long}")
print(f"总信号数量: {n_short + n_long}")

smic_px = price_data['SMIC']
hhic_px = price_data['HHIC']

# Correlation of the two price series (price levels, not returns).
correlation = smic_px.corr(hhic_px)
print(f"\n股票相关性: {correlation:.4f}")

# First/last price and change over the period for each stock, followed by
# the standard deviation of the row-to-row percentage changes.
print("\n价格统计:")
print(f"中芯国际 - 期初价格: {smic_px.iloc[0]:.2f} 港元")
print(f"中芯国际 - 期末价格: {smic_px.iloc[-1]:.2f} 港元")
print(f"中芯国际 - 期间涨跌幅: {(smic_px.iloc[-1] / smic_px.iloc[0] - 1):.2%}")
print(f"华虹半导体 - 期初价格: {hhic_px.iloc[0]:.2f} 港元")
print(f"华虹半导体 - 期末价格: {hhic_px.iloc[-1]:.2f} 港元")
print(f"华虹半导体 - 期间涨跌幅: {(hhic_px.iloc[-1] / hhic_px.iloc[0] - 1):.2%}")
print(f"中芯国际 - 价格波动率: {smic_px.pct_change().std():.4f}")
print(f"华虹半导体 - 价格波动率: {hhic_px.pct_change().std():.4f}")
|
||
|
||
# Write the per-row results to a CSV file (best effort).
try:
    # Column order of the CSV follows the insertion order of this mapping.
    export_columns = {
        'Date': price_data.index,
        'SMIC_Price': price_data['SMIC'],
        'HHIC_Price': price_data['HHIC'],
        'Spread': spread_series,
        'ZScore': zscore_series,
        'Portfolio_Value': portfolio.value(),
        'Short_Signals': short_signals,
        'Long_Signals': long_signals,
    }
    results_df = pd.DataFrame(export_columns)

    # The dates are already in the 'Date' column, so the index is not written.
    results_df.to_csv('pair_trading_results_3years.csv', index=False)
    print("\n详细结果已保存到: pair_trading_results_3years.csv")
except Exception as e:
    print(f"保存结果错误: {e}")
|
||
|
||
# Show the ten most recent orders.
# `portfolio.orders` is used as a property and `records_readable` is its
# labelled DataFrame. The old `portfolio.orders()` call followed by
# `.tail()` / `.iterrows()` on the records object always failed.
# NOTE(review): confirm against the installed vectorbt version.
try:
    orders = portfolio.orders.records_readable
    if len(orders) > 0:
        recent_trades = orders.tail(10)
        print(f"\n最近10笔交易:")
        for _, trade in recent_trades.iterrows():
            print(f" 时间: {trade['Timestamp']}, 股票: {trade['Column']}, "
                  f"数量: {trade['Size']:.4f}, 价格: {trade['Price']:.2f}, "
                  f"手续费: {trade['Fees']:.2f}")
    else:
        print("没有交易记录")
except Exception as e:
    # Best-effort reporting: a failure here must not abort the script.
    print(f"显示最近交易错误: {e}")
|
||
|
||
# Per-year performance (best effort), then the closing message.
try:
    print("\n年度绩效分析:")
    annual = portfolio.annual_returns()
    # One line per entry: the index label followed by its return.
    for period_label, period_return in annual.items():
        print(f"{period_label}年收益率: {period_return:.2%}")
except Exception as e:
    print(f"年度绩效分析错误: {e}")

print("\n回测完成!")