From 181f82d3c07eb2f22d707aba0809d1ad3880bd36 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Thu, 19 Mar 2026 23:03:52 +0900 Subject: [PATCH] fix: address critical code review issues (PnL double recording, sync HTTP, race conditions) - fix(bot): prevent PnL double recording in _close_and_reenter using asyncio.Event - fix(bot): prevent SYNC detection PnL duplication with _close_handled_by_sync flag - fix(notifier): move sync HTTP call to background thread via run_in_executor - fix(risk_manager): make is_trading_allowed async with lock for thread safety - fix(exchange): cache exchange info at class level (1 API call for all symbols) - fix(exchange): use `is not None` instead of truthy check for price/stop_price - refactor(backtester): extract _calc_trade_stats to eliminate code duplication - fix(ml_features): apply rolling z-score to OI/funding rate in serving (train-serve skew) - fix(bot): use config.correlation_symbols instead of hardcoded BTCUSDT/ETHUSDT - fix(bot): expand OI/funding history deque to 96 for z-score window - cleanup(config): remove unused stop_loss_pct, take_profit_pct, trailing_stop_pct fields Co-Authored-By: Claude Opus 4.6 (1M context) --- src/backtester.py | 207 +++++++++++++------------------------ src/bot.py | 48 +++++++-- src/config.py | 3 - src/exchange.py | 30 ++++-- src/ml_features.py | 29 ++++-- src/notifier.py | 9 ++ src/risk_manager.py | 19 ++-- tests/test_bot.py | 16 ++- tests/test_risk_manager.py | 10 +- 9 files changed, 189 insertions(+), 182 deletions(-) diff --git a/src/backtester.py b/src/backtester.py index ff46e9d..6860d75 100644 --- a/src/backtester.py +++ b/src/backtester.py @@ -14,6 +14,73 @@ import numpy as np import pandas as pd from loguru import logger + +def _calc_trade_stats(trades: list[dict], initial_balance: float) -> dict: + """거래 리스트에서 통계 요약을 계산한다. Backtester와 WalkForward 공통 사용.""" + if not trades: + return { + "total_trades": 0, "total_pnl": 0.0, "return_pct": 0.0, + "win_rate": 0.0, "avg_win": 0.0, "avg_loss": 0.0, + "payoff_ratio": 0.0, "max_consecutive_losses": 0, + "profit_factor": 0.0, "max_drawdown_pct": 0.0, + "sharpe_ratio": 0.0, "total_fees": 0.0, "close_reasons": {}, + } + + pnls = [t["net_pnl"] for t in trades] + wins = [p for p in pnls if p > 0] + losses = [p for p in pnls if p <= 0] + + total_pnl = sum(pnls) + total_fees = sum(t["entry_fee"] + t["exit_fee"] for t in trades) + gross_profit = sum(wins) if wins else 0.0 + gross_loss = abs(sum(losses)) if losses else 0.0 + + cumulative = np.cumsum(pnls) + equity = initial_balance + cumulative + peak = np.maximum.accumulate(equity) + drawdown = (peak - equity) / peak + mdd = float(np.max(drawdown)) * 100 if len(drawdown) > 0 else 0.0 + + if len(pnls) > 1: + pnl_arr = np.array(pnls) + sharpe = float(np.mean(pnl_arr) / np.std(pnl_arr) * np.sqrt(24192)) if np.std(pnl_arr) > 0 else 0.0 + else: + sharpe = 0.0 + + avg_w = float(np.mean(wins)) if wins else 0.0 + avg_l = float(np.mean(losses)) if losses else 0.0 + payoff_ratio = round(avg_w / abs(avg_l), 2) if avg_l != 0 else float("inf") + + max_consec_loss = 0 + cur_streak = 0 + for p in pnls: + if p <= 0: + cur_streak += 1 + max_consec_loss = max(max_consec_loss, cur_streak) + else: + cur_streak = 0 + + reasons = {} + for t in trades: + r = t["close_reason"] + reasons[r] = reasons.get(r, 0) + 1 + + return { + "total_trades": len(trades), + "total_pnl": round(total_pnl, 4), + "return_pct": round(total_pnl / initial_balance * 100, 2), + "win_rate": round(len(wins) / len(trades) * 100, 2), + "avg_win": round(avg_w, 4), + "avg_loss": round(avg_l, 4), + "payoff_ratio": payoff_ratio, + "max_consecutive_losses": max_consec_loss, + "profit_factor": round(gross_profit / gross_loss, 2) if gross_loss > 0 else float("inf"), + "max_drawdown_pct": round(mdd, 2), + "sharpe_ratio": round(sharpe, 2), + "total_fees": round(total_fees, 4), + "close_reasons": reasons, + } + import warnings import joblib @@ -524,80 +591,7 @@ class Backtester: } def _calc_summary(self) -> dict: - if not self.trades: - return { - "total_trades": 0, - "total_pnl": 0.0, - "return_pct": 0.0, - "win_rate": 0.0, - "avg_win": 0.0, - "avg_loss": 0.0, - "profit_factor": 0.0, - "max_drawdown_pct": 0.0, - "sharpe_ratio": 0.0, - "total_fees": 0.0, - "close_reasons": {}, - } - - pnls = [t["net_pnl"] for t in self.trades] - wins = [p for p in pnls if p > 0] - losses = [p for p in pnls if p <= 0] - - total_pnl = sum(pnls) - total_fees = sum(t["entry_fee"] + t["exit_fee"] for t in self.trades) - gross_profit = sum(wins) if wins else 0.0 - gross_loss = abs(sum(losses)) if losses else 0.0 - - # MDD 계산 - cumulative = np.cumsum(pnls) - equity = self.cfg.initial_balance + cumulative - peak = np.maximum.accumulate(equity) - drawdown = (peak - equity) / peak - mdd = float(np.max(drawdown)) * 100 if len(drawdown) > 0 else 0.0 - - # 샤프비율 (연율화, 15분봉 기준: 252일 * 96봉 = 24192) - if len(pnls) > 1: - pnl_arr = np.array(pnls) - sharpe = float(np.mean(pnl_arr) / np.std(pnl_arr) * np.sqrt(24192)) if np.std(pnl_arr) > 0 else 0.0 - else: - sharpe = 0.0 - - # 손익비 (avg_win / |avg_loss|) - avg_w = float(np.mean(wins)) if wins else 0.0 - avg_l = float(np.mean(losses)) if losses else 0.0 - payoff_ratio = round(avg_w / abs(avg_l), 2) if avg_l != 0 else float("inf") - - # 최대 연속 손실 횟수 - max_consec_loss = 0 - cur_streak = 0 - for p in pnls: - if p <= 0: - cur_streak += 1 - max_consec_loss = max(max_consec_loss, cur_streak) - else: - cur_streak = 0 - - # 청산 사유별 비율 - reasons = {} - for t in self.trades: - r = t["close_reason"] - reasons[r] = reasons.get(r, 0) + 1 - - return { - "total_trades": len(self.trades), - "total_pnl": round(total_pnl, 4), - "return_pct": round(total_pnl / self.cfg.initial_balance * 100, 2), - "win_rate": round(len(wins) / len(self.trades) * 100, 2) if self.trades else 0.0, - "avg_win": round(avg_w, 4), - "avg_loss": round(avg_l, 4), - "payoff_ratio": payoff_ratio, - "max_consecutive_losses": max_consec_loss, - "profit_factor": round(gross_profit / gross_loss, 2) if gross_loss > 0 else float("inf"), - "max_drawdown_pct": round(mdd, 2), - "sharpe_ratio": round(sharpe, 2), - "total_fees": round(total_fees, 4), - "close_reasons": reasons, - } + return _calc_trade_stats(self.trades, self.cfg.initial_balance) # ── Walk-Forward 백테스트 ───────────────────────────────────────────── @@ -810,70 +804,7 @@ class WalkForwardBacktester: """폴드별 결과를 합산하여 전체 Walk-Forward 결과 생성.""" from src.backtest_validator import validate - # 전체 통계 계산 - if not all_trades: - summary = {"total_trades": 0, "total_pnl": 0.0, "return_pct": 0.0, - "win_rate": 0.0, "avg_win": 0.0, "avg_loss": 0.0, - "payoff_ratio": 0.0, "max_consecutive_losses": 0, - "profit_factor": 0.0, "max_drawdown_pct": 0.0, - "sharpe_ratio": 0.0, "total_fees": 0.0, "close_reasons": {}} - else: - pnls = [t["net_pnl"] for t in all_trades] - wins = [p for p in pnls if p > 0] - losses = [p for p in pnls if p <= 0] - total_pnl = sum(pnls) - total_fees = sum(t["entry_fee"] + t["exit_fee"] for t in all_trades) - gross_profit = sum(wins) if wins else 0.0 - gross_loss = abs(sum(losses)) if losses else 0.0 - - cumulative = np.cumsum(pnls) - equity = self.cfg.initial_balance + cumulative - peak = np.maximum.accumulate(equity) - drawdown = (peak - equity) / peak - mdd = float(np.max(drawdown)) * 100 if len(drawdown) > 0 else 0.0 - - if len(pnls) > 1: - pnl_arr = np.array(pnls) - sharpe = float(np.mean(pnl_arr) / np.std(pnl_arr) * np.sqrt(24192)) if np.std(pnl_arr) > 0 else 0.0 - else: - sharpe = 0.0 - - # 손익비 (avg_win / |avg_loss|) - avg_w = float(np.mean(wins)) if wins else 0.0 - avg_l = float(np.mean(losses)) if losses else 0.0 - payoff_ratio = round(avg_w / abs(avg_l), 2) if avg_l != 0 else float("inf") - - # 최대 연속 손실 횟수 - max_consec_loss = 0 - cur_streak = 0 - for p in pnls: - if p <= 0: - cur_streak += 1 - max_consec_loss = max(max_consec_loss, cur_streak) - else: - cur_streak = 0 - - reasons = {} - for t in all_trades: - r = t["close_reason"] - reasons[r] = reasons.get(r, 0) + 1 - - summary = { - "total_trades": len(all_trades), - "total_pnl": round(total_pnl, 4), - "return_pct": round(total_pnl / self.cfg.initial_balance * 100, 2), - "win_rate": round(len(wins) / len(all_trades) * 100, 2), - "avg_win": round(avg_w, 4), - "avg_loss": round(avg_l, 4), - "payoff_ratio": payoff_ratio, - "max_consecutive_losses": max_consec_loss, - "profit_factor": round(gross_profit / gross_loss, 2) if gross_loss > 0 else float("inf"), - "max_drawdown_pct": round(mdd, 2), - "sharpe_ratio": round(sharpe, 2), - "total_fees": round(total_fees, 4), - "close_reasons": reasons, - } - + summary = _calc_trade_stats(all_trades, self.cfg.initial_balance) validation = validate(all_trades, summary, self.cfg) return { diff --git a/src/bot.py b/src/bot.py index 064a552..e0668b0 100644 --- a/src/bot.py +++ b/src/bot.py @@ -76,8 +76,11 @@ class TradingBot: self._entry_price: float | None = None self._entry_quantity: float | None = None self._is_reentering: bool = False # _close_and_reenter 중 콜백 상태 초기화 방지 + self._close_event = asyncio.Event() # 콜백 청산 완료 대기용 + self._close_handled_by_sync: bool = False # SYNC 감지 시 콜백 중복 방지 self._prev_oi: float | None = None # OI 변화율 계산용 이전 값 - self._oi_history: deque = deque(maxlen=5) + self._oi_history: deque = deque(maxlen=96) # z-score 윈도우(96=1일분 15분봉) + self._funding_history: deque = deque(maxlen=96) self._latest_ret_1: float = 0.0 self._killed: bool = False # 킬스위치 발동 상태 self._trade_history: list[dict] = [] # 최근 거래 이력 (net_pnl 기록) @@ -190,8 +193,9 @@ class TradingBot: async def _on_candle_closed(self, candle: dict): primary_df = self.stream.get_dataframe(self.symbol) - btc_df = self.stream.get_dataframe("BTCUSDT") - eth_df = self.stream.get_dataframe("ETHUSDT") + corr = self.config.correlation_symbols + btc_df = self.stream.get_dataframe(corr[0]) if len(corr) > 0 else None + eth_df = self.stream.get_dataframe(corr[1]) if len(corr) > 1 else None if primary_df is not None: await self.process_candle(primary_df, btc_df=btc_df, eth_df=eth_df) @@ -240,9 +244,13 @@ class TradingBot: oi_change = 0.0 fr_float = float(fr_val) if isinstance(fr_val, (int, float)) else 0.0 - # OI 히스토리 업데이트 및 MA5 계산 + # 히스토리 업데이트 (z-score 계산용) self._oi_history.append(oi_change) - oi_ma5 = sum(self._oi_history) / len(self._oi_history) if self._oi_history else 0.0 + self._funding_history.append(fr_float) + + # OI MA5 계산 + recent_5 = list(self._oi_history)[-5:] + oi_ma5 = sum(recent_5) / len(recent_5) if recent_5 else 0.0 # OI-가격 스프레드 oi_price_spread = oi_change - self._latest_ret_1 @@ -274,7 +282,7 @@ class TradingBot: # 캔들 마감 시 OI/펀딩비 실시간 조회 (실패해도 0으로 폴백) oi_change, funding_rate, oi_ma5, oi_price_spread = await self._fetch_market_microstructure() - if not self.risk.is_trading_allowed(): + if not await self.risk.is_trading_allowed(): logger.warning(f"[{self.symbol}] 리스크 한도 초과 - 거래 중단") return @@ -313,6 +321,8 @@ class TradingBot: btc_df=btc_df, eth_df=eth_df, oi_change=oi_change, funding_rate=funding_rate, oi_change_ma5=oi_ma5, oi_price_spread=oi_price_spread, + oi_history=list(self._oi_history), + funding_history=list(self._funding_history), ) if self.ml_filter.is_model_loaded(): if not self.ml_filter.should_enter(features): @@ -419,6 +429,12 @@ class TradingBot: exit_price: float, ) -> None: """User Data Stream에서 청산 감지 시 호출되는 콜백.""" + # SYNC 핸들러가 이미 처리한 경우 중복 기록 방지 + if self._close_handled_by_sync: + logger.debug(f"[{self.symbol}] SYNC에서 이미 처리된 청산 — 콜백 건너뜀") + self._close_event.set() + return + estimated_pnl = self._calc_estimated_pnl(exit_price) diff = net_pnl - estimated_pnl @@ -443,6 +459,9 @@ class TradingBot: self._append_trade(net_pnl, close_reason) self._check_kill_switch() + # _close_and_reenter 대기 해제 + self._close_event.set() + # _close_and_reenter 중이면 신규 포지션 상태를 덮어쓰지 않는다 if self._is_reentering: return @@ -469,6 +488,8 @@ class TradingBot: f"[{self.symbol}] 포지션 불일치 감지: " f"봇={self.current_trade_side}, 바이낸스=포지션 없음 — 상태 동기화" ) + # 콜백 중복 방지 플래그 설정 + self._close_handled_by_sync = True # Binance income API에서 실제 PnL 조회 realized_pnl = 0.0 commission = 0.0 @@ -509,6 +530,7 @@ class TradingBot: self.current_trade_side = None self._entry_price = None self._entry_quantity = None + self._close_handled_by_sync = False continue except Exception as e: logger.debug(f"[{self.symbol}] 포지션 동기화 확인 실패 (무시): {e}") @@ -550,15 +572,21 @@ class TradingBot: """기존 포지션을 청산하고, ML 필터 통과 시 반대 방향으로 즉시 재진입한다.""" # 재진입 플래그: User Data Stream 콜백이 신규 포지션 상태를 초기화하지 않도록 보호 self._is_reentering = True - prev_side = self.current_trade_side + self._close_event.clear() try: await self._close_position(position) - # 청산 완료 확인: 콜백이 처리했든 아니든 로컬 상태를 명시적으로 Flat으로 전환 + # 콜백이 PnL을 기록할 때까지 대기 (최대 10초) + try: + await asyncio.wait_for(self._close_event.wait(), timeout=10) + except asyncio.TimeoutError: + logger.warning(f"[{self.symbol}] 청산 콜백 타임아웃 — 수동 동기화") + await self.risk.close_position(self.symbol, 0.0) + + # 로컬 상태를 Flat으로 전환 self.current_trade_side = None self._entry_price = None self._entry_quantity = None - await self.risk.close_position(self.symbol, 0.0) if prev_side and self.symbol not in self.risk.open_positions else None if self._killed: logger.info(f"[{self.symbol}] 킬스위치 활성 — 재진입 건너뜀 (청산만 수행)") @@ -574,6 +602,8 @@ class TradingBot: btc_df=btc_df, eth_df=eth_df, oi_change=oi_change, funding_rate=funding_rate, oi_change_ma5=oi_change_ma5, oi_price_spread=oi_price_spread, + oi_history=list(self._oi_history), + funding_history=list(self._funding_history), ) if not self.ml_filter.should_enter(features): logger.info(f"[{self.symbol}] ML 필터 차단: {signal} 재진입 무시") diff --git a/src/config.py b/src/config.py index 69a48bf..c0d92ff 100644 --- a/src/config.py +++ b/src/config.py @@ -25,9 +25,6 @@ class Config: leverage: int = 10 max_positions: int = 3 max_same_direction: int = 2 - stop_loss_pct: float = 0.015 # 1.5% - take_profit_pct: float = 0.045 # 4.5% (3:1 RR) - trailing_stop_pct: float = 0.01 # 1% discord_webhook_url: str = "" margin_max_ratio: float = 0.50 margin_min_ratio: float = 0.20 diff --git a/src/exchange.py b/src/exchange.py index 5b3f979..7a1dbb6 100644 --- a/src/exchange.py +++ b/src/exchange.py @@ -7,6 +7,9 @@ from src.config import Config class BinanceFuturesClient: + # 클래스 레벨 exchange info 캐시 (전체 심볼 1회만 조회) + _exchange_info_cache: dict | None = None + def __init__(self, config: Config, symbol: str = None): self.config = config self.symbol = symbol or config.symbol @@ -19,10 +22,21 @@ class BinanceFuturesClient: MIN_NOTIONAL = 5.0 # 바이낸스 선물 최소 명목금액 (USDT) + @classmethod + def _get_exchange_info(cls, client: Client) -> dict | None: + """exchange info를 클래스 레벨로 캐시하여 1회만 조회한다.""" + if cls._exchange_info_cache is None: + try: + cls._exchange_info_cache = client.futures_exchange_info() + except Exception as e: + logger.warning(f"exchange info 조회 실패: {e}") + return None + return cls._exchange_info_cache + def _load_symbol_precision(self) -> None: """바이낸스 exchange info에서 심볼별 수량/가격 정밀도를 로드한다.""" - try: - info = self.client.futures_exchange_info() + info = self._get_exchange_info(self.client) + if info is not None: for s in info["symbols"]: if s["symbol"] == self.symbol: self._qty_precision = s.get("quantityPrecision", 1) @@ -32,12 +46,8 @@ class BinanceFuturesClient: ) return logger.warning(f"[{self.symbol}] exchange info에서 심볼 미발견, 기본 정밀도 사용") - self._qty_precision = 1 - self._price_precision = 2 - except Exception as e: - logger.warning(f"[{self.symbol}] exchange info 조회 실패 ({e}), 기본 정밀도 사용") - self._qty_precision = 1 - self._price_precision = 2 + self._qty_precision = 1 + self._price_precision = 2 @property def qty_precision(self) -> int: @@ -109,10 +119,10 @@ class BinanceFuturesClient: quantity=quantity, reduceOnly=reduce_only, ) - if price: + if price is not None: params["price"] = price params["timeInForce"] = "GTC" - if stop_price: + if stop_price is not None: params["stopPrice"] = stop_price try: return await loop.run_in_executor( diff --git a/src/ml_features.py b/src/ml_features.py index ebc1200..4fbeae0 100644 --- a/src/ml_features.py +++ b/src/ml_features.py @@ -167,6 +167,8 @@ def build_features_aligned( funding_rate: float | None = None, oi_change_ma5: float | None = None, oi_price_spread: float | None = None, + oi_history: list[float] | None = None, + funding_history: list[float] | None = None, ) -> pd.Series: """ 학습(dataset_builder._calc_features_vectorized)과 동일한 rolling z-score를 @@ -297,12 +299,27 @@ def build_features_aligned( "primary_eth_rs": _rolling_zscore_last(rs_eth), }) - # OI/펀딩비 z-score (실시간 값이 제공되면 히스토리 끝에 추가하여 z-score) - # 서빙 시 OI/펀딩비 히스토리가 없으므로 단일 값 → z-score 불가, NaN 처리 - # LightGBM은 NaN을 자체 처리함 - base["oi_change"] = float(oi_change) if oi_change is not None else np.nan - base["funding_rate"] = float(funding_rate) if funding_rate is not None else np.nan - base["oi_change_ma5"] = float(oi_change_ma5) if oi_change_ma5 is not None else np.nan + # OI/펀딩비 z-score (학습과 동일한 rolling z-score 적용) + if oi_history and len(oi_history) >= 2 and oi_change is not None: + oi_arr = np.array(oi_history, dtype=np.float64) + base["oi_change"] = _rolling_zscore_last(oi_arr, window=_ZSCORE_WINDOW_OI) + else: + base["oi_change"] = np.nan + + if funding_history and len(funding_history) >= 2 and funding_rate is not None: + fr_arr = np.array(funding_history, dtype=np.float64) + base["funding_rate"] = _rolling_zscore_last(fr_arr, window=_ZSCORE_WINDOW_OI) + else: + base["funding_rate"] = np.nan + + if oi_history and len(oi_history) >= 5 and oi_change_ma5 is not None: + # OI MA5 히스토리로 z-score + oi_arr = np.array(oi_history, dtype=np.float64) + ma5 = pd.Series(oi_arr).rolling(5, min_periods=1).mean().values + base["oi_change_ma5"] = _rolling_zscore_last(ma5, window=_ZSCORE_WINDOW_OI) + else: + base["oi_change_ma5"] = np.nan + base["oi_price_spread"] = float(oi_price_spread) if oi_price_spread is not None else np.nan base["adx"] = adx_z diff --git a/src/notifier.py b/src/notifier.py index 7096332..7a147c3 100644 --- a/src/notifier.py +++ b/src/notifier.py @@ -1,3 +1,4 @@ +import asyncio import httpx from loguru import logger @@ -10,9 +11,17 @@ class DiscordNotifier: self._enabled = bool(webhook_url) def _send(self, content: str) -> None: + """알림 전송. 이벤트 루프 내에서는 백그라운드 스레드로 실행하여 블로킹 방지.""" if not self._enabled: logger.debug("Discord 웹훅 URL 미설정 - 알림 건너뜀") return + try: + loop = asyncio.get_running_loop() + loop.run_in_executor(None, self._send_sync, content) + except RuntimeError: + self._send_sync(content) + + def _send_sync(self, content: str) -> None: try: resp = httpx.post( self.webhook_url, diff --git a/src/risk_manager.py b/src/risk_manager.py index a59260c..1ad5bb6 100644 --- a/src/risk_manager.py +++ b/src/risk_manager.py @@ -12,17 +12,18 @@ class RiskManager: self.open_positions: dict[str, str] = {} # {symbol: side} self._lock = asyncio.Lock() - def is_trading_allowed(self) -> bool: + async def is_trading_allowed(self) -> bool: """일일 최대 손실 초과 시 거래 중단""" - if self.initial_balance <= 0: + async with self._lock: + if self.initial_balance <= 0: + return True + loss_pct = abs(self.daily_pnl) / self.initial_balance + if self.daily_pnl < 0 and loss_pct >= self.max_daily_loss_pct: + logger.warning( + f"일일 손실 한도 초과: {loss_pct:.2%} >= {self.max_daily_loss_pct:.2%}" + ) + return False return True - loss_pct = abs(self.daily_pnl) / self.initial_balance - if self.daily_pnl < 0 and loss_pct >= self.max_daily_loss_pct: - logger.warning( - f"일일 손실 한도 초과: {loss_pct:.2%} >= {self.max_daily_loss_pct:.2%}" - ) - return False - return True async def can_open_new_position(self, symbol: str, side: str) -> bool: """포지션 오픈 가능 여부 (전체 한도 + 중복 진입 + 동일 방향 제한)""" diff --git a/tests/test_bot.py b/tests/test_bot.py index 0d4e8a3..93f791b 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -84,7 +84,7 @@ async def test_bot_processes_signal(config, sample_df): bot.exchange.MIN_NOTIONAL = 5.0 bot.risk = MagicMock() - bot.risk.is_trading_allowed.return_value = True + bot.risk.is_trading_allowed = AsyncMock(return_value=True) bot.risk.can_open_new_position = AsyncMock(return_value=True) bot.risk.register_position = AsyncMock() bot.risk.get_dynamic_margin_ratio.return_value = 0.50 @@ -108,10 +108,14 @@ async def test_close_and_reenter_calls_open_when_ml_passes(config, sample_df): bot._open_position = AsyncMock() bot.risk = MagicMock() bot.risk.can_open_new_position = AsyncMock(return_value=True) + bot.risk.close_position = AsyncMock() bot.ml_filter = MagicMock() bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.should_enter.return_value = True + # 콜백 대기를 건너뛰도록 Event 미리 설정 + bot._close_event.set() + position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"} await bot._close_and_reenter(position, "SHORT", sample_df) @@ -129,10 +133,13 @@ async def test_close_and_reenter_skips_open_when_ml_blocks(config, sample_df): bot._open_position = AsyncMock() bot.risk = MagicMock() bot.risk.can_open_new_position = AsyncMock(return_value=True) + bot.risk.close_position = AsyncMock() bot.ml_filter = MagicMock() bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.should_enter.return_value = False + bot._close_event.set() + position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"} await bot._close_and_reenter(position, "SHORT", sample_df) @@ -150,6 +157,9 @@ async def test_close_and_reenter_skips_open_when_max_positions_reached(config, s bot._open_position = AsyncMock() bot.risk = MagicMock() bot.risk.can_open_new_position = AsyncMock(return_value=False) + bot.risk.close_position = AsyncMock() + + bot._close_event.set() position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"} await bot._close_and_reenter(position, "SHORT", sample_df) @@ -234,7 +244,7 @@ async def test_process_candle_fetches_oi_and_funding(config, sample_df): bot.exchange.get_funding_rate = AsyncMock(return_value=0.0001) bot.risk = MagicMock() - bot.risk.is_trading_allowed.return_value = True + bot.risk.is_trading_allowed = AsyncMock(return_value=True) bot.risk.can_open_new_position = AsyncMock(return_value=True) bot.risk.register_position = AsyncMock() bot.risk.get_dynamic_margin_ratio.return_value = 0.50 @@ -266,7 +276,7 @@ def test_bot_has_oi_history_deque(config): with patch("src.bot.BinanceFuturesClient"): bot = TradingBot(config) assert isinstance(bot._oi_history, deque) - assert bot._oi_history.maxlen == 5 + assert bot._oi_history.maxlen == 96 @pytest.mark.asyncio diff --git a/tests/test_risk_manager.py b/tests/test_risk_manager.py index fe5bedb..76ecc8b 100644 --- a/tests/test_risk_manager.py +++ b/tests/test_risk_manager.py @@ -15,18 +15,20 @@ def config(): return Config() -def test_max_drawdown_check(config): +@pytest.mark.asyncio +async def test_max_drawdown_check(config): rm = RiskManager(config, max_daily_loss_pct=0.05) rm.daily_pnl = -60.0 rm.initial_balance = 1000.0 - assert rm.is_trading_allowed() is False + assert await rm.is_trading_allowed() is False -def test_trading_allowed_normal(config): +@pytest.mark.asyncio +async def test_trading_allowed_normal(config): rm = RiskManager(config, max_daily_loss_pct=0.05) rm.daily_pnl = -10.0 rm.initial_balance = 1000.0 - assert rm.is_trading_allowed() is True + assert await rm.is_trading_allowed() is True @pytest.mark.asyncio