diff --git a/src/bot.py b/src/bot.py index 474e536..12a01d8 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,5 +1,6 @@ import asyncio from collections import deque +from pathlib import Path import pandas as pd from loguru import logger from src.config import Config @@ -20,7 +21,19 @@ class TradingBot: self.exchange = BinanceFuturesClient(config, symbol=self.symbol) self.notifier = DiscordNotifier(config.discord_webhook_url) self.risk = risk or RiskManager(config) - self.ml_filter = MLFilter(threshold=config.ml_threshold) + # 심볼별 모델 디렉토리. 없으면 기존 models/ 루트로 폴백 + symbol_model_dir = Path(f"models/{self.symbol.lower()}") + if symbol_model_dir.exists(): + onnx_path = str(symbol_model_dir / "mlx_filter.weights.onnx") + lgbm_path = str(symbol_model_dir / "lgbm_filter.pkl") + else: + onnx_path = "models/mlx_filter.weights.onnx" + lgbm_path = "models/lgbm_filter.pkl" + self.ml_filter = MLFilter( + onnx_path=onnx_path, + lgbm_path=lgbm_path, + threshold=config.ml_threshold, + ) self.current_trade_side: str | None = None # "LONG" | "SHORT" self._entry_price: float | None = None self._entry_quantity: float | None = None