diff --git a/.venv b/.venv new file mode 120000 index 0000000..872124e --- /dev/null +++ b/.venv @@ -0,0 +1 @@ +/Users/gihyeon/github/cointrader/.venv \ No newline at end of file diff --git a/main.py b/main.py index 8055649..b5d24f2 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,9 @@ import asyncio from dotenv import load_dotenv +from loguru import logger from src.config import Config from src.bot import TradingBot +from src.risk_manager import RiskManager from src.logger_setup import setup_logger load_dotenv() @@ -10,8 +12,15 @@ load_dotenv() async def main(): setup_logger(log_level="INFO") config = Config() - bot = TradingBot(config) - await bot.run() + risk = RiskManager(config) + + bots = [] + for symbol in config.symbols: + bot = TradingBot(config, symbol=symbol, risk=risk) + bots.append(bot) + + logger.info(f"멀티심볼 봇 시작: {config.symbols} ({len(bots)}개 인스턴스)") + await asyncio.gather(*[bot.run() for bot in bots]) if __name__ == "__main__": diff --git a/src/bot.py b/src/bot.py index 4567101..12a01d8 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,5 +1,6 @@ import asyncio from collections import deque +from pathlib import Path import pandas as pd from loguru import logger from src.config import Config @@ -14,12 +15,25 @@ from src.user_data_stream import UserDataStream class TradingBot: - def __init__(self, config: Config): + def __init__(self, config: Config, symbol: str = None, risk: RiskManager = None): self.config = config - self.exchange = BinanceFuturesClient(config) + self.symbol = symbol or config.symbol + self.exchange = BinanceFuturesClient(config, symbol=self.symbol) self.notifier = DiscordNotifier(config.discord_webhook_url) - self.risk = RiskManager(config) - self.ml_filter = MLFilter(threshold=config.ml_threshold) + self.risk = risk or RiskManager(config) + # 심볼별 모델 디렉토리. 없으면 기존 models/ 루트로 폴백 + symbol_model_dir = Path(f"models/{self.symbol.lower()}") + if symbol_model_dir.exists(): + onnx_path = str(symbol_model_dir / "mlx_filter.weights.onnx") + lgbm_path = str(symbol_model_dir / "lgbm_filter.pkl") + else: + onnx_path = "models/mlx_filter.weights.onnx" + lgbm_path = "models/lgbm_filter.pkl" + self.ml_filter = MLFilter( + onnx_path=onnx_path, + lgbm_path=lgbm_path, + threshold=config.ml_threshold, + ) self.current_trade_side: str | None = None # "LONG" | "SHORT" self._entry_price: float | None = None self._entry_quantity: float | None = None @@ -28,17 +42,17 @@ class TradingBot: self._oi_history: deque = deque(maxlen=5) self._latest_ret_1: float = 0.0 self.stream = MultiSymbolStream( - symbols=[config.symbol, "BTCUSDT", "ETHUSDT"], + symbols=[self.symbol] + config.correlation_symbols, interval="15m", on_candle=self._on_candle_closed, ) async def _on_candle_closed(self, candle: dict): - xrp_df = self.stream.get_dataframe(self.config.symbol) + primary_df = self.stream.get_dataframe(self.symbol) btc_df = self.stream.get_dataframe("BTCUSDT") eth_df = self.stream.get_dataframe("ETHUSDT") - if xrp_df is not None: - await self.process_candle(xrp_df, btc_df=btc_df, eth_df=eth_df) + if primary_df is not None: + await self.process_candle(primary_df, btc_df=btc_df, eth_df=eth_df) async def _recover_position(self) -> None: """재시작 시 바이낸스에서 현재 포지션을 조회하여 상태 복구.""" @@ -134,8 +148,8 @@ class TradingBot: if position is None and raw_signal != "HOLD": self.current_trade_side = None - if not self.risk.can_open_new_position(): - logger.info("최대 포지션 수 도달") + if not await self.risk.can_open_new_position(self.symbol, raw_signal): + logger.info(f"[{self.symbol}] 포지션 오픈 불가") return signal = raw_signal features = build_features( @@ -163,12 +177,14 @@ class TradingBot: async def _open_position(self, signal: str, df): balance = await self.exchange.get_balance() + num_symbols = len(self.config.symbols) + per_symbol_balance = balance / num_symbols price = df["close"].iloc[-1] margin_ratio = self.risk.get_dynamic_margin_ratio(balance) quantity = self.exchange.calculate_quantity( - balance=balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio + balance=per_symbol_balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio ) - logger.info(f"포지션 크기: 잔고={balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") + logger.info(f"[{self.symbol}] 포지션 크기: 잔고={per_symbol_balance:.2f}/{balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") stop_loss, take_profit = Indicators(df).get_atr_stop(df, signal, price) notional = quantity * price @@ -190,11 +206,12 @@ class TradingBot: "atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0, } + await self.risk.register_position(self.symbol, signal) self.current_trade_side = signal self._entry_price = price self._entry_quantity = quantity self.notifier.notify_open( - symbol=self.config.symbol, + symbol=self.symbol, side=signal, entry_price=price, quantity=quantity, @@ -245,10 +262,10 @@ class TradingBot: estimated_pnl = self._calc_estimated_pnl(exit_price) diff = net_pnl - estimated_pnl - self.risk.record_pnl(net_pnl) + await self.risk.close_position(self.symbol, net_pnl) self.notifier.notify_close( - symbol=self.config.symbol, + symbol=self.symbol, side=self.current_trade_side or "UNKNOWN", close_reason=close_reason, exit_price=exit_price, @@ -317,8 +334,8 @@ class TradingBot: try: await self._close_position(position) - if not self.risk.can_open_new_position(): - logger.info("최대 포지션 수 도달 — 재진입 건너뜀") + if not await self.risk.can_open_new_position(self.symbol, signal): + logger.info(f"[{self.symbol}] 최대 포지션 수 도달 — 재진입 건너뜀") return if self.ml_filter.is_model_loaded(): @@ -337,7 +354,7 @@ class TradingBot: self._is_reentering = False async def run(self): - logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x") + logger.info(f"[{self.symbol}] 봇 시작, 레버리지 {self.config.leverage}x") await self._recover_position() await self._init_oi_history() balance = await self.exchange.get_balance() @@ -345,7 +362,7 @@ class TradingBot: logger.info(f"기준 잔고 설정: {balance:.2f} USDT (동적 증거금 비율 기준점)") user_stream = UserDataStream( - symbol=self.config.symbol, + symbol=self.symbol, on_order_filled=self._on_position_closed, ) diff --git a/src/config.py b/src/config.py index 86bee31..3f9b477 100644 --- a/src/config.py +++ b/src/config.py @@ -10,8 +10,11 @@ class Config: api_key: str = "" api_secret: str = "" symbol: str = "XRPUSDT" + symbols: list = None + correlation_symbols: list = None leverage: int = 10 max_positions: int = 3 + max_same_direction: int = 2 stop_loss_pct: float = 0.015 # 1.5% take_profit_pct: float = 0.045 # 4.5% (3:1 RR) trailing_stop_pct: float = 0.01 # 1% @@ -31,3 +34,15 @@ class Config: self.margin_min_ratio = float(os.getenv("MARGIN_MIN_RATIO", "0.20")) self.margin_decay_rate = float(os.getenv("MARGIN_DECAY_RATE", "0.0006")) self.ml_threshold = float(os.getenv("ML_THRESHOLD", "0.55")) + self.max_same_direction = int(os.getenv("MAX_SAME_DIRECTION", "2")) + + # symbols: SYMBOLS 환경변수 우선, 없으면 SYMBOL에서 변환 + symbols_env = os.getenv("SYMBOLS", "") + if symbols_env: + self.symbols = [s.strip() for s in symbols_env.split(",") if s.strip()] + else: + self.symbols = [self.symbol] + + # correlation_symbols + corr_env = os.getenv("CORRELATION_SYMBOLS", "BTCUSDT,ETHUSDT") + self.correlation_symbols = [s.strip() for s in corr_env.split(",") if s.strip()] diff --git a/src/exchange.py b/src/exchange.py index 1dba3bb..0fe71f5 100644 --- a/src/exchange.py +++ b/src/exchange.py @@ -6,8 +6,9 @@ from src.config import Config class BinanceFuturesClient: - def __init__(self, config: Config): + def __init__(self, config: Config, symbol: str = None): self.config = config + self.symbol = symbol or config.symbol self.client = Client( api_key=config.api_key, api_secret=config.api_secret, @@ -31,7 +32,7 @@ class BinanceFuturesClient: return await loop.run_in_executor( None, lambda: self.client.futures_change_leverage( - symbol=self.config.symbol, leverage=leverage + symbol=self.symbol, leverage=leverage ), ) @@ -68,7 +69,7 @@ class BinanceFuturesClient: ) params = dict( - symbol=self.config.symbol, + symbol=self.symbol, side=side, type=order_type, quantity=quantity, @@ -98,7 +99,7 @@ class BinanceFuturesClient: """STOP_MARKET / TAKE_PROFIT_MARKET 등 Algo Order API(/fapi/v1/algoOrder)로 전송.""" loop = asyncio.get_event_loop() params = dict( - symbol=self.config.symbol, + symbol=self.symbol, side=side, algoType="CONDITIONAL", type=order_type, @@ -120,7 +121,7 @@ class BinanceFuturesClient: positions = await loop.run_in_executor( None, lambda: self.client.futures_position_information( - symbol=self.config.symbol + symbol=self.symbol ), ) for p in positions: @@ -134,14 +135,14 @@ class BinanceFuturesClient: await loop.run_in_executor( None, lambda: self.client.futures_cancel_all_open_orders( - symbol=self.config.symbol + symbol=self.symbol ), ) try: await loop.run_in_executor( None, lambda: self.client.futures_cancel_all_algo_open_orders( - symbol=self.config.symbol + symbol=self.symbol ), ) except Exception as e: @@ -153,7 +154,7 @@ class BinanceFuturesClient: try: result = await loop.run_in_executor( None, - lambda: self.client.futures_open_interest(symbol=self.config.symbol), + lambda: self.client.futures_open_interest(symbol=self.symbol), ) return float(result["openInterest"]) except Exception as e: @@ -166,7 +167,7 @@ class BinanceFuturesClient: try: result = await loop.run_in_executor( None, - lambda: self.client.futures_mark_price(symbol=self.config.symbol), + lambda: self.client.futures_mark_price(symbol=self.symbol), ) return float(result["lastFundingRate"]) except Exception as e: @@ -180,7 +181,7 @@ class BinanceFuturesClient: result = await loop.run_in_executor( None, lambda: self.client.futures_open_interest_hist( - symbol=self.config.symbol, period="15m", limit=limit + 1, + symbol=self.symbol, period="15m", limit=limit + 1, ), ) if len(result) < 2: diff --git a/src/risk_manager.py b/src/risk_manager.py index 480bb34..3b0f395 100644 --- a/src/risk_manager.py +++ b/src/risk_manager.py @@ -1,3 +1,4 @@ +import asyncio from loguru import logger from src.config import Config @@ -5,10 +6,11 @@ from src.config import Config class RiskManager: def __init__(self, config: Config, max_daily_loss_pct: float = 0.05): self.config = config - self.max_daily_loss_pct = max_daily_loss_pct # 일일 최대 손실 5% + self.max_daily_loss_pct = max_daily_loss_pct self.daily_pnl: float = 0.0 self.initial_balance: float = 0.0 - self.open_positions: list = [] + self.open_positions: dict[str, str] = {} # {symbol: side} + self._lock = asyncio.Lock() def is_trading_allowed(self) -> bool: """일일 최대 손실 초과 시 거래 중단""" @@ -22,9 +24,33 @@ class RiskManager: return False return True - def can_open_new_position(self) -> bool: - """최대 동시 포지션 수 체크""" - return len(self.open_positions) < self.config.max_positions + async def can_open_new_position(self, symbol: str, side: str) -> bool: + """포지션 오픈 가능 여부 (전체 한도 + 중복 진입 + 동일 방향 제한)""" + async with self._lock: + if len(self.open_positions) >= self.config.max_positions: + logger.info(f"최대 포지션 수 도달: {len(self.open_positions)}/{self.config.max_positions}") + return False + if symbol in self.open_positions: + logger.info(f"{symbol} 이미 포지션 보유 중") + return False + same_dir = sum(1 for s in self.open_positions.values() if s == side) + if same_dir >= self.config.max_same_direction: + logger.info(f"동일 방향({side}) 한도 도달: {same_dir}/{self.config.max_same_direction}") + return False + return True + + async def register_position(self, symbol: str, side: str): + """포지션 등록""" + async with self._lock: + self.open_positions[symbol] = side + logger.info(f"포지션 등록: {symbol} {side} (현재 {len(self.open_positions)}개)") + + async def close_position(self, symbol: str, pnl: float): + """포지션 닫기 + PnL 기록""" + async with self._lock: + self.open_positions.pop(symbol, None) + self.daily_pnl += pnl + logger.info(f"포지션 종료: {symbol}, PnL={pnl:+.4f}, 누적={self.daily_pnl:+.4f}") def record_pnl(self, pnl: float): self.daily_pnl += pnl @@ -36,7 +62,7 @@ class RiskManager: logger.info("일일 PnL 초기화") def set_base_balance(self, balance: float) -> None: - """봇 시작 시 기준 잔고 설정 (동적 비율 계산 기준점)""" + """봇 시작 시 기준 잔고 설정""" self.initial_balance = balance def get_dynamic_margin_ratio(self, balance: float) -> float: diff --git a/tests/test_bot.py b/tests/test_bot.py index dab47be..e9846ba 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -37,6 +37,25 @@ def sample_df(): }) +def test_bot_accepts_symbol_and_risk(config): + """TradingBot이 symbol과 risk를 외부에서 주입받을 수 있다.""" + from src.risk_manager import RiskManager + risk = RiskManager(config) + with patch("src.bot.BinanceFuturesClient"): + bot = TradingBot(config, symbol="TRXUSDT", risk=risk) + assert bot.symbol == "TRXUSDT" + assert bot.risk is risk + + +def test_bot_stream_uses_injected_symbol(config): + """봇의 stream이 주입된 심볼을 primary로 사용한다.""" + from src.risk_manager import RiskManager + risk = RiskManager(config) + with patch("src.bot.BinanceFuturesClient"): + bot = TradingBot(config, symbol="DOGEUSDT", risk=risk) + assert "dogeusdt" in bot.stream.buffers + + def test_bot_uses_multi_symbol_stream(config): from src.data_stream import MultiSymbolStream with patch("src.bot.BinanceFuturesClient"): @@ -64,6 +83,12 @@ async def test_bot_processes_signal(config, sample_df): bot.exchange.calculate_quantity = MagicMock(return_value=100.0) bot.exchange.MIN_NOTIONAL = 5.0 + bot.risk = MagicMock() + bot.risk.is_trading_allowed.return_value = True + bot.risk.can_open_new_position = AsyncMock(return_value=True) + bot.risk.register_position = AsyncMock() + bot.risk.get_dynamic_margin_ratio.return_value = 0.50 + with patch("src.bot.Indicators") as MockInd: mock_ind = MagicMock() mock_ind.calculate_all.return_value = sample_df @@ -82,7 +107,7 @@ async def test_close_and_reenter_calls_open_when_ml_passes(config, sample_df): bot._close_position = AsyncMock() bot._open_position = AsyncMock() bot.risk = MagicMock() - bot.risk.can_open_new_position.return_value = True + bot.risk.can_open_new_position = AsyncMock(return_value=True) bot.ml_filter = MagicMock() bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.should_enter.return_value = True @@ -102,6 +127,8 @@ async def test_close_and_reenter_skips_open_when_ml_blocks(config, sample_df): bot._close_position = AsyncMock() bot._open_position = AsyncMock() + bot.risk = MagicMock() + bot.risk.can_open_new_position = AsyncMock(return_value=True) bot.ml_filter = MagicMock() bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.should_enter.return_value = False @@ -122,7 +149,7 @@ async def test_close_and_reenter_skips_open_when_max_positions_reached(config, s bot._close_position = AsyncMock() bot._open_position = AsyncMock() bot.risk = MagicMock() - bot.risk.can_open_new_position.return_value = False + bot.risk.can_open_new_position = AsyncMock(return_value=False) position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"} await bot._close_and_reenter(position, "SHORT", sample_df) @@ -206,6 +233,12 @@ async def test_process_candle_fetches_oi_and_funding(config, sample_df): bot.exchange.get_open_interest = AsyncMock(return_value=5000000.0) bot.exchange.get_funding_rate = AsyncMock(return_value=0.0001) + bot.risk = MagicMock() + bot.risk.is_trading_allowed.return_value = True + bot.risk.can_open_new_position = AsyncMock(return_value=True) + bot.risk.register_position = AsyncMock() + bot.risk.get_dynamic_margin_ratio.return_value = 0.50 + # 신호를 LONG으로 강제해 build_features가 반드시 호출되도록 함 with patch("src.bot.Indicators") as mock_ind_cls: mock_ind = MagicMock() diff --git a/tests/test_config.py b/tests/test_config.py index 4593858..41d6f9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,3 +19,32 @@ def test_config_dynamic_margin_params(): assert cfg.margin_max_ratio == 0.50 assert cfg.margin_min_ratio == 0.20 assert cfg.margin_decay_rate == 0.0006 + + +def test_config_loads_symbols_list(): + """SYMBOLS 환경변수로 쉼표 구분 리스트를 로드한다.""" + os.environ["SYMBOLS"] = "XRPUSDT,TRXUSDT,DOGEUSDT" + os.environ.pop("SYMBOL", None) + cfg = Config() + assert cfg.symbols == ["XRPUSDT", "TRXUSDT", "DOGEUSDT"] + + +def test_config_fallback_to_symbol(): + """SYMBOLS 미설정 시 SYMBOL에서 1개짜리 리스트로 변환한다.""" + os.environ.pop("SYMBOLS", None) + os.environ["SYMBOL"] = "XRPUSDT" + cfg = Config() + assert cfg.symbols == ["XRPUSDT"] + + +def test_config_correlation_symbols(): + """상관관계 심볼 로드.""" + os.environ["CORRELATION_SYMBOLS"] = "BTCUSDT,ETHUSDT" + cfg = Config() + assert cfg.correlation_symbols == ["BTCUSDT", "ETHUSDT"] + + +def test_config_max_same_direction_default(): + """동일 방향 최대 수 기본값 2.""" + cfg = Config() + assert cfg.max_same_direction == 2 diff --git a/tests/test_exchange.py b/tests/test_exchange.py index 592d46c..473839a 100644 --- a/tests/test_exchange.py +++ b/tests/test_exchange.py @@ -22,6 +22,7 @@ def client(): config.leverage = 10 c = BinanceFuturesClient.__new__(BinanceFuturesClient) c.config = config + c.symbol = config.symbol return c @@ -36,10 +37,24 @@ def exchange(): config = Config() c = BinanceFuturesClient.__new__(BinanceFuturesClient) c.config = config + c.symbol = config.symbol c.client = MagicMock() return c +def test_exchange_uses_own_symbol(): + """Exchange 클라이언트가 config.symbol 대신 생성자의 symbol을 사용한다.""" + os.environ.update({ + "BINANCE_API_KEY": "test_key", + "BINANCE_API_SECRET": "test_secret", + "SYMBOL": "XRPUSDT", + }) + config = Config() + with patch("src.exchange.Client"): + client = BinanceFuturesClient(config, symbol="TRXUSDT") + assert client.symbol == "TRXUSDT" + + @pytest.mark.asyncio async def test_set_leverage(config): with patch("src.exchange.Client") as MockClient: diff --git a/tests/test_risk_manager.py b/tests/test_risk_manager.py index 0aa06be..fe5bedb 100644 --- a/tests/test_risk_manager.py +++ b/tests/test_risk_manager.py @@ -29,10 +29,13 @@ def test_trading_allowed_normal(config): assert rm.is_trading_allowed() is True -def test_position_size_capped(config): +@pytest.mark.asyncio +async def test_position_size_capped(config): rm = RiskManager(config, max_daily_loss_pct=0.05) - rm.open_positions = ["pos1", "pos2", "pos3"] - assert rm.can_open_new_position() is False + await rm.register_position("XRPUSDT", "LONG") + await rm.register_position("TRXUSDT", "SHORT") + await rm.register_position("DOGEUSDT", "LONG") + assert await rm.can_open_new_position("SOLUSDT", "SHORT") is False # --- 동적 증거금 비율 테스트 --- @@ -81,3 +84,54 @@ def test_ratio_clamped_at_max(risk): """잔고가 기준보다 작아도 최대 비율(50%) 초과하지 않음""" ratio = risk.get_dynamic_margin_ratio(5.0) assert ratio == pytest.approx(0.50, abs=1e-6) + + +# --- 멀티심볼 공유 RiskManager 테스트 --- + +@pytest.fixture +def shared_risk(config): + config.max_same_direction = 2 + return RiskManager(config) + + +@pytest.mark.asyncio +async def test_can_open_new_position_async(shared_risk): + """비동기 포지션 오픈 허용 체크.""" + assert await shared_risk.can_open_new_position("XRPUSDT", "LONG") is True + + +@pytest.mark.asyncio +async def test_register_and_close_position(shared_risk): + """포지션 등록 후 닫기.""" + await shared_risk.register_position("XRPUSDT", "LONG") + assert "XRPUSDT" in shared_risk.open_positions + await shared_risk.close_position("XRPUSDT", pnl=1.5) + assert "XRPUSDT" not in shared_risk.open_positions + assert shared_risk.daily_pnl == 1.5 + + +@pytest.mark.asyncio +async def test_same_symbol_blocked(shared_risk): + """같은 심볼 중복 진입 차단.""" + await shared_risk.register_position("XRPUSDT", "LONG") + assert await shared_risk.can_open_new_position("XRPUSDT", "SHORT") is False + + +@pytest.mark.asyncio +async def test_max_same_direction_limit(shared_risk): + """같은 방향 2개 초과 차단.""" + await shared_risk.register_position("XRPUSDT", "LONG") + await shared_risk.register_position("TRXUSDT", "LONG") + # 3번째 LONG 차단 + assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False + # SHORT은 허용 + assert await shared_risk.can_open_new_position("DOGEUSDT", "SHORT") is True + + +@pytest.mark.asyncio +async def test_max_positions_global_limit(shared_risk): + """전체 포지션 수 한도 초과 차단.""" + shared_risk.config.max_positions = 2 + await shared_risk.register_position("XRPUSDT", "LONG") + await shared_risk.register_position("TRXUSDT", "SHORT") + assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False