# Listing metadata: 287 lines, 12 KiB, Python.
import logging
import os
import sys
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional

import pandas as pd

# Add this file's directory to sys.path so the project-local module below can be imported.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from financial_data_manager import FinancialDataManager


class WindDataFetcher:
    """Fetch financial indicators from the Wind API and store them in the database.

    The set of indicators is defined once in ``self.indicator_definitions``;
    the Wind request string and the column names read back from the result
    are both derived from that mapping so they cannot drift apart.
    """

    def __init__(self, db_config: Dict[str, Any]):
        """
        Initialise the fetcher.

        Args:
            db_config: Database configuration dict; it is unpacked as keyword
                arguments into ``FinancialDataManager`` (the original
                docstring lists host, database, user, password, port).
        """
        self.db_config = db_config
        self.logger = self._setup_logger()

        # Indicator definitions, keyed by lower-case Wind field name.
        # The key doubles as the ``indicator_id`` stored with every record.
        self.indicator_definitions = {
            'wgsd_capex_ff': {
                'indicator_id': 'wgsd_capex_ff',
                'indicator_name': '资本支出_现金流量表',
                'category': '现金流量表',
                'unit': '元',
                'indicator_desc': '资本性支出,反映公司用于购置固定资产、无形资产和其他长期资产支付的现金'
            },
            'wgsd_assets_bus_cf': {
                'indicator_id': 'wgsd_assets_bus_cf',
                'indicator_name': '资产处置收益_现金流量表',
                'category': '现金流量表',
                'unit': '元',
                'indicator_desc': '资产处置产生的现金流量,反映公司处置固定资产、无形资产和其他长期资产收回的现金净额'
            },
            'wgsd_net_profit_is': {
                'indicator_id': 'wgsd_net_profit_is',
                'indicator_name': '净利润_利润表',
                'category': '利润表',
                'unit': '元',
                'indicator_desc': '净利润,反映公司在一定会计期间内实现的税后利润'
            }
        }

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'WindDataFetcher' logger, attaching a stream handler once."""
        logger = logging.getLogger('WindDataFetcher')
        logger.setLevel(logging.DEBUG)  # DEBUG on purpose: the class logs per-row diagnostics

        # getLogger returns the same object on every call, so only add a
        # handler the first time to avoid duplicated output.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def ensure_indicators_defined(self, db_manager: "FinancialDataManager") -> bool:
        """
        Insert every indicator definition through the database manager.

        Args:
            db_manager: Database manager instance; its ``insert_indicator``
                method is called once per definition.

        Returns:
            True when the loop completed without raising (an individual
            insert reporting failure is only logged as a warning);
            False when an exception was raised.
        """
        try:
            self.logger.debug("开始确保指标定义存在")
            for indicator_id, definition in self.indicator_definitions.items():
                self.logger.debug(f"处理指标定义: {indicator_id}")
                success = db_manager.insert_indicator(definition)
                if not success:
                    self.logger.warning(f"插入指标定义失败: {indicator_id}")
            self.logger.info("财务指标定义已确保存在")
            return True
        except Exception as e:
            self.logger.error(f"确保指标定义失败: {e}")
            return False

    @staticmethod
    def _result_to_dataframe(result: Any) -> pd.DataFrame:
        """Convert a Wind result (``Data``/``Fields``/``Times``) into a tidy DataFrame.

        One row per entry of ``result.Times`` (exposed as the ``report_date``
        column) and one column per field. Field names are upper-cased so the
        columns always match what ``_build_records`` looks up, regardless of
        the casing Wind echoes back.
        """
        fields = [str(field).upper() for field in result.Fields]
        df = pd.DataFrame(result.Data, index=fields, columns=result.Times).T
        df.index.name = 'report_date'
        df.reset_index(inplace=True)
        return df

    def fetch_wind_data(self, wind_code: str, end_date: Optional[str] = None) -> Optional[pd.DataFrame]:
        """
        Fetch the configured indicators for one stock from the Wind API.

        Args:
            wind_code: Stock code (e.g. "000001.SZ").
            end_date: End date formatted "YYYY-MM-DD"; defaults to today.

        Returns:
            DataFrame with a ``report_date`` column plus one upper-case
            column per indicator, or None when WindPy is missing, Wind
            reports a non-zero error code, or any other exception occurs.
        """
        try:
            # Imported lazily so the module can be loaded without WindPy installed.
            from WindPy import w

            self.logger.debug("启动Wind API")
            w.start()

            if end_date is None:
                end_date = datetime.now().strftime('%Y-%m-%d')

            # Request string derived from the single source of truth.
            indicators = ",".join(self.indicator_definitions)

            self.logger.info(f"开始从Wind获取数据: {wind_code}, 指标: {indicators}, 结束日期: {end_date}")

            # "ED-10Y": start ten years before the end date (per the original comment).
            self.logger.debug("调用Wind API...")
            result = w.wsd(wind_code, indicators, "ED-10Y", end_date,
                           "unit=1;rptType=1;currencyType=;Period=Y;Days=Alldays;Currency=CNY")

            self.logger.debug(f"Wind API返回 - ErrorCode: {result.ErrorCode}, 数据形状: {len(result.Data)}x{len(result.Times) if result.Data else 0}")

            if result.ErrorCode != 0:
                self.logger.error(f"Wind API错误: {result.ErrorCode} - {result.Data[0] if result.Data else 'Unknown error'}")
                return None

            self.logger.debug("转换数据为DataFrame")
            df = self._result_to_dataframe(result)

            self.logger.debug(f"DataFrame列名: {df.columns.tolist()}")
            self.logger.debug(f"DataFrame形状: {df.shape}")
            self.logger.debug(f"DataFrame前3行:\n{df.head(3)}")
            self.logger.debug(f"数据统计:\n{df.describe()}")

            self.logger.info(f"成功获取数据: {wind_code}, 共 {len(df)} 条记录")
            return df

        except ImportError:
            self.logger.error("未安装WindPy,请先安装Wind官方Python API")
            return None
        except Exception as e:
            # Keep the traceback: this is the catch-all boundary for the fetch.
            self.logger.error(f"获取Wind数据失败: {e}", exc_info=True)
            return None

    def _build_records(self, wind_code: str, df: pd.DataFrame) -> tuple:
        """Turn a fetched DataFrame into database records.

        Returns:
            ``(records, valid_row_count, invalid_row_count)`` where
            ``records`` is a list of dicts (one per non-null indicator value)
            and a row counts as valid when it produced at least one record.
        """
        records: List[Dict[str, Any]] = []
        valid_row_count = 0
        invalid_row_count = 0

        self.logger.debug("开始处理每一行数据")
        for idx, row in df.iterrows():
            self.logger.debug(f"处理第 {idx} 行数据")
            report_date = row['report_date']
            self.logger.debug(f"报告日期: {report_date}")

            row_has_valid_data = False
            for indicator_id, definition in self.indicator_definitions.items():
                # Columns carry the upper-case Wind field name.
                wind_field = indicator_id.upper()
                value = row.get(wind_field)
                self.logger.debug(f"指标 {wind_field} 的值: {value}, 类型: {type(value)}")

                # pd.notna is False for both None and NaN.
                if not pd.notna(value):
                    self.logger.debug(f"跳过空值: {wind_field}")
                    continue

                try:
                    records.append({
                        'stock_id': wind_code,
                        'indicator_id': indicator_id,
                        'report_date': report_date.strftime('%Y-%m-%d') if hasattr(report_date, 'strftime') else str(report_date),
                        'data_value': float(value),
                        'fiscal_year': report_date.year if hasattr(report_date, 'year') else pd.to_datetime(report_date).year,
                        'period_type': 'Y',
                        'currency': 'CNY',
                        'data_unit': definition.get('unit', '元'),
                        'source_system': 'Wind'
                    })
                    row_has_valid_data = True
                    self.logger.debug(f"添加数据记录: {wind_field} = {value}")
                except Exception as e:
                    # A single unconvertible value must not abort the whole batch.
                    self.logger.warning(f"处理数据值时出错: {value}, 错误: {e}")

            if row_has_valid_data:
                valid_row_count += 1
            else:
                invalid_row_count += 1

        return records, valid_row_count, invalid_row_count

    def _log_filtered_rows(self, df: pd.DataFrame) -> None:
        """Dump every cell at DEBUG level to explain why no record was produced."""
        self.logger.debug("检查数据过滤原因:")
        for idx, row in df.iterrows():
            self.logger.debug(f"第{idx}行数据:")
            for col in df.columns:
                self.logger.debug(f" {col}: {row[col]} (类型: {type(row[col])}, 是否空值: {pd.isna(row[col])})")

    def process_and_save_data(self, wind_code: str, end_date: Optional[str] = None) -> bool:
        """
        Fetch one stock's data and save it to the database.

        Args:
            wind_code: Stock code.
            end_date: End date, forwarded to ``fetch_wind_data``.

        Returns:
            True when records were saved; False when the fetch failed, the
            data was empty, no value was usable, the insert failed, or an
            exception was raised.
        """
        try:
            self.logger.info(f"开始处理股票数据: {wind_code}")

            df = self.fetch_wind_data(wind_code, end_date)
            if df is None:
                self.logger.error(f"获取数据失败: {wind_code}")
                return False

            if df.empty:
                self.logger.warning(f"获取到空数据: {wind_code}")
                return False

            self.logger.debug(f"获取到的数据行数: {len(df)}")
            self.logger.debug(f"DataFrame实际列名: {df.columns.tolist()}")

            with FinancialDataManager(**self.db_config) as db_manager:
                # Indicator definitions are inserted before the data rows.
                self.logger.debug("确保指标定义存在")
                self.ensure_indicators_defined(db_manager)

                data_to_insert, valid_row_count, invalid_row_count = self._build_records(wind_code, df)

                self.logger.info(f"数据处理完成 - 有效行: {valid_row_count}, 无效行: {invalid_row_count}, 总记录数: {len(data_to_insert)}")

                if not data_to_insert:
                    self.logger.warning(f"没有有效数据需要保存: {wind_code}")
                    self._log_filtered_rows(df)
                    return False

                self.logger.info(f"准备插入 {len(data_to_insert)} 条数据到数据库")
                success = db_manager.batch_insert_financial_data(data_to_insert)
                if success:
                    self.logger.info(f"成功保存 {len(data_to_insert)} 条数据到数据库: {wind_code}")
                else:
                    self.logger.error(f"保存数据到数据库失败: {wind_code}")
                return bool(success)

        except Exception as e:
            self.logger.error(f"处理保存数据失败: {e}", exc_info=True)
            return False

    def batch_fetch_stocks(self, stock_list: List[str], end_date: Optional[str] = None,
                           delay_seconds: float = 1.0) -> Dict[str, bool]:
        """
        Fetch and save data for several stocks, one after another.

        Args:
            stock_list: Stock codes to process.
            end_date: End date, forwarded to ``process_and_save_data``.
            delay_seconds: Pause between two consecutive stocks to avoid
                issuing requests too frequently; 0 disables the pause.

        Returns:
            Dict mapping each stock code to whether it was processed successfully.
        """
        results: Dict[str, bool] = {}
        total = len(stock_list)

        for position, stock_code in enumerate(stock_list, start=1):
            self.logger.info(f"开始处理股票: {stock_code}")
            results[stock_code] = self.process_and_save_data(stock_code, end_date)

            # Throttle only between requests; nothing follows the last one.
            if delay_seconds > 0 and position < total:
                time.sleep(delay_seconds)

        success_count = sum(1 for ok in results.values() if ok)
        self.logger.info(f"批量处理完成: 成功 {success_count}/{len(stock_list)}")

        return results