# Listing metadata from the original paste: 127 lines, 3.6 KiB, Python.
import vectorbt as vbt
import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime

# Download daily Hong Kong bars for both tickers, 00981 first and then 01347
# (same call order as a pair of separate assignments).
stock_00981, stock_01347 = (
    ak.stock_hk_daily(symbol=ticker) for ticker in ("00981", "01347")
)
# Data preprocessing: index each price table by its trading date.
def prepare_data(df):
    """Return a copy of *df* indexed by its ``date`` column, oldest row first.

    The ``date`` column is parsed with ``pd.to_datetime`` and moved into the
    index. Rows are then sorted chronologically, because the rolling
    indicators (RSI, moving average) and the backtest in this script are
    order-sensitive. The input frame itself is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Price table with a ``date`` column (e.g. a frame returned by
        ``ak.stock_hk_daily``).

    Returns
    -------
    pandas.DataFrame
        Copy of *df* with a ``DatetimeIndex`` named ``date``, sorted ascending.
        Rows sharing a date keep their original relative order.

    Raises
    ------
    KeyError
        If *df* has no ``date`` column.
    """
    df = df.copy()
    df['date'] = pd.to_datetime(df['date'])
    # Do not rely on the data source returning rows already in time order;
    # a stable sort keeps any duplicate dates in their original order.
    return df.set_index('date').sort_index(kind='mergesort')
# Index both price tables by date (00981 first, then 01347).
stock_00981, stock_01347 = map(prepare_data, (stock_00981, stock_01347))

# Closing prices of 00981; this series feeds the indicator and the backtest.
stock_00981_close = stock_00981['close']
print("中芯国际最近5个交易日收盘价:", stock_00981_close.tail(5), sep="\n")
def custom_indicator(close, rsi_window=14, ma_window=50, *, upper=70, lower=30):
    """Combine RSI and a moving average into a -1 / 0 / +1 signal array.

    Used below as the ``apply_func`` of a vectorbt ``IndicatorFactory``.

    Parameters
    ----------
    close : array-like
        Closing prices (NumPy array or pandas object).
    rsi_window : int
        Look-back window passed to ``vbt.RSI.run``.
    ma_window : int
        Look-back window passed to ``vbt.MA.run``.
    upper : float, keyword-only
        RSI level above which the output is -1 (default 70).
    lower : float, keyword-only
        RSI level below which the output is +1, provided the close is also
        below its moving average (default 30).

    Returns
    -------
    numpy.ndarray
        Integer array built with ``np.where``:

        * -1 where RSI > *upper*;
        * +1 where RSI < *lower* and close < MA;
        * 0 elsewhere.

        Comparisons involving NaN (e.g. during the RSI/MA warm-up windows)
        are False, so such bars yield 0.
    """
    # Daily RSI and moving average of the close.
    rsi = vbt.RSI.run(close, window=rsi_window).rsi
    ma = vbt.MA.run(close, window=ma_window).ma

    # Compare plain arrays so pandas index alignment cannot interfere with
    # the element-wise logic; np.asarray is a no-op for ndarray inputs.
    close_np = np.asarray(close)
    rsi_np = np.asarray(rsi)
    ma_np = np.asarray(ma)

    trend = np.where(rsi_np > upper, -1, 0)
    # Applied second, so a bar matching both rules gets +1
    # (only possible when lower > upper).
    trend = np.where((rsi_np < lower) & (close_np < ma_np), 1, trend)
    return trend
# Wrap custom_indicator in a vectorbt indicator class named "Combination"
# with one input (close), two parameters, and one output (value).
# keep_pd=True is deliberately not passed, so vectorbt's default input
# handling applies. Parameter defaults: rsi_window=14, ma_window=50.
ind = (
    vbt.IndicatorFactory(
        class_name="Combination",
        short_name="comb",
        input_names=["close"],
        param_names=["rsi_window", "ma_window"],
        output_names=["value"],
    )
    .from_apply_func(custom_indicator, rsi_window=14, ma_window=50)
)
# Evaluate the indicator on 00981 closes: RSI window fixed at 21, three
# moving-average windows (one result column per parameter combination).
res = ind.run(
    stock_00981_close,
    rsi_window=21,
    ma_window=[21, 50, 100],
)
print("指标结果:", res.value, sep="\n")
# Trading signals: indicator value +1 -> entry (buy), -1 -> exit (sell).
entries = res.value == 1.0
exits = res.value == -1.0

# These totals are computed before column selection, so with several
# parameter combinations they count signals across all of them.
print(f"买入信号数量: {entries.sum().sum()}")
print(f"卖出信号数量: {exits.sum().sum()}")

# Backtest a single parameter combination. Any 2-D result is reduced to its
# first column -- including a one-column DataFrame. Downstream code expects
# a Series, e.g. truthiness checks on `.any()` and scalar formatting of the
# portfolio's total return.
if entries.ndim > 1:
    if entries.shape[1] > 1:
        print("使用第一组参数")
    entries = entries.iloc[:, 0]
    exits = exits.iloc[:, 0]
# Backtest the selected signals on 00981 closing prices; freq='D' tells
# vectorbt each bar spans one day when computing time-based statistics.
pf = vbt.Portfolio.from_signals(
    close=stock_00981_close,
    entries=entries,
    exits=exits,
    freq='D',
)
# Plotting: try vectorbt's own chart first; if that fails for any reason,
# fall back to a static matplotlib figure.


def _signal_mask(signals):
    """Return *signals* as a 1-D NumPy boolean mask.

    A 2-D input (e.g. a DataFrame with one column per parameter combination)
    is reduced to its first column. The result therefore always supports
    ``mask.any()`` and positional boolean indexing. Missing values count as
    "no signal".
    """
    if signals.ndim > 1:
        signals = signals.iloc[:, 0]
    return signals.to_numpy(dtype=bool, na_value=False)


def _plot_fallback(close, entries, exits, portfolio):
    """Draw two panels with matplotlib and show them.

    The top panel shows the price with buy/sell markers; the bottom panel
    shows the portfolio value.

    Parameters
    ----------
    close : pandas.Series
        Closing prices; its index supplies the x-axis.
    entries, exits : pandas.Series or pandas.DataFrame
        Boolean signals. They must be positionally aligned with *close*, as
        they are here, since both derive from the same close series. If 2-D,
        the first column is used.
    portfolio
        Backtest result; only ``portfolio.value()`` is used.
    """
    _, (ax_price, ax_value) = plt.subplots(2, 1, figsize=(12, 8))

    # Top panel: price line plus one marker layer per signal type.
    ax_price.plot(close.index, close.values, label='Close Price', color='blue')
    marker_specs = (
        (entries, 'green', '^', 'Buy'),
        (exits, 'red', 'v', 'Sell'),
    )
    for signals, color, marker, label in marker_specs:
        mask = _signal_mask(signals)
        if mask.any():
            ax_price.scatter(close.index[mask], close.values[mask],
                             color=color, marker=marker, s=100, label=label)
    ax_price.set_title('Price and Trading Signals')
    ax_price.legend()
    ax_price.grid(True)

    # Bottom panel: portfolio value over time.
    portfolio_value = portfolio.value()
    ax_value.plot(portfolio_value.index, portfolio_value.values,
                  label='Portfolio Value', color='orange')
    ax_value.set_title('Portfolio Value')
    ax_value.legend()
    ax_value.grid(True)

    plt.tight_layout()
    plt.show()


try:
    pf.plot().show()
except Exception as e:  # top-level boundary: any failure of the vectorbt chart -> matplotlib
    print(f"绘图错误: {e}")
    _plot_fallback(stock_00981_close, entries, exits, pf)
# Text summary of the backtest: full statistics table, then total return.
print("-------------------", "投资组合统计:", sep="\n")
stats = pf.stats()
print(stats)
print(f"总收益率: {pf.total_return():.2%}", "-------------------", sep="\n")