diff --git a/src/bot.py b/src/bot.py index 5a98ee0..bfc72f9 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,4 +1,3 @@ -import asyncio import pandas as pd from loguru import logger from src.config import Config @@ -25,12 +24,12 @@ class TradingBot: on_candle=self._on_candle_closed, ) - def _on_candle_closed(self, candle: dict): + async def _on_candle_closed(self, candle: dict): xrp_df = self.stream.get_dataframe(self.config.symbol) btc_df = self.stream.get_dataframe("BTCUSDT") eth_df = self.stream.get_dataframe("ETHUSDT") if xrp_df is not None: - asyncio.create_task(self.process_candle(xrp_df, btc_df=btc_df, eth_df=eth_df)) + await self.process_candle(xrp_df, btc_df=btc_df, eth_df=eth_df) async def _recover_position(self) -> None: """재시작 시 바이낸스에서 현재 포지션을 조회하여 상태 복구.""" diff --git a/src/data_stream.py b/src/data_stream.py index 7fe9065..086e005 100644 --- a/src/data_stream.py +++ b/src/data_stream.py @@ -40,12 +40,12 @@ class KlineStream: "is_closed": k["x"], } - def handle_message(self, msg: dict): + async def handle_message(self, msg: dict): candle = self.parse_kline(msg) if candle["is_closed"]: self.buffer.append(candle) if self.on_candle: - self.on_candle(candle) + await self.on_candle(candle) def get_dataframe(self) -> pd.DataFrame | None: if len(self.buffer) < _MIN_CANDLES_FOR_SIGNAL: @@ -90,7 +90,7 @@ class KlineStream: ) as stream: while True: msg = await stream.recv() - self.handle_message(msg) + await self.handle_message(msg) finally: await client.close_connection() @@ -129,7 +129,7 @@ class MultiSymbolStream: "is_closed": k["x"], } - def handle_message(self, msg: dict): + async def handle_message(self, msg: dict): # Combined stream 메시지는 {"stream": "...", "data": {...}} 형태 if "stream" in msg: data = msg["data"] @@ -145,7 +145,7 @@ class MultiSymbolStream: if candle["is_closed"] and symbol in self.buffers: self.buffers[symbol].append(candle) if symbol == self.primary_symbol and self.on_candle: - self.on_candle(candle) + await self.on_candle(candle) def get_dataframe(self, symbol: str) -> pd.DataFrame | None: key = symbol.lower() @@ -192,6 +192,6 @@ class MultiSymbolStream: async with bm.futures_multiplex_socket(streams) as stream: while True: msg = await stream.recv() - self.handle_message(msg) + await self.handle_message(msg) finally: await client.close_connection() diff --git a/tests/test_data_stream.py b/tests/test_data_stream.py index 6f4e4ac..e5c84c6 100644 --- a/tests/test_data_stream.py +++ b/tests/test_data_stream.py @@ -63,11 +63,11 @@ async def test_kline_stream_parses_message(): @pytest.mark.asyncio async def test_callback_called_on_closed_candle(): - received = [] + callback = AsyncMock() stream = KlineStream( symbol="XRPUSDT", interval="1m", - on_candle=lambda c: received.append(c), + on_candle=callback, ) raw_msg = { "k": { @@ -80,8 +80,8 @@ async def test_callback_called_on_closed_candle(): "x": True, } } - stream.handle_message(raw_msg) - assert len(received) == 1 + await stream.handle_message(raw_msg) + assert callback.call_count == 1 @pytest.mark.asyncio