diff --git a/src/backtest_validator.py b/src/backtest_validator.py index 8856e8c..1931a28 100644 --- a/src/backtest_validator.py +++ b/src/backtest_validator.py @@ -30,7 +30,7 @@ def validate(trades: list[dict], summary: dict, cfg) -> dict: results: list[CheckResult] = [] # 검증 1: 논리적 불변 조건 - results.extend(_check_invariants(trades)) + results.extend(_check_invariants(trades, cfg)) # 검증 2: 통계적 이상 감지 results.extend(_check_statistics(trades, summary)) @@ -47,7 +47,7 @@ def validate(trades: list[dict], summary: dict, cfg) -> dict: } -def _check_invariants(trades: list[dict]) -> list[CheckResult]: +def _check_invariants(trades: list[dict], cfg=None) -> list[CheckResult]: """논리적 불변 조건. 하나라도 위반 시 FAIL.""" results = [] @@ -120,7 +120,7 @@ def _check_invariants(trades: list[dict]) -> list[CheckResult]: )) # 5. 잔고가 음수가 된 적 없음 - balance = 1000.0 # cfg.initial_balance를 몰라도 trades에서 추적 가능 + balance = cfg.initial_balance if cfg is not None else 1000.0 min_balance = balance for t in trades: balance += t["net_pnl"] diff --git a/src/backtester.py b/src/backtester.py index 0a2d563..73d2553 100644 --- a/src/backtester.py +++ b/src/backtester.py @@ -317,16 +317,9 @@ class Backtester: self.ml_filters = {} for sym in self.cfg.symbols: if sym in ml_models and ml_models[sym] is not None: - mf = MLFilter.__new__(MLFilter) - mf._disabled = False - mf._onnx_session = None - mf._lgbm_model = ml_models[sym] - mf._threshold = self.cfg.ml_threshold - mf._onnx_path = Path("/dev/null") - mf._lgbm_path = Path("/dev/null") - mf._loaded_onnx_mtime = 0.0 - mf._loaded_lgbm_mtime = 0.0 - self.ml_filters[sym] = mf + self.ml_filters[sym] = MLFilter.from_model( + ml_models[sym], threshold=self.cfg.ml_threshold + ) else: self.ml_filters[sym] = None diff --git a/src/ml_filter.py b/src/ml_filter.py index 7491610..1336450 100644 --- a/src/ml_filter.py +++ b/src/ml_filter.py @@ -155,6 +155,21 @@ class MLFilter: logger.warning(f"ML 필터 예측 오류 (진입 차단): {e}") return False + @classmethod + def from_model(cls, model, threshold: float = 0.55) -> "MLFilter": + """외부에서 학습된 LightGBM 모델을 주입하여 MLFilter를 생성한다. + backtester walk-forward에서 사용.""" + instance = cls.__new__(cls) + instance._disabled = False + instance._onnx_session = None + instance._lgbm_model = model + instance._threshold = threshold + instance._onnx_path = Path("/dev/null") + instance._lgbm_path = Path("/dev/null") + instance._loaded_onnx_mtime = 0.0 + instance._loaded_lgbm_mtime = 0.0 + return instance + def reload_model(self): """외부에서 강제 리로드할 때 사용 (하위 호환).""" prev_backend = self.active_backend diff --git a/tests/test_ml_pipeline_fixes.py b/tests/test_ml_pipeline_fixes.py index 24a1f6a..a9ef2d1 100644 --- a/tests/test_ml_pipeline_fixes.py +++ b/tests/test_ml_pipeline_fixes.py @@ -100,3 +100,16 @@ def test_mlx_no_double_normalization(): assert np.allclose(model._mean, 0.0), "normalize=False시 mean은 0이어야 한다" assert np.allclose(model._std, 1.0), "normalize=False시 std는 1이어야 한다" + + +def test_ml_filter_from_model(): + """MLFilter.from_model()로 LightGBM 모델을 주입할 수 있어야 한다.""" + from src.ml_filter import MLFilter + from unittest.mock import MagicMock + + mock_model = MagicMock() + mock_model.predict_proba.return_value = [[0.3, 0.7]] + + mf = MLFilter.from_model(mock_model, threshold=0.55) + assert mf.is_model_loaded() + assert mf.active_backend == "LightGBM"