Files
quant/finance/financial_data_manager.py
2025-11-16 09:23:47 +08:00

319 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from datetime import datetime
from typing import List, Dict, Optional, Any
try:
    import MySQLdb
    from MySQLdb import Error
    MYSQL_AVAILABLE = True
except ImportError:
    # Keep the module importable without mysqlclient: bind placeholders so that
    # names referenced elsewhere in this module (MySQLdb, and Error in `except`
    # clauses) always exist. FinancialDataManager.__init__ still refuses to run
    # when MYSQL_AVAILABLE is False.
    MySQLdb = None

    class Error(Exception):
        """Placeholder for MySQLdb.Error when mysqlclient is not installed."""

    MYSQL_AVAILABLE = False
    # Use a named logger rather than logging.warning(): the module-level helper
    # implicitly calls logging.basicConfig() on the root logger, reconfiguring
    # the host application's logging as an import-time side effect.
    logging.getLogger(__name__).warning("未安装MySQLdb请安装: pip install mysqlclient")
class FinancialDataManager:
    """Financial data manager backed by MySQL through MySQLdb (mysqlclient).

    Holds a single MySQLdb connection and provides helpers to upsert stock
    metadata, financial indicator definitions and financial data points, and
    to query them back. Write helpers commit after every statement. Failures
    are logged and reported as ``False`` / empty results rather than raised.

    Usable as a context manager: ``with FinancialDataManager(...) as m:``
    connects on entry and disconnects on exit.
    """

    def __init__(self, host: str, database: str, user: str, password: str, port: int = 3306):
        """Store connection settings. No connection is opened here.

        Args:
            host: Database host.
            database: Database name.
            user: User name.
            password: Password.
            port: Port number, defaults to 3306.

        Raises:
            ImportError: If MySQLdb (mysqlclient) is not installed.
        """
        if not MYSQL_AVAILABLE:
            raise ImportError("未安装MySQLdb请安装: pip install mysqlclient")
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        # Set by connect(); reset to None by disconnect().
        self.connection = None
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'FinancialDataManager' logger, set to INFO.

        A stream handler is attached only if the logger has none yet, so
        creating several managers does not stack duplicate handlers.
        """
        logger = logging.getLogger('FinancialDataManager')
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def _is_connected(self) -> bool:
        """Return True if a connection object is held and reports itself open."""
        return self.connection is not None and bool(self.connection.open)

    def connect(self) -> bool:
        """Open a database connection (charset utf8mb4, autocommit disabled).

        Any connection this instance already holds is closed first, so that
        reconnecting does not leak the previous connection.

        Returns:
            True if the connection is open, False otherwise (error logged).
        """
        if self._is_connected():
            self.disconnect()
        try:
            self.connection = MySQLdb.connect(
                host=self.host,
                user=self.user,
                passwd=self.password,
                db=self.database,
                port=self.port,
                charset='utf8mb4',
                autocommit=False
            )
        except Error as e:
            self.logger.error(f"数据库连接失败: {e}")
            self.connection = None
            return False
        if self.connection.open:
            self.logger.info("成功连接到MySQL数据库")
            return True
        return False

    def disconnect(self):
        """Close the connection if it is open, then forget it."""
        if self._is_connected():
            self.connection.close()
            self.logger.info("数据库连接已关闭")
        # Drop the handle so later calls report "not connected" instead of
        # issuing statements (and rollbacks) on a closed connection.
        self.connection = None

    def __enter__(self):
        """Connect and return self.

        A failed connection is only logged by connect(). The instance is still
        returned, and subsequent operations report that the database is not
        connected.
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disconnect. Exceptions raised inside the with-block propagate."""
        self.disconnect()

    def _safe_rollback(self) -> None:
        """Roll back the current transaction, logging (not raising) any failure.

        This runs inside error handlers. If the connection itself is broken the
        rollback can fail too, and that must not mask the original error or
        escape a bool-returning helper.
        """
        if not self._is_connected():
            return
        try:
            self.connection.rollback()
        except Error as e:
            self.logger.error(f"回滚失败: {e}")

    def _execute_query(self, query: str, params: Optional[tuple] = None) -> bool:
        """Execute one write statement and commit it.

        Args:
            query: SQL text using %s placeholders.
            params: Values bound to the placeholders. None means no values.

        Returns:
            True if committed. False if not connected, or if the statement
            failed (the error is logged and the transaction rolled back).
        """
        if not self._is_connected():
            self.logger.error("数据库未连接")
            return False
        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, params or ())
            self.connection.commit()
            return True
        except Error as e:
            self.logger.error(f"执行查询失败: {e}")
            self._safe_rollback()
            return False
        finally:
            if cursor is not None:
                cursor.close()

    def _fetch_all(self, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Run a SELECT and return all rows as dicts keyed by column name.

        Args:
            query: SQL text using %s placeholders.
            params: Values bound to the placeholders. None means no values.

        Returns:
            A list of row dicts. Empty if not connected or on error (logged).
        """
        # NOTE(review): the connection uses autocommit=False and reads never
        # commit, so depending on the server's isolation level a long-lived
        # connection may keep seeing an old snapshot until the next write
        # commits. Confirm this is acceptable.
        if not self._is_connected():
            self.logger.error("数据库未连接")
            return []
        cursor = None
        try:
            cursor = self.connection.cursor(MySQLdb.cursors.DictCursor)
            cursor.execute(query, params or ())
            # fetchall() yields a tuple; convert it to honor the List annotation.
            return list(cursor.fetchall())
        except Error as e:
            self.logger.error(f"查询失败: {e}")
            return []
        finally:
            if cursor is not None:
                cursor.close()

    def _has_required_fields(self, data: Dict[str, Any], required_fields: List[str]) -> bool:
        """Return True if every key in ``required_fields`` is present in ``data``.

        Logs only the keys that are actually missing.
        """
        missing = [field for field in required_fields if field not in data]
        if missing:
            self.logger.error(f"缺少必需字段: {missing}")
            return False
        return True

    # ==================== Stock data operations ====================
    def insert_stock(self, stock_data: Dict[str, Any]) -> bool:
        """Insert or update (upsert) one stock's basic information.

        Args:
            stock_data: Must contain 'stock_id'. Optional keys are
                'wind_industry_code', 'wind_industry_name', 'country_code',
                'exchange' and 'listing_date'. On the update path, omitted
                optional keys overwrite the stored values with NULL.

        Returns:
            True if the statement was committed, False otherwise.
        """
        if not self._has_required_fields(stock_data, ['stock_id']):
            return False
        query = """
            INSERT INTO stocks (stock_id, wind_industry_code, wind_industry_name, country_code, exchange, listing_date)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                wind_industry_code = VALUES(wind_industry_code),
                wind_industry_name = VALUES(wind_industry_name),
                country_code = VALUES(country_code),
                exchange = VALUES(exchange),
                listing_date = VALUES(listing_date),
                updated_at = CURRENT_TIMESTAMP
        """
        success = self._execute_query(query, (
            stock_data['stock_id'],
            stock_data.get('wind_industry_code'),
            stock_data.get('wind_industry_name'),
            stock_data.get('country_code'),
            stock_data.get('exchange'),
            stock_data.get('listing_date')
        ))
        if success:
            self.logger.info(f"成功处理股票: {stock_data['stock_id']}")
        return success

    def batch_insert_stocks(self, stocks_list: List[Dict[str, Any]]) -> bool:
        """Upsert each stock independently, each in its own commit.

        A failing row does not roll back rows that already succeeded.

        Returns:
            True only if every row succeeded (True for an empty list).
        """
        success_count = sum(1 for stock_data in stocks_list if self.insert_stock(stock_data))
        self.logger.info(f"批量插入完成: {success_count}/{len(stocks_list)} 成功")
        return success_count == len(stocks_list)

    # ==================== Financial indicator definitions ====================
    def insert_indicator(self, indicator_data: Dict[str, Any]) -> bool:
        """Insert or update (upsert) a financial indicator definition.

        Args:
            indicator_data: Must contain 'indicator_id' and 'indicator_name'.
                Optional keys are 'indicator_desc', 'category', 'unit' and
                'data_source' (defaults to 'Wind').

        Returns:
            True if the statement was committed, False otherwise.
        """
        if not self._has_required_fields(indicator_data, ['indicator_id', 'indicator_name']):
            return False
        query = """
            INSERT INTO financial_indicators
            (indicator_id, indicator_name, indicator_desc, category, unit, data_source)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                indicator_name = VALUES(indicator_name),
                indicator_desc = VALUES(indicator_desc),
                category = VALUES(category),
                unit = VALUES(unit),
                data_source = VALUES(data_source)
        """
        success = self._execute_query(query, (
            indicator_data['indicator_id'],
            indicator_data['indicator_name'],
            indicator_data.get('indicator_desc'),
            indicator_data.get('category'),
            indicator_data.get('unit'),
            indicator_data.get('data_source', 'Wind')
        ))
        if success:
            self.logger.info(f"成功处理指标: {indicator_data['indicator_id']}")
        return success

    # ==================== Financial data operations ====================
    def insert_financial_data(self, financial_data: Dict[str, Any]) -> bool:
        """Insert one financial data point, or update it if it already exists.

        Args:
            financial_data: Must contain 'stock_id', 'indicator_id',
                'report_date' and 'data_value'. Optional keys, with the
                defaults applied here: 'fiscal_year' (None), 'period_type'
                ('Y'), 'currency' ('CNY'), 'data_unit' (None),
                'data_status' ('valid'), 'source_system' ('Wind').

        On a duplicate key, only data_value, data_unit, data_status,
        source_system and last_updated are updated.

        Returns:
            True if the statement was committed, False otherwise.
        """
        required_fields = ['stock_id', 'indicator_id', 'report_date', 'data_value']
        if not self._has_required_fields(financial_data, required_fields):
            return False
        query = """
            INSERT INTO financial_data
            (stock_id, indicator_id, report_date, fiscal_year, period_type,
             currency, data_value, data_unit, data_status, source_system)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                data_value = VALUES(data_value),
                data_unit = VALUES(data_unit),
                data_status = VALUES(data_status),
                source_system = VALUES(source_system),
                last_updated = CURRENT_TIMESTAMP
        """
        success = self._execute_query(query, (
            financial_data['stock_id'],
            financial_data['indicator_id'],
            financial_data['report_date'],
            financial_data.get('fiscal_year'),
            financial_data.get('period_type', 'Y'),
            financial_data.get('currency', 'CNY'),
            financial_data['data_value'],
            financial_data.get('data_unit'),
            financial_data.get('data_status', 'valid'),
            financial_data.get('source_system', 'Wind')
        ))
        if success:
            self.logger.info(
                f"成功插入财务数据: {financial_data['stock_id']} - "
                f"{financial_data['indicator_id']} - {financial_data['report_date']}"
            )
        return success

    def batch_insert_financial_data(self, data_list: List[Dict[str, Any]]) -> bool:
        """Insert each data point independently, each in its own commit.

        A failing row does not roll back rows that already succeeded.

        Returns:
            True only if every row succeeded (True for an empty list).
        """
        success_count = sum(1 for data in data_list if self.insert_financial_data(data))
        self.logger.info(f"批量插入财务数据完成: {success_count}/{len(data_list)} 成功")
        return success_count == len(data_list)

    # ==================== Query methods ====================
    def get_stock_info(self, stock_id: str) -> Optional[Dict[str, Any]]:
        """Return the stocks row for ``stock_id`` as a dict, or None if absent."""
        query = "SELECT * FROM stocks WHERE stock_id = %s"
        results = self._fetch_all(query, (stock_id,))
        return results[0] if results else None

    def get_financial_data(self, stock_id: str, indicator_id: Optional[str] = None,
                           start_date: Optional[str] = None,
                           end_date: Optional[str] = None) -> List[Dict[str, Any]]:
        """Query a stock's financial data joined with indicator and stock info.

        Args:
            stock_id: Stock to query.
            indicator_id: If given, restrict to this indicator.
            start_date: If given, keep rows with report_date >= start_date.
            end_date: If given, keep rows with report_date <= end_date.

        Returns:
            Row dicts ordered by report_date descending, then indicator_id.
            Empty on error.
        """
        # NOTE(review): insert_stock() never writes a stock_name column. Confirm
        # the stocks table still has one; otherwise this query fails and [] is
        # returned.
        query = """
            SELECT fd.*, fi.indicator_name, fi.category, s.stock_name
            FROM financial_data fd
            LEFT JOIN financial_indicators fi ON fd.indicator_id = fi.indicator_id
            LEFT JOIN stocks s ON fd.stock_id = s.stock_id
            WHERE fd.stock_id = %s
        """
        params = [stock_id]
        # Only constant SQL fragments are appended; all values stay bound params.
        if indicator_id:
            query += " AND fd.indicator_id = %s"
            params.append(indicator_id)
        if start_date:
            query += " AND fd.report_date >= %s"
            params.append(start_date)
        if end_date:
            query += " AND fd.report_date <= %s"
            params.append(end_date)
        query += " ORDER BY fd.report_date DESC, fd.indicator_id"
        return self._fetch_all(query, tuple(params))

    def check_data_exists(self, stock_id: str, indicator_id: str,
                          report_date: str, period_type: str = 'Y') -> bool:
        """Return True if a financial_data row exists for the given key.

        Returns False if not connected or on a query error (logged).
        """
        if not self._is_connected():
            self.logger.error("数据库未连接")
            return False
        query = """
            SELECT COUNT(*) as count FROM financial_data
            WHERE stock_id = %s AND indicator_id = %s
            AND report_date = %s AND period_type = %s
        """
        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, (stock_id, indicator_id, report_date, period_type))
            result = cursor.fetchone()
            return result[0] > 0
        except Error as e:
            self.logger.error(f"检查数据存在性失败: {e}")
            return False
        finally:
            if cursor is not None:
                cursor.close()