import logging
from datetime import datetime
from typing import List, Dict, Optional, Any

try:
    import MySQLdb
    from MySQLdb import Error
    MYSQL_AVAILABLE = True
except ImportError:
    # The module stays importable without the driver; FinancialDataManager
    # refuses to be constructed instead (see __init__).
    MYSQL_AVAILABLE = False
    logging.warning("未安装MySQLdb,请安装: pip install mysqlclient")


class FinancialDataManager:
    """Data-access helper for financial data stored in MySQL (via MySQLdb).

    The class wraps a single MySQLdb connection and offers upsert helpers for
    the ``stocks``, ``financial_indicators`` and ``financial_data`` tables
    plus a few read queries.  Database errors are logged and reported through
    return values (``False`` / ``[]`` / ``None``) rather than raised.

    It can be used as a context manager::

        with FinancialDataManager(host, db, user, pw) as mgr:
            mgr.insert_stock({"stock_id": "600000.SH"})
    """

    def __init__(self, host: str, database: str, user: str, password: str,
                 port: int = 3306):
        """Store the connection settings; no connection is opened here.

        Args:
            host: Database host.
            database: Database (schema) name.
            user: User name.
            password: Password.
            port: TCP port, 3306 by default.

        Raises:
            ImportError: If the MySQLdb driver could not be imported.
        """
        if not MYSQL_AVAILABLE:
            raise ImportError("未安装MySQLdb,请安装: pip install mysqlclient")

        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.connection = None
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'FinancialDataManager' logger at INFO level.

        A stream handler is attached only when the logger has none yet, so
        creating several managers does not duplicate log output.
        """
        logger = logging.getLogger('FinancialDataManager')
        logger.setLevel(logging.INFO)

        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def connect(self) -> bool:
        """Open the database connection.

        Autocommit is disabled; the write helpers commit explicitly.

        Returns:
            True when the connection was opened, False otherwise (the driver
            error, if any, is logged).
        """
        try:
            self.connection = MySQLdb.connect(
                host=self.host,
                user=self.user,
                passwd=self.password,
                db=self.database,
                port=self.port,
                charset='utf8mb4',
                autocommit=False
            )
        except Error as e:
            self.logger.error(f"数据库连接失败: {e}")
            return False

        if self.connection.open:
            self.logger.info("成功连接到MySQL数据库")
            return True
        # Always honour the declared bool return type.
        return False

    def disconnect(self):
        """Close the connection if one exists and is still open."""
        if self.connection and self.connection.open:
            self.connection.close()
            self.logger.info("数据库连接已关闭")

    def __enter__(self):
        """Context-manager entry: connect and return the manager itself.

        NOTE(review): the result of connect() is not checked, so the ``with``
        body runs even when connecting failed; the helpers then log and
        return their failure values.
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit: close the connection; exceptions propagate."""
        self.disconnect()

    def _execute_query(self, query: str, params: Optional[tuple] = None) -> bool:
        """Execute one write statement and commit it.

        Args:
            query: SQL text using ``%s`` placeholders.
            params: Values for the placeholders; an empty tuple is used when
                omitted.

        Returns:
            True when the statement was executed and committed; False when
            there is no connection or the driver raised an error (in which
            case a rollback is attempted).
        """
        if not self.connection:
            self.logger.error("数据库未连接")
            return False

        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, params or ())
            self.connection.commit()
            return True
        except Error as e:
            self.logger.error(f"执行查询失败: {e}")
            # The rollback itself can fail (e.g. the connection was lost);
            # that must not escape a method that reports errors via False.
            try:
                self.connection.rollback()
            except Error as rollback_error:
                self.logger.error(f"回滚失败: {rollback_error}")
            return False
        finally:
            if cursor is not None:
                cursor.close()

    def _fetch_all(self, query: str,
                   params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Run a read query and return every row as a dict.

        Args:
            query: SQL text using ``%s`` placeholders.
            params: Values for the placeholders; an empty tuple is used when
                omitted.

        Returns:
            A list of row dicts (a DictCursor is used); an empty list when
            there is no connection or the driver raised an error.
        """
        if not self.connection:
            self.logger.error("数据库未连接")
            return []

        cursor = None
        try:
            cursor = self.connection.cursor(MySQLdb.cursors.DictCursor)
            cursor.execute(query, params or ())
            # Normalise the driver's result to a list so success and failure
            # paths return the same type.
            return list(cursor.fetchall())
        except Error as e:
            self.logger.error(f"查询失败: {e}")
            return []
        finally:
            if cursor is not None:
                cursor.close()

    def _has_required_fields(self, data: Dict[str, Any],
                             required_fields: List[str]) -> bool:
        """Return True if *data* has every required key; log the missing ones."""
        missing = [field for field in required_fields if field not in data]
        if missing:
            self.logger.error(f"缺少必需字段: {missing}")
            return False
        return True

    # ==================== Stock operations ====================

    def insert_stock(self, stock_data: Dict[str, Any]) -> bool:
        """Insert a stock row, or update it when the key already exists.

        Args:
            stock_data: Must contain ``stock_id``.  Optional keys:
                ``wind_industry_code``, ``wind_industry_name``,
                ``country_code``, ``exchange``, ``listing_date`` (missing
                keys are written as NULL).

        Returns:
            True on success, False when ``stock_id`` is missing or the
            statement failed.
        """
        if not self._has_required_fields(stock_data, ['stock_id']):
            return False

        query = """
            INSERT INTO stocks
                (stock_id, wind_industry_code, wind_industry_name,
                 country_code, exchange, listing_date)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                wind_industry_code = VALUES(wind_industry_code),
                wind_industry_name = VALUES(wind_industry_name),
                country_code = VALUES(country_code),
                exchange = VALUES(exchange),
                listing_date = VALUES(listing_date),
                updated_at = CURRENT_TIMESTAMP
        """

        success = self._execute_query(query, (
            stock_data['stock_id'],
            stock_data.get('wind_industry_code'),
            stock_data.get('wind_industry_name'),
            stock_data.get('country_code'),
            stock_data.get('exchange'),
            stock_data.get('listing_date')
        ))

        if success:
            self.logger.info(f"成功处理股票: {stock_data['stock_id']}")
        return success

    def batch_insert_stocks(self, stocks_list: List[Dict[str, Any]]) -> bool:
        """Upsert several stocks, one statement (and one commit) per item.

        Returns:
            True only when every item succeeded (also True for an empty
            list); failures of individual items do not stop the loop.
        """
        success_count = 0
        for stock_data in stocks_list:
            if self.insert_stock(stock_data):
                success_count += 1

        self.logger.info(f"批量插入完成: {success_count}/{len(stocks_list)} 成功")
        return success_count == len(stocks_list)

    # ==================== Indicator definition operations ====================

    def insert_indicator(self, indicator_data: Dict[str, Any]) -> bool:
        """Insert a financial-indicator definition, or update an existing one.

        Args:
            indicator_data: Must contain ``indicator_id`` and
                ``indicator_name``.  Optional keys: ``indicator_desc``,
                ``category``, ``unit``, ``data_source`` (defaults to
                ``'Wind'``).

        Returns:
            True on success, False when a required key is missing or the
            statement failed.
        """
        if not self._has_required_fields(
                indicator_data, ['indicator_id', 'indicator_name']):
            return False

        query = """
            INSERT INTO financial_indicators
                (indicator_id, indicator_name, indicator_desc,
                 category, unit, data_source)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                indicator_name = VALUES(indicator_name),
                indicator_desc = VALUES(indicator_desc),
                category = VALUES(category),
                unit = VALUES(unit),
                data_source = VALUES(data_source)
        """

        success = self._execute_query(query, (
            indicator_data['indicator_id'],
            indicator_data['indicator_name'],
            indicator_data.get('indicator_desc'),
            indicator_data.get('category'),
            indicator_data.get('unit'),
            indicator_data.get('data_source', 'Wind')
        ))

        if success:
            self.logger.info(f"成功处理指标: {indicator_data['indicator_id']}")
        return success

    # ==================== Financial data operations ====================

    def insert_financial_data(self, financial_data: Dict[str, Any]) -> bool:
        """Insert one financial data point, or update it on a duplicate key.

        Args:
            financial_data: Must contain ``stock_id``, ``indicator_id``,
                ``report_date`` and ``data_value``.  Optional keys and their
                defaults: ``fiscal_year`` (NULL), ``period_type`` (``'Y'``),
                ``currency`` (``'CNY'``), ``data_unit`` (NULL),
                ``data_status`` (``'valid'``), ``source_system``
                (``'Wind'``).

        Returns:
            True on success, False when a required key is missing or the
            statement failed.
        """
        required_fields = ['stock_id', 'indicator_id', 'report_date', 'data_value']
        if not self._has_required_fields(financial_data, required_fields):
            return False

        query = """
            INSERT INTO financial_data
                (stock_id, indicator_id, report_date, fiscal_year, period_type,
                 currency, data_value, data_unit, data_status, source_system)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                data_value = VALUES(data_value),
                data_unit = VALUES(data_unit),
                data_status = VALUES(data_status),
                source_system = VALUES(source_system),
                last_updated = CURRENT_TIMESTAMP
        """

        success = self._execute_query(query, (
            financial_data['stock_id'],
            financial_data['indicator_id'],
            financial_data['report_date'],
            financial_data.get('fiscal_year'),
            financial_data.get('period_type', 'Y'),
            financial_data.get('currency', 'CNY'),
            financial_data['data_value'],
            financial_data.get('data_unit'),
            financial_data.get('data_status', 'valid'),
            financial_data.get('source_system', 'Wind')
        ))

        if success:
            self.logger.info(
                f"成功插入财务数据: {financial_data['stock_id']} - "
                f"{financial_data['indicator_id']} - {financial_data['report_date']}"
            )
        return success

    def batch_insert_financial_data(self, data_list: List[Dict[str, Any]]) -> bool:
        """Upsert several data points, one statement (and one commit) per item.

        Returns:
            True only when every item succeeded (also True for an empty
            list); failures of individual items do not stop the loop.
        """
        success_count = 0
        for data in data_list:
            if self.insert_financial_data(data):
                success_count += 1

        self.logger.info(f"批量插入财务数据完成: {success_count}/{len(data_list)} 成功")
        return success_count == len(data_list)

    # ==================== Query methods ====================

    def get_stock_info(self, stock_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``stocks`` row for *stock_id*, or None if not found.

        None is also returned when the query could not be run.
        """
        query = "SELECT * FROM stocks WHERE stock_id = %s"
        results = self._fetch_all(query, (stock_id,))
        return results[0] if results else None

    def get_financial_data(self, stock_id: str,
                           indicator_id: Optional[str] = None,
                           start_date: Optional[str] = None,
                           end_date: Optional[str] = None) -> List[Dict[str, Any]]:
        """Query financial data for one stock, with optional filters.

        Args:
            stock_id: Stock to query (required).
            indicator_id: When given, restrict to this indicator.
            start_date: When given, keep rows with ``report_date >=`` it.
            end_date: When given, keep rows with ``report_date <=`` it.

        Returns:
            Row dicts ordered by ``report_date`` descending, then
            ``indicator_id``; an empty list when nothing matches or the
            query failed.
        """
        # NOTE(review): s.stock_name is selected here, but insert_stock never
        # writes a stock_name column. If the stocks table has no such column
        # this query fails and the method always returns [] -- confirm
        # against the schema.
        query = """
            SELECT fd.*, fi.indicator_name, fi.category, s.stock_name
            FROM financial_data fd
            LEFT JOIN financial_indicators fi ON fd.indicator_id = fi.indicator_id
            LEFT JOIN stocks s ON fd.stock_id = s.stock_id
            WHERE fd.stock_id = %s
        """
        params = [stock_id]

        if indicator_id:
            query += " AND fd.indicator_id = %s"
            params.append(indicator_id)
        if start_date:
            query += " AND fd.report_date >= %s"
            params.append(start_date)
        if end_date:
            query += " AND fd.report_date <= %s"
            params.append(end_date)

        query += " ORDER BY fd.report_date DESC, fd.indicator_id"
        return self._fetch_all(query, tuple(params))

    def check_data_exists(self, stock_id: str, indicator_id: str,
                          report_date: str, period_type: str = 'Y') -> bool:
        """Report whether a matching ``financial_data`` row exists.

        Args:
            stock_id: Stock identifier.
            indicator_id: Indicator identifier.
            report_date: Report date to match exactly.
            period_type: Period type to match, ``'Y'`` by default.

        Returns:
            True when at least one row matches; False when none match, when
            there is no connection, or when the query failed.
        """
        # Same guard as the other helpers: without it a missing connection
        # raised AttributeError instead of returning False.
        if not self.connection:
            self.logger.error("数据库未连接")
            return False

        query = """
            SELECT COUNT(*) as count
            FROM financial_data
            WHERE stock_id = %s AND indicator_id = %s
              AND report_date = %s AND period_type = %s
        """

        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(query, (stock_id, indicator_id, report_date, period_type))
            result = cursor.fetchone()
            # COUNT(*) yields one row; the None check is purely defensive.
            return result is not None and result[0] > 0
        except Error as e:
            self.logger.error(f"检查数据存在性失败: {e}")
            return False
        finally:
            if cursor is not None:
                cursor.close()