增加vectorbt绘图

This commit is contained in:
2025-11-01 16:03:50 +08:00
parent 4d909e65c2
commit a66fcd0da9

View File

@ -199,13 +199,12 @@ try:
print("投资组合创建成功!")
# 计算配对交易统计 - 修复统计计算
# 计算配对交易统计
print("\n=== 配对交易策略表现 ===")
# 获取组合总价值
portfolio_value = portfolio.value()
# 绘制结果
print("\n绘制图表...")
@ -252,6 +251,184 @@ try:
plt.tight_layout()
plt.show()
# ========== 修复vectorbt自带的可视化 ==========
print("\n=== VectorBT 专业可视化 ===")
# 1. 组合价值、持仓和交易的可视化 - 分别绘制每只股票
print("\n绘制中芯国际分析...")
try:
# 中芯国际
fig1 = portfolio['00981'].plot(subplots=[
'orders', # 订单
'trade_pnl', # 交易盈亏
'cum_returns', # 累积收益
'drawdowns' # 回撤
])
fig1.update_layout(
title='中芯国际配对交易分析',
height=800
)
fig1.show()
except Exception as e:
print(f"中芯国际分析绘制失败: {e}")
print("\n绘制华虹半导体分析...")
try:
# 华虹半导体
fig2 = portfolio['01347'].plot(subplots=[
'orders', # 订单
'trade_pnl', # 交易盈亏
'cum_returns', # 累积收益
'drawdowns' # 回撤
])
fig2.update_layout(
title='华虹半导体配对交易分析',
height=800
)
fig2.show()
except Exception as e:
print(f"华虹半导体分析绘制失败: {e}")
# 2. 资产价值变化 - 使用组合总价值
print("\n绘制组合总价值变化...")
try:
fig = portfolio_value.vbt.plot(
title='配对交易组合总价值',
xlabel='日期',
ylabel='组合价值'
)
fig.show()
except Exception as e:
print(f"组合价值变化绘制失败: {e}")
# 3. 累积收益 - 使用组合总收益
print("\n绘制组合累积收益...")
try:
cumulative_returns = portfolio.cumulative_returns()
# 如果是多列,取第一列或平均值
if hasattr(cumulative_returns, 'columns') and len(cumulative_returns.columns) > 1:
cumulative_returns = cumulative_returns.mean(axis=1)
fig = cumulative_returns.vbt.plot(
title='配对交易累积收益率',
xlabel='日期',
ylabel='累积收益'
)
fig.show()
except Exception as e:
print(f"累积收益绘制失败: {e}")
# 4. 回撤分析 - 使用组合总回撤
print("\n绘制组合回撤分析...")
try:
# 如果是多列,取第一列或平均值
if hasattr(drawdown, 'columns') and len(drawdown.columns) > 1:
drawdown = drawdown.mean(axis=1)
fig = drawdown.vbt.plot(
title='配对交易回撤分析',
xlabel='日期',
ylabel='回撤'
)
fig.show()
except Exception as e:
print(f"回撤分析绘制失败: {e}")
# 5. 月度收益热力图
print("\n绘制月度收益热力图...")
try:
monthly_returns = portfolio.returns().vbt.to_monthly()
# 如果是多列,取平均值
if hasattr(monthly_returns, 'columns') and len(monthly_returns.columns) > 1:
monthly_returns = monthly_returns.mean(axis=1)
fig = monthly_returns.vbt.heatmap(
xaxis_title='年份',
yaxis_title='月份',
title='配对交易月度收益热力图'
)
fig.show()
except Exception as e:
print(f"月度收益热力图绘制失败: {e}")
# 6. 交易分析 - 分别分析每只股票
print("\n绘制交易分析...")
try:
# 中芯国际交易分析
trades_smic = portfolio['00981'].trades
if len(trades_smic) > 0:
fig = trades_smic.plot_pnl()
fig.update_layout(title='中芯国际交易盈亏分布')
fig.show()
fig = trades_smic.plot_duration()
fig.update_layout(title='中芯国际交易持续时间分布')
fig.show()
# 华虹半导体交易分析
trades_hhic = portfolio['01347'].trades
if len(trades_hhic) > 0:
fig = trades_hhic.plot_pnl()
fig.update_layout(title='华虹半导体交易盈亏分布')
fig.show()
fig = trades_hhic.plot_duration()
fig.update_layout(title='华虹半导体交易持续时间分布')
fig.show()
except Exception as e:
print(f"交易分析绘制失败: {e}")
# 7. 订单流分析 - 分别分析每只股票
print("\n绘制订单流分析...")
try:
# 中芯国际订单
fig1 = portfolio['00981'].orders.plot()
fig1.update_layout(title='中芯国际订单流分析')
fig1.show()
# 华虹半导体订单
fig2 = portfolio['01347'].orders.plot()
fig2.update_layout(title='华虹半导体订单流分析')
fig2.show()
except Exception as e:
print(f"订单流分析绘制失败: {e}")
# 8. 持仓分析 - 分别分析每只股票
print("\n绘制持仓分析...")
try:
# 中芯国际持仓
holdings_smic = portfolio['00981'].holdings
fig1 = holdings_smic.vbt.plot(
title='中芯国际持仓变化',
xlabel='日期',
ylabel='持仓价值'
)
fig1.show()
# 华虹半导体持仓
holdings_hhic = portfolio['01347'].holdings
fig2 = holdings_hhic.vbt.plot(
title='华虹半导体持仓变化',
xlabel='日期',
ylabel='持仓价值'
)
fig2.show()
except Exception as e:
print(f"持仓分析绘制失败: {e}")
# 9. 现金变化 - 使用组合总现金
print("\n绘制现金变化...")
try:
cash = portfolio.cash
# 如果是多列,取第一列
if hasattr(cash, 'columns') and len(cash.columns) > 1:
cash = cash.iloc[:, 0]
fig = cash.vbt.plot(
title='配对交易现金余额变化',
xlabel='日期',
ylabel='现金余额'
)
fig.show()
except Exception as e:
print(f"现金变化绘制失败: {e}")
# 显示交易记录
non_zero_size = size[(size != 0).any(axis=1)]
print(f"\n非零交易数量: {len(non_zero_size)}")
@ -259,48 +436,51 @@ try:
print("交易记录示例:")
print(non_zero_size.head(20))
# 打印详细统计 - 使用更安全的方式
# 打印详细统计
print("\n=== 详细统计 ===")
try:
stats = portfolio.stats()
# 安全地获取统计值
# 分别获取每只股票的统计
stats_smic = portfolio['00981'].stats()
stats_hhic = portfolio['01347'].stats()
def safe_get_stat(stat_dict, key, default="N/A"):
value = stat_dict.get(key, default)
if hasattr(value, 'iloc'):
return value.iloc[0] if len(value) == 1 else value
return value
print(f"开始日期: {safe_get_stat(stats, 'Start')}")
print(f"结束日期: {safe_get_stat(stats, 'End')}")
print(f"期间: {safe_get_stat(stats, 'Period')}")
print(f"总收益率: {safe_get_stat(stats, 'Total Return [%]', 'N/A')}%")
print(f"年化收益率: {safe_get_stat(stats, 'Annual Return [%]', 'N/A')}%")
print(f"年化波动率: {safe_get_stat(stats, 'Annual Volatility [%]', 'N/A')}%")
print(f"夏普比率: {safe_get_stat(stats, 'Sharpe Ratio', 'N/A')}")
print(f"最大回撤: {safe_get_stat(stats, 'Max Drawdown [%]', 'N/A')}%")
print(f"总交易次数: {safe_get_stat(stats, 'Total Trades', 'N/A')}")
print(f"率: {safe_get_stat(stats, 'Win Rate [%]', 'N/A')}%")
print("\n中芯国际统计:")
print(f"总收益率: {safe_get_stat(stats_smic, 'Total Return [%]', 'N/A')}%")
print(f"年化收益率: {safe_get_stat(stats_smic, 'Annual Return [%]', 'N/A')}%")
print(f"夏普比率: {safe_get_stat(stats_smic, 'Sharpe Ratio', 'N/A')}")
print(f"最大回撤: {safe_get_stat(stats_smic, 'Max Drawdown [%]', 'N/A')}%")
print("\n华虹半导体统计:")
print(f"总收益率: {safe_get_stat(stats_hhic, 'Total Return [%]', 'N/A')}%")
print(f"年化收益率: {safe_get_stat(stats_hhic, 'Annual Return [%]', 'N/A')}%")
print(f"夏普比率: {safe_get_stat(stats_hhic, 'Sharpe Ratio', 'N/A')}")
print(f"最大回撤: {safe_get_stat(stats_hhic, 'Max Drawdown [%]', 'N/A')}%")
except Exception as e:
print(f"获取详细统计时出错: {e}")
# 分析每笔交易
try:
trades = portfolio.trades.records_readable
if len(trades) > 0:
print(f"\n交易分析:")
print(f"总交易次数: {len(trades)}")
if 'Duration' in trades.columns:
print(f"平均持仓时间: {trades['Duration'].mean():.1f}")
if 'PnL' in trades.columns:
print(f"最大单笔盈利: {trades['PnL'].max():.2f}")
print(f"最大单笔亏损: {trades['PnL'].min():.2f}")
winning_trades = trades[trades['PnL'] > 0]
losing_trades = trades[trades['PnL'] < 0]
if len(winning_trades) > 0:
print(f"平均盈利: {winning_trades['PnL'].mean():.2f}")
if len(losing_trades) > 0:
print(f"平均亏损: {losing_trades['PnL'].mean():.2f}")
trades_smic = portfolio['00981'].trades.records_readable
trades_hhic = portfolio['01347'].trades.records_readable
if len(trades_smic) > 0:
print(f"\n中芯国际交易分析:")
print(f"总交易次数: {len(trades_smic)}")
if 'Duration' in trades_smic.columns:
print(f"平均持仓时间: {trades_smic['Duration'].mean():.1f}")
if len(trades_hhic) > 0:
print(f"\n华虹半导体交易分析:")
print(f"总交易次数: {len(trades_hhic)}")
if 'Duration' in trades_hhic.columns:
print(f"平均持仓时间: {trades_hhic['Duration'].mean():.1f}")
except Exception as e:
print(f"分析交易时出错: {e}")