Files
quant/comtom_indicator_example.py

127 lines
3.6 KiB
Python
Raw Permalink Normal View History

import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
# Download daily quote tables for two HK-listed tickers via akshare's
# `stock_hk_daily`; 00981 is fetched first, then 01347.
stock_00981, stock_01347 = (
    ak.stock_hk_daily(symbol=code) for code in ("00981", "01347")
)
# Data preprocessing: index each quote table by its trading date.
def prepare_data(df):
    """Return a copy of ``df`` indexed by its parsed ``date`` column.

    The ``date`` column is converted with ``pd.to_datetime`` and moved into
    the index (index name ``"date"``). The caller's frame is not modified.
    """
    dated = df.assign(date=pd.to_datetime(df["date"]))
    return dated.set_index("date")
# Date-index both tables, then show SMIC's (00981) five most recent closes.
stock_00981, stock_01347 = map(prepare_data, (stock_00981, stock_01347))
stock_00981_close = stock_00981['close']
print("中芯国际最近5个交易日收盘价:")
print(stock_00981_close.tail(5))
def custom_indicator(close, rsi_window=14, ma_window=50):
    """Combine an RSI reading with a moving-average filter into one signal.

    Args:
        close: closing prices (array-like or pandas object).
        rsi_window: lookback passed to ``vbt.RSI.run``.
        ma_window: lookback passed to ``vbt.MA.run``.

    Returns:
        An integer array from ``np.where``: ``1`` where RSI < 30 and the close
        is below its moving average, ``-1`` where RSI > 70, ``0`` otherwise.
        Warm-up NaNs compare False and therefore yield ``0``.
    """
    def as_array(obj):
        # Compare plain NumPy arrays so no pandas index alignment can occur.
        return obj.to_numpy() if hasattr(obj, 'to_numpy') else obj

    rsi_values = as_array(vbt.RSI.run(close, window=rsi_window).rsi)
    ma_values = as_array(vbt.MA.run(close, window=ma_window).ma)
    price = as_array(close)

    overbought = rsi_values > 70
    oversold_below_ma = (rsi_values < 30) & (price < ma_values)
    # The buy condition takes precedence over the sell condition.
    return np.where(oversold_below_ma, 1, np.where(overbought, -1, 0))
# Wrap custom_indicator as a vectorbt indicator class ("Combination", short
# name "comb") with one input (close), two parameters and one output ("value").
# keep_pd=True is deliberately not passed, so (per vectorbt's default) the apply
# function receives NumPy arrays rather than pandas objects.
# The keyword arguments after the function presumably serve as default
# parameter values for `ind.run` — confirm against the vectorbt docs.
ind = (
    vbt.IndicatorFactory(
        class_name="Combination",
        short_name="comb",
        input_names=["close"],
        param_names=["rsi_window", "ma_window"],
        output_names=["value"],
    )
    .from_apply_func(
        custom_indicator,
        rsi_window=14,
        ma_window=50,
    )
)
# Run the indicator with one RSI window and three MA windows; vectorbt
# evaluates each parameter combination, producing one output column per combo.
res = ind.run(
    stock_00981_close,
    rsi_window=21,
    ma_window=[21, 50, 100],
)
print("指标结果:")
print(res.value)

# Turn the indicator output into boolean signals:
# +1 -> entry (buy), -1 -> exit (sell), 0 -> no action.
entries = res.value == 1.0
exits = res.value == -1.0

# A 2-D result holds one column per parameter combination, and only the first
# one is backtested. Reduce any 2-D result (even a single column) to a Series
# so later code can treat the signals as one series.
if entries.ndim > 1:
    if entries.shape[1] > 1:
        print("使用第一组参数")
    entries = entries.iloc[:, 0]
    exits = exits.iloc[:, 0]

# Count signals on the series that is actually traded. Counting before the
# reduction summed every parameter combination and overstated the number of
# signals behind the backtest.
print(f"买入信号数量: {entries.sum()}")
print(f"卖出信号数量: {exits.sum()}")
# Build the backtest portfolio from the close series and the entry/exit
# signals; freq='D' declares daily bars for vectorbt's time-based statistics.
pf = vbt.Portfolio.from_signals(
    close=stock_00981_close,
    entries=entries,
    exits=exits,
    freq='D',
)
# Plot the backtest with vectorbt's built-in figure; if that fails, report the
# error and draw a simplified two-panel matplotlib chart instead.
try:
    pf.plot().show()
except Exception as e:
    # Best-effort fallback: the plotting failure is reported, not re-raised.
    print(f"绘图错误: {e}")
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

    # Upper panel: close price with entry (buy) and exit (sell) markers.
    ax1.plot(stock_00981_close.index, stock_00981_close.values, label='Close Price', color='blue')
    for signal, colour, glyph, tag in (
        (entries, 'green', '^', 'Buy'),
        (exits, 'red', 'v', 'Sell'),
    ):
        mask = signal.fillna(False)
        if not mask.any():
            continue  # nothing of this kind to mark
        ax1.scatter(
            stock_00981_close.index[mask],
            stock_00981_close[mask],
            color=colour,
            marker=glyph,
            s=100,
            label=tag,
        )
    ax1.set_title('Price and Trading Signals')
    ax1.legend()
    ax1.grid(True)

    # Lower panel: portfolio value over time.
    equity = pf.value()
    ax2.plot(equity.index, equity.values, label='Portfolio Value', color='orange')
    ax2.set_title('Portfolio Value')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()
# Print vectorbt's statistics table and the total return (as a percentage with
# two decimals), framed by dashed divider lines.
divider = "-------------------"
print(divider)
print("投资组合统计:")
stats = pf.stats()
print(stats)
total_ret = pf.total_return()
print(f"总收益率: {total_ret:.2%}")
print(divider)