配对交易最初代码(不能运行版本)

This commit is contained in:
2025-10-31 19:16:36 +08:00
parent 33c243a22a
commit 7175ca8a1c
5 changed files with 2159 additions and 121 deletions

127
comtom_indicator_example.py Normal file
View File

@ -0,0 +1,127 @@
import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
stock_00981 = ak.stock_hk_daily(symbol="00981")
stock_01347 = ak.stock_hk_daily(symbol="01347")
# 数据预处理 - 设置日期索引
def prepare_data(df):
"""准备数据,设置日期索引"""
df = df.copy()
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
return df
stock_00981 = prepare_data(stock_00981)
stock_01347 = prepare_data(stock_01347)
stock_00981_close = stock_00981['close']
print("中芯国际最近5个交易日收盘价:")
print(stock_00981_close.tail(5))
def custom_indicator(close, rsi_window=14, ma_window=50):
# 计算日线RSI
rsi = vbt.RSI.run(close, window=rsi_window).rsi
# 计算日线移动平均
ma = vbt.MA.run(close, window=ma_window).ma
# 确保所有数组都是numpy数组避免pandas索引问题
close_np = close.to_numpy() if hasattr(close, 'to_numpy') else close
rsi_np = rsi.to_numpy() if hasattr(rsi, 'to_numpy') else rsi
ma_np = ma.to_numpy() if hasattr(ma, 'to_numpy') else ma
# 使用numpy数组进行计算
trend = np.where(rsi_np > 70, -1, 0)
trend = np.where((rsi_np < 30) & (close_np < ma_np), 1, trend)
return trend
# 创建自定义指标 - 移除 keep_pd=True
ind = vbt.IndicatorFactory(
class_name="Combination",
short_name="comb",
input_names=["close"],
param_names=["rsi_window", "ma_window"],
output_names=["value"]
).from_apply_func(
custom_indicator,
rsi_window=14,
ma_window=50
# 移除 keep_pd=True
)
# 运行指标
res = ind.run(
stock_00981_close,
rsi_window=21,
ma_window=[21,50,100]
)
print("指标结果:")
print(res.value)
# 生成交易信号
entries = res.value == 1.0
exits = res.value == -1.0
print(f"买入信号数量: {entries.sum().sum()}")
print(f"卖出信号数量: {exits.sum().sum()}")
# 如果结果有多列,选择第一列
if entries.ndim > 1 and entries.shape[1] > 1:
entries = entries.iloc[:, 0]
exits = exits.iloc[:, 0]
print("使用第一组参数")
# 创建投资组合
pf = vbt.Portfolio.from_signals(stock_00981_close, entries, exits, freq='D')
# 绘图
try:
pf.plot().show()
except Exception as e:
print(f"绘图错误: {e}")
# 简化绘图
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
# 价格和信号
ax1.plot(stock_00981_close.index, stock_00981_close.values, label='Close Price', color='blue')
# 买入信号
buy_signals = entries.fillna(False)
if buy_signals.any():
buy_dates = stock_00981_close.index[buy_signals]
buy_prices = stock_00981_close[buy_signals]
ax1.scatter(buy_dates, buy_prices, color='green', marker='^', s=100, label='Buy')
# 卖出信号
sell_signals = exits.fillna(False)
if sell_signals.any():
sell_dates = stock_00981_close.index[sell_signals]
sell_prices = stock_00981_close[sell_signals]
ax1.scatter(sell_dates, sell_prices, color='red', marker='v', s=100, label='Sell')
ax1.set_title('Price and Trading Signals')
ax1.legend()
ax1.grid(True)
# 投资组合价值
portfolio_value = pf.value()
ax2.plot(portfolio_value.index, portfolio_value.values, label='Portfolio Value', color='orange')
ax2.set_title('Portfolio Value')
ax2.legend()
ax2.grid(True)
plt.tight_layout()
plt.show()
print("-------------------")
print("投资组合统计:")
stats = pf.stats()
print(stats)
print(f"总收益率: {pf.total_return():.2%}")
print("-------------------")