319 lines
11 KiB
Python
319 lines
11 KiB
Python
import logging
|
||
from datetime import datetime
|
||
from typing import List, Dict, Optional, Any
|
||
|
||
# Probe for the optional MySQL driver (mysqlclient provides the MySQLdb package).
# MYSQL_AVAILABLE records the outcome; FinancialDataManager.__init__ checks it.
try:
    import MySQLdb
    from MySQLdb import Error
except ImportError:
    MYSQL_AVAILABLE = False
    logging.warning("未安装MySQLdb,请安装: pip install mysqlclient")
else:
    MYSQL_AVAILABLE = True
|
||
|
||
|
||
class FinancialDataManager:
    """Financial data manager that talks to MySQL through MySQLdb (mysqlclient).

    Wraps a single MySQLdb connection and provides upsert helpers for the
    ``stocks``, ``financial_indicators`` and ``financial_data`` tables, plus a
    few read queries. The connection is opened with ``autocommit=False`` and
    every write statement is committed explicitly right after it executes.
    Database errors are logged and reported through return values
    (``False`` / ``[]`` / ``None``) rather than raised.

    Usable as a context manager, which connects on entry and disconnects on exit.
    """

    def __init__(self, host: str, database: str, user: str, password: str, port: int = 3306):
        """Store connection parameters; no connection is opened here.

        Args:
            host: Database host.
            database: Database (schema) name.
            user: User name.
            password: Password.
            port: Port number, default 3306.

        Raises:
            ImportError: If MySQLdb (mysqlclient) is not installed.
        """
        if not MYSQL_AVAILABLE:
            raise ImportError("未安装MySQLdb,请安装: pip install mysqlclient")

        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.connection = None  # set by connect(), cleared by disconnect()
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'FinancialDataManager' logger at INFO level.

        A stream handler is attached only once, so creating several managers
        does not duplicate log lines.
        """
        logger = logging.getLogger('FinancialDataManager')
        logger.setLevel(logging.INFO)

        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def connect(self) -> bool:
        """Open the database connection.

        Returns:
            True if the connection was established and reports itself open,
            False otherwise (connection errors are logged).
        """
        try:
            self.connection = MySQLdb.connect(
                host=self.host,
                user=self.user,
                passwd=self.password,
                db=self.database,
                port=self.port,
                charset='utf8mb4',
                autocommit=False,
            )
        except Error as e:
            self.logger.error(f"数据库连接失败: {e}")
            return False

        if self.connection.open:
            self.logger.info("成功连接到MySQL数据库")
            return True
        # Previously fell through and returned None; keep the bool contract.
        return False

    def disconnect(self):
        """Close the connection if it is open, then forget it."""
        if self.connection and self.connection.open:
            self.connection.close()
            self.logger.info("数据库连接已关闭")
        # Drop the reference so later calls hit the "not connected" guards
        # instead of operating on a closed handle.
        self.connection = None

    def __enter__(self):
        """Context-manager entry: connect and return self.

        NOTE: a failed connect is only logged; subsequent calls then report
        "not connected" through their return values.
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit: disconnect; exceptions are not suppressed."""
        self.disconnect()

    def _has_required_fields(self, data: Dict[str, Any], required_fields: List[str]) -> bool:
        """Return True if every key in required_fields is present in data.

        Logs the keys that are missing and returns False otherwise.
        """
        missing = [field for field in required_fields if field not in data]
        if missing:
            self.logger.error(f"缺少必需字段: {missing}")
            return False
        return True

    def _safe_rollback(self) -> None:
        """Roll back the current transaction; log, never raise, on failure."""
        if not self.connection:
            return
        try:
            self.connection.rollback()
        except Error as e:
            # E.g. the connection dropped; don't mask the original error.
            self.logger.error(f"回滚失败: {e}")

    def _execute_query(self, query: str, params: Optional[tuple] = None) -> bool:
        """Execute one write statement and commit it.

        Args:
            query: SQL with %s placeholders.
            params: Placeholder values (None means no parameters).

        Returns:
            True on success. False if not connected or the statement failed;
            on failure the transaction is rolled back.
        """
        if not self.connection:
            self.logger.error("数据库未连接")
            return False

        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, params or ())
            self.connection.commit()
            return True
        except Error as e:
            self.logger.error(f"执行查询失败: {e}")
            self._safe_rollback()
            return False
        finally:
            if cursor is not None:
                cursor.close()

    def _fetch_all(self, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Run a SELECT and return all rows as dicts (column name -> value).

        Returns an empty list if not connected or the query failed (logged).

        NOTE(review): with autocommit=False, reads run inside a transaction that
        is only ended by a later write's commit. Under MySQL's default
        REPEATABLE READ this can return a stale snapshot; confirm this is
        acceptable for long-lived read-only use.
        """
        if not self.connection:
            self.logger.error("数据库未连接")
            return []

        # Import the submodule explicitly: `import MySQLdb` alone may not load
        # `MySQLdb.cursors` as a package attribute.
        from MySQLdb.cursors import DictCursor

        cursor = None
        try:
            cursor = self.connection.cursor(DictCursor)
            cursor.execute(query, params or ())
            # fetchall() yields a tuple; honour the List return annotation.
            return list(cursor.fetchall())
        except Error as e:
            self.logger.error(f"查询失败: {e}")
            return []
        finally:
            if cursor is not None:
                cursor.close()

    # ==================== Stock data ====================

    def insert_stock(self, stock_data: Dict[str, Any]) -> bool:
        """Insert a row into `stocks`, or update it on a duplicate key.

        Args:
            stock_data: Must contain 'stock_id'. Optional keys:
                'wind_industry_code', 'wind_industry_name', 'country_code',
                'exchange', 'listing_date' (missing keys are written as NULL).

        Returns:
            True if the statement executed and was committed.
        """
        if not self._has_required_fields(stock_data, ['stock_id']):
            return False

        query = """
            INSERT INTO stocks (stock_id, wind_industry_code, wind_industry_name, country_code, exchange, listing_date)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            wind_industry_code = VALUES(wind_industry_code),
            wind_industry_name = VALUES(wind_industry_name),
            country_code = VALUES(country_code),
            exchange = VALUES(exchange),
            listing_date = VALUES(listing_date),
            updated_at = CURRENT_TIMESTAMP
        """

        success = self._execute_query(query, (
            stock_data['stock_id'],
            stock_data.get('wind_industry_code'),
            stock_data.get('wind_industry_name'),
            stock_data.get('country_code'),
            stock_data.get('exchange'),
            stock_data.get('listing_date'),
        ))

        if success:
            self.logger.info(f"成功处理股票: {stock_data['stock_id']}")
        return success

    def batch_insert_stocks(self, stocks_list: List[Dict[str, Any]]) -> bool:
        """Upsert each stock in turn; return True only if every one succeeded.

        NOTE: each row is committed separately, so a partial failure leaves the
        successful rows in place.
        """
        success_count = sum(1 for stock_data in stocks_list if self.insert_stock(stock_data))
        self.logger.info(f"批量插入完成: {success_count}/{len(stocks_list)} 成功")
        return success_count == len(stocks_list)

    # ==================== Financial indicator definitions ====================

    def insert_indicator(self, indicator_data: Dict[str, Any]) -> bool:
        """Insert a row into `financial_indicators`, or update it on a duplicate key.

        Args:
            indicator_data: Must contain 'indicator_id' and 'indicator_name'.
                Optional: 'indicator_desc', 'category', 'unit', and
                'data_source' (defaults to 'Wind').

        Returns:
            True if the statement executed and was committed.
        """
        if not self._has_required_fields(indicator_data, ['indicator_id', 'indicator_name']):
            return False

        query = """
            INSERT INTO financial_indicators
            (indicator_id, indicator_name, indicator_desc, category, unit, data_source)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            indicator_name = VALUES(indicator_name),
            indicator_desc = VALUES(indicator_desc),
            category = VALUES(category),
            unit = VALUES(unit),
            data_source = VALUES(data_source)
        """

        success = self._execute_query(query, (
            indicator_data['indicator_id'],
            indicator_data['indicator_name'],
            indicator_data.get('indicator_desc'),
            indicator_data.get('category'),
            indicator_data.get('unit'),
            indicator_data.get('data_source', 'Wind'),
        ))

        if success:
            self.logger.info(f"成功处理指标: {indicator_data['indicator_id']}")
        return success

    # ==================== Financial data ====================

    def insert_financial_data(self, financial_data: Dict[str, Any]) -> bool:
        """Insert a row into `financial_data`, or update its value fields on a duplicate key.

        Args:
            financial_data: Must contain 'stock_id', 'indicator_id',
                'report_date' and 'data_value'. Optional keys:
                'fiscal_year', 'period_type' (default 'Y'), 'currency'
                (default 'CNY'), 'data_unit', 'data_status' (default 'valid'),
                and 'source_system' (default 'Wind').

        Returns:
            True if the statement executed and was committed.
        """
        required_fields = ['stock_id', 'indicator_id', 'report_date', 'data_value']
        if not self._has_required_fields(financial_data, required_fields):
            return False

        query = """
            INSERT INTO financial_data
            (stock_id, indicator_id, report_date, fiscal_year, period_type,
            currency, data_value, data_unit, data_status, source_system)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            data_value = VALUES(data_value),
            data_unit = VALUES(data_unit),
            data_status = VALUES(data_status),
            source_system = VALUES(source_system),
            last_updated = CURRENT_TIMESTAMP
        """

        success = self._execute_query(query, (
            financial_data['stock_id'],
            financial_data['indicator_id'],
            financial_data['report_date'],
            financial_data.get('fiscal_year'),
            financial_data.get('period_type', 'Y'),
            financial_data.get('currency', 'CNY'),
            financial_data['data_value'],
            financial_data.get('data_unit'),
            financial_data.get('data_status', 'valid'),
            financial_data.get('source_system', 'Wind'),
        ))

        if success:
            self.logger.info(
                f"成功插入财务数据: {financial_data['stock_id']} - "
                f"{financial_data['indicator_id']} - {financial_data['report_date']}"
            )
        return success

    def batch_insert_financial_data(self, data_list: List[Dict[str, Any]]) -> bool:
        """Upsert each record in turn; return True only if every one succeeded.

        NOTE: each row is committed separately (one round trip per row).
        """
        success_count = sum(1 for data in data_list if self.insert_financial_data(data))
        self.logger.info(f"批量插入财务数据完成: {success_count}/{len(data_list)} 成功")
        return success_count == len(data_list)

    # ==================== Queries ====================

    def get_stock_info(self, stock_id: str) -> Optional[Dict[str, Any]]:
        """Return the `stocks` row for stock_id as a dict, or None if not found."""
        query = "SELECT * FROM stocks WHERE stock_id = %s"
        results = self._fetch_all(query, (stock_id,))
        return results[0] if results else None

    def get_financial_data(self, stock_id: str, indicator_id: Optional[str] = None,
                           start_date: Optional[str] = None,
                           end_date: Optional[str] = None) -> List[Dict[str, Any]]:
        """Query financial data for one stock, joined with indicator and stock info.

        Args:
            stock_id: Stock to query.
            indicator_id: If given, only this indicator.
            start_date: If given, only rows with report_date >= start_date.
            end_date: If given, only rows with report_date <= end_date.

        Returns:
            Rows ordered by report_date descending, then indicator_id.

        NOTE(review): insert_stock() never writes a `stock_name` column. Confirm
        that `stocks.stock_name` still exists in the schema; otherwise this
        query fails and returns [].
        """
        query = """
            SELECT fd.*, fi.indicator_name, fi.category, s.stock_name
            FROM financial_data fd
            LEFT JOIN financial_indicators fi ON fd.indicator_id = fi.indicator_id
            LEFT JOIN stocks s ON fd.stock_id = s.stock_id
            WHERE fd.stock_id = %s
        """
        params: List[Any] = [stock_id]

        # Only parameter placeholders are appended; values never enter the SQL text.
        if indicator_id:
            query += " AND fd.indicator_id = %s"
            params.append(indicator_id)

        if start_date:
            query += " AND fd.report_date >= %s"
            params.append(start_date)

        if end_date:
            query += " AND fd.report_date <= %s"
            params.append(end_date)

        query += " ORDER BY fd.report_date DESC, fd.indicator_id"

        return self._fetch_all(query, tuple(params))

    def check_data_exists(self, stock_id: str, indicator_id: str,
                          report_date: str, period_type: str = 'Y') -> bool:
        """Return True if a matching `financial_data` row exists.

        Returns False when not connected or when the query fails (logged).
        """
        if not self.connection:
            self.logger.error("数据库未连接")
            return False

        query = """
            SELECT COUNT(*) as count FROM financial_data
            WHERE stock_id = %s AND indicator_id = %s
            AND report_date = %s AND period_type = %s
        """
        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, (stock_id, indicator_id, report_date, period_type))
            row = cursor.fetchone()
            return row is not None and row[0] > 0
        except Error as e:
            self.logger.error(f"检查数据存在性失败: {e}")
            return False
        finally:
            if cursor is not None:
                cursor.close()