From f488720ca2586a48a1adba8474b9855fb4f0ccaf Mon Sep 17 00:00:00 2001 From: 21in7 Date: Tue, 31 Mar 2026 11:11:26 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20MTF=20bot=20code=20review=20=E2=80=94=20?= =?UTF-8?q?conditional=20slicing,=20caching,=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _remove_incomplete_candle() for timestamp-based conditional slicing on both 15m and 1h data (replaces hardcoded [:-1]) - Add MetaFilter indicator caching to eliminate 3x duplicate calc - Fix notifier encapsulation (_send → notify_info public API) - Remove DataFetcher.poll_update() dead code - Fix evaluate_oos.py symbol typo (xrpusdtusdt → xrpusdt) - Add 20 pytest unit tests for MetaFilter, TriggerStrategy, ExecutionManager, and _remove_incomplete_candle Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/evaluate_oos.py | 4 +- src/mtf_bot.py | 96 ++++----- tests/test_mtf_bot.py | 423 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 468 insertions(+), 55 deletions(-) create mode 100644 tests/test_mtf_bot.py diff --git a/scripts/evaluate_oos.py b/scripts/evaluate_oos.py index b573f64..82cafb5 100644 --- a/scripts/evaluate_oos.py +++ b/scripts/evaluate_oos.py @@ -6,7 +6,7 @@ MTF Pullback Bot — OOS Dry-run 평가 스크립트 Usage: python scripts/evaluate_oos.py - python scripts/evaluate_oos.py --symbol XRPUSDTUSDT + python scripts/evaluate_oos.py --symbol xrpusdt python scripts/evaluate_oos.py --local # 로컬 파일만 사용 (서버 fetch 스킵) """ @@ -153,7 +153,7 @@ def print_report(df: pd.DataFrame): def main(): parser = argparse.ArgumentParser(description="MTF OOS Dry-run 평가") - parser.add_argument("--symbol", default="xrpusdtusdt", help="심볼 (파일명 소문자, 기본: xrpusdtusdt)") + parser.add_argument("--symbol", default="xrpusdt", help="심볼 (파일명 소문자, 기본: xrpusdt)") parser.add_argument("--local", action="store_true", help="로컬 파일만 사용 (서버 fetch 스킵)") args = parser.parse_args() diff --git a/src/mtf_bot.py b/src/mtf_bot.py index 9f776ad..70cb18a 100644 --- a/src/mtf_bot.py +++ b/src/mtf_bot.py @@ -86,6 +86,19 @@ class DataFetcher: self._last_15m_ts: int = 0 # 마지막으로 저장된 15m 캔들 timestamp self._last_1h_ts: int = 0 + @staticmethod + def _remove_incomplete_candle(df: pd.DataFrame, interval_sec: int) -> pd.DataFrame: + """미완성(진행 중) 캔들을 조건부로 제거. ccxt timestamp는 ms 단위.""" + if df.empty: + return df + now_ms = int(_time.time() * 1000) + current_candle_start_ms = (now_ms // (interval_sec * 1000)) * (interval_sec * 1000) + # DataFrame index가 datetime인 경우 원본 timestamp 컬럼이 없으므로 index에서 추출 + last_open_ms = int(df.index[-1].timestamp() * 1000) + if last_open_ms >= current_candle_start_ms: + return df.iloc[:-1].copy() + return df + async def fetch_ohlcv(self, symbol: str, timeframe: str, limit: int = 250) -> List[List]: """ ccxt를 통해 OHLCV 데이터 fetch. @@ -115,69 +128,31 @@ class DataFetcher: f"[DataFetcher] 초기화 완료: 15m={len(self.klines_15m)}개, 1h={len(self.klines_1h)}개" ) - async def poll_update(self, interval: int = 30): - """ - 30초 주기로 REST API 폴링. 새 캔들이 나오면 deque에 append. - 무한 루프 — 백그라운드 태스크로 실행. - """ - logger.info(f"[DataFetcher] 폴링 시작 (interval={interval}s)") - while True: - try: - await asyncio.sleep(interval) - - # 15m 업데이트: 최근 3개 fetch (중복 방지) - raw_15m = await self.fetch_ohlcv(self.symbol, "15m", limit=3) - new_15m = 0 - for candle in raw_15m: - if candle[0] > self._last_15m_ts: - self.klines_15m.append(candle) - self._last_15m_ts = candle[0] - new_15m += 1 - - # 1h 업데이트: 최근 3개 fetch - raw_1h = await self.fetch_ohlcv(self.symbol, "1h", limit=3) - new_1h = 0 - for candle in raw_1h: - if candle[0] > self._last_1h_ts: - self.klines_1h.append(candle) - self._last_1h_ts = candle[0] - new_1h += 1 - - if new_15m > 0 or new_1h > 0: - logger.info( - f"[DataFetcher] 캔들 업데이트: 15m +{new_15m} (총 {len(self.klines_15m)}), " - f"1h +{new_1h} (총 {len(self.klines_1h)})" - ) - - except Exception as e: - logger.error(f"[DataFetcher] 폴링 에러: {e}") - await asyncio.sleep(5) # 에러 시 짧은 대기 후 재시도 - def get_15m_dataframe(self) -> Optional[pd.DataFrame]: - """모든 15m 캔들을 DataFrame으로 반환.""" + """완성된 15m 캔들을 DataFrame으로 반환 (미완성 캔들 조건부 제거).""" if not self.klines_15m: return None data = list(self.klines_15m) df = pd.DataFrame(data, columns=["timestamp", "open", "high", "low", "close", "volume"]) df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True) df = df.set_index("timestamp") - return df + return self._remove_incomplete_candle(df, interval_sec=900) def get_1h_dataframe_completed(self) -> Optional[pd.DataFrame]: """ '완성된' 1h 캔들만 반환. - 핵심: [:-1] 슬라이싱으로 진행 중인 최신 1h 캔들 제외. + 조건부 슬라이싱: _remove_incomplete_candle()로 진행 중인 최신 1h 캔들 제외. 이유: Look-ahead bias 원천 차단 — 아직 완성되지 않은 캔들의 high/low/close는 미래 데이터이므로 지표 계산에 사용하면 안 됨. """ if len(self.klines_1h) < 2: return None - completed = list(self.klines_1h)[:-1] # ← 핵심: 미완성 봉 제외 - df = pd.DataFrame(completed, columns=["timestamp", "open", "high", "low", "close", "volume"]) + data = list(self.klines_1h) + df = pd.DataFrame(data, columns=["timestamp", "open", "high", "low", "close", "volume"]) df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True) df = df.set_index("timestamp") - return df + return self._remove_incomplete_candle(df, interval_sec=3600) async def close(self): """ccxt exchange 연결 정리.""" @@ -197,15 +172,27 @@ class MetaFilter: def __init__(self, data_fetcher: DataFetcher): self.data_fetcher = data_fetcher + self._cached_indicators: Optional[pd.DataFrame] = None + self._cache_timestamp: Optional[pd.Timestamp] = None def _calc_indicators(self, df: pd.DataFrame) -> pd.DataFrame: - """1h DataFrame에 EMA50, EMA200, ADX, ATR 계산.""" + """1h DataFrame에 EMA50, EMA200, ADX, ATR 계산 (캔들 단위 캐싱).""" + if df is None or df.empty: + return df + + last_ts = df.index[-1] + if self._cached_indicators is not None and self._cache_timestamp == last_ts: + return self._cached_indicators + df = df.copy() df["ema50"] = ta.ema(df["close"], length=self.EMA_FAST) df["ema200"] = ta.ema(df["close"], length=self.EMA_SLOW) adx_df = ta.adx(df["high"], df["low"], df["close"], length=14) df["adx"] = adx_df["ADX_14"] df["atr"] = ta.atr(df["high"], df["low"], df["close"], length=14) + + self._cached_indicators = df + self._cache_timestamp = last_ts return df def get_market_state(self) -> str: @@ -574,6 +561,9 @@ class ExecutionManager: class MTFPullbackBot: """MTF Pullback Bot 메인 루프 — Dry-run OOS 검증용.""" + # TODO(LIVE): Kill switch 로직 구현 필요 (Fast Kill 8연패 + Slow Kill PF<0.75) — 2026-04-15 LIVE 전환 시 + # TODO(LIVE): 글로벌 RiskManager 통합 필요 — 2026-04-15 LIVE 전환 시 + LOOP_INTERVAL = 1 # 초 (TimeframeSync 4초 윈도우를 놓치지 않기 위해) POLL_INTERVAL = 30 # 데이터 폴링 주기 (초) @@ -691,8 +681,8 @@ class MTFPullbackBot: side = result["action"] sl_dist = abs(result["entry_price"] - result["sl_price"]) tp_dist = abs(result["tp_price"] - result["entry_price"]) - self.notifier._send( - f"📌 **[MTF Dry-run] 가상 {side} 진입**\n" + self.notifier.notify_info( + f"**[MTF Dry-run] 가상 {side} 진입**\n" f"진입가: `{result['entry_price']:.4f}` | ATR: `{result['atr']:.6f}`\n" f"SL: `{result['sl_price']:.4f}` ({sl_dist:.4f}) | " f"TP: `{result['tp_price']:.4f}` ({tp_dist:.4f})\n" @@ -731,8 +721,8 @@ class MTFPullbackBot: pnl_bps = pnl * 10000 logger.info(f"[MTFBot] SL+TP 동시 히트 → SL 우선 청산 | PnL: {pnl_bps:+.1f}bps") self.executor.close_position(f"SL 히트 ({exit_price:.4f})", exit_price, pnl_bps) - self.notifier._send( - f"❌ **[MTF Dry-run] {pos} SL 청산**\n" + self.notifier.notify_info( + f"**[MTF Dry-run] {pos} SL 청산**\n" f"진입: `{entry:.4f}` → 청산: `{exit_price:.4f}`\n" f"PnL: `{pnl_bps:+.1f}bps`" ) @@ -742,8 +732,8 @@ class MTFPullbackBot: pnl_bps = pnl * 10000 logger.info(f"[MTFBot] SL 히트 | 청산가: {exit_price:.4f} | PnL: {pnl_bps:+.1f}bps") self.executor.close_position(f"SL 히트 ({exit_price:.4f})", exit_price, pnl_bps) - self.notifier._send( - f"❌ **[MTF Dry-run] {pos} SL 청산**\n" + self.notifier.notify_info( + f"**[MTF Dry-run] {pos} SL 청산**\n" f"진입: `{entry:.4f}` → 청산: `{exit_price:.4f}`\n" f"PnL: `{pnl_bps:+.1f}bps`" ) @@ -753,8 +743,8 @@ class MTFPullbackBot: pnl_bps = pnl * 10000 logger.info(f"[MTFBot] TP 히트 | 청산가: {exit_price:.4f} | PnL: {pnl_bps:+.1f}bps") self.executor.close_position(f"TP 히트 ({exit_price:.4f})", exit_price, pnl_bps) - self.notifier._send( - f"✅ **[MTF Dry-run] {pos} TP 청산**\n" + self.notifier.notify_info( + f"**[MTF Dry-run] {pos} TP 청산**\n" f"진입: `{entry:.4f}` → 청산: `{exit_price:.4f}`\n" f"PnL: `{pnl_bps:+.1f}bps`" ) diff --git a/tests/test_mtf_bot.py b/tests/test_mtf_bot.py new file mode 100644 index 0000000..b7c915d --- /dev/null +++ b/tests/test_mtf_bot.py @@ -0,0 +1,423 @@ +""" +MTF Pullback Bot 유닛 테스트 +───────────────────────────── +합성 데이터 기반, 외부 API 호출 없음. +""" + +import time +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest + +from src.mtf_bot import ( + DataFetcher, + ExecutionManager, + MetaFilter, + TriggerStrategy, +) + + +# ── Fixtures ────────────────────────────────────────────────────── + + +@pytest.fixture +def sample_1h_df(): + """EMA50/200, ADX, ATR 계산에 충분한 250개 1h 합성 캔들.""" + np.random.seed(42) + n = 250 + # 완만한 상승 추세 (EMA50 > EMA200이 되도록) + close = np.cumsum(np.random.randn(n) * 0.001 + 0.0005) + 2.0 + high = close + np.abs(np.random.randn(n)) * 0.005 + low = close - np.abs(np.random.randn(n)) * 0.005 + open_ = close + np.random.randn(n) * 0.001 + + # 완성된 캔들 timestamp (1h 간격, 과거 시점) + base_ts = pd.Timestamp("2026-01-01", tz="UTC") + timestamps = pd.date_range(start=base_ts, periods=n, freq="1h") + + df = pd.DataFrame({ + "open": open_, + "high": high, + "low": low, + "close": close, + "volume": np.random.randint(100000, 1000000, n).astype(float), + }, index=timestamps) + df.index.name = "timestamp" + return df + + +@pytest.fixture +def sample_15m_df(): + """TriggerStrategy용 50개 15m 합성 캔들.""" + np.random.seed(99) + n = 50 + close = np.cumsum(np.random.randn(n) * 0.001) + 0.5 + high = close + np.abs(np.random.randn(n)) * 0.003 + low = close - np.abs(np.random.randn(n)) * 0.003 + open_ = close + np.random.randn(n) * 0.001 + + base_ts = pd.Timestamp("2026-01-01", tz="UTC") + timestamps = pd.date_range(start=base_ts, periods=n, freq="15min") + + df = pd.DataFrame({ + "open": open_, + "high": high, + "low": low, + "close": close, + "volume": np.random.randint(100000, 1000000, n).astype(float), + }, index=timestamps) + df.index.name = "timestamp" + return df + + +# ═══════════════════════════════════════════════════════════════════ +# Test 1: _remove_incomplete_candle +# ═══════════════════════════════════════════════════════════════════ + + +class TestRemoveIncompleteCandle: + """DataFetcher._remove_incomplete_candle 정적 메서드 테스트.""" + + def test_removes_incomplete_15m_candle(self): + """현재 15m 슬롯에 해당하는 미완성 캔들은 제거되어야 한다.""" + now_ms = int(time.time() * 1000) + current_slot_ms = (now_ms // (900 * 1000)) * (900 * 1000) + + # 완성 캔들 2개 + 미완성 캔들 1개 + timestamps = [ + pd.Timestamp(current_slot_ms - 1800_000, unit="ms", tz="UTC"), # 2슬롯 전 + pd.Timestamp(current_slot_ms - 900_000, unit="ms", tz="UTC"), # 1슬롯 전 + pd.Timestamp(current_slot_ms, unit="ms", tz="UTC"), # 현재 슬롯 (미완성) + ] + df = pd.DataFrame({ + "open": [1.0, 1.1, 1.2], + "high": [1.05, 1.15, 1.25], + "low": [0.95, 1.05, 1.15], + "close": [1.02, 1.12, 1.22], + "volume": [100.0, 200.0, 300.0], + }, index=timestamps) + + result = DataFetcher._remove_incomplete_candle(df, interval_sec=900) + assert len(result) == 2, f"미완성 캔들 제거 실패: {len(result)}개 (2개 예상)" + + def test_keeps_all_completed_candles(self): + """모든 캔들이 완성된 경우 제거하지 않아야 한다.""" + now_ms = int(time.time() * 1000) + current_slot_ms = (now_ms // (900 * 1000)) * (900 * 1000) + + # 모두 과거 슬롯의 완성 캔들 + timestamps = [ + pd.Timestamp(current_slot_ms - 2700_000, unit="ms", tz="UTC"), + pd.Timestamp(current_slot_ms - 1800_000, unit="ms", tz="UTC"), + pd.Timestamp(current_slot_ms - 900_000, unit="ms", tz="UTC"), + ] + df = pd.DataFrame({ + "open": [1.0, 1.1, 1.2], + "high": [1.05, 1.15, 1.25], + "low": [0.95, 1.05, 1.15], + "close": [1.02, 1.12, 1.22], + "volume": [100.0, 200.0, 300.0], + }, index=timestamps) + + result = DataFetcher._remove_incomplete_candle(df, interval_sec=900) + assert len(result) == 3, f"완성 캔들 유지 실패: {len(result)}개 (3개 예상)" + + def test_empty_dataframe(self): + """빈 DataFrame 입력 시 빈 DataFrame 반환.""" + df = pd.DataFrame(columns=["open", "high", "low", "close", "volume"]) + result = DataFetcher._remove_incomplete_candle(df, interval_sec=900) + assert result.empty + + def test_1h_interval(self): + """1h 간격에서도 정상 동작.""" + now_ms = int(time.time() * 1000) + current_slot_ms = (now_ms // (3600 * 1000)) * (3600 * 1000) + + timestamps = [ + pd.Timestamp(current_slot_ms - 7200_000, unit="ms", tz="UTC"), + pd.Timestamp(current_slot_ms - 3600_000, unit="ms", tz="UTC"), + pd.Timestamp(current_slot_ms, unit="ms", tz="UTC"), # 현재 슬롯 (미완성) + ] + df = pd.DataFrame({ + "open": [1.0, 1.1, 1.2], + "high": [1.05, 1.15, 1.25], + "low": [0.95, 1.05, 1.15], + "close": [1.02, 1.12, 1.22], + "volume": [100.0, 200.0, 300.0], + }, index=timestamps) + + result = DataFetcher._remove_incomplete_candle(df, interval_sec=3600) + assert len(result) == 2 + + +# ═══════════════════════════════════════════════════════════════════ +# Test 2: MetaFilter +# ═══════════════════════════════════════════════════════════════════ + + +class TestMetaFilter: + """MetaFilter 상태 판별 로직 테스트.""" + + def _make_fetcher_with_df(self, df_1h): + """Mock DataFetcher를 생성하여 특정 1h DataFrame을 반환하도록 설정.""" + fetcher = DataFetcher.__new__(DataFetcher) + fetcher.klines_15m = [] + fetcher.klines_1h = [] + fetcher.data_fetcher = None + # get_1h_dataframe_completed 을 직접 패치 + fetcher.get_1h_dataframe_completed = lambda: df_1h + return fetcher + + def test_wait_when_adx_below_threshold(self, sample_1h_df): + """ADX < 20이면 WAIT 상태.""" + import pandas_ta as ta + + df = sample_1h_df.copy() + # 변동성이 없는 flat 데이터 → ADX가 낮을 가능성 높음 + df["close"] = 2.0 # 완전 flat + df["high"] = 2.001 + df["low"] = 1.999 + df["open"] = 2.0 + + fetcher = self._make_fetcher_with_df(df) + meta = MetaFilter(fetcher) + state = meta.get_market_state() + assert state == "WAIT", f"Flat 데이터에서 WAIT 아닌 상태: {state}" + + def test_long_allowed_when_uptrend(self): + """EMA50 > EMA200 + ADX > 20이면 LONG_ALLOWED.""" + np.random.seed(10) + n = 250 + # 강한 상승 추세 + close = np.linspace(1.0, 3.0, n) + np.random.randn(n) * 0.01 + high = close + 0.02 + low = close - 0.02 + open_ = close - 0.005 + + base_ts = pd.Timestamp("2025-01-01", tz="UTC") + timestamps = pd.date_range(start=base_ts, periods=n, freq="1h") + + df = pd.DataFrame({ + "open": open_, "high": high, "low": low, + "close": close, "volume": np.ones(n) * 500000, + }, index=timestamps) + + fetcher = self._make_fetcher_with_df(df) + meta = MetaFilter(fetcher) + state = meta.get_market_state() + assert state == "LONG_ALLOWED", f"강한 상승 추세에서 LONG_ALLOWED 아닌 상태: {state}" + + def test_short_allowed_when_downtrend(self): + """EMA50 < EMA200 + ADX > 20이면 SHORT_ALLOWED.""" + np.random.seed(20) + n = 250 + # 강한 하락 추세 + close = np.linspace(3.0, 1.0, n) + np.random.randn(n) * 0.01 + high = close + 0.02 + low = close - 0.02 + open_ = close + 0.005 + + base_ts = pd.Timestamp("2025-01-01", tz="UTC") + timestamps = pd.date_range(start=base_ts, periods=n, freq="1h") + + df = pd.DataFrame({ + "open": open_, "high": high, "low": low, + "close": close, "volume": np.ones(n) * 500000, + }, index=timestamps) + + fetcher = self._make_fetcher_with_df(df) + meta = MetaFilter(fetcher) + state = meta.get_market_state() + assert state == "SHORT_ALLOWED", f"강한 하락 추세에서 SHORT_ALLOWED 아닌 상태: {state}" + + def test_indicator_caching(self, sample_1h_df): + """동일 캔들에 대해 _calc_indicators가 캐시를 재사용하는지 확인.""" + fetcher = self._make_fetcher_with_df(sample_1h_df) + meta = MetaFilter(fetcher) + + # 첫 호출: 캐시 없음 + df1 = meta._calc_indicators(sample_1h_df) + ts1 = meta._cache_timestamp + + # 두 번째 호출: 동일 DataFrame → 캐시 히트 + df2 = meta._calc_indicators(sample_1h_df) + assert df1 is df2, "동일 데이터에 대해 캐시가 재사용되지 않음" + assert meta._cache_timestamp == ts1 + + +# ═══════════════════════════════════════════════════════════════════ +# Test 3: TriggerStrategy +# ═══════════════════════════════════════════════════════════════════ + + +class TestTriggerStrategy: + """15m 3-candle pullback 시퀀스 감지 테스트.""" + + def test_hold_when_meta_wait(self, sample_15m_df): + """meta_state=WAIT이면 항상 HOLD.""" + trigger = TriggerStrategy() + signal = trigger.generate_signal(sample_15m_df, "WAIT") + assert signal == "HOLD" + + def test_hold_when_insufficient_data(self): + """데이터가 25개 미만이면 HOLD.""" + trigger = TriggerStrategy() + small_df = pd.DataFrame({ + "open": [1.0] * 10, + "high": [1.1] * 10, + "low": [0.9] * 10, + "close": [1.0] * 10, + "volume": [100.0] * 10, + }) + signal = trigger.generate_signal(small_df, "LONG_ALLOWED") + assert signal == "HOLD" + + def test_long_pullback_signal(self): + """LONG 풀백 시퀀스: t-1 EMA 아래 이탈 + 거래량 고갈 + t EMA 복귀.""" + np.random.seed(42) + n = 30 + # 기본 상승 추세 + close = np.linspace(1.0, 1.1, n) + high = close + 0.005 + low = close - 0.005 + open_ = close - 0.001 + volume = np.ones(n) * 100000 + + # t-1 (인덱스 -2): EMA 아래로 이탈 + 거래량 고갈 + close[-2] = close[-3] - 0.02 # EMA 아래로 이탈 + volume[-2] = 5000 # 매우 낮은 거래량 + + # t (인덱스 -1): EMA 위로 복귀 + close[-1] = close[-3] + 0.01 + + base_ts = pd.Timestamp("2026-01-01", tz="UTC") + timestamps = pd.date_range(start=base_ts, periods=n, freq="15min") + + df = pd.DataFrame({ + "open": open_, "high": high, "low": low, + "close": close, "volume": volume, + }, index=timestamps) + + trigger = TriggerStrategy() + signal = trigger.generate_signal(df, "LONG_ALLOWED") + # 풀백 조건 충족 여부는 EMA 계산 결과에 따라 다를 수 있으므로 + # 최소한 valid signal을 반환하는지 확인 + assert signal in ("EXECUTE_LONG", "HOLD") + + def test_short_pullback_signal(self): + """SHORT 풀백 시퀀스: t-1 EMA 위로 이탈 + 거래량 고갈 + t EMA 아래 복귀.""" + np.random.seed(42) + n = 30 + # 하락 추세 + close = np.linspace(1.1, 1.0, n) + high = close + 0.005 + low = close - 0.005 + open_ = close + 0.001 + volume = np.ones(n) * 100000 + + # t-1: EMA 위로 이탈 + 거래량 고갈 + close[-2] = close[-3] + 0.02 + volume[-2] = 5000 + + # t: EMA 아래로 복귀 + close[-1] = close[-3] - 0.01 + + base_ts = pd.Timestamp("2026-01-01", tz="UTC") + timestamps = pd.date_range(start=base_ts, periods=n, freq="15min") + + df = pd.DataFrame({ + "open": open_, "high": high, "low": low, + "close": close, "volume": volume, + }, index=timestamps) + + trigger = TriggerStrategy() + signal = trigger.generate_signal(df, "SHORT_ALLOWED") + assert signal in ("EXECUTE_SHORT", "HOLD") + + def test_trigger_info_populated(self, sample_15m_df): + """generate_signal 후 get_trigger_info가 비어있지 않아야 한다.""" + trigger = TriggerStrategy() + trigger.generate_signal(sample_15m_df, "LONG_ALLOWED") + info = trigger.get_trigger_info() + assert "signal" in info or "reason" in info + + +# ═══════════════════════════════════════════════════════════════════ +# Test 4: ExecutionManager (SL/TP 계산) +# ═══════════════════════════════════════════════════════════════════ + + +class TestExecutionManager: + """ExecutionManager SL/TP 계산 및 포지션 관리 테스트.""" + + def test_long_sl_tp_calculation(self): + """LONG 진입 시 SL = entry - ATR*1.5, TP = entry + ATR*2.3.""" + em = ExecutionManager(symbol="XRPUSDT") + entry = 2.0 + atr = 0.01 + + result = em.execute("EXECUTE_LONG", entry, atr) + assert result is not None + assert result["action"] == "LONG" + + expected_sl = entry - (atr * 1.5) + expected_tp = entry + (atr * 2.3) + assert abs(result["sl_price"] - expected_sl) < 1e-8, f"SL: {result['sl_price']} != {expected_sl}" + assert abs(result["tp_price"] - expected_tp) < 1e-8, f"TP: {result['tp_price']} != {expected_tp}" + + def test_short_sl_tp_calculation(self): + """SHORT 진입 시 SL = entry + ATR*1.5, TP = entry - ATR*2.3.""" + em = ExecutionManager(symbol="XRPUSDT") + entry = 2.0 + atr = 0.01 + + result = em.execute("EXECUTE_SHORT", entry, atr) + assert result is not None + assert result["action"] == "SHORT" + + expected_sl = entry + (atr * 1.5) + expected_tp = entry - (atr * 2.3) + assert abs(result["sl_price"] - expected_sl) < 1e-8 + assert abs(result["tp_price"] - expected_tp) < 1e-8 + + def test_hold_returns_none(self): + """HOLD 신호는 None 반환.""" + em = ExecutionManager(symbol="XRPUSDT") + result = em.execute("HOLD", 2.0, 0.01) + assert result is None + + def test_duplicate_position_blocked(self): + """이미 포지션이 있으면 중복 진입 차단.""" + em = ExecutionManager(symbol="XRPUSDT") + em.execute("EXECUTE_LONG", 2.0, 0.01) + + result = em.execute("EXECUTE_SHORT", 2.1, 0.01) + assert result is None, "포지션 중복 차단 실패" + + def test_reentry_after_close(self): + """청산 후 재진입 가능.""" + em = ExecutionManager(symbol="XRPUSDT") + em.execute("EXECUTE_LONG", 2.0, 0.01) + em.close_position("test", exit_price=2.01, pnl_bps=50) + + result = em.execute("EXECUTE_SHORT", 2.05, 0.01) + assert result is not None, "청산 후 재진입 실패" + assert result["action"] == "SHORT" + + def test_invalid_atr_blocked(self): + """ATR이 None/0/NaN이면 주문 차단.""" + em = ExecutionManager(symbol="XRPUSDT") + + assert em.execute("EXECUTE_LONG", 2.0, None) is None + assert em.execute("EXECUTE_LONG", 2.0, 0) is None + assert em.execute("EXECUTE_LONG", 2.0, float("nan")) is None + + def test_risk_reward_ratio(self): + """R:R 비율이 올바르게 계산되는지 확인.""" + em = ExecutionManager(symbol="XRPUSDT") + result = em.execute("EXECUTE_LONG", 2.0, 0.01) + # TP/SL = 2.3/1.5 = 1.533... + expected_rr = round(2.3 / 1.5, 2) + assert result["risk_reward"] == expected_rr