From e7620248c71e151f9495bbd5948af7fcb2697c70 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Thu, 5 Mar 2026 23:13:04 +0900 Subject: [PATCH] feat: TradingBot accepts symbol and shared RiskManager, removes config.symbol dependency Co-Authored-By: Claude Opus 4.6 --- src/bot.py | 40 ++++++++++++++++++++++------------------ tests/test_bot.py | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/src/bot.py b/src/bot.py index 4567101..474e536 100644 --- a/src/bot.py +++ b/src/bot.py @@ -14,11 +14,12 @@ from src.user_data_stream import UserDataStream class TradingBot: - def __init__(self, config: Config): + def __init__(self, config: Config, symbol: str = None, risk: RiskManager = None): self.config = config - self.exchange = BinanceFuturesClient(config) + self.symbol = symbol or config.symbol + self.exchange = BinanceFuturesClient(config, symbol=self.symbol) self.notifier = DiscordNotifier(config.discord_webhook_url) - self.risk = RiskManager(config) + self.risk = risk or RiskManager(config) self.ml_filter = MLFilter(threshold=config.ml_threshold) self.current_trade_side: str | None = None # "LONG" | "SHORT" self._entry_price: float | None = None @@ -28,17 +29,17 @@ class TradingBot: self._oi_history: deque = deque(maxlen=5) self._latest_ret_1: float = 0.0 self.stream = MultiSymbolStream( - symbols=[config.symbol, "BTCUSDT", "ETHUSDT"], + symbols=[self.symbol] + config.correlation_symbols, interval="15m", on_candle=self._on_candle_closed, ) async def _on_candle_closed(self, candle: dict): - xrp_df = self.stream.get_dataframe(self.config.symbol) + primary_df = self.stream.get_dataframe(self.symbol) btc_df = self.stream.get_dataframe("BTCUSDT") eth_df = self.stream.get_dataframe("ETHUSDT") - if xrp_df is not None: - await self.process_candle(xrp_df, btc_df=btc_df, eth_df=eth_df) + if primary_df is not None: + await self.process_candle(primary_df, btc_df=btc_df, eth_df=eth_df) async def _recover_position(self) -> None: """재시작 시 바이낸스에서 현재 포지션을 조회하여 상태 복구.""" @@ -134,8 +135,8 @@ class TradingBot: if position is None and raw_signal != "HOLD": self.current_trade_side = None - if not self.risk.can_open_new_position(): - logger.info("최대 포지션 수 도달") + if not await self.risk.can_open_new_position(self.symbol, raw_signal): + logger.info(f"[{self.symbol}] 포지션 오픈 불가") return signal = raw_signal features = build_features( @@ -163,12 +164,14 @@ class TradingBot: async def _open_position(self, signal: str, df): balance = await self.exchange.get_balance() + num_symbols = len(self.config.symbols) + per_symbol_balance = balance / num_symbols price = df["close"].iloc[-1] margin_ratio = self.risk.get_dynamic_margin_ratio(balance) quantity = self.exchange.calculate_quantity( - balance=balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio + balance=per_symbol_balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio ) - logger.info(f"포지션 크기: 잔고={balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") + logger.info(f"[{self.symbol}] 포지션 크기: 잔고={per_symbol_balance:.2f}/{balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") stop_loss, take_profit = Indicators(df).get_atr_stop(df, signal, price) notional = quantity * price @@ -190,11 +193,12 @@ class TradingBot: "atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0, } + await self.risk.register_position(self.symbol, signal) self.current_trade_side = signal self._entry_price = price self._entry_quantity = quantity self.notifier.notify_open( - symbol=self.config.symbol, + symbol=self.symbol, side=signal, entry_price=price, quantity=quantity, @@ -245,10 +249,10 @@ class TradingBot: estimated_pnl = self._calc_estimated_pnl(exit_price) diff = net_pnl - estimated_pnl - self.risk.record_pnl(net_pnl) + await self.risk.close_position(self.symbol, net_pnl) self.notifier.notify_close( - symbol=self.config.symbol, + symbol=self.symbol, side=self.current_trade_side or "UNKNOWN", close_reason=close_reason, exit_price=exit_price, @@ -317,8 +321,8 @@ class TradingBot: try: await self._close_position(position) - if not self.risk.can_open_new_position(): - logger.info("최대 포지션 수 도달 — 재진입 건너뜀") + if not await self.risk.can_open_new_position(self.symbol, signal): + logger.info(f"[{self.symbol}] 최대 포지션 수 도달 — 재진입 건너뜀") return if self.ml_filter.is_model_loaded(): @@ -337,7 +341,7 @@ class TradingBot: self._is_reentering = False async def run(self): - logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x") + logger.info(f"[{self.symbol}] 봇 시작, 레버리지 {self.config.leverage}x") await self._recover_position() await self._init_oi_history() balance = await self.exchange.get_balance() @@ -345,7 +349,7 @@ class TradingBot: logger.info(f"기준 잔고 설정: {balance:.2f} USDT (동적 증거금 비율 기준점)") user_stream = UserDataStream( - symbol=self.config.symbol, + symbol=self.symbol, on_order_filled=self._on_position_closed, ) diff --git a/tests/test_bot.py b/tests/test_bot.py index dab47be..e9846ba 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -37,6 +37,25 @@ def sample_df(): }) +def test_bot_accepts_symbol_and_risk(config): + """TradingBot이 symbol과 risk를 외부에서 주입받을 수 있다.""" + from src.risk_manager import RiskManager + risk = RiskManager(config) + with patch("src.bot.BinanceFuturesClient"): + bot = TradingBot(config, symbol="TRXUSDT", risk=risk) + assert bot.symbol == "TRXUSDT" + assert bot.risk is risk + + +def test_bot_stream_uses_injected_symbol(config): + """봇의 stream이 주입된 심볼을 primary로 사용한다.""" + from src.risk_manager import RiskManager + risk = RiskManager(config) + with patch("src.bot.BinanceFuturesClient"): + bot = TradingBot(config, symbol="DOGEUSDT", risk=risk) + assert "dogeusdt" in bot.stream.buffers + + def test_bot_uses_multi_symbol_stream(config): from src.data_stream import MultiSymbolStream with patch("src.bot.BinanceFuturesClient"): @@ -64,6 +83,12 @@ async def test_bot_processes_signal(config, sample_df): bot.exchange.calculate_quantity = MagicMock(return_value=100.0) bot.exchange.MIN_NOTIONAL = 5.0 + bot.risk = MagicMock() + bot.risk.is_trading_allowed.return_value = True + bot.risk.can_open_new_position = AsyncMock(return_value=True) + bot.risk.register_position = AsyncMock() + bot.risk.get_dynamic_margin_ratio.return_value = 0.50 + with patch("src.bot.Indicators") as MockInd: mock_ind = MagicMock() mock_ind.calculate_all.return_value = sample_df @@ -82,7 +107,7 @@ async def test_close_and_reenter_calls_open_when_ml_passes(config, sample_df): bot._close_position = AsyncMock() bot._open_position = AsyncMock() bot.risk = MagicMock() - bot.risk.can_open_new_position.return_value = True + bot.risk.can_open_new_position = AsyncMock(return_value=True) bot.ml_filter = MagicMock() bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.should_enter.return_value = True @@ -102,6 +127,8 @@ async def test_close_and_reenter_skips_open_when_ml_blocks(config, sample_df): bot._close_position = AsyncMock() bot._open_position = AsyncMock() + bot.risk = MagicMock() + bot.risk.can_open_new_position = AsyncMock(return_value=True) bot.ml_filter = MagicMock() bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.should_enter.return_value = False @@ -122,7 +149,7 @@ async def test_close_and_reenter_skips_open_when_max_positions_reached(config, s bot._close_position = AsyncMock() bot._open_position = AsyncMock() bot.risk = MagicMock() - bot.risk.can_open_new_position.return_value = False + bot.risk.can_open_new_position = AsyncMock(return_value=False) position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"} await bot._close_and_reenter(position, "SHORT", sample_df) @@ -206,6 +233,12 @@ async def test_process_candle_fetches_oi_and_funding(config, sample_df): bot.exchange.get_open_interest = AsyncMock(return_value=5000000.0) bot.exchange.get_funding_rate = AsyncMock(return_value=0.0001) + bot.risk = MagicMock() + bot.risk.is_trading_allowed.return_value = True + bot.risk.can_open_new_position = AsyncMock(return_value=True) + bot.risk.register_position = AsyncMock() + bot.risk.get_dynamic_margin_ratio.return_value = 0.50 + # 신호를 LONG으로 강제해 build_features가 반드시 호출되도록 함 with patch("src.bot.Indicators") as mock_ind_cls: mock_ind = MagicMock()