配对交易最初代码(不能运行版本)
This commit is contained in:
1678
PairsTrading.ipynb
Normal file
1678
PairsTrading.ipynb
Normal file
File diff suppressed because one or more lines are too long
343
PairsTrading.py
Normal file
343
PairsTrading.py
Normal file
@ -0,0 +1,343 @@
|
||||
import vectorbt as vbt
|
||||
import akshare as ak
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import datetime
|
||||
from numba import njit
|
||||
from collections import namedtuple
|
||||
|
||||
# 设置中文显示
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 获取股票数据
|
||||
print("正在获取股票数据...")
|
||||
stock_00981 = ak.stock_hk_daily(symbol="00981")
|
||||
stock_01347 = ak.stock_hk_daily(symbol="01347")
|
||||
|
||||
# 数据预处理
|
||||
def preprocess_data(df, symbol):
|
||||
"""预处理股票数据"""
|
||||
df = df.copy()
|
||||
df['date'] = pd.to_datetime(df['日期'])
|
||||
df.set_index('date', inplace=True)
|
||||
df = df.sort_index()
|
||||
# 重命名列以符合vectorbt要求
|
||||
df = df.rename(columns={
|
||||
'开盘': 'open',
|
||||
'最高': 'high',
|
||||
'最低': 'low',
|
||||
'收盘': 'close',
|
||||
'成交量': 'volume'
|
||||
})
|
||||
return df[['open', 'high', 'low', 'close', 'volume']]
|
||||
|
||||
# 预处理数据
|
||||
smic_data = preprocess_data(stock_00981, "00981")
|
||||
hhic_data = preprocess_data(stock_01347, "01347")
|
||||
|
||||
# 对齐数据时间索引
|
||||
common_index = smic_data.index.intersection(hhic_data.index)
|
||||
smic_data = smic_data.loc[common_index]
|
||||
hhic_data = hhic_data.loc[common_index]
|
||||
|
||||
print(f"数据时间范围: {common_index.min()} 到 {common_index.max()}")
|
||||
print(f"总交易日数: {len(common_index)}")
|
||||
|
||||
# 配对交易参数
|
||||
PERIOD = 30 # 回看周期
|
||||
UPPER = 2.0 # 上界
|
||||
LOWER = -2.0 # 下界
|
||||
ORDER_PCT1 = 0.1 # 中芯国际交易比例
|
||||
ORDER_PCT2 = 0.1 # 华虹半导体交易比例
|
||||
COMMPERC = 0.002 # 手续费率 0.2%
|
||||
MODE = 'OLS' # 使用OLS方法
|
||||
INITIAL_CASH = 100000 # 初始资金
|
||||
|
||||
# 定义数据结构
|
||||
Memory = namedtuple("Memory", ('spread', 'zscore', 'status'))
|
||||
Params = namedtuple("Params", ('period', 'upper', 'lower', 'order_pct1', 'order_pct2'))
|
||||
|
||||
@njit
|
||||
def ols_spread_nb(a, b):
|
||||
"""计算OLS价差"""
|
||||
a = np.log(a)
|
||||
b = np.log(b)
|
||||
_b = np.vstack((b, np.ones(len(b)))).T
|
||||
slope, intercept = np.dot(np.linalg.inv(np.dot(_b.T, _b)), np.dot(_b.T, a))
|
||||
spread = a - (slope * b + intercept)
|
||||
return spread[-1]
|
||||
|
||||
@njit
|
||||
def pre_group_func_nb(c, _period, _upper, _lower, _order_pct1, _order_pct2):
|
||||
"""准备当前组(资产对)"""
|
||||
assert c.group_len == 2
|
||||
|
||||
# 初始化内存数组
|
||||
spread = np.full(c.target_shape[0], np.nan, dtype=np.float64)
|
||||
zscore = np.full(c.target_shape[0], np.nan, dtype=np.float64)
|
||||
status = np.full(1, 0, dtype=np.int64)
|
||||
memory = Memory(spread, zscore, status)
|
||||
|
||||
# 选择参数
|
||||
period = _period[0] if len(_period) > 1 else _period
|
||||
upper = _upper[0] if len(_upper) > 1 else _upper
|
||||
lower = _lower[0] if len(_lower) > 1 else _lower
|
||||
order_pct1 = _order_pct1[0] if len(_order_pct1) > 1 else _order_pct1
|
||||
order_pct2 = _order_pct2[0] if len(_order_pct2) > 1 else _order_pct2
|
||||
|
||||
params = Params(period, upper, lower, order_pct1, order_pct2)
|
||||
|
||||
# 创建仓位大小数组
|
||||
size = np.empty(c.group_len, dtype=np.float64)
|
||||
|
||||
return (memory, params, size)
|
||||
|
||||
@njit
|
||||
def pre_segment_func_nb(c, memory, params, size, mode):
|
||||
"""准备当前段(组内行)"""
|
||||
|
||||
# 等待足够的数据
|
||||
if c.i < params.period - 1:
|
||||
size[0] = np.nan
|
||||
size[1] = np.nan
|
||||
return (size,)
|
||||
|
||||
# 计算窗口切片
|
||||
window_slice = slice(max(0, c.i + 1 - params.period), c.i + 1)
|
||||
|
||||
# 根据模式计算价差
|
||||
if mode == 'OLS':
|
||||
a = c.close[window_slice, c.from_col] # 中芯国际
|
||||
b = c.close[window_slice, c.from_col + 1] # 华虹半导体
|
||||
memory.spread[c.i] = ols_spread_nb(a, b)
|
||||
elif mode == 'log_return':
|
||||
# 对数收益率方法
|
||||
logret_a = np.log(c.close[c.i, c.from_col] / c.close[c.i - 1, c.from_col])
|
||||
logret_b = np.log(c.close[c.i, c.from_col + 1] / c.close[c.i - 1, c.from_col + 1])
|
||||
memory.spread[c.i] = logret_a - logret_b
|
||||
else:
|
||||
raise ValueError("Unknown mode")
|
||||
|
||||
# 计算z-score
|
||||
spread_mean = np.mean(memory.spread[window_slice])
|
||||
spread_std = np.std(memory.spread[window_slice])
|
||||
memory.zscore[c.i] = (memory.spread[c.i] - spread_mean) / spread_std
|
||||
|
||||
# 使用前一个z-score生成交易信号(避免未来数据)
|
||||
if c.i > 0 and not np.isnan(memory.zscore[c.i - 1]):
|
||||
# 做空信号:z-score > 上界 且 当前不是做空状态
|
||||
if memory.zscore[c.i - 1] > params.upper and memory.status[0] != 1:
|
||||
size[0] = -params.order_pct1 # 卖空中芯国际
|
||||
size[1] = params.order_pct2 # 买入华虹半导体
|
||||
# 执行顺序:先卖后买
|
||||
c.call_seq_now[0] = 0
|
||||
c.call_seq_now[1] = 1
|
||||
memory.status[0] = 1 # 设置为做空状态
|
||||
|
||||
# 做多信号:z-score < 下界 且 当前不是做多状态
|
||||
elif memory.zscore[c.i - 1] < params.lower and memory.status[0] != 2:
|
||||
size[0] = params.order_pct1 # 买入中芯国际
|
||||
size[1] = -params.order_pct2 # 卖空华虹半导体
|
||||
# 执行顺序:先卖后买
|
||||
c.call_seq_now[0] = 1
|
||||
c.call_seq_now[1] = 0
|
||||
memory.status[0] = 2 # 设置为做多状态
|
||||
|
||||
else:
|
||||
size[0] = np.nan
|
||||
size[1] = np.nan
|
||||
else:
|
||||
size[0] = np.nan
|
||||
size[1] = np.nan
|
||||
|
||||
# 设置估值价格
|
||||
c.last_val_price[c.from_col] = c.close[c.i - 1, c.from_col]
|
||||
c.last_val_price[c.from_col + 1] = c.close[c.i - 1, c.from_col + 1]
|
||||
|
||||
return (size,)
|
||||
|
||||
@njit
|
||||
def order_func_nb(c, size, price, commperc):
|
||||
"""执行订单"""
|
||||
group_col = c.col - c.from_col
|
||||
return vbt.portfolio.nb.order_nb(
|
||||
size=size[group_col],
|
||||
price=price[c.i, c.col],
|
||||
size_type=vbt.portfolio.enums.SizeType.TargetPercent,
|
||||
fees=commperc
|
||||
)
|
||||
|
||||
# 准备价格数据
|
||||
print("准备回测数据...")
|
||||
price_data = pd.DataFrame({
|
||||
'SMIC': smic_data['close'],
|
||||
'HHIC': hhic_data['close']
|
||||
})
|
||||
|
||||
# 运行配对交易回测
|
||||
print("运行配对交易回测...")
|
||||
portfolio = vbt.Portfolio.from_order_func(
|
||||
price_data,
|
||||
order_func_nb,
|
||||
pre_group_func_nb=pre_group_func_nb,
|
||||
pre_segment_func_nb=pre_segment_func_nb,
|
||||
pre_group_args=(
|
||||
np.array([PERIOD]), # period
|
||||
np.array([UPPER]), # upper
|
||||
np.array([LOWER]), # lower
|
||||
np.array([ORDER_PCT1]), # order_pct1
|
||||
np.array([ORDER_PCT2]) # order_pct2
|
||||
),
|
||||
pre_segment_args=(MODE,), # mode
|
||||
order_args=(COMMPERC,), # commperc
|
||||
group_by=np.array([0, 0]), # 将两列分为同一组
|
||||
cash_sharing=True, # 共享现金
|
||||
cash=INITIAL_CASH,
|
||||
freq='1D'
|
||||
)
|
||||
|
||||
# 计算额外指标
|
||||
def calculate_additional_metrics(portfolio, price_data):
|
||||
"""计算额外指标"""
|
||||
# 价差和z-score(需要重新计算用于分析)
|
||||
smic_close = price_data['SMIC'].values
|
||||
hhic_close = price_data['HHIC'].values
|
||||
|
||||
if MODE == 'OLS':
|
||||
# 使用OLS方法计算价差
|
||||
spread = np.full(len(smic_close), np.nan)
|
||||
for i in range(PERIOD, len(smic_close)):
|
||||
window_slice = slice(i - PERIOD + 1, i + 1)
|
||||
a = smic_close[window_slice]
|
||||
b = hhic_close[window_slice]
|
||||
spread[i] = ols_spread_nb(a, b)
|
||||
else:
|
||||
# 对数收益率方法
|
||||
spread = np.full(len(smic_close), np.nan)
|
||||
spread[1:] = np.log(smic_close[1:] / smic_close[:-1]) - np.log(hhic_close[1:] / hhic_close[:-1])
|
||||
|
||||
# 计算z-score
|
||||
zscore = np.full(len(spread), np.nan)
|
||||
for i in range(PERIOD, len(spread)):
|
||||
window_slice = slice(i - PERIOD + 1, i + 1)
|
||||
spread_mean = np.nanmean(spread[window_slice])
|
||||
spread_std = np.nanstd(spread[window_slice])
|
||||
if spread_std > 0:
|
||||
zscore[i] = (spread[i] - spread_mean) / spread_std
|
||||
|
||||
return pd.Series(spread, index=price_data.index), pd.Series(zscore, index=price_data.index)
|
||||
|
||||
# 计算指标
|
||||
spread_series, zscore_series = calculate_additional_metrics(portfolio, price_data)
|
||||
|
||||
# 生成交易信号
|
||||
short_signals = (zscore_series > UPPER).rename('short_signals')
|
||||
long_signals = (zscore_series < LOWER).rename('long_signals')
|
||||
|
||||
# 可视化结果
|
||||
print("生成分析图表...")
|
||||
|
||||
# 创建子图
|
||||
fig = plt.figure(figsize=(15, 12))
|
||||
|
||||
# 1. 价格走势
|
||||
ax1 = plt.subplot(3, 1, 1)
|
||||
plt.plot(price_data.index, price_data['SMIC'], label='中芯国际(00981)', linewidth=1)
|
||||
plt.plot(price_data.index, price_data['HHIC'], label='华虹半导体(01347)', linewidth=1)
|
||||
plt.title('股票价格走势')
|
||||
plt.ylabel('价格(港元)')
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
# 2. 价差和z-score
|
||||
ax2 = plt.subplot(3, 1, 2)
|
||||
plt.plot(spread_series.index, spread_series.values, label='价差', color='blue', linewidth=1)
|
||||
plt.title('价差走势')
|
||||
plt.ylabel('价差')
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
ax3 = ax2.twinx()
|
||||
plt.plot(zscore_series.index, zscore_series.values, label='Z-Score', color='red', linewidth=1, alpha=0.7)
|
||||
plt.axhline(y=UPPER, color='red', linestyle='--', alpha=0.5, label=f'上界({UPPER})')
|
||||
plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
|
||||
plt.axhline(y=LOWER, color='green', linestyle='--', alpha=0.5, label=f'下界({LOWER})')
|
||||
plt.ylabel('Z-Score')
|
||||
plt.legend()
|
||||
|
||||
# 标记交易信号
|
||||
short_points = zscore_series[short_signals]
|
||||
long_points = zscore_series[long_signals]
|
||||
plt.scatter(short_points.index, short_points.values, color='red', marker='v', s=30, label='做空信号')
|
||||
plt.scatter(long_points.index, long_points.values, color='green', marker='^', s=30, label='做多信号')
|
||||
|
||||
# 3. 资产价值
|
||||
ax4 = plt.subplot(3, 1, 3)
|
||||
portfolio.value().vbt.plot(ax=ax4, title='投资组合价值')
|
||||
plt.ylabel('资产价值(港元)')
|
||||
plt.xlabel('日期')
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# 输出回测结果
|
||||
print("\n" + "="*50)
|
||||
print("配对交易回测结果")
|
||||
print("="*50)
|
||||
|
||||
# 基本统计
|
||||
stats = portfolio.stats()
|
||||
print(f"初始资金: {INITIAL_CASH:,.2f} 港元")
|
||||
print(f"最终资产: {portfolio.value().iloc[-1]:,.2f} 港元")
|
||||
print(f"总收益率: {stats['Total Return']:.2%}")
|
||||
print(f"年化收益率: {stats['Annual Return']:.2%}")
|
||||
print(f"夏普比率: {stats['Sharpe Ratio']:.2f}")
|
||||
print(f"最大回撤: {stats['Max Drawdown']:.2%}")
|
||||
print(f"总交易次数: {stats['Total Trades']}")
|
||||
|
||||
# 交易统计
|
||||
orders = portfolio.orders()
|
||||
if len(orders) > 0:
|
||||
print(f"\n交易统计:")
|
||||
print(f"中芯国际交易次数: {len(orders[orders['Column'] == 'SMIC'])}")
|
||||
print(f"华虹半导体交易次数: {len(orders[orders['Column'] == 'HHIC'])}")
|
||||
|
||||
# 计算平均持仓时间
|
||||
positions = portfolio.positions()
|
||||
if len(positions) > 0:
|
||||
avg_holding = positions.duration.mean()
|
||||
print(f"平均持仓时间: {avg_holding:.1f} 天")
|
||||
|
||||
# 信号统计
|
||||
print(f"\n信号统计:")
|
||||
print(f"做空信号数量: {short_signals.sum()}")
|
||||
print(f"做多信号数量: {long_signals.sum()}")
|
||||
print(f"总信号数量: {short_signals.sum() + long_signals.sum()}")
|
||||
|
||||
# 相关性分析
|
||||
correlation = price_data['SMIC'].corr(price_data['HHIC'])
|
||||
print(f"\n股票相关性: {correlation:.4f}")
|
||||
|
||||
# 保存结果到文件
|
||||
results_df = pd.DataFrame({
|
||||
'Date': price_data.index,
|
||||
'SMIC_Price': price_data['SMIC'],
|
||||
'HHIC_Price': price_data['HHIC'],
|
||||
'Spread': spread_series,
|
||||
'ZScore': zscore_series,
|
||||
'Portfolio_Value': portfolio.value(),
|
||||
'Short_Signals': short_signals,
|
||||
'Long_Signals': long_signals
|
||||
})
|
||||
|
||||
results_df.to_csv('pair_trading_results.csv', index=False)
|
||||
print(f"\n详细结果已保存到: pair_trading_results.csv")
|
||||
|
||||
# 显示最近交易
|
||||
recent_trades = orders.tail(10)
|
||||
if len(recent_trades) > 0:
|
||||
print(f"\n最近10笔交易:")
|
||||
print(recent_trades[['Timestamp', 'Column', 'Size', 'Price', 'Fees']])
|
||||
127
comtom_indicator_example.py
Normal file
127
comtom_indicator_example.py
Normal file
@ -0,0 +1,127 @@
|
||||
import vectorbt as vbt
|
||||
import akshare as ak
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import datetime
|
||||
|
||||
stock_00981 = ak.stock_hk_daily(symbol="00981")
|
||||
stock_01347 = ak.stock_hk_daily(symbol="01347")
|
||||
|
||||
# 数据预处理 - 设置日期索引
|
||||
def prepare_data(df):
|
||||
"""准备数据,设置日期索引"""
|
||||
df = df.copy()
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
return df
|
||||
|
||||
stock_00981 = prepare_data(stock_00981)
|
||||
stock_01347 = prepare_data(stock_01347)
|
||||
|
||||
stock_00981_close = stock_00981['close']
|
||||
print("中芯国际最近5个交易日收盘价:")
|
||||
print(stock_00981_close.tail(5))
|
||||
|
||||
def custom_indicator(close, rsi_window=14, ma_window=50):
|
||||
# 计算日线RSI
|
||||
rsi = vbt.RSI.run(close, window=rsi_window).rsi
|
||||
|
||||
# 计算日线移动平均
|
||||
ma = vbt.MA.run(close, window=ma_window).ma
|
||||
|
||||
# 确保所有数组都是numpy数组,避免pandas索引问题
|
||||
close_np = close.to_numpy() if hasattr(close, 'to_numpy') else close
|
||||
rsi_np = rsi.to_numpy() if hasattr(rsi, 'to_numpy') else rsi
|
||||
ma_np = ma.to_numpy() if hasattr(ma, 'to_numpy') else ma
|
||||
|
||||
# 使用numpy数组进行计算
|
||||
trend = np.where(rsi_np > 70, -1, 0)
|
||||
trend = np.where((rsi_np < 30) & (close_np < ma_np), 1, trend)
|
||||
return trend
|
||||
|
||||
# 创建自定义指标 - 移除 keep_pd=True
|
||||
ind = vbt.IndicatorFactory(
|
||||
class_name="Combination",
|
||||
short_name="comb",
|
||||
input_names=["close"],
|
||||
param_names=["rsi_window", "ma_window"],
|
||||
output_names=["value"]
|
||||
).from_apply_func(
|
||||
custom_indicator,
|
||||
rsi_window=14,
|
||||
ma_window=50
|
||||
# 移除 keep_pd=True
|
||||
)
|
||||
|
||||
# 运行指标
|
||||
res = ind.run(
|
||||
stock_00981_close,
|
||||
rsi_window=21,
|
||||
ma_window=[21,50,100]
|
||||
)
|
||||
|
||||
print("指标结果:")
|
||||
print(res.value)
|
||||
|
||||
# 生成交易信号
|
||||
entries = res.value == 1.0
|
||||
exits = res.value == -1.0
|
||||
|
||||
print(f"买入信号数量: {entries.sum().sum()}")
|
||||
print(f"卖出信号数量: {exits.sum().sum()}")
|
||||
|
||||
# 如果结果有多列,选择第一列
|
||||
if entries.ndim > 1 and entries.shape[1] > 1:
|
||||
entries = entries.iloc[:, 0]
|
||||
exits = exits.iloc[:, 0]
|
||||
print("使用第一组参数")
|
||||
|
||||
# 创建投资组合
|
||||
pf = vbt.Portfolio.from_signals(stock_00981_close, entries, exits, freq='D')
|
||||
|
||||
# 绘图
|
||||
try:
|
||||
pf.plot().show()
|
||||
except Exception as e:
|
||||
print(f"绘图错误: {e}")
|
||||
# 简化绘图
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
|
||||
|
||||
# 价格和信号
|
||||
ax1.plot(stock_00981_close.index, stock_00981_close.values, label='Close Price', color='blue')
|
||||
|
||||
# 买入信号
|
||||
buy_signals = entries.fillna(False)
|
||||
if buy_signals.any():
|
||||
buy_dates = stock_00981_close.index[buy_signals]
|
||||
buy_prices = stock_00981_close[buy_signals]
|
||||
ax1.scatter(buy_dates, buy_prices, color='green', marker='^', s=100, label='Buy')
|
||||
|
||||
# 卖出信号
|
||||
sell_signals = exits.fillna(False)
|
||||
if sell_signals.any():
|
||||
sell_dates = stock_00981_close.index[sell_signals]
|
||||
sell_prices = stock_00981_close[sell_signals]
|
||||
ax1.scatter(sell_dates, sell_prices, color='red', marker='v', s=100, label='Sell')
|
||||
|
||||
ax1.set_title('Price and Trading Signals')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
# 投资组合价值
|
||||
portfolio_value = pf.value()
|
||||
ax2.plot(portfolio_value.index, portfolio_value.values, label='Portfolio Value', color='orange')
|
||||
ax2.set_title('Portfolio Value')
|
||||
ax2.legend()
|
||||
ax2.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
print("-------------------")
|
||||
print("投资组合统计:")
|
||||
stats = pf.stats()
|
||||
print(stats)
|
||||
print(f"总收益率: {pf.total_return():.2%}")
|
||||
print("-------------------")
|
||||
BIN
python-for-algorithmic-trading-cookbook-main.zip
Normal file
BIN
python-for-algorithmic-trading-cookbook-main.zip
Normal file
Binary file not shown.
132
t.py
132
t.py
@ -1,127 +1,17 @@
|
||||
import vectorbt as vbt
|
||||
import akshare as ak
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import datetime
|
||||
|
||||
stock_00981 = ak.stock_hk_daily(symbol="00981")
|
||||
stock_01347 = ak.stock_hk_daily(symbol="01347")
|
||||
# 创建两个DataFrame
|
||||
df1 = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}, index=['x', 'y', 'z'])
|
||||
df2 = pd.DataFrame({'A': [7, 8, 9], 'C': [10, 11, 12]}, index=['y', 'z', 'w'])
|
||||
|
||||
# 数据预处理 - 设置日期索引
|
||||
def prepare_data(df):
|
||||
"""准备数据,设置日期索引"""
|
||||
df = df.copy()
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
return df
|
||||
print(df1)
|
||||
print(df2)
|
||||
|
||||
stock_00981 = prepare_data(stock_00981)
|
||||
stock_01347 = prepare_data(stock_01347)
|
||||
# 对齐两个DataFrame
|
||||
aligned_df1, aligned_df2 = df1.align(df2, broadcast_axis=0)
|
||||
|
||||
stock_00981_close = stock_00981['close']
|
||||
print("中芯国际最近5个交易日收盘价:")
|
||||
print(stock_00981_close.tail(5))
|
||||
|
||||
def custom_indicator(close, rsi_window=14, ma_window=50):
|
||||
# 计算日线RSI
|
||||
rsi = vbt.RSI.run(close, window=rsi_window).rsi
|
||||
|
||||
# 计算日线移动平均
|
||||
ma = vbt.MA.run(close, window=ma_window).ma
|
||||
|
||||
# 确保所有数组都是numpy数组,避免pandas索引问题
|
||||
close_np = close.to_numpy() if hasattr(close, 'to_numpy') else close
|
||||
rsi_np = rsi.to_numpy() if hasattr(rsi, 'to_numpy') else rsi
|
||||
ma_np = ma.to_numpy() if hasattr(ma, 'to_numpy') else ma
|
||||
|
||||
# 使用numpy数组进行计算
|
||||
trend = np.where(rsi_np > 70, -1, 0)
|
||||
trend = np.where((rsi_np < 30) & (close_np < ma_np), 1, trend)
|
||||
return trend
|
||||
|
||||
# 创建自定义指标 - 移除 keep_pd=True
|
||||
ind = vbt.IndicatorFactory(
|
||||
class_name="Combination",
|
||||
short_name="comb",
|
||||
input_names=["close"],
|
||||
param_names=["rsi_window", "ma_window"],
|
||||
output_names=["value"]
|
||||
).from_apply_func(
|
||||
custom_indicator,
|
||||
rsi_window=14,
|
||||
ma_window=50
|
||||
# 移除 keep_pd=True
|
||||
)
|
||||
|
||||
# 运行指标
|
||||
res = ind.run(
|
||||
stock_00981_close,
|
||||
rsi_window=21,
|
||||
ma_window=50
|
||||
)
|
||||
|
||||
print("指标结果:")
|
||||
print(res.value)
|
||||
|
||||
# 生成交易信号
|
||||
entries = res.value == 1.0
|
||||
exits = res.value == -1.0
|
||||
|
||||
print(f"买入信号数量: {entries.sum().sum()}")
|
||||
print(f"卖出信号数量: {exits.sum().sum()}")
|
||||
|
||||
# 如果结果有多列,选择第一列
|
||||
if entries.ndim > 1 and entries.shape[1] > 1:
|
||||
entries = entries.iloc[:, 0]
|
||||
exits = exits.iloc[:, 0]
|
||||
print("使用第一组参数")
|
||||
|
||||
# 创建投资组合
|
||||
pf = vbt.Portfolio.from_signals(stock_00981_close, entries, exits, freq='D')
|
||||
|
||||
# 绘图
|
||||
try:
|
||||
pf.plot().show()
|
||||
except Exception as e:
|
||||
print(f"绘图错误: {e}")
|
||||
# 简化绘图
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
|
||||
|
||||
# 价格和信号
|
||||
ax1.plot(stock_00981_close.index, stock_00981_close.values, label='Close Price', color='blue')
|
||||
|
||||
# 买入信号
|
||||
buy_signals = entries.fillna(False)
|
||||
if buy_signals.any():
|
||||
buy_dates = stock_00981_close.index[buy_signals]
|
||||
buy_prices = stock_00981_close[buy_signals]
|
||||
ax1.scatter(buy_dates, buy_prices, color='green', marker='^', s=100, label='Buy')
|
||||
|
||||
# 卖出信号
|
||||
sell_signals = exits.fillna(False)
|
||||
if sell_signals.any():
|
||||
sell_dates = stock_00981_close.index[sell_signals]
|
||||
sell_prices = stock_00981_close[sell_signals]
|
||||
ax1.scatter(sell_dates, sell_prices, color='red', marker='v', s=100, label='Sell')
|
||||
|
||||
ax1.set_title('Price and Trading Signals')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
# 投资组合价值
|
||||
portfolio_value = pf.value()
|
||||
ax2.plot(portfolio_value.index, portfolio_value.values, label='Portfolio Value', color='orange')
|
||||
ax2.set_title('Portfolio Value')
|
||||
ax2.legend()
|
||||
ax2.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
print("-------------------")
|
||||
print("投资组合统计:")
|
||||
stats = pf.stats()
|
||||
print(stats)
|
||||
print(f"总收益率: {pf.total_return():.2%}")
|
||||
print("-------------------")
|
||||
print("对齐后的df1:")
|
||||
print(aligned_df1)
|
||||
print("\n对齐后的df2:")
|
||||
print(aligned_df2)
|
||||
Reference in New Issue
Block a user