Files
quant/finance/excel_stock_importer.py
2025-11-16 09:23:47 +08:00

188 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import logging
from typing import List, Dict, Any, Optional
from financial_data_manager import FinancialDataManager
class ExcelStockImporter:
    """Import stock master data from an Excel sheet into the database.

    The workflow is: read the sheet (``read_excel_file``), convert each row
    into a plain dict (``process_stock_data``) and hand the dicts one by one
    to ``FinancialDataManager.insert_stock`` (``import_to_database``).
    ``validate_data`` offers a read-only sanity check of a sheet.
    """

    # Header names that must be present in the spreadsheet.
    _REQUIRED_COLUMNS = (
        'Wind代码',
        'Wind一级行业代码',
        'Wind一级行业名称',
        '上市地国家(地区)代码',
        '交易所',
    )

    def __init__(self, db_config: Dict[str, Any]):
        """Create an importer.

        Args:
            db_config: Keyword arguments forwarded unchanged to
                ``FinancialDataManager`` when a database session is opened.
        """
        self.db_config = db_config
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'ExcelStockImporter' logger at INFO level.

        A stream handler is attached only when the logger has none yet, so
        creating several importers does not duplicate every log line.
        """
        logger = logging.getLogger('ExcelStockImporter')
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    @staticmethod
    def _clean_cell(value: Any) -> Optional[str]:
        """Normalise one spreadsheet cell to a stripped string.

        Returns ``None`` for missing cells (NaN / None / NaT) instead of the
        literal text ``'nan'``.  Integral floats are rendered without the
        trailing ``.0``: pandas stores a numeric column that contains blanks
        as float64, which would otherwise turn a code ``40`` into ``'40.0'``.
        """
        if pd.isna(value):
            return None
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return str(value).strip()

    def read_excel_file(self, file_path: str) -> Optional[pd.DataFrame]:
        """Read the Excel file and check that the required columns exist.

        Args:
            file_path: Path of the Excel file.

        Returns:
            The loaded DataFrame, or ``None`` when the file is missing,
            cannot be parsed, or lacks one of the required columns.  Errors
            are logged rather than raised.
        """
        try:
            self.logger.info(f"开始读取Excel文件: {file_path}")
            df = pd.read_excel(file_path)

            # Refuse sheets that do not carry every required header.
            missing_columns = [
                col for col in self._REQUIRED_COLUMNS if col not in df.columns
            ]
            if missing_columns:
                self.logger.error(f"Excel文件中缺少必要的列: {missing_columns}")
                self.logger.info(f"Excel文件中的列: {df.columns.tolist()}")
                return None

            self.logger.info(f"成功读取Excel文件{len(df)} 行数据")
            self.logger.debug(f"数据前5行:\n{df.head()}")
            return df
        except FileNotFoundError:
            self.logger.error(f"文件不存在: {file_path}")
            return None
        except Exception as e:
            # Best-effort API: any other read/parse failure is reported as None.
            self.logger.error(f"读取Excel文件失败: {e}")
            return None

    def process_stock_data(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Convert sheet rows into the dicts expected by the database layer.

        Rows whose stock code is missing or blank are skipped with a warning,
        as are rows that raise while being converted.

        Args:
            df: DataFrame holding the required columns.

        Returns:
            One dict per usable row with the keys ``stock_id``,
            ``wind_industry_code``, ``wind_industry_name``, ``country_code``
            and ``exchange``.  Optional fields are ``None`` when the cell is
            empty.
        """
        processed_data = []
        # Count rows ourselves: index labels are not guaranteed to be
        # integers, while the log messages need a 1-based row number.
        for row_number, (_, row) in enumerate(df.iterrows(), start=1):
            try:
                stock_data = {
                    'stock_id': self._clean_cell(row['Wind代码']),
                    'wind_industry_code': self._clean_cell(row['Wind一级行业代码']),
                    'wind_industry_name': self._clean_cell(row['Wind一级行业名称']),
                    'country_code': self._clean_cell(row['上市地国家(地区)代码']),
                    'exchange': self._clean_cell(row['交易所']),
                    # A listing-date column ('上市日期'), if the sheet has
                    # one, could be mapped here as well.
                }
                # The stock code is mandatory; None and '' are both rejected.
                if not stock_data['stock_id']:
                    self.logger.warning(f"{row_number} 行股票代码为空,跳过")
                    continue
                processed_data.append(stock_data)
            except Exception as e:
                # Keep going: one malformed row must not abort the import.
                self.logger.warning(f"处理第 {row_number} 行数据时出错: {e}")
                continue

        self.logger.info(f"成功处理 {len(processed_data)} 条股票数据")
        return processed_data

    def import_to_database(self, file_path: str) -> bool:
        """Read an Excel file and insert its stocks into the database.

        Args:
            file_path: Path of the Excel file.

        Returns:
            ``True`` when at least one row was inserted, ``False`` when the
            file could not be read, held no usable rows, no insert succeeded
            or an unexpected error occurred.
        """
        try:
            df = self.read_excel_file(file_path)
            if df is None or df.empty:
                self.logger.error("读取Excel文件失败或文件为空")
                return False

            stock_data_list = self.process_stock_data(df)
            if not stock_data_list:
                self.logger.error("没有有效的股票数据需要导入")
                return False

            with FinancialDataManager(**self.db_config) as db_manager:
                success_count = 0
                total_count = len(stock_data_list)
                for stock_data in stock_data_list:
                    # Rows are inserted independently so one failure does
                    # not prevent the remaining rows from being tried.
                    try:
                        if db_manager.insert_stock(stock_data):
                            success_count += 1
                        else:
                            self.logger.warning(f"插入股票数据失败: {stock_data['stock_id']}")
                    except Exception as e:
                        self.logger.error(f"插入股票 {stock_data['stock_id']} 时出错: {e}")

                self.logger.info(f"导入完成: 成功 {success_count}/{total_count}")
                return success_count > 0
        except Exception as e:
            self.logger.error(f"导入过程失败: {e}")
            return False

    def validate_data(self, file_path: str) -> Dict[str, Any]:
        """Summarise an Excel file and flag duplicated stock codes.

        Args:
            file_path: Path of the Excel file.

        Returns:
            On success a dict with ``valid``, ``total_rows``, ``stock_codes``,
            ``industries``, ``exchanges`` and ``duplicate_stocks``; ``valid``
            is ``False`` when any stock code occurs more than once.  On
            failure a dict with ``valid`` set to ``False`` and a ``message``.
        """
        try:
            df = self.read_excel_file(file_path)
            if df is None:
                return {'valid': False, 'message': '文件读取失败'}

            result = {
                'valid': True,
                'total_rows': len(df),
                'stock_codes': df['Wind代码'].tolist(),
                'industries': df['Wind一级行业名称'].unique().tolist(),
                'exchanges': df['交易所'].unique().tolist(),
                # Second and later occurrences of each repeated code.
                'duplicate_stocks': df[df.duplicated('Wind代码')]['Wind代码'].tolist()
            }

            if result['duplicate_stocks']:
                self.logger.warning(f"发现重复的股票代码: {result['duplicate_stocks']}")
                result['valid'] = False
            return result
        except Exception as e:
            self.logger.error(f"数据验证失败: {e}")
            return {'valid': False, 'message': str(e)}