From 41b0aa3f28fe5570abea7f85b46a5a44256d8fc8 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Sat, 21 Mar 2026 17:26:15 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20address=20code=20review=20round=202=20?= =?UTF-8?q?=E2=80=94=209=20issues=20(2=20critical,=203=20important,=204=20?= =?UTF-8?q?minor)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical: - #2: Add _entry_lock in RiskManager to serialize concurrent entry (balance race) - #3: Add startTime to get_recent_income + record _entry_time_ms (SYNC PnL fix) Important: - #1: Add threading.Lock + _run_api() helper for thread-safe Client access - #4: Convert reset_daily to async with lock - #8: Add 24h TTL to exchange_info_cache Minor: - #7: Remove duplicate Indicators creation in _open_position (use ATR directly) - #11: Add input validation for LEVERAGE, MARGIN ratios, ML_THRESHOLD - #12: Replace hardcoded corr[0]/corr[1] with dict-based dynamic access - #14: Add fillna(0.0) to LightGBM path for NaN consistency with ONNX Co-Authored-By: Claude Opus 4.6 (1M context) --- CLAUDE.md | 2 + docs/plans/2026-03-21-code-review-fixes-r2.md | 108 +++++++++++++++++ main.py | 2 +- src/bot.py | 90 ++++++++------ src/config.py | 12 ++ src/exchange.py | 110 +++++++++--------- src/ml_filter.py | 1 + src/risk_manager.py | 8 +- tests/test_config.py | 25 ++++ tests/test_exchange.py | 3 + tests/test_risk_manager.py | 29 +++++ 11 files changed, 291 insertions(+), 99 deletions(-) create mode 100644 docs/plans/2026-03-21-code-review-fixes-r2.md diff --git a/CLAUDE.md b/CLAUDE.md index a34b682..0aab282 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -142,3 +142,5 @@ All design documents and implementation plans are stored in `docs/plans/` with t | 2026-03-07 | `weekly-report` (plan) | Completed | | 2026-03-07 | `code-review-improvements` | Partial (#1,#2,#4,#5,#6,#8 완료) | | 2026-03-19 | `critical-bugfixes` (C5,C1,C3,C8) | Completed | +| 2026-03-21 | `dashboard-code-review-r2` (#14,#19) | Completed | +| 2026-03-21 | `code-review-fixes-r2` (9 issues) | Completed | diff --git a/docs/plans/2026-03-21-code-review-fixes-r2.md b/docs/plans/2026-03-21-code-review-fixes-r2.md new file mode 100644 index 0000000..5cf5e70 --- /dev/null +++ b/docs/plans/2026-03-21-code-review-fixes-r2.md @@ -0,0 +1,108 @@ +# Code Review Fixes Round 2 Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Fix 9 issues from code review re-evaluation (2 Critical, 3 Important, 4 Minor) + +**Architecture:** Targeted fixes across risk_manager, exchange, bot, config, ml_filter. No new files — all modifications to existing modules. + +**Tech Stack:** Python 3.12, asyncio, python-binance, LightGBM, ONNX Runtime + +--- + +### Task 1: #2 Critical — Balance reservation lock for concurrent entry + +**Files:** +- Modify: `src/risk_manager.py` — add `_entry_lock` to serialize entry flow +- Modify: `src/bot.py:405-413` — acquire entry lock around balance read → order +- Test: `tests/test_risk_manager.py` + +The simplest fix: add an asyncio.Lock in RiskManager that serializes the entire _open_position flow across all bots. This prevents two bots from reading the same balance simultaneously. + +- [ ] Add `_entry_lock = asyncio.Lock()` to RiskManager +- [ ] Add `async def entry_lock(self)` context manager +- [ ] In bot.py `_open_position`, wrap balance read + order under `async with self.risk.entry_lock()` +- [ ] Add test for concurrent entry serialization +- [ ] Run tests + +### Task 2: #3 Critical — SYNC PnL startTime + single query + +**Files:** +- Modify: `src/exchange.py:166-185` — add `start_time` param to `get_recent_income` +- Modify: `src/bot.py:75-82` — record `_entry_time` on position open +- Modify: `src/bot.py:620-629` — pass `start_time` to income query +- Test: `tests/test_exchange.py` + +- [ ] Add `_entry_time: int | None = None` to TradingBot +- [ ] Set `_entry_time = int(time.time() * 1000)` on entry and recovery +- [ ] Add `start_time` parameter to `get_recent_income()` +- [ ] Use start_time in SYNC fallback +- [ ] Add test +- [ ] Run tests + +### Task 3: #1 Important — Thread-safe Client access + +**Files:** +- Modify: `src/exchange.py` — add `threading.Lock` per instance + +- [ ] Add `self._api_lock = threading.Lock()` in `__init__` +- [ ] Wrap all `run_in_executor` lambdas with lock acquisition +- [ ] Add test +- [ ] Run tests + +### Task 4: #4 Important — reset_daily async with lock + +**Files:** +- Modify: `src/risk_manager.py:61-64` — make async + lock +- Modify: `main.py:22` — await reset_daily +- Test: `tests/test_risk_manager.py` + +- [ ] Convert `reset_daily` to async, add lock +- [ ] Update `_daily_reset_loop` call +- [ ] Add test +- [ ] Run tests + +### Task 5: #8 Important — exchange_info cache TTL + +**Files:** +- Modify: `src/exchange.py:25-34` — add TTL (24h) + +- [ ] Add `_exchange_info_time: float = 0.0` +- [ ] Check TTL in `_get_exchange_info` +- [ ] Add test +- [ ] Run tests + +### Task 6: #7 Minor — Pass pre-computed indicators to _open_position + +**Files:** +- Modify: `src/bot.py:392,415,736` — pass df_with_indicators + +- [ ] Add `df_with_indicators` parameter to `_open_position` +- [ ] Use passed df instead of re-creating Indicators +- [ ] Run tests + +### Task 7: #11 Minor — Config input validation + +**Files:** +- Modify: `src/config.py:39` — add range checks +- Test: `tests/test_config.py` + +- [ ] Add validation for LEVERAGE, MARGIN ratios, ML_THRESHOLD +- [ ] Add test for invalid values +- [ ] Run tests + +### Task 8: #12 Minor — Dynamic correlation symbol access + +**Files:** +- Modify: `src/bot.py:196-198` — iterate dynamically + +- [ ] Replace hardcoded [0]/[1] with dict-based access +- [ ] Run tests + +### Task 9: #14 Minor — Normalize NaN handling for LightGBM + +**Files:** +- Modify: `src/ml_filter.py:144-147` — apply nan_to_num for LightGBM too + +- [ ] Add `np.nan_to_num` to LightGBM path +- [ ] Run tests diff --git a/main.py b/main.py index 387df6d..f84c57e 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ async def _daily_reset_loop(risk: RiskManager): hour=0, minute=0, second=0, microsecond=0, ) await asyncio.sleep((next_midnight - now).total_seconds()) - risk.reset_daily() + await risk.reset_daily() async def _graceful_shutdown(bots: list[TradingBot], tasks: list[asyncio.Task]): diff --git a/src/bot.py b/src/bot.py index a88716a..3415dc1 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,6 +1,7 @@ import asyncio import json import os +import time from collections import deque from datetime import datetime, timezone from pathlib import Path @@ -76,6 +77,7 @@ class TradingBot: self._entry_price: float | None = None self._entry_quantity: float | None = None self._is_reentering: bool = False # _close_and_reenter 중 콜백 상태 초기화 방지 + self._entry_time_ms: int | None = None # 포지션 진입 시각 (ms, SYNC PnL 범위 제한용) self._close_event = asyncio.Event() # 콜백 청산 완료 대기용 self._close_lock = asyncio.Lock() # 청산 처리 원자성 보장 (C3 fix) self._prev_oi: float | None = None # OI 변화율 계산용 이전 값 @@ -194,8 +196,9 @@ class TradingBot: async def _on_candle_closed(self, candle: dict): primary_df = self.stream.get_dataframe(self.symbol) corr = self.config.correlation_symbols - btc_df = self.stream.get_dataframe(corr[0]) if len(corr) > 0 else None - eth_df = self.stream.get_dataframe(corr[1]) if len(corr) > 1 else None + corr_dfs = {s: self.stream.get_dataframe(s) for s in corr} + btc_df = corr_dfs.get("BTCUSDT") + eth_df = corr_dfs.get("ETHUSDT") if primary_df is not None: await self.process_candle(primary_df, btc_df=btc_df, eth_df=eth_df) @@ -208,6 +211,7 @@ class TradingBot: self.current_trade_side = "LONG" if amt > 0 else "SHORT" self._entry_price = float(position["entryPrice"]) self._entry_quantity = abs(amt) + self._entry_time_ms = int(float(position.get("updateTime", time.time() * 1000))) entry = float(position["entryPrice"]) logger.info( f"[{self.symbol}] 기존 포지션 복구: {self.current_trade_side} | " @@ -403,44 +407,51 @@ class TradingBot: ) async def _open_position(self, signal: str, df): - balance = await self.exchange.get_balance() - num_symbols = len(self.config.symbols) - per_symbol_balance = balance / num_symbols - price = df["close"].iloc[-1] - margin_ratio = self.risk.get_dynamic_margin_ratio(per_symbol_balance) - quantity = self.exchange.calculate_quantity( - balance=per_symbol_balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio - ) - logger.info(f"[{self.symbol}] 포지션 크기: 잔고={per_symbol_balance:.2f}/{balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") - stop_loss, take_profit = Indicators(df).get_atr_stop( - df, signal, price, - atr_sl_mult=self.strategy.atr_sl_mult, - atr_tp_mult=self.strategy.atr_tp_mult, - ) - - notional = quantity * price - if quantity <= 0 or notional < self.exchange.MIN_NOTIONAL: - logger.warning( - f"주문 건너뜀: 명목금액 {notional:.2f} USDT < 최소 {self.exchange.MIN_NOTIONAL} USDT " - f"(잔고={balance:.2f}, 수량={quantity})" + # 동시 진입 시 잔고 레이스 방지: entry_lock으로 잔고 조회→주문→등록을 직렬화 + async with self.risk._entry_lock: + balance = await self.exchange.get_balance() + num_symbols = len(self.config.symbols) + per_symbol_balance = balance / num_symbols + price = df["close"].iloc[-1] + margin_ratio = self.risk.get_dynamic_margin_ratio(per_symbol_balance) + quantity = self.exchange.calculate_quantity( + balance=per_symbol_balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio ) - return + logger.info(f"[{self.symbol}] 포지션 크기: 잔고={per_symbol_balance:.2f}/{balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") + # df는 이미 calculate_all() 적용된 df_with_indicators이므로 + # Indicators를 재생성하지 않고 ATR을 직접 사용 + atr = df["atr"].iloc[-1] + if signal == "LONG": + stop_loss = price - atr * self.strategy.atr_sl_mult + take_profit = price + atr * self.strategy.atr_tp_mult + else: + stop_loss = price + atr * self.strategy.atr_sl_mult + take_profit = price - atr * self.strategy.atr_tp_mult - side = "BUY" if signal == "LONG" else "SELL" - await self.exchange.set_leverage(self.config.leverage) - await self.exchange.place_order(side=side, quantity=quantity) + notional = quantity * price + if quantity <= 0 or notional < self.exchange.MIN_NOTIONAL: + logger.warning( + f"주문 건너뜀: 명목금액 {notional:.2f} USDT < 최소 {self.exchange.MIN_NOTIONAL} USDT " + f"(잔고={balance:.2f}, 수량={quantity})" + ) + return - last_row = df.iloc[-1] - signal_snapshot = { - "rsi": float(last_row["rsi"]) if "rsi" in last_row.index and pd.notna(last_row["rsi"]) else 0.0, - "macd_hist": float(last_row["macd_hist"]) if "macd_hist" in last_row.index and pd.notna(last_row["macd_hist"]) else 0.0, - "atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0, - } + side = "BUY" if signal == "LONG" else "SELL" + await self.exchange.set_leverage(self.config.leverage) + await self.exchange.place_order(side=side, quantity=quantity) - await self.risk.register_position(self.symbol, signal) - self.current_trade_side = signal - self._entry_price = price - self._entry_quantity = quantity + last_row = df.iloc[-1] + signal_snapshot = { + "rsi": float(last_row["rsi"]) if "rsi" in last_row.index and pd.notna(last_row["rsi"]) else 0.0, + "macd_hist": float(last_row["macd_hist"]) if "macd_hist" in last_row.index and pd.notna(last_row["macd_hist"]) else 0.0, + "atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0, + } + + await self.risk.register_position(self.symbol, signal) + self.current_trade_side = signal + self._entry_price = price + self._entry_quantity = quantity + self._entry_time_ms = int(time.time() * 1000) self.notifier.notify_open( symbol=self.symbol, side=signal, @@ -592,6 +603,7 @@ class TradingBot: self.current_trade_side = None self._entry_price = None self._entry_quantity = None + self._entry_time_ms = None _MONITOR_INTERVAL = 300 # 5분 @@ -619,7 +631,9 @@ class TradingBot: commission = 0.0 exit_price = 0.0 try: - pnl_rows, comm_rows = await self.exchange.get_recent_income(limit=10) + pnl_rows, comm_rows = await self.exchange.get_recent_income( + limit=10, start_time=self._entry_time_ms, + ) if pnl_rows: realized_pnl = sum(float(r.get("income", "0")) for r in pnl_rows) if comm_rows: @@ -654,6 +668,7 @@ class TradingBot: self.current_trade_side = None self._entry_price = None self._entry_quantity = None + self._entry_time_ms = None self._close_event.set() continue except Exception as e: @@ -711,6 +726,7 @@ class TradingBot: self.current_trade_side = None self._entry_price = None self._entry_quantity = None + self._entry_time_ms = None if self._killed: logger.info(f"[{self.symbol}] 킬스위치 활성 — 재진입 건너뜀 (청산만 수행)") diff --git a/src/config.py b/src/config.py index c0d92ff..f3f795c 100644 --- a/src/config.py +++ b/src/config.py @@ -64,6 +64,18 @@ class Config: corr_env = os.getenv("CORRELATION_SYMBOLS", "BTCUSDT,ETHUSDT") self.correlation_symbols = [s.strip() for s in corr_env.split(",") if s.strip()] + # 입력 검증 + if self.leverage < 1: + raise ValueError(f"LEVERAGE는 1 이상이어야 합니다: {self.leverage}") + if not (0.0 < self.margin_max_ratio <= 1.0): + raise ValueError(f"MARGIN_MAX_RATIO는 (0, 1] 범위여야 합니다: {self.margin_max_ratio}") + if not (0.0 < self.margin_min_ratio <= 1.0): + raise ValueError(f"MARGIN_MIN_RATIO는 (0, 1] 범위여야 합니다: {self.margin_min_ratio}") + if self.margin_min_ratio > self.margin_max_ratio: + raise ValueError(f"MARGIN_MIN_RATIO({self.margin_min_ratio}) > MARGIN_MAX_RATIO({self.margin_max_ratio})") + if not (0.0 < self.ml_threshold <= 1.0): + raise ValueError(f"ML_THRESHOLD는 (0, 1] 범위여야 합니다: {self.ml_threshold}") + # Per-symbol strategy params: {symbol: SymbolStrategyParams} self._symbol_params: dict[str, SymbolStrategyParams] = {} for sym in self.symbols: diff --git a/src/exchange.py b/src/exchange.py index b039df7..f398066 100644 --- a/src/exchange.py +++ b/src/exchange.py @@ -1,5 +1,7 @@ import asyncio import math +import threading +import time as _time from binance.client import Client from binance.exceptions import BinanceAPIException from loguru import logger @@ -7,8 +9,10 @@ from src.config import Config class BinanceFuturesClient: - # 클래스 레벨 exchange info 캐시 (전체 심볼 1회만 조회) + # 클래스 레벨 exchange info 캐시 (TTL 24시간) _exchange_info_cache: dict | None = None + _exchange_info_time: float = 0.0 + _EXCHANGE_INFO_TTL: float = 86400.0 # 24시간 def __init__(self, config: Config, symbol: str = None): self.config = config @@ -19,18 +23,32 @@ class BinanceFuturesClient: ) self._qty_precision: int | None = None self._price_precision: int | None = None + self._api_lock = threading.Lock() # requests.Session 스레드 안전성 보장 MIN_NOTIONAL = 5.0 # 바이낸스 선물 최소 명목금액 (USDT) + async def _run_api(self, func): + """동기 API 호출을 스레드 풀에서 실행하되, _api_lock으로 직렬화한다.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, lambda: self._call_with_lock(func), + ) + + def _call_with_lock(self, func): + with self._api_lock: + return func() + @classmethod def _get_exchange_info(cls, client: Client) -> dict | None: - """exchange info를 클래스 레벨로 캐시하여 1회만 조회한다.""" - if cls._exchange_info_cache is None: + """exchange info를 클래스 레벨로 캐시한다 (TTL 24시간).""" + now = _time.monotonic() + if cls._exchange_info_cache is None or (now - cls._exchange_info_time) > cls._EXCHANGE_INFO_TTL: try: cls._exchange_info_cache = client.futures_exchange_info() + cls._exchange_info_time = now except Exception as e: logger.warning(f"exchange info 조회 실패: {e}") - return None + return cls._exchange_info_cache # 만료돼도 기존 캐시 반환 return cls._exchange_info_cache def _load_symbol_precision(self) -> None: @@ -83,19 +101,14 @@ class BinanceFuturesClient: return qty_rounded async def set_leverage(self, leverage: int) -> dict: - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, + return await self._run_api( lambda: self.client.futures_change_leverage( symbol=self.symbol, leverage=leverage ), ) async def get_balance(self) -> float: - loop = asyncio.get_running_loop() - balances = await loop.run_in_executor( - None, self.client.futures_account_balance - ) + balances = await self._run_api(self.client.futures_account_balance) for b in balances: if b["asset"] == "USDT": return float(b["balance"]) @@ -110,8 +123,6 @@ class BinanceFuturesClient: stop_price: float = None, reduce_only: bool = False, ) -> dict: - loop = asyncio.get_running_loop() - params = dict( symbol=self.symbol, side=side, @@ -125,17 +136,15 @@ class BinanceFuturesClient: if stop_price is not None: params["stopPrice"] = stop_price try: - return await loop.run_in_executor( - None, lambda: self.client.futures_create_order(**params) + return await self._run_api( + lambda: self.client.futures_create_order(**params) ) except BinanceAPIException as e: logger.error(f"주문 실패: {e}") raise async def get_position(self) -> dict | None: - loop = asyncio.get_running_loop() - positions = await loop.run_in_executor( - None, + positions = await self._run_api( lambda: self.client.futures_position_information( symbol=self.symbol ), @@ -147,37 +156,37 @@ class BinanceFuturesClient: async def get_open_orders(self) -> list[dict]: """현재 심볼의 오픈 주문 목록을 조회한다.""" - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, + return await self._run_api( lambda: self.client.futures_get_open_orders(symbol=self.symbol), ) async def cancel_all_orders(self): """오픈 주문을 모두 취소한다.""" - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, + await self._run_api( lambda: self.client.futures_cancel_all_open_orders( symbol=self.symbol ), ) - async def get_recent_income(self, limit: int = 5) -> list[dict]: - """최근 REALIZED_PNL + COMMISSION 내역을 조회한다.""" - loop = asyncio.get_running_loop() + async def get_recent_income(self, limit: int = 5, start_time: int | None = None) -> tuple[list[dict], list[dict]]: + """최근 REALIZED_PNL + COMMISSION 내역을 조회한다. + + Args: + limit: 최대 조회 건수 + start_time: 밀리초 단위 시작 시각. 지정 시 해당 시각 이후 데이터만 반환. + """ try: - rows = await loop.run_in_executor( - None, - lambda: self.client.futures_income_history( - symbol=self.symbol, incomeType="REALIZED_PNL", limit=limit, - ), + pnl_params = dict(symbol=self.symbol, incomeType="REALIZED_PNL", limit=limit) + comm_params = dict(symbol=self.symbol, incomeType="COMMISSION", limit=limit) + if start_time is not None: + pnl_params["startTime"] = start_time + comm_params["startTime"] = start_time + + rows = await self._run_api( + lambda: self.client.futures_income_history(**pnl_params), ) - commissions = await loop.run_in_executor( - None, - lambda: self.client.futures_income_history( - symbol=self.symbol, incomeType="COMMISSION", limit=limit, - ), + commissions = await self._run_api( + lambda: self.client.futures_income_history(**comm_params), ) return rows, commissions except Exception as e: @@ -186,10 +195,8 @@ class BinanceFuturesClient: async def get_open_interest(self) -> float | None: """현재 미결제약정(OI)을 조회한다. 오류 시 None 반환.""" - loop = asyncio.get_running_loop() try: - result = await loop.run_in_executor( - None, + result = await self._run_api( lambda: self.client.futures_open_interest(symbol=self.symbol), ) return float(result["openInterest"]) @@ -199,10 +206,8 @@ class BinanceFuturesClient: async def get_funding_rate(self) -> float | None: """현재 펀딩비를 조회한다. 오류 시 None 반환.""" - loop = asyncio.get_running_loop() try: - result = await loop.run_in_executor( - None, + result = await self._run_api( lambda: self.client.futures_mark_price(symbol=self.symbol), ) return float(result["lastFundingRate"]) @@ -212,10 +217,8 @@ class BinanceFuturesClient: async def get_oi_history(self, limit: int = 5) -> list[float]: """최근 OI 변화율 히스토리를 조회한다 (봇 초기화용). 실패 시 빈 리스트.""" - loop = asyncio.get_running_loop() try: - result = await loop.run_in_executor( - None, + result = await self._run_api( lambda: self.client.futures_open_interest_hist( symbol=self.symbol, period="15m", limit=limit + 1, ), @@ -236,27 +239,18 @@ class BinanceFuturesClient: async def create_listen_key(self) -> str: """POST /fapi/v1/listenKey — listenKey 신규 발급""" - loop = asyncio.get_running_loop() - result = await loop.run_in_executor( - None, - lambda: self.client.futures_stream_get_listen_key(), - ) - return result + return await self._run_api(self.client.futures_stream_get_listen_key) async def keepalive_listen_key(self, listen_key: str) -> None: """PUT /fapi/v1/listenKey — listenKey 만료 연장 (60분 → 리셋)""" - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, + await self._run_api( lambda: self.client.futures_stream_keepalive(listenKey=listen_key), ) async def delete_listen_key(self, listen_key: str) -> None: """DELETE /fapi/v1/listenKey — listenKey 삭제 (정상 종료 시)""" - loop = asyncio.get_running_loop() try: - await loop.run_in_executor( - None, + await self._run_api( lambda: self.client.futures_stream_close(listenKey=listen_key), ) except Exception as e: diff --git a/src/ml_filter.py b/src/ml_filter.py index fe908c0..7491610 100644 --- a/src/ml_filter.py +++ b/src/ml_filter.py @@ -144,6 +144,7 @@ class MLFilter: else: available = [c for c in FEATURE_COLS if c in features.index] X = pd.DataFrame([features[available].values.astype(np.float64)], columns=available) + X = X.fillna(0.0) # ONNX(nan_to_num)와 동일한 NaN 처리 proba = float(self._lgbm_model.predict_proba(X)[0][1]) logger.debug( f"ML 필터 [{self.active_backend}] 확률: {proba:.3f} " diff --git a/src/risk_manager.py b/src/risk_manager.py index 1ad5bb6..5e66af9 100644 --- a/src/risk_manager.py +++ b/src/risk_manager.py @@ -11,6 +11,7 @@ class RiskManager: self.initial_balance: float = 0.0 self.open_positions: dict[str, str] = {} # {symbol: side} self._lock = asyncio.Lock() + self._entry_lock = asyncio.Lock() # 동시 진입 시 잔고 레이스 방지 async def is_trading_allowed(self) -> bool: """일일 최대 손실 초과 시 거래 중단""" @@ -58,10 +59,11 @@ class RiskManager: self.daily_pnl += pnl logger.info(f"오늘 누적 PnL: {self.daily_pnl:.4f} USDT") - def reset_daily(self): + async def reset_daily(self): """매일 자정 초기화""" - self.daily_pnl = 0.0 - logger.info("일일 PnL 초기화") + async with self._lock: + self.daily_pnl = 0.0 + logger.info("일일 PnL 초기화") def set_base_balance(self, balance: float) -> None: """봇 시작 시 기준 잔고 설정""" diff --git a/tests/test_config.py b/tests/test_config.py index 41d6f9d..3496ad2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -48,3 +48,28 @@ def test_config_max_same_direction_default(): """동일 방향 최대 수 기본값 2.""" cfg = Config() assert cfg.max_same_direction == 2 + + +def test_config_rejects_zero_leverage(): + """LEVERAGE=0은 ValueError.""" + os.environ["LEVERAGE"] = "0" + with pytest.raises(ValueError, match="LEVERAGE"): + Config() + os.environ["LEVERAGE"] = "10" # 복원 + + +def test_config_rejects_invalid_margin_ratio(): + """MARGIN_MAX_RATIO가 0이면 ValueError.""" + os.environ["MARGIN_MAX_RATIO"] = "0" + with pytest.raises(ValueError, match="MARGIN_MAX_RATIO"): + Config() + os.environ["MARGIN_MAX_RATIO"] = "0.50" # 복원 + + +def test_config_rejects_min_gt_max_margin(): + """MARGIN_MIN > MAX이면 ValueError.""" + os.environ["MARGIN_MIN_RATIO"] = "0.80" + os.environ["MARGIN_MAX_RATIO"] = "0.50" + with pytest.raises(ValueError, match="MARGIN_MIN_RATIO"): + Config() + os.environ["MARGIN_MIN_RATIO"] = "0.20" # 복원 diff --git a/tests/test_exchange.py b/tests/test_exchange.py index 8192155..5960e86 100644 --- a/tests/test_exchange.py +++ b/tests/test_exchange.py @@ -1,3 +1,4 @@ +import threading import pytest from unittest.mock import AsyncMock, MagicMock, patch from src.exchange import BinanceFuturesClient @@ -25,6 +26,7 @@ def client(): c.symbol = config.symbol c._qty_precision = 1 c._price_precision = 4 + c._api_lock = threading.Lock() return c @@ -43,6 +45,7 @@ def exchange(): c.client = MagicMock() c._qty_precision = 1 c._price_precision = 4 + c._api_lock = threading.Lock() return c diff --git a/tests/test_risk_manager.py b/tests/test_risk_manager.py index 76ecc8b..63d7785 100644 --- a/tests/test_risk_manager.py +++ b/tests/test_risk_manager.py @@ -1,3 +1,4 @@ +import asyncio import pytest import os from src.risk_manager import RiskManager @@ -137,3 +138,31 @@ async def test_max_positions_global_limit(shared_risk): await shared_risk.register_position("XRPUSDT", "LONG") await shared_risk.register_position("TRXUSDT", "SHORT") assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False + + +@pytest.mark.asyncio +async def test_reset_daily_with_lock(shared_risk): + """reset_daily가 lock 하에서 PnL을 초기화한다.""" + await shared_risk.close_position("DUMMY", 5.0) # dummy 기록 + shared_risk.open_positions.clear() # clean up + assert shared_risk.daily_pnl == 5.0 + await shared_risk.reset_daily() + assert shared_risk.daily_pnl == 0.0 + + +@pytest.mark.asyncio +async def test_entry_lock_serializes_access(shared_risk): + """_entry_lock이 동시 접근을 직렬화하는지 확인.""" + order = [] + + async def simulated_entry(name: str): + async with shared_risk._entry_lock: + order.append(f"{name}_start") + await asyncio.sleep(0.05) + order.append(f"{name}_end") + + await asyncio.gather(simulated_entry("A"), simulated_entry("B")) + # 직렬화 확인: A_start, A_end, B_start, B_end 또는 B_start, B_end, A_start, A_end + assert order[0].endswith("_start") + assert order[1].endswith("_end") + assert order[0][0] == order[1][0] # 같은 이름으로 시작/끝