Compare commits

..

7 Commits

Author SHA1 Message Date
21in7
9f4c22b5e6 feat: add virtual environment symlink for project dependencies 2026-03-05 23:34:32 +09:00
21in7
ae5692cde4 feat: MLFilter falls back to models/ root if symbol-specific dir not found
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 23:14:38 +09:00
21in7
7acbdca3f4 feat: main.py spawns per-symbol TradingBot instances with shared RiskManager
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 23:13:52 +09:00
21in7
e7620248c7 feat: TradingBot accepts symbol and shared RiskManager, removes config.symbol dependency
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 23:13:04 +09:00
21in7
2e09f5340a feat: exchange client accepts explicit symbol parameter, removes config.symbol dependency
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 23:07:44 +09:00
21in7
9318fb887e feat: shared RiskManager with async lock, same-direction limit, per-symbol tracking
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 23:06:41 +09:00
21in7
7aef391b69 feat: add multi-symbol config (symbols list, correlation_symbols, max_same_direction)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 23:05:22 +09:00
11 changed files with 245 additions and 43 deletions

View File

@@ -1,9 +1,11 @@
BINANCE_API_KEY= BINANCE_API_KEY=
BINANCE_API_SECRET= BINANCE_API_SECRET=
SYMBOL=XRPUSDT SYMBOLS=XRPUSDT
CORRELATION_SYMBOLS=BTCUSDT,ETHUSDT
LEVERAGE=10 LEVERAGE=10
RISK_PER_TRADE=0.02 RISK_PER_TRADE=0.02
DISCORD_WEBHOOK_URL= DISCORD_WEBHOOK_URL=
ML_THRESHOLD=0.55 ML_THRESHOLD=0.55
MAX_SAME_DIRECTION=2
BINANCE_TESTNET_API_KEY= BINANCE_TESTNET_API_KEY=
BINANCE_TESTNET_API_SECRET= BINANCE_TESTNET_API_SECRET=

1
.venv Symbolic link
View File

@@ -0,0 +1 @@
/Users/gihyeon/github/cointrader/.venv

13
main.py
View File

@@ -1,7 +1,9 @@
import asyncio import asyncio
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger
from src.config import Config from src.config import Config
from src.bot import TradingBot from src.bot import TradingBot
from src.risk_manager import RiskManager
from src.logger_setup import setup_logger from src.logger_setup import setup_logger
load_dotenv() load_dotenv()
@@ -10,8 +12,15 @@ load_dotenv()
async def main(): async def main():
setup_logger(log_level="INFO") setup_logger(log_level="INFO")
config = Config() config = Config()
bot = TradingBot(config) risk = RiskManager(config)
await bot.run()
bots = []
for symbol in config.symbols:
bot = TradingBot(config, symbol=symbol, risk=risk)
bots.append(bot)
logger.info(f"멀티심볼 봇 시작: {config.symbols} ({len(bots)}개 인스턴스)")
await asyncio.gather(*[bot.run() for bot in bots])
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from collections import deque from collections import deque
from pathlib import Path
import pandas as pd import pandas as pd
from loguru import logger from loguru import logger
from src.config import Config from src.config import Config
@@ -14,12 +15,25 @@ from src.user_data_stream import UserDataStream
class TradingBot: class TradingBot:
def __init__(self, config: Config): def __init__(self, config: Config, symbol: str = None, risk: RiskManager = None):
self.config = config self.config = config
self.exchange = BinanceFuturesClient(config) self.symbol = symbol or config.symbol
self.exchange = BinanceFuturesClient(config, symbol=self.symbol)
self.notifier = DiscordNotifier(config.discord_webhook_url) self.notifier = DiscordNotifier(config.discord_webhook_url)
self.risk = RiskManager(config) self.risk = risk or RiskManager(config)
self.ml_filter = MLFilter(threshold=config.ml_threshold) # 심볼별 모델 디렉토리. 없으면 기존 models/ 루트로 폴백
symbol_model_dir = Path(f"models/{self.symbol.lower()}")
if symbol_model_dir.exists():
onnx_path = str(symbol_model_dir / "mlx_filter.weights.onnx")
lgbm_path = str(symbol_model_dir / "lgbm_filter.pkl")
else:
onnx_path = "models/mlx_filter.weights.onnx"
lgbm_path = "models/lgbm_filter.pkl"
self.ml_filter = MLFilter(
onnx_path=onnx_path,
lgbm_path=lgbm_path,
threshold=config.ml_threshold,
)
self.current_trade_side: str | None = None # "LONG" | "SHORT" self.current_trade_side: str | None = None # "LONG" | "SHORT"
self._entry_price: float | None = None self._entry_price: float | None = None
self._entry_quantity: float | None = None self._entry_quantity: float | None = None
@@ -28,17 +42,17 @@ class TradingBot:
self._oi_history: deque = deque(maxlen=5) self._oi_history: deque = deque(maxlen=5)
self._latest_ret_1: float = 0.0 self._latest_ret_1: float = 0.0
self.stream = MultiSymbolStream( self.stream = MultiSymbolStream(
symbols=[config.symbol, "BTCUSDT", "ETHUSDT"], symbols=[self.symbol] + config.correlation_symbols,
interval="15m", interval="15m",
on_candle=self._on_candle_closed, on_candle=self._on_candle_closed,
) )
async def _on_candle_closed(self, candle: dict): async def _on_candle_closed(self, candle: dict):
xrp_df = self.stream.get_dataframe(self.config.symbol) primary_df = self.stream.get_dataframe(self.symbol)
btc_df = self.stream.get_dataframe("BTCUSDT") btc_df = self.stream.get_dataframe("BTCUSDT")
eth_df = self.stream.get_dataframe("ETHUSDT") eth_df = self.stream.get_dataframe("ETHUSDT")
if xrp_df is not None: if primary_df is not None:
await self.process_candle(xrp_df, btc_df=btc_df, eth_df=eth_df) await self.process_candle(primary_df, btc_df=btc_df, eth_df=eth_df)
async def _recover_position(self) -> None: async def _recover_position(self) -> None:
"""재시작 시 바이낸스에서 현재 포지션을 조회하여 상태 복구.""" """재시작 시 바이낸스에서 현재 포지션을 조회하여 상태 복구."""
@@ -134,8 +148,8 @@ class TradingBot:
if position is None and raw_signal != "HOLD": if position is None and raw_signal != "HOLD":
self.current_trade_side = None self.current_trade_side = None
if not self.risk.can_open_new_position(): if not await self.risk.can_open_new_position(self.symbol, raw_signal):
logger.info("최대 포지션 수 도달") logger.info(f"[{self.symbol}] 포지션 오픈 불가")
return return
signal = raw_signal signal = raw_signal
features = build_features( features = build_features(
@@ -163,12 +177,14 @@ class TradingBot:
async def _open_position(self, signal: str, df): async def _open_position(self, signal: str, df):
balance = await self.exchange.get_balance() balance = await self.exchange.get_balance()
num_symbols = len(self.config.symbols)
per_symbol_balance = balance / num_symbols
price = df["close"].iloc[-1] price = df["close"].iloc[-1]
margin_ratio = self.risk.get_dynamic_margin_ratio(balance) margin_ratio = self.risk.get_dynamic_margin_ratio(balance)
quantity = self.exchange.calculate_quantity( quantity = self.exchange.calculate_quantity(
balance=balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio balance=per_symbol_balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio
) )
logger.info(f"포지션 크기: 잔고={balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}") logger.info(f"[{self.symbol}] 포지션 크기: 잔고={per_symbol_balance:.2f}/{balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}")
stop_loss, take_profit = Indicators(df).get_atr_stop(df, signal, price) stop_loss, take_profit = Indicators(df).get_atr_stop(df, signal, price)
notional = quantity * price notional = quantity * price
@@ -190,11 +206,12 @@ class TradingBot:
"atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0, "atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0,
} }
await self.risk.register_position(self.symbol, signal)
self.current_trade_side = signal self.current_trade_side = signal
self._entry_price = price self._entry_price = price
self._entry_quantity = quantity self._entry_quantity = quantity
self.notifier.notify_open( self.notifier.notify_open(
symbol=self.config.symbol, symbol=self.symbol,
side=signal, side=signal,
entry_price=price, entry_price=price,
quantity=quantity, quantity=quantity,
@@ -245,10 +262,10 @@ class TradingBot:
estimated_pnl = self._calc_estimated_pnl(exit_price) estimated_pnl = self._calc_estimated_pnl(exit_price)
diff = net_pnl - estimated_pnl diff = net_pnl - estimated_pnl
self.risk.record_pnl(net_pnl) await self.risk.close_position(self.symbol, net_pnl)
self.notifier.notify_close( self.notifier.notify_close(
symbol=self.config.symbol, symbol=self.symbol,
side=self.current_trade_side or "UNKNOWN", side=self.current_trade_side or "UNKNOWN",
close_reason=close_reason, close_reason=close_reason,
exit_price=exit_price, exit_price=exit_price,
@@ -317,8 +334,8 @@ class TradingBot:
try: try:
await self._close_position(position) await self._close_position(position)
if not self.risk.can_open_new_position(): if not await self.risk.can_open_new_position(self.symbol, signal):
logger.info("최대 포지션 수 도달 — 재진입 건너뜀") logger.info(f"[{self.symbol}] 최대 포지션 수 도달 — 재진입 건너뜀")
return return
if self.ml_filter.is_model_loaded(): if self.ml_filter.is_model_loaded():
@@ -337,7 +354,7 @@ class TradingBot:
self._is_reentering = False self._is_reentering = False
async def run(self): async def run(self):
logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x") logger.info(f"[{self.symbol}] 봇 시작, 레버리지 {self.config.leverage}x")
await self._recover_position() await self._recover_position()
await self._init_oi_history() await self._init_oi_history()
balance = await self.exchange.get_balance() balance = await self.exchange.get_balance()
@@ -345,7 +362,7 @@ class TradingBot:
logger.info(f"기준 잔고 설정: {balance:.2f} USDT (동적 증거금 비율 기준점)") logger.info(f"기준 잔고 설정: {balance:.2f} USDT (동적 증거금 비율 기준점)")
user_stream = UserDataStream( user_stream = UserDataStream(
symbol=self.config.symbol, symbol=self.symbol,
on_order_filled=self._on_position_closed, on_order_filled=self._on_position_closed,
) )

View File

@@ -10,8 +10,11 @@ class Config:
api_key: str = "" api_key: str = ""
api_secret: str = "" api_secret: str = ""
symbol: str = "XRPUSDT" symbol: str = "XRPUSDT"
symbols: list = None
correlation_symbols: list = None
leverage: int = 10 leverage: int = 10
max_positions: int = 3 max_positions: int = 3
max_same_direction: int = 2
stop_loss_pct: float = 0.015 # 1.5% stop_loss_pct: float = 0.015 # 1.5%
take_profit_pct: float = 0.045 # 4.5% (3:1 RR) take_profit_pct: float = 0.045 # 4.5% (3:1 RR)
trailing_stop_pct: float = 0.01 # 1% trailing_stop_pct: float = 0.01 # 1%
@@ -31,3 +34,15 @@ class Config:
self.margin_min_ratio = float(os.getenv("MARGIN_MIN_RATIO", "0.20")) self.margin_min_ratio = float(os.getenv("MARGIN_MIN_RATIO", "0.20"))
self.margin_decay_rate = float(os.getenv("MARGIN_DECAY_RATE", "0.0006")) self.margin_decay_rate = float(os.getenv("MARGIN_DECAY_RATE", "0.0006"))
self.ml_threshold = float(os.getenv("ML_THRESHOLD", "0.55")) self.ml_threshold = float(os.getenv("ML_THRESHOLD", "0.55"))
self.max_same_direction = int(os.getenv("MAX_SAME_DIRECTION", "2"))
# symbols: SYMBOLS 환경변수 우선, 없으면 SYMBOL에서 변환
symbols_env = os.getenv("SYMBOLS", "")
if symbols_env:
self.symbols = [s.strip() for s in symbols_env.split(",") if s.strip()]
else:
self.symbols = [self.symbol]
# correlation_symbols
corr_env = os.getenv("CORRELATION_SYMBOLS", "BTCUSDT,ETHUSDT")
self.correlation_symbols = [s.strip() for s in corr_env.split(",") if s.strip()]

View File

@@ -6,8 +6,9 @@ from src.config import Config
class BinanceFuturesClient: class BinanceFuturesClient:
def __init__(self, config: Config): def __init__(self, config: Config, symbol: str = None):
self.config = config self.config = config
self.symbol = symbol or config.symbol
self.client = Client( self.client = Client(
api_key=config.api_key, api_key=config.api_key,
api_secret=config.api_secret, api_secret=config.api_secret,
@@ -31,7 +32,7 @@ class BinanceFuturesClient:
return await loop.run_in_executor( return await loop.run_in_executor(
None, None,
lambda: self.client.futures_change_leverage( lambda: self.client.futures_change_leverage(
symbol=self.config.symbol, leverage=leverage symbol=self.symbol, leverage=leverage
), ),
) )
@@ -68,7 +69,7 @@ class BinanceFuturesClient:
) )
params = dict( params = dict(
symbol=self.config.symbol, symbol=self.symbol,
side=side, side=side,
type=order_type, type=order_type,
quantity=quantity, quantity=quantity,
@@ -98,7 +99,7 @@ class BinanceFuturesClient:
"""STOP_MARKET / TAKE_PROFIT_MARKET 등 Algo Order API(/fapi/v1/algoOrder)로 전송.""" """STOP_MARKET / TAKE_PROFIT_MARKET 등 Algo Order API(/fapi/v1/algoOrder)로 전송."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
params = dict( params = dict(
symbol=self.config.symbol, symbol=self.symbol,
side=side, side=side,
algoType="CONDITIONAL", algoType="CONDITIONAL",
type=order_type, type=order_type,
@@ -120,7 +121,7 @@ class BinanceFuturesClient:
positions = await loop.run_in_executor( positions = await loop.run_in_executor(
None, None,
lambda: self.client.futures_position_information( lambda: self.client.futures_position_information(
symbol=self.config.symbol symbol=self.symbol
), ),
) )
for p in positions: for p in positions:
@@ -134,14 +135,14 @@ class BinanceFuturesClient:
await loop.run_in_executor( await loop.run_in_executor(
None, None,
lambda: self.client.futures_cancel_all_open_orders( lambda: self.client.futures_cancel_all_open_orders(
symbol=self.config.symbol symbol=self.symbol
), ),
) )
try: try:
await loop.run_in_executor( await loop.run_in_executor(
None, None,
lambda: self.client.futures_cancel_all_algo_open_orders( lambda: self.client.futures_cancel_all_algo_open_orders(
symbol=self.config.symbol symbol=self.symbol
), ),
) )
except Exception as e: except Exception as e:
@@ -153,7 +154,7 @@ class BinanceFuturesClient:
try: try:
result = await loop.run_in_executor( result = await loop.run_in_executor(
None, None,
lambda: self.client.futures_open_interest(symbol=self.config.symbol), lambda: self.client.futures_open_interest(symbol=self.symbol),
) )
return float(result["openInterest"]) return float(result["openInterest"])
except Exception as e: except Exception as e:
@@ -166,7 +167,7 @@ class BinanceFuturesClient:
try: try:
result = await loop.run_in_executor( result = await loop.run_in_executor(
None, None,
lambda: self.client.futures_mark_price(symbol=self.config.symbol), lambda: self.client.futures_mark_price(symbol=self.symbol),
) )
return float(result["lastFundingRate"]) return float(result["lastFundingRate"])
except Exception as e: except Exception as e:
@@ -180,7 +181,7 @@ class BinanceFuturesClient:
result = await loop.run_in_executor( result = await loop.run_in_executor(
None, None,
lambda: self.client.futures_open_interest_hist( lambda: self.client.futures_open_interest_hist(
symbol=self.config.symbol, period="15m", limit=limit + 1, symbol=self.symbol, period="15m", limit=limit + 1,
), ),
) )
if len(result) < 2: if len(result) < 2:

View File

@@ -1,3 +1,4 @@
import asyncio
from loguru import logger from loguru import logger
from src.config import Config from src.config import Config
@@ -5,10 +6,11 @@ from src.config import Config
class RiskManager: class RiskManager:
def __init__(self, config: Config, max_daily_loss_pct: float = 0.05): def __init__(self, config: Config, max_daily_loss_pct: float = 0.05):
self.config = config self.config = config
self.max_daily_loss_pct = max_daily_loss_pct # 일일 최대 손실 5% self.max_daily_loss_pct = max_daily_loss_pct
self.daily_pnl: float = 0.0 self.daily_pnl: float = 0.0
self.initial_balance: float = 0.0 self.initial_balance: float = 0.0
self.open_positions: list = [] self.open_positions: dict[str, str] = {} # {symbol: side}
self._lock = asyncio.Lock()
def is_trading_allowed(self) -> bool: def is_trading_allowed(self) -> bool:
"""일일 최대 손실 초과 시 거래 중단""" """일일 최대 손실 초과 시 거래 중단"""
@@ -22,9 +24,33 @@ class RiskManager:
return False return False
return True return True
def can_open_new_position(self) -> bool: async def can_open_new_position(self, symbol: str, side: str) -> bool:
"""최대 동시 포지션 수 체크""" """포지션 오픈 가능 여부 (전체 한도 + 중복 진입 + 동일 방향 제한)"""
return len(self.open_positions) < self.config.max_positions async with self._lock:
if len(self.open_positions) >= self.config.max_positions:
logger.info(f"최대 포지션 수 도달: {len(self.open_positions)}/{self.config.max_positions}")
return False
if symbol in self.open_positions:
logger.info(f"{symbol} 이미 포지션 보유 중")
return False
same_dir = sum(1 for s in self.open_positions.values() if s == side)
if same_dir >= self.config.max_same_direction:
logger.info(f"동일 방향({side}) 한도 도달: {same_dir}/{self.config.max_same_direction}")
return False
return True
async def register_position(self, symbol: str, side: str):
"""포지션 등록"""
async with self._lock:
self.open_positions[symbol] = side
logger.info(f"포지션 등록: {symbol} {side} (현재 {len(self.open_positions)}개)")
async def close_position(self, symbol: str, pnl: float):
"""포지션 닫기 + PnL 기록"""
async with self._lock:
self.open_positions.pop(symbol, None)
self.daily_pnl += pnl
logger.info(f"포지션 종료: {symbol}, PnL={pnl:+.4f}, 누적={self.daily_pnl:+.4f}")
def record_pnl(self, pnl: float): def record_pnl(self, pnl: float):
self.daily_pnl += pnl self.daily_pnl += pnl
@@ -36,7 +62,7 @@ class RiskManager:
logger.info("일일 PnL 초기화") logger.info("일일 PnL 초기화")
def set_base_balance(self, balance: float) -> None: def set_base_balance(self, balance: float) -> None:
"""봇 시작 시 기준 잔고 설정 (동적 비율 계산 기준점)""" """봇 시작 시 기준 잔고 설정"""
self.initial_balance = balance self.initial_balance = balance
def get_dynamic_margin_ratio(self, balance: float) -> float: def get_dynamic_margin_ratio(self, balance: float) -> float:

View File

@@ -37,6 +37,25 @@ def sample_df():
}) })
def test_bot_accepts_symbol_and_risk(config):
"""TradingBot이 symbol과 risk를 외부에서 주입받을 수 있다."""
from src.risk_manager import RiskManager
risk = RiskManager(config)
with patch("src.bot.BinanceFuturesClient"):
bot = TradingBot(config, symbol="TRXUSDT", risk=risk)
assert bot.symbol == "TRXUSDT"
assert bot.risk is risk
def test_bot_stream_uses_injected_symbol(config):
"""봇의 stream이 주입된 심볼을 primary로 사용한다."""
from src.risk_manager import RiskManager
risk = RiskManager(config)
with patch("src.bot.BinanceFuturesClient"):
bot = TradingBot(config, symbol="DOGEUSDT", risk=risk)
assert "dogeusdt" in bot.stream.buffers
def test_bot_uses_multi_symbol_stream(config): def test_bot_uses_multi_symbol_stream(config):
from src.data_stream import MultiSymbolStream from src.data_stream import MultiSymbolStream
with patch("src.bot.BinanceFuturesClient"): with patch("src.bot.BinanceFuturesClient"):
@@ -64,6 +83,12 @@ async def test_bot_processes_signal(config, sample_df):
bot.exchange.calculate_quantity = MagicMock(return_value=100.0) bot.exchange.calculate_quantity = MagicMock(return_value=100.0)
bot.exchange.MIN_NOTIONAL = 5.0 bot.exchange.MIN_NOTIONAL = 5.0
bot.risk = MagicMock()
bot.risk.is_trading_allowed.return_value = True
bot.risk.can_open_new_position = AsyncMock(return_value=True)
bot.risk.register_position = AsyncMock()
bot.risk.get_dynamic_margin_ratio.return_value = 0.50
with patch("src.bot.Indicators") as MockInd: with patch("src.bot.Indicators") as MockInd:
mock_ind = MagicMock() mock_ind = MagicMock()
mock_ind.calculate_all.return_value = sample_df mock_ind.calculate_all.return_value = sample_df
@@ -82,7 +107,7 @@ async def test_close_and_reenter_calls_open_when_ml_passes(config, sample_df):
bot._close_position = AsyncMock() bot._close_position = AsyncMock()
bot._open_position = AsyncMock() bot._open_position = AsyncMock()
bot.risk = MagicMock() bot.risk = MagicMock()
bot.risk.can_open_new_position.return_value = True bot.risk.can_open_new_position = AsyncMock(return_value=True)
bot.ml_filter = MagicMock() bot.ml_filter = MagicMock()
bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.is_model_loaded.return_value = True
bot.ml_filter.should_enter.return_value = True bot.ml_filter.should_enter.return_value = True
@@ -102,6 +127,8 @@ async def test_close_and_reenter_skips_open_when_ml_blocks(config, sample_df):
bot._close_position = AsyncMock() bot._close_position = AsyncMock()
bot._open_position = AsyncMock() bot._open_position = AsyncMock()
bot.risk = MagicMock()
bot.risk.can_open_new_position = AsyncMock(return_value=True)
bot.ml_filter = MagicMock() bot.ml_filter = MagicMock()
bot.ml_filter.is_model_loaded.return_value = True bot.ml_filter.is_model_loaded.return_value = True
bot.ml_filter.should_enter.return_value = False bot.ml_filter.should_enter.return_value = False
@@ -122,7 +149,7 @@ async def test_close_and_reenter_skips_open_when_max_positions_reached(config, s
bot._close_position = AsyncMock() bot._close_position = AsyncMock()
bot._open_position = AsyncMock() bot._open_position = AsyncMock()
bot.risk = MagicMock() bot.risk = MagicMock()
bot.risk.can_open_new_position.return_value = False bot.risk.can_open_new_position = AsyncMock(return_value=False)
position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"} position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"}
await bot._close_and_reenter(position, "SHORT", sample_df) await bot._close_and_reenter(position, "SHORT", sample_df)
@@ -206,6 +233,12 @@ async def test_process_candle_fetches_oi_and_funding(config, sample_df):
bot.exchange.get_open_interest = AsyncMock(return_value=5000000.0) bot.exchange.get_open_interest = AsyncMock(return_value=5000000.0)
bot.exchange.get_funding_rate = AsyncMock(return_value=0.0001) bot.exchange.get_funding_rate = AsyncMock(return_value=0.0001)
bot.risk = MagicMock()
bot.risk.is_trading_allowed.return_value = True
bot.risk.can_open_new_position = AsyncMock(return_value=True)
bot.risk.register_position = AsyncMock()
bot.risk.get_dynamic_margin_ratio.return_value = 0.50
# 신호를 LONG으로 강제해 build_features가 반드시 호출되도록 함 # 신호를 LONG으로 강제해 build_features가 반드시 호출되도록 함
with patch("src.bot.Indicators") as mock_ind_cls: with patch("src.bot.Indicators") as mock_ind_cls:
mock_ind = MagicMock() mock_ind = MagicMock()

View File

@@ -19,3 +19,32 @@ def test_config_dynamic_margin_params():
assert cfg.margin_max_ratio == 0.50 assert cfg.margin_max_ratio == 0.50
assert cfg.margin_min_ratio == 0.20 assert cfg.margin_min_ratio == 0.20
assert cfg.margin_decay_rate == 0.0006 assert cfg.margin_decay_rate == 0.0006
def test_config_loads_symbols_list():
"""SYMBOLS 환경변수로 쉼표 구분 리스트를 로드한다."""
os.environ["SYMBOLS"] = "XRPUSDT,TRXUSDT,DOGEUSDT"
os.environ.pop("SYMBOL", None)
cfg = Config()
assert cfg.symbols == ["XRPUSDT", "TRXUSDT", "DOGEUSDT"]
def test_config_fallback_to_symbol():
"""SYMBOLS 미설정 시 SYMBOL에서 1개짜리 리스트로 변환한다."""
os.environ.pop("SYMBOLS", None)
os.environ["SYMBOL"] = "XRPUSDT"
cfg = Config()
assert cfg.symbols == ["XRPUSDT"]
def test_config_correlation_symbols():
"""상관관계 심볼 로드."""
os.environ["CORRELATION_SYMBOLS"] = "BTCUSDT,ETHUSDT"
cfg = Config()
assert cfg.correlation_symbols == ["BTCUSDT", "ETHUSDT"]
def test_config_max_same_direction_default():
"""동일 방향 최대 수 기본값 2."""
cfg = Config()
assert cfg.max_same_direction == 2

View File

@@ -22,6 +22,7 @@ def client():
config.leverage = 10 config.leverage = 10
c = BinanceFuturesClient.__new__(BinanceFuturesClient) c = BinanceFuturesClient.__new__(BinanceFuturesClient)
c.config = config c.config = config
c.symbol = config.symbol
return c return c
@@ -36,10 +37,24 @@ def exchange():
config = Config() config = Config()
c = BinanceFuturesClient.__new__(BinanceFuturesClient) c = BinanceFuturesClient.__new__(BinanceFuturesClient)
c.config = config c.config = config
c.symbol = config.symbol
c.client = MagicMock() c.client = MagicMock()
return c return c
def test_exchange_uses_own_symbol():
"""Exchange 클라이언트가 config.symbol 대신 생성자의 symbol을 사용한다."""
os.environ.update({
"BINANCE_API_KEY": "test_key",
"BINANCE_API_SECRET": "test_secret",
"SYMBOL": "XRPUSDT",
})
config = Config()
with patch("src.exchange.Client"):
client = BinanceFuturesClient(config, symbol="TRXUSDT")
assert client.symbol == "TRXUSDT"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_leverage(config): async def test_set_leverage(config):
with patch("src.exchange.Client") as MockClient: with patch("src.exchange.Client") as MockClient:

View File

@@ -29,10 +29,13 @@ def test_trading_allowed_normal(config):
assert rm.is_trading_allowed() is True assert rm.is_trading_allowed() is True
def test_position_size_capped(config): @pytest.mark.asyncio
async def test_position_size_capped(config):
rm = RiskManager(config, max_daily_loss_pct=0.05) rm = RiskManager(config, max_daily_loss_pct=0.05)
rm.open_positions = ["pos1", "pos2", "pos3"] await rm.register_position("XRPUSDT", "LONG")
assert rm.can_open_new_position() is False await rm.register_position("TRXUSDT", "SHORT")
await rm.register_position("DOGEUSDT", "LONG")
assert await rm.can_open_new_position("SOLUSDT", "SHORT") is False
# --- 동적 증거금 비율 테스트 --- # --- 동적 증거금 비율 테스트 ---
@@ -81,3 +84,54 @@ def test_ratio_clamped_at_max(risk):
"""잔고가 기준보다 작아도 최대 비율(50%) 초과하지 않음""" """잔고가 기준보다 작아도 최대 비율(50%) 초과하지 않음"""
ratio = risk.get_dynamic_margin_ratio(5.0) ratio = risk.get_dynamic_margin_ratio(5.0)
assert ratio == pytest.approx(0.50, abs=1e-6) assert ratio == pytest.approx(0.50, abs=1e-6)
# --- 멀티심볼 공유 RiskManager 테스트 ---
@pytest.fixture
def shared_risk(config):
config.max_same_direction = 2
return RiskManager(config)
@pytest.mark.asyncio
async def test_can_open_new_position_async(shared_risk):
"""비동기 포지션 오픈 허용 체크."""
assert await shared_risk.can_open_new_position("XRPUSDT", "LONG") is True
@pytest.mark.asyncio
async def test_register_and_close_position(shared_risk):
"""포지션 등록 후 닫기."""
await shared_risk.register_position("XRPUSDT", "LONG")
assert "XRPUSDT" in shared_risk.open_positions
await shared_risk.close_position("XRPUSDT", pnl=1.5)
assert "XRPUSDT" not in shared_risk.open_positions
assert shared_risk.daily_pnl == 1.5
@pytest.mark.asyncio
async def test_same_symbol_blocked(shared_risk):
"""같은 심볼 중복 진입 차단."""
await shared_risk.register_position("XRPUSDT", "LONG")
assert await shared_risk.can_open_new_position("XRPUSDT", "SHORT") is False
@pytest.mark.asyncio
async def test_max_same_direction_limit(shared_risk):
"""같은 방향 2개 초과 차단."""
await shared_risk.register_position("XRPUSDT", "LONG")
await shared_risk.register_position("TRXUSDT", "LONG")
# 3번째 LONG 차단
assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False
# SHORT은 허용
assert await shared_risk.can_open_new_position("DOGEUSDT", "SHORT") is True
@pytest.mark.asyncio
async def test_max_positions_global_limit(shared_risk):
"""전체 포지션 수 한도 초과 차단."""
shared_risk.config.max_positions = 2
await shared_risk.register_position("XRPUSDT", "LONG")
await shared_risk.register_position("TRXUSDT", "SHORT")
assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False