import asyncio
from collections import deque
import logging.config

from pybit.unified_trading import WebSocket

import database.request as rq
from app.bybit.logger_bybit.logger_bybit import LOGGING_CONFIG
from app.bybit.telegram_message_handler import TelegramMessageHandler

logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger("web_socket")


class CustomWebSocket(WebSocket):
    """pybit WebSocket that logs errors and closes before delegating to the base class."""

    def _on_error(self, error):
        """Log the error, then let the base class handle it."""
        logger.error("WebSocket error: %s", error)
        return super()._on_error(error)

    def _on_close(self):
        """Log the close event, then let the base class handle it."""
        logger.warning("WebSocket connection closed")
        super()._on_close()


class WebSocketBot:
    """
    Manages multiple Bybit private WebSocket connections for Telegram users.

    Stream callbacks append raw messages to per-user deques; per-user asyncio
    tasks poll those deques and forward the messages to the Telegram message
    handler.
    """

    # Seconds between two passes over the user list.
    _USER_CHECK_INTERVAL = 10
    # Seconds a queue-processing task sleeps between polls of its deque.
    _QUEUE_POLL_INTERVAL = 0.01

    def __init__(self, telegram_bot):
        """
        Initialize WebSocketBot.

        Args:
            telegram_bot: Telegram bot instance for message handling
        """
        self.telegram_bot = telegram_bot
        self.user_sockets = {}  # {tg_id: CustomWebSocket}
        self.user_messages = {}  # {tg_id: {"position": ..., "order": ..., "execution": ...}}
        self.user_keys = {}  # {tg_id: (api_key, api_secret)}
        self.loop = None  # set by run_user_check_loop
        self.message_handler = TelegramMessageHandler(telegram_bot)
        self.order_queues = {}  # {tg_id: deque}
        self.execution_queues = {}  # {tg_id: deque}
        self.processing_tasks = {}  # {tg_id: order-queue task}
        # Execution tasks get their own mapping. Storing them in
        # processing_tasks under ``tg_id + 1`` collided with the order task of
        # the user whose id is ``tg_id + 1``.
        self.execution_tasks = {}  # {tg_id: execution-queue task}

    async def run_user_check_loop(self):
        """
        Periodically reconcile WebSocket connections with the user table.

        Runs forever. A failure while handling one user is logged and does not
        prevent the remaining users from being processed in the same pass.
        """
        self.loop = asyncio.get_running_loop()
        logger.info("Starting WebSocket user check loop")
        while True:
            try:
                users = await WebSocketBot.get_users_from_db()
                for user in users:
                    tg_id = user.tg_id
                    try:
                        await self._sync_user_connection(tg_id)
                    except Exception as e:
                        logger.error(
                            "Error in user check loop for user %s: %s", tg_id, e
                        )
            except Exception as e:
                logger.error("Error in user check loop: %s", e)
            await asyncio.sleep(self._USER_CHECK_INTERVAL)

    async def _sync_user_connection(self, tg_id):
        """Connect the user, or reconnect when the stored API keys have changed."""
        api_key, api_secret = await rq.get_user_api(tg_id=tg_id)
        if not api_key or not api_secret:
            return

        credentials = (api_key, api_secret)
        socket_exists = tg_id in self.user_sockets
        if socket_exists and self.user_keys.get(tg_id) == credentials:
            return  # already connected with the current keys

        if socket_exists:
            # Keys changed: drop the old connection before opening a new one.
            await self.close_user_socket(tg_id)

        if await self.try_connect_user(api_key, api_secret, tg_id):
            self.user_keys[tg_id] = credentials
            self.user_messages.setdefault(
                tg_id, {"position": None, "order": None, "execution": None}
            )
            logger.info("User %s successfully connected", tg_id)

    async def try_connect_user(self, api_key, api_secret, tg_id):
        """
        Open a private WebSocket for the user and start its queue processors.

        Args:
            api_key: Bybit API key of the user
            api_secret: Bybit API secret of the user
            tg_id: Telegram id used as the key for all per-user state

        Returns:
            True when the streams were subscribed, False on any error. On
            failure every resource created for the user is released again.
        """
        try:
            # NOTE(review): the pybit constructor appears to connect
            # synchronously, which would block the event loop - confirm.
            ws = CustomWebSocket(
                demo=True,
                testnet=False,
                channel_type="private",
                api_key=api_key,
                api_secret=api_secret,
            )
            self.user_sockets[tg_id] = ws

            # The callbacks close over these deques instead of looking them up
            # by tg_id, so a message delivered after cleanup cannot raise
            # KeyError or land in the queue of a successor connection.
            order_queue = deque()
            execution_queue = deque()
            self.order_queues[tg_id] = order_queue
            self.execution_queues[tg_id] = execution_queue

            self.processing_tasks[tg_id] = asyncio.create_task(
                self._process_order_queue(tg_id)
            )
            self.execution_tasks[tg_id] = asyncio.create_task(
                self._process_execution_queue(tg_id)
            )

            def order_callback(msg):
                order_queue.append(msg)

            def execution_callback(msg):
                execution_queue.append(msg)

            ws.order_stream(order_callback)
            ws.execution_stream(execution_callback)
            logger.info("WebSocket streams configured for user %s", tg_id)
            return True
        except Exception as e:
            logger.error("Error connecting user %s: %s", tg_id, e)
            await self._release_user_resources(tg_id)
            return False

    async def _drain_queue(self, tg_id, queues, handler, label):
        """
        Forward queued messages to ``handler`` while the user has a socket.

        Args:
            tg_id: Telegram id of the user
            queues: mapping {tg_id: deque} to read from
            handler: coroutine function called as ``handler(message, tg_id)``
            label: queue name used in the error log message
        """
        while tg_id in self.user_sockets:
            try:
                queue = queues.get(tg_id)
                # Empty the backlog on each wake-up rather than handling one
                # message per poll interval.
                while queue and tg_id in self.user_sockets:
                    await handler(queue.popleft(), tg_id)
            except Exception as e:
                logger.error("Error processing %s queue %s: %s", label, tg_id, e)
            await asyncio.sleep(self._QUEUE_POLL_INTERVAL)

    async def _process_order_queue(self, tg_id):
        """Continuously process order queue for user."""
        await self._drain_queue(
            tg_id, self.order_queues, self.handle_order_update, "order"
        )

    async def _process_execution_queue(self, tg_id):
        """Continuously process execution queue for user."""
        await self._drain_queue(
            tg_id, self.execution_queues, self.handle_execution_update, "execution"
        )

    async def _release_user_resources(self, tg_id):
        """Close the user's socket, cancel its tasks and drop its queues."""
        ws = self.user_sockets.pop(tg_id, None)
        if ws is not None:
            try:
                # Run the close in a worker thread so a slow close handshake
                # does not stall the event loop.
                await asyncio.get_running_loop().run_in_executor(None, ws.exit)
            except Exception as e:
                logger.error("Error closing WebSocket for user %s: %s", tg_id, e)

        for tasks in (self.processing_tasks, self.execution_tasks):
            task = tasks.pop(tg_id, None)
            if task is not None and not task.done():
                task.cancel()

        self.order_queues.pop(tg_id, None)
        self.execution_queues.pop(tg_id, None)

    async def close_user_socket(self, tg_id):
        """Close the user's connection and forget all per-user state."""
        await self._release_user_resources(tg_id)
        self.user_messages.pop(tg_id, None)
        self.user_keys.pop(tg_id, None)
        logger.info("Cleaned up user %s", tg_id)

    async def handle_order_update(self, message, tg_id):
        """Pass an order update to the message handler; errors are logged."""
        try:
            await self.message_handler.format_order_update(message, tg_id)
        except Exception as e:
            logger.error("Error handling order update for %s: %s", tg_id, e)

    async def handle_execution_update(self, message, tg_id):
        """Pass an execution update to the message handler; errors are logged."""
        try:
            await self.message_handler.format_execution_update(message, tg_id)
        except Exception as e:
            logger.error("Error handling execution update for %s: %s", tg_id, e)

    @staticmethod
    async def get_users_from_db():
        """Fetch all users from database; returns an empty list on error."""
        try:
            return await rq.get_users()
        except Exception as e:
            logger.error("Error getting users from DB: %s", e)
            return []