import asyncio
import logging.config

from pybit.unified_trading import WebSocket

import database.request as rq
from app.bybit.logger_bybit.logger_bybit import LOGGING_CONFIG
from app.bybit.telegram_message_handler import TelegramMessageHandler

logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger("web_socket")


class WebSocketBot:
    """
    Class to handle WebSocket connections and messages.

    One private Bybit WebSocket is kept per Telegram user (keyed by ``tg_id``).
    Stream callbacks registered with pybit may run outside the event loop
    thread, so each callback hands its work to the loop via
    ``call_soon_threadsafe`` and the handler coroutine runs there as a task.
    """

    # Seconds to wait between passes over the user table.
    user_check_interval = 10
    # Seconds before a user whose connection attempt failed is retried.
    connect_retry_delay = 30

    def __init__(self, telegram_bot):
        """Initialize the WebSocketBot.

        Args:
            telegram_bot: Bot instance; passed to TelegramMessageHandler,
                which formats and delivers the stream updates.
        """
        self.telegram_bot = telegram_bot
        # Most recently created socket (kept for backward compatibility).
        self.ws_private = None
        # tg_id -> {"position": ..., "order": ..., "execution": ...}
        self.user_messages = {}
        # tg_id -> WebSocket currently registered for that user
        self.user_sockets = {}
        # tg_id -> (api_key, api_secret) the registered socket was opened with
        self.user_keys = {}
        self.loop = None
        self.message_handler = TelegramMessageHandler(telegram_bot)
        # Strong references to in-flight handler tasks: the event loop only
        # keeps weak references, so an unreferenced task may be GC'd mid-run.
        self._background_tasks = set()
        # tg_id -> loop time before which a failed user is not retried
        self._retry_after = {}

    async def run_user_check_loop(self):
        """Run a loop to check for users and connect them to the WebSocket.

        Runs forever. On every pass each user is connected, reconnected
        (when their stored API keys changed) or skipped. Errors for a single
        user or a whole pass are logged and do not stop the loop.
        """
        self.loop = asyncio.get_running_loop()
        while True:
            try:
                users = await WebSocketBot.get_users_from_db()
            except Exception:
                logger.exception("Failed to load users from the database")
                users = []
            for user in users:
                tg_id = user.tg_id
                try:
                    await self._check_user(tg_id)
                except Exception:
                    # Isolate failures so one user cannot kill the whole loop.
                    logger.exception("Error while checking user %s", tg_id)
            await asyncio.sleep(self.user_check_interval)

    async def _check_user(self, tg_id):
        """Connect, reconnect or skip one user based on their stored keys."""
        api_key, api_secret = await rq.get_user_api(tg_id=tg_id)
        if not api_key or not api_secret:
            return
        keys = (api_key, api_secret)
        loop_time = asyncio.get_running_loop().time()

        if tg_id in self.user_sockets:
            if self.user_keys.get(tg_id) == keys:
                return
            # Keys changed: close and forget only THIS user's socket.
            await self._disconnect_user(tg_id)
            logger.info(
                "Closed old websocket for user %s due to key change", tg_id
            )
        elif loop_time < self._retry_after.get(tg_id, 0.0):
            # A recent attempt failed; back off for this user only instead of
            # sleeping and delaying every other user.
            return

        if await self.try_connect_user(api_key, api_secret, tg_id):
            self._retry_after.pop(tg_id, None)
            self.user_keys[tg_id] = keys
            self.user_messages.setdefault(
                tg_id, {"position": None, "order": None, "execution": None}
            )
            logger.info("User %s connected to WebSocket", tg_id)
        else:
            self._retry_after[tg_id] = loop_time + self.connect_retry_delay

    async def _disconnect_user(self, tg_id):
        """Forget one user's socket state and close their socket, if any."""
        ws = self.user_sockets.pop(tg_id, None)
        self.user_messages.pop(tg_id, None)
        self.user_keys.pop(tg_id, None)
        if ws is not None:
            await self._close_socket(ws, tg_id)

    @staticmethod
    async def _close_socket(ws, tg_id):
        """Close a pybit WebSocket, logging (not raising) on failure."""
        try:
            # NOTE(review): pybit's exit() may block while the connection
            # shuts down, so run it off the event loop thread.
            await asyncio.get_running_loop().run_in_executor(None, ws.exit)
        except Exception:
            logger.exception("Error closing websocket for user %s", tg_id)

    async def clear_user_sockets(self):
        """Close every registered socket and clear all per-user state."""
        sockets = list(self.user_sockets.items())
        self.user_sockets.clear()
        self.user_messages.clear()
        self.user_keys.clear()
        self._retry_after.clear()
        # Close the sockets too; dropping them from the dict alone would leave
        # their connections running and still delivering messages.
        for tg_id, ws in sockets:
            await self._close_socket(ws, tg_id)
        logger.info("Cleared user_sockets")

    async def try_connect_user(self, api_key, api_secret, tg_id):
        """Try to connect a user to the WebSocket.

        Opens a private-channel socket and subscribes to the position, order
        and execution streams.

        Returns:
            True on success. False if creating the socket or subscribing
            failed; in that case the partially set-up socket is closed and
            not left registered in ``user_sockets``.
        """
        if self.loop is None:
            # Callbacks need the loop; capture it if called outside the run loop.
            self.loop = asyncio.get_running_loop()
        ws = None
        try:
            ws = WebSocket(
                testnet=False,
                channel_type="private",
                api_key=api_key,
                api_secret=api_secret,
            )
            self.ws_private = ws
            self.user_sockets[tg_id] = ws
            # Handle position updates
            ws.position_stream(
                lambda msg: self._schedule(self.handle_position_update, msg)
            )
            # Handle order updates
            ws.order_stream(
                lambda msg: self._schedule(self.handle_order_update, msg, tg_id)
            )
            # Handle execution updates
            ws.execution_stream(
                lambda msg: self._schedule(
                    self.handle_execution_update, msg, tg_id
                )
            )
            return True
        except Exception as e:
            logger.error("Error connecting user %s: %s", tg_id, e)
            if ws is not None:
                # Don't leave a half-subscribed socket registered; it would
                # look like a key change on the next pass.
                if self.user_sockets.get(tg_id) is ws:
                    del self.user_sockets[tg_id]
                await self._close_socket(ws, tg_id)
            return False

    def _schedule(self, handler, *args):
        """Hand ``handler(*args)`` to the event loop from a stream callback.

        Only ``call_soon_threadsafe`` is used here, because the callback may
        run on a non-loop thread. The coroutine itself is created on the
        loop thread.
        """
        try:
            self.loop.call_soon_threadsafe(self._start_task, handler, args)
        except RuntimeError:
            # The event loop is closed; nothing is left to handle the message.
            logger.warning("Event loop closed; dropping %s", handler.__name__)

    def _start_task(self, handler, args):
        """Create and track a handler task (runs on the event loop thread)."""
        task = self.loop.create_task(handler(*args))
        self._background_tasks.add(task)
        task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task):
        """Release a finished handler task and log any exception it raised."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            logger.error("Websocket handler failed", exc_info=task.exception())

    async def handle_position_update(self, message):
        """Handle position updates."""
        # NOTE(review): unlike order/execution updates, no tg_id is passed
        # here; confirm format_position_update does not need it.
        await self.message_handler.format_position_update(message)

    async def handle_order_update(self, message, tg_id):
        """Handle order updates."""
        await self.message_handler.format_order_update(message, tg_id)

    async def handle_execution_update(self, message, tg_id):
        """Handle execution updates."""
        await self.message_handler.format_execution_update(message, tg_id)

    @staticmethod
    async def get_users_from_db():
        """Get all users from the database."""
        return await rq.get_users()