Compare commits
9 Commits
852d3a8265
...
d92fae13f8
| Author | SHA1 | Date | |
|---|---|---|---|
| d92fae13f8 | |||
|
|
dfcd803db5 | ||
|
|
9f4c22b5e6 | ||
|
|
ae5692cde4 | ||
|
|
7acbdca3f4 | ||
|
|
e7620248c7 | ||
|
|
2e09f5340a | ||
|
|
9318fb887e | ||
|
|
7aef391b69 |
13
main.py
13
main.py
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from src.config import Config
|
||||
from src.bot import TradingBot
|
||||
from src.risk_manager import RiskManager
|
||||
from src.logger_setup import setup_logger
|
||||
|
||||
load_dotenv()
|
||||
@@ -10,8 +12,15 @@ load_dotenv()
|
||||
async def main():
|
||||
setup_logger(log_level="INFO")
|
||||
config = Config()
|
||||
bot = TradingBot(config)
|
||||
await bot.run()
|
||||
risk = RiskManager(config)
|
||||
|
||||
bots = []
|
||||
for symbol in config.symbols:
|
||||
bot = TradingBot(config, symbol=symbol, risk=risk)
|
||||
bots.append(bot)
|
||||
|
||||
logger.info(f"멀티심볼 봇 시작: {config.symbols} ({len(bots)}개 인스턴스)")
|
||||
await asyncio.gather(*[bot.run() for bot in bots])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
55
src/bot.py
55
src/bot.py
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from src.config import Config
|
||||
@@ -14,12 +15,25 @@ from src.user_data_stream import UserDataStream
|
||||
|
||||
|
||||
class TradingBot:
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, config: Config, symbol: str = None, risk: RiskManager = None):
|
||||
self.config = config
|
||||
self.exchange = BinanceFuturesClient(config)
|
||||
self.symbol = symbol or config.symbol
|
||||
self.exchange = BinanceFuturesClient(config, symbol=self.symbol)
|
||||
self.notifier = DiscordNotifier(config.discord_webhook_url)
|
||||
self.risk = RiskManager(config)
|
||||
self.ml_filter = MLFilter(threshold=config.ml_threshold)
|
||||
self.risk = risk or RiskManager(config)
|
||||
# 심볼별 모델 디렉토리. 없으면 기존 models/ 루트로 폴백
|
||||
symbol_model_dir = Path(f"models/{self.symbol.lower()}")
|
||||
if symbol_model_dir.exists():
|
||||
onnx_path = str(symbol_model_dir / "mlx_filter.weights.onnx")
|
||||
lgbm_path = str(symbol_model_dir / "lgbm_filter.pkl")
|
||||
else:
|
||||
onnx_path = "models/mlx_filter.weights.onnx"
|
||||
lgbm_path = "models/lgbm_filter.pkl"
|
||||
self.ml_filter = MLFilter(
|
||||
onnx_path=onnx_path,
|
||||
lgbm_path=lgbm_path,
|
||||
threshold=config.ml_threshold,
|
||||
)
|
||||
self.current_trade_side: str | None = None # "LONG" | "SHORT"
|
||||
self._entry_price: float | None = None
|
||||
self._entry_quantity: float | None = None
|
||||
@@ -28,17 +42,17 @@ class TradingBot:
|
||||
self._oi_history: deque = deque(maxlen=5)
|
||||
self._latest_ret_1: float = 0.0
|
||||
self.stream = MultiSymbolStream(
|
||||
symbols=[config.symbol, "BTCUSDT", "ETHUSDT"],
|
||||
symbols=[self.symbol] + config.correlation_symbols,
|
||||
interval="15m",
|
||||
on_candle=self._on_candle_closed,
|
||||
)
|
||||
|
||||
async def _on_candle_closed(self, candle: dict):
|
||||
xrp_df = self.stream.get_dataframe(self.config.symbol)
|
||||
primary_df = self.stream.get_dataframe(self.symbol)
|
||||
btc_df = self.stream.get_dataframe("BTCUSDT")
|
||||
eth_df = self.stream.get_dataframe("ETHUSDT")
|
||||
if xrp_df is not None:
|
||||
await self.process_candle(xrp_df, btc_df=btc_df, eth_df=eth_df)
|
||||
if primary_df is not None:
|
||||
await self.process_candle(primary_df, btc_df=btc_df, eth_df=eth_df)
|
||||
|
||||
async def _recover_position(self) -> None:
|
||||
"""재시작 시 바이낸스에서 현재 포지션을 조회하여 상태 복구."""
|
||||
@@ -134,8 +148,8 @@ class TradingBot:
|
||||
|
||||
if position is None and raw_signal != "HOLD":
|
||||
self.current_trade_side = None
|
||||
if not self.risk.can_open_new_position():
|
||||
logger.info("최대 포지션 수 도달")
|
||||
if not await self.risk.can_open_new_position(self.symbol, raw_signal):
|
||||
logger.info(f"[{self.symbol}] 포지션 오픈 불가")
|
||||
return
|
||||
signal = raw_signal
|
||||
features = build_features(
|
||||
@@ -163,12 +177,14 @@ class TradingBot:
|
||||
|
||||
async def _open_position(self, signal: str, df):
|
||||
balance = await self.exchange.get_balance()
|
||||
num_symbols = len(self.config.symbols)
|
||||
per_symbol_balance = balance / num_symbols
|
||||
price = df["close"].iloc[-1]
|
||||
margin_ratio = self.risk.get_dynamic_margin_ratio(balance)
|
||||
quantity = self.exchange.calculate_quantity(
|
||||
balance=balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio
|
||||
balance=per_symbol_balance, price=price, leverage=self.config.leverage, margin_ratio=margin_ratio
|
||||
)
|
||||
logger.info(f"포지션 크기: 잔고={balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}")
|
||||
logger.info(f"[{self.symbol}] 포지션 크기: 잔고={per_symbol_balance:.2f}/{balance:.2f} USDT, 증거금비율={margin_ratio:.1%}, 수량={quantity}")
|
||||
stop_loss, take_profit = Indicators(df).get_atr_stop(df, signal, price)
|
||||
|
||||
notional = quantity * price
|
||||
@@ -190,11 +206,12 @@ class TradingBot:
|
||||
"atr": float(last_row["atr"]) if "atr" in last_row.index and pd.notna(last_row["atr"]) else 0.0,
|
||||
}
|
||||
|
||||
await self.risk.register_position(self.symbol, signal)
|
||||
self.current_trade_side = signal
|
||||
self._entry_price = price
|
||||
self._entry_quantity = quantity
|
||||
self.notifier.notify_open(
|
||||
symbol=self.config.symbol,
|
||||
symbol=self.symbol,
|
||||
side=signal,
|
||||
entry_price=price,
|
||||
quantity=quantity,
|
||||
@@ -245,10 +262,10 @@ class TradingBot:
|
||||
estimated_pnl = self._calc_estimated_pnl(exit_price)
|
||||
diff = net_pnl - estimated_pnl
|
||||
|
||||
self.risk.record_pnl(net_pnl)
|
||||
await self.risk.close_position(self.symbol, net_pnl)
|
||||
|
||||
self.notifier.notify_close(
|
||||
symbol=self.config.symbol,
|
||||
symbol=self.symbol,
|
||||
side=self.current_trade_side or "UNKNOWN",
|
||||
close_reason=close_reason,
|
||||
exit_price=exit_price,
|
||||
@@ -317,8 +334,8 @@ class TradingBot:
|
||||
try:
|
||||
await self._close_position(position)
|
||||
|
||||
if not self.risk.can_open_new_position():
|
||||
logger.info("최대 포지션 수 도달 — 재진입 건너뜀")
|
||||
if not await self.risk.can_open_new_position(self.symbol, signal):
|
||||
logger.info(f"[{self.symbol}] 최대 포지션 수 도달 — 재진입 건너뜀")
|
||||
return
|
||||
|
||||
if self.ml_filter.is_model_loaded():
|
||||
@@ -337,7 +354,7 @@ class TradingBot:
|
||||
self._is_reentering = False
|
||||
|
||||
async def run(self):
|
||||
logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x")
|
||||
logger.info(f"[{self.symbol}] 봇 시작, 레버리지 {self.config.leverage}x")
|
||||
await self._recover_position()
|
||||
await self._init_oi_history()
|
||||
balance = await self.exchange.get_balance()
|
||||
@@ -345,7 +362,7 @@ class TradingBot:
|
||||
logger.info(f"기준 잔고 설정: {balance:.2f} USDT (동적 증거금 비율 기준점)")
|
||||
|
||||
user_stream = UserDataStream(
|
||||
symbol=self.config.symbol,
|
||||
symbol=self.symbol,
|
||||
on_order_filled=self._on_position_closed,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,8 +10,11 @@ class Config:
|
||||
api_key: str = ""
|
||||
api_secret: str = ""
|
||||
symbol: str = "XRPUSDT"
|
||||
symbols: list = None
|
||||
correlation_symbols: list = None
|
||||
leverage: int = 10
|
||||
max_positions: int = 3
|
||||
max_same_direction: int = 2
|
||||
stop_loss_pct: float = 0.015 # 1.5%
|
||||
take_profit_pct: float = 0.045 # 4.5% (3:1 RR)
|
||||
trailing_stop_pct: float = 0.01 # 1%
|
||||
@@ -31,3 +34,15 @@ class Config:
|
||||
self.margin_min_ratio = float(os.getenv("MARGIN_MIN_RATIO", "0.20"))
|
||||
self.margin_decay_rate = float(os.getenv("MARGIN_DECAY_RATE", "0.0006"))
|
||||
self.ml_threshold = float(os.getenv("ML_THRESHOLD", "0.55"))
|
||||
self.max_same_direction = int(os.getenv("MAX_SAME_DIRECTION", "2"))
|
||||
|
||||
# symbols: SYMBOLS 환경변수 우선, 없으면 SYMBOL에서 변환
|
||||
symbols_env = os.getenv("SYMBOLS", "")
|
||||
if symbols_env:
|
||||
self.symbols = [s.strip() for s in symbols_env.split(",") if s.strip()]
|
||||
else:
|
||||
self.symbols = [self.symbol]
|
||||
|
||||
# correlation_symbols
|
||||
corr_env = os.getenv("CORRELATION_SYMBOLS", "BTCUSDT,ETHUSDT")
|
||||
self.correlation_symbols = [s.strip() for s in corr_env.split(",") if s.strip()]
|
||||
|
||||
@@ -6,8 +6,9 @@ from src.config import Config
|
||||
|
||||
|
||||
class BinanceFuturesClient:
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, config: Config, symbol: str = None):
|
||||
self.config = config
|
||||
self.symbol = symbol or config.symbol
|
||||
self.client = Client(
|
||||
api_key=config.api_key,
|
||||
api_secret=config.api_secret,
|
||||
@@ -31,7 +32,7 @@ class BinanceFuturesClient:
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_change_leverage(
|
||||
symbol=self.config.symbol, leverage=leverage
|
||||
symbol=self.symbol, leverage=leverage
|
||||
),
|
||||
)
|
||||
|
||||
@@ -68,7 +69,7 @@ class BinanceFuturesClient:
|
||||
)
|
||||
|
||||
params = dict(
|
||||
symbol=self.config.symbol,
|
||||
symbol=self.symbol,
|
||||
side=side,
|
||||
type=order_type,
|
||||
quantity=quantity,
|
||||
@@ -98,7 +99,7 @@ class BinanceFuturesClient:
|
||||
"""STOP_MARKET / TAKE_PROFIT_MARKET 등 Algo Order API(/fapi/v1/algoOrder)로 전송."""
|
||||
loop = asyncio.get_event_loop()
|
||||
params = dict(
|
||||
symbol=self.config.symbol,
|
||||
symbol=self.symbol,
|
||||
side=side,
|
||||
algoType="CONDITIONAL",
|
||||
type=order_type,
|
||||
@@ -120,7 +121,7 @@ class BinanceFuturesClient:
|
||||
positions = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_position_information(
|
||||
symbol=self.config.symbol
|
||||
symbol=self.symbol
|
||||
),
|
||||
)
|
||||
for p in positions:
|
||||
@@ -134,14 +135,14 @@ class BinanceFuturesClient:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_cancel_all_open_orders(
|
||||
symbol=self.config.symbol
|
||||
symbol=self.symbol
|
||||
),
|
||||
)
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_cancel_all_algo_open_orders(
|
||||
symbol=self.config.symbol
|
||||
symbol=self.symbol
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -153,7 +154,7 @@ class BinanceFuturesClient:
|
||||
try:
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_open_interest(symbol=self.config.symbol),
|
||||
lambda: self.client.futures_open_interest(symbol=self.symbol),
|
||||
)
|
||||
return float(result["openInterest"])
|
||||
except Exception as e:
|
||||
@@ -166,7 +167,7 @@ class BinanceFuturesClient:
|
||||
try:
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_mark_price(symbol=self.config.symbol),
|
||||
lambda: self.client.futures_mark_price(symbol=self.symbol),
|
||||
)
|
||||
return float(result["lastFundingRate"])
|
||||
except Exception as e:
|
||||
@@ -180,7 +181,7 @@ class BinanceFuturesClient:
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.client.futures_open_interest_hist(
|
||||
symbol=self.config.symbol, period="15m", limit=limit + 1,
|
||||
symbol=self.symbol, period="15m", limit=limit + 1,
|
||||
),
|
||||
)
|
||||
if len(result) < 2:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from src.config import Config
|
||||
|
||||
@@ -5,10 +6,11 @@ from src.config import Config
|
||||
class RiskManager:
|
||||
def __init__(self, config: Config, max_daily_loss_pct: float = 0.05):
|
||||
self.config = config
|
||||
self.max_daily_loss_pct = max_daily_loss_pct # 일일 최대 손실 5%
|
||||
self.max_daily_loss_pct = max_daily_loss_pct
|
||||
self.daily_pnl: float = 0.0
|
||||
self.initial_balance: float = 0.0
|
||||
self.open_positions: list = []
|
||||
self.open_positions: dict[str, str] = {} # {symbol: side}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def is_trading_allowed(self) -> bool:
|
||||
"""일일 최대 손실 초과 시 거래 중단"""
|
||||
@@ -22,9 +24,33 @@ class RiskManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_open_new_position(self) -> bool:
|
||||
"""최대 동시 포지션 수 체크"""
|
||||
return len(self.open_positions) < self.config.max_positions
|
||||
async def can_open_new_position(self, symbol: str, side: str) -> bool:
|
||||
"""포지션 오픈 가능 여부 (전체 한도 + 중복 진입 + 동일 방향 제한)"""
|
||||
async with self._lock:
|
||||
if len(self.open_positions) >= self.config.max_positions:
|
||||
logger.info(f"최대 포지션 수 도달: {len(self.open_positions)}/{self.config.max_positions}")
|
||||
return False
|
||||
if symbol in self.open_positions:
|
||||
logger.info(f"{symbol} 이미 포지션 보유 중")
|
||||
return False
|
||||
same_dir = sum(1 for s in self.open_positions.values() if s == side)
|
||||
if same_dir >= self.config.max_same_direction:
|
||||
logger.info(f"동일 방향({side}) 한도 도달: {same_dir}/{self.config.max_same_direction}")
|
||||
return False
|
||||
return True
|
||||
|
||||
async def register_position(self, symbol: str, side: str):
|
||||
"""포지션 등록"""
|
||||
async with self._lock:
|
||||
self.open_positions[symbol] = side
|
||||
logger.info(f"포지션 등록: {symbol} {side} (현재 {len(self.open_positions)}개)")
|
||||
|
||||
async def close_position(self, symbol: str, pnl: float):
|
||||
"""포지션 닫기 + PnL 기록"""
|
||||
async with self._lock:
|
||||
self.open_positions.pop(symbol, None)
|
||||
self.daily_pnl += pnl
|
||||
logger.info(f"포지션 종료: {symbol}, PnL={pnl:+.4f}, 누적={self.daily_pnl:+.4f}")
|
||||
|
||||
def record_pnl(self, pnl: float):
|
||||
self.daily_pnl += pnl
|
||||
@@ -36,7 +62,7 @@ class RiskManager:
|
||||
logger.info("일일 PnL 초기화")
|
||||
|
||||
def set_base_balance(self, balance: float) -> None:
|
||||
"""봇 시작 시 기준 잔고 설정 (동적 비율 계산 기준점)"""
|
||||
"""봇 시작 시 기준 잔고 설정"""
|
||||
self.initial_balance = balance
|
||||
|
||||
def get_dynamic_margin_ratio(self, balance: float) -> float:
|
||||
|
||||
@@ -37,6 +37,25 @@ def sample_df():
|
||||
})
|
||||
|
||||
|
||||
def test_bot_accepts_symbol_and_risk(config):
|
||||
"""TradingBot이 symbol과 risk를 외부에서 주입받을 수 있다."""
|
||||
from src.risk_manager import RiskManager
|
||||
risk = RiskManager(config)
|
||||
with patch("src.bot.BinanceFuturesClient"):
|
||||
bot = TradingBot(config, symbol="TRXUSDT", risk=risk)
|
||||
assert bot.symbol == "TRXUSDT"
|
||||
assert bot.risk is risk
|
||||
|
||||
|
||||
def test_bot_stream_uses_injected_symbol(config):
|
||||
"""봇의 stream이 주입된 심볼을 primary로 사용한다."""
|
||||
from src.risk_manager import RiskManager
|
||||
risk = RiskManager(config)
|
||||
with patch("src.bot.BinanceFuturesClient"):
|
||||
bot = TradingBot(config, symbol="DOGEUSDT", risk=risk)
|
||||
assert "dogeusdt" in bot.stream.buffers
|
||||
|
||||
|
||||
def test_bot_uses_multi_symbol_stream(config):
|
||||
from src.data_stream import MultiSymbolStream
|
||||
with patch("src.bot.BinanceFuturesClient"):
|
||||
@@ -64,6 +83,12 @@ async def test_bot_processes_signal(config, sample_df):
|
||||
bot.exchange.calculate_quantity = MagicMock(return_value=100.0)
|
||||
bot.exchange.MIN_NOTIONAL = 5.0
|
||||
|
||||
bot.risk = MagicMock()
|
||||
bot.risk.is_trading_allowed.return_value = True
|
||||
bot.risk.can_open_new_position = AsyncMock(return_value=True)
|
||||
bot.risk.register_position = AsyncMock()
|
||||
bot.risk.get_dynamic_margin_ratio.return_value = 0.50
|
||||
|
||||
with patch("src.bot.Indicators") as MockInd:
|
||||
mock_ind = MagicMock()
|
||||
mock_ind.calculate_all.return_value = sample_df
|
||||
@@ -82,7 +107,7 @@ async def test_close_and_reenter_calls_open_when_ml_passes(config, sample_df):
|
||||
bot._close_position = AsyncMock()
|
||||
bot._open_position = AsyncMock()
|
||||
bot.risk = MagicMock()
|
||||
bot.risk.can_open_new_position.return_value = True
|
||||
bot.risk.can_open_new_position = AsyncMock(return_value=True)
|
||||
bot.ml_filter = MagicMock()
|
||||
bot.ml_filter.is_model_loaded.return_value = True
|
||||
bot.ml_filter.should_enter.return_value = True
|
||||
@@ -102,6 +127,8 @@ async def test_close_and_reenter_skips_open_when_ml_blocks(config, sample_df):
|
||||
|
||||
bot._close_position = AsyncMock()
|
||||
bot._open_position = AsyncMock()
|
||||
bot.risk = MagicMock()
|
||||
bot.risk.can_open_new_position = AsyncMock(return_value=True)
|
||||
bot.ml_filter = MagicMock()
|
||||
bot.ml_filter.is_model_loaded.return_value = True
|
||||
bot.ml_filter.should_enter.return_value = False
|
||||
@@ -122,7 +149,7 @@ async def test_close_and_reenter_skips_open_when_max_positions_reached(config, s
|
||||
bot._close_position = AsyncMock()
|
||||
bot._open_position = AsyncMock()
|
||||
bot.risk = MagicMock()
|
||||
bot.risk.can_open_new_position.return_value = False
|
||||
bot.risk.can_open_new_position = AsyncMock(return_value=False)
|
||||
|
||||
position = {"positionAmt": "100", "entryPrice": "0.5", "markPrice": "0.52"}
|
||||
await bot._close_and_reenter(position, "SHORT", sample_df)
|
||||
@@ -206,6 +233,12 @@ async def test_process_candle_fetches_oi_and_funding(config, sample_df):
|
||||
bot.exchange.get_open_interest = AsyncMock(return_value=5000000.0)
|
||||
bot.exchange.get_funding_rate = AsyncMock(return_value=0.0001)
|
||||
|
||||
bot.risk = MagicMock()
|
||||
bot.risk.is_trading_allowed.return_value = True
|
||||
bot.risk.can_open_new_position = AsyncMock(return_value=True)
|
||||
bot.risk.register_position = AsyncMock()
|
||||
bot.risk.get_dynamic_margin_ratio.return_value = 0.50
|
||||
|
||||
# 신호를 LONG으로 강제해 build_features가 반드시 호출되도록 함
|
||||
with patch("src.bot.Indicators") as mock_ind_cls:
|
||||
mock_ind = MagicMock()
|
||||
|
||||
@@ -19,3 +19,32 @@ def test_config_dynamic_margin_params():
|
||||
assert cfg.margin_max_ratio == 0.50
|
||||
assert cfg.margin_min_ratio == 0.20
|
||||
assert cfg.margin_decay_rate == 0.0006
|
||||
|
||||
|
||||
def test_config_loads_symbols_list():
|
||||
"""SYMBOLS 환경변수로 쉼표 구분 리스트를 로드한다."""
|
||||
os.environ["SYMBOLS"] = "XRPUSDT,TRXUSDT,DOGEUSDT"
|
||||
os.environ.pop("SYMBOL", None)
|
||||
cfg = Config()
|
||||
assert cfg.symbols == ["XRPUSDT", "TRXUSDT", "DOGEUSDT"]
|
||||
|
||||
|
||||
def test_config_fallback_to_symbol():
|
||||
"""SYMBOLS 미설정 시 SYMBOL에서 1개짜리 리스트로 변환한다."""
|
||||
os.environ.pop("SYMBOLS", None)
|
||||
os.environ["SYMBOL"] = "XRPUSDT"
|
||||
cfg = Config()
|
||||
assert cfg.symbols == ["XRPUSDT"]
|
||||
|
||||
|
||||
def test_config_correlation_symbols():
|
||||
"""상관관계 심볼 로드."""
|
||||
os.environ["CORRELATION_SYMBOLS"] = "BTCUSDT,ETHUSDT"
|
||||
cfg = Config()
|
||||
assert cfg.correlation_symbols == ["BTCUSDT", "ETHUSDT"]
|
||||
|
||||
|
||||
def test_config_max_same_direction_default():
|
||||
"""동일 방향 최대 수 기본값 2."""
|
||||
cfg = Config()
|
||||
assert cfg.max_same_direction == 2
|
||||
|
||||
@@ -22,6 +22,7 @@ def client():
|
||||
config.leverage = 10
|
||||
c = BinanceFuturesClient.__new__(BinanceFuturesClient)
|
||||
c.config = config
|
||||
c.symbol = config.symbol
|
||||
return c
|
||||
|
||||
|
||||
@@ -36,10 +37,24 @@ def exchange():
|
||||
config = Config()
|
||||
c = BinanceFuturesClient.__new__(BinanceFuturesClient)
|
||||
c.config = config
|
||||
c.symbol = config.symbol
|
||||
c.client = MagicMock()
|
||||
return c
|
||||
|
||||
|
||||
def test_exchange_uses_own_symbol():
|
||||
"""Exchange 클라이언트가 config.symbol 대신 생성자의 symbol을 사용한다."""
|
||||
os.environ.update({
|
||||
"BINANCE_API_KEY": "test_key",
|
||||
"BINANCE_API_SECRET": "test_secret",
|
||||
"SYMBOL": "XRPUSDT",
|
||||
})
|
||||
config = Config()
|
||||
with patch("src.exchange.Client"):
|
||||
client = BinanceFuturesClient(config, symbol="TRXUSDT")
|
||||
assert client.symbol == "TRXUSDT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_leverage(config):
|
||||
with patch("src.exchange.Client") as MockClient:
|
||||
|
||||
@@ -29,10 +29,13 @@ def test_trading_allowed_normal(config):
|
||||
assert rm.is_trading_allowed() is True
|
||||
|
||||
|
||||
def test_position_size_capped(config):
|
||||
@pytest.mark.asyncio
|
||||
async def test_position_size_capped(config):
|
||||
rm = RiskManager(config, max_daily_loss_pct=0.05)
|
||||
rm.open_positions = ["pos1", "pos2", "pos3"]
|
||||
assert rm.can_open_new_position() is False
|
||||
await rm.register_position("XRPUSDT", "LONG")
|
||||
await rm.register_position("TRXUSDT", "SHORT")
|
||||
await rm.register_position("DOGEUSDT", "LONG")
|
||||
assert await rm.can_open_new_position("SOLUSDT", "SHORT") is False
|
||||
|
||||
|
||||
# --- 동적 증거금 비율 테스트 ---
|
||||
@@ -81,3 +84,54 @@ def test_ratio_clamped_at_max(risk):
|
||||
"""잔고가 기준보다 작아도 최대 비율(50%) 초과하지 않음"""
|
||||
ratio = risk.get_dynamic_margin_ratio(5.0)
|
||||
assert ratio == pytest.approx(0.50, abs=1e-6)
|
||||
|
||||
|
||||
# --- 멀티심볼 공유 RiskManager 테스트 ---
|
||||
|
||||
@pytest.fixture
|
||||
def shared_risk(config):
|
||||
config.max_same_direction = 2
|
||||
return RiskManager(config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_open_new_position_async(shared_risk):
|
||||
"""비동기 포지션 오픈 허용 체크."""
|
||||
assert await shared_risk.can_open_new_position("XRPUSDT", "LONG") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_and_close_position(shared_risk):
|
||||
"""포지션 등록 후 닫기."""
|
||||
await shared_risk.register_position("XRPUSDT", "LONG")
|
||||
assert "XRPUSDT" in shared_risk.open_positions
|
||||
await shared_risk.close_position("XRPUSDT", pnl=1.5)
|
||||
assert "XRPUSDT" not in shared_risk.open_positions
|
||||
assert shared_risk.daily_pnl == 1.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_symbol_blocked(shared_risk):
|
||||
"""같은 심볼 중복 진입 차단."""
|
||||
await shared_risk.register_position("XRPUSDT", "LONG")
|
||||
assert await shared_risk.can_open_new_position("XRPUSDT", "SHORT") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_same_direction_limit(shared_risk):
|
||||
"""같은 방향 2개 초과 차단."""
|
||||
await shared_risk.register_position("XRPUSDT", "LONG")
|
||||
await shared_risk.register_position("TRXUSDT", "LONG")
|
||||
# 3번째 LONG 차단
|
||||
assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False
|
||||
# SHORT은 허용
|
||||
assert await shared_risk.can_open_new_position("DOGEUSDT", "SHORT") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_positions_global_limit(shared_risk):
|
||||
"""전체 포지션 수 한도 초과 차단."""
|
||||
shared_risk.config.max_positions = 2
|
||||
await shared_risk.register_position("XRPUSDT", "LONG")
|
||||
await shared_risk.register_position("TRXUSDT", "SHORT")
|
||||
assert await shared_risk.can_open_new_position("DOGEUSDT", "LONG") is False
|
||||
|
||||
Reference in New Issue
Block a user