refactor(ml): add MLFilter.from_model(), fix validator initial_balance

- MLFilter.from_model() classmethod eliminates brittle __new__() private-attribute
  manipulation in backtester walk-forward model injection
- backtest_validator._check_invariants() now accepts cfg and uses cfg.initial_balance
  instead of a hardcoded 1000.0 for the negative-balance invariant check
- backtester.py walk-forward injection block simplified to use the new factory method

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
21in7
2026-03-21 18:36:30 +09:00
parent a34fc6f996
commit 5bad7dd691
4 changed files with 34 additions and 13 deletions

View File

@@ -30,7 +30,7 @@ def validate(trades: list[dict], summary: dict, cfg) -> dict:
results: list[CheckResult] = []
# 검증 1: 논리적 불변 조건
results.extend(_check_invariants(trades))
results.extend(_check_invariants(trades, cfg))
# 검증 2: 통계적 이상 감지
results.extend(_check_statistics(trades, summary))
@@ -47,7 +47,7 @@ def validate(trades: list[dict], summary: dict, cfg) -> dict:
}
def _check_invariants(trades: list[dict]) -> list[CheckResult]:
def _check_invariants(trades: list[dict], cfg=None) -> list[CheckResult]:
"""논리적 불변 조건. 하나라도 위반 시 FAIL."""
results = []
@@ -120,7 +120,7 @@ def _check_invariants(trades: list[dict]) -> list[CheckResult]:
))
# 5. 잔고가 음수가 된 적 없음
balance = 1000.0 # cfg.initial_balance를 몰라도 trades에서 추적 가능
balance = cfg.initial_balance if cfg is not None else 1000.0
min_balance = balance
for t in trades:
balance += t["net_pnl"]

View File

@@ -317,16 +317,9 @@ class Backtester:
self.ml_filters = {}
for sym in self.cfg.symbols:
if sym in ml_models and ml_models[sym] is not None:
mf = MLFilter.__new__(MLFilter)
mf._disabled = False
mf._onnx_session = None
mf._lgbm_model = ml_models[sym]
mf._threshold = self.cfg.ml_threshold
mf._onnx_path = Path("/dev/null")
mf._lgbm_path = Path("/dev/null")
mf._loaded_onnx_mtime = 0.0
mf._loaded_lgbm_mtime = 0.0
self.ml_filters[sym] = mf
self.ml_filters[sym] = MLFilter.from_model(
ml_models[sym], threshold=self.cfg.ml_threshold
)
else:
self.ml_filters[sym] = None

View File

@@ -155,6 +155,21 @@ class MLFilter:
logger.warning(f"ML 필터 예측 오류 (진입 차단): {e}")
return False
@classmethod
def from_model(cls, model, threshold: float = 0.55) -> "MLFilter":
"""외부에서 학습된 LightGBM 모델을 주입하여 MLFilter를 생성한다.
backtester walk-forward에서 사용."""
instance = cls.__new__(cls)
instance._disabled = False
instance._onnx_session = None
instance._lgbm_model = model
instance._threshold = threshold
instance._onnx_path = Path("/dev/null")
instance._lgbm_path = Path("/dev/null")
instance._loaded_onnx_mtime = 0.0
instance._loaded_lgbm_mtime = 0.0
return instance
def reload_model(self):
"""외부에서 강제 리로드할 때 사용 (하위 호환)."""
prev_backend = self.active_backend

View File

@@ -100,3 +100,16 @@ def test_mlx_no_double_normalization():
assert np.allclose(model._mean, 0.0), "normalize=False시 mean은 0이어야 한다"
assert np.allclose(model._std, 1.0), "normalize=False시 std는 1이어야 한다"
def test_ml_filter_from_model():
"""MLFilter.from_model()로 LightGBM 모델을 주입할 수 있어야 한다."""
from src.ml_filter import MLFilter
from unittest.mock import MagicMock
mock_model = MagicMock()
mock_model.predict_proba.return_value = [[0.3, 0.7]]
mf = MLFilter.from_model(mock_model, threshold=0.55)
assert mf.is_model_loaded()
assert mf.active_backend == "LightGBM"