refactor(ml): add MLFilter.from_model(), fix validator initial_balance
- MLFilter.from_model() classmethod eliminates brittle __new__() private-attribute manipulation in backtester walk-forward model injection - backtest_validator._check_invariants() now accepts cfg and uses cfg.initial_balance instead of a hardcoded 1000.0 for the negative-balance invariant check - backtester.py walk-forward injection block simplified to use the new factory method Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -30,7 +30,7 @@ def validate(trades: list[dict], summary: dict, cfg) -> dict:
|
||||
results: list[CheckResult] = []
|
||||
|
||||
# 검증 1: 논리적 불변 조건
|
||||
results.extend(_check_invariants(trades))
|
||||
results.extend(_check_invariants(trades, cfg))
|
||||
|
||||
# 검증 2: 통계적 이상 감지
|
||||
results.extend(_check_statistics(trades, summary))
|
||||
@@ -47,7 +47,7 @@ def validate(trades: list[dict], summary: dict, cfg) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _check_invariants(trades: list[dict]) -> list[CheckResult]:
|
||||
def _check_invariants(trades: list[dict], cfg=None) -> list[CheckResult]:
|
||||
"""논리적 불변 조건. 하나라도 위반 시 FAIL."""
|
||||
results = []
|
||||
|
||||
@@ -120,7 +120,7 @@ def _check_invariants(trades: list[dict]) -> list[CheckResult]:
|
||||
))
|
||||
|
||||
# 5. 잔고가 음수가 된 적 없음
|
||||
balance = 1000.0 # cfg.initial_balance를 몰라도 trades에서 추적 가능
|
||||
balance = cfg.initial_balance if cfg is not None else 1000.0
|
||||
min_balance = balance
|
||||
for t in trades:
|
||||
balance += t["net_pnl"]
|
||||
|
||||
@@ -317,16 +317,9 @@ class Backtester:
|
||||
self.ml_filters = {}
|
||||
for sym in self.cfg.symbols:
|
||||
if sym in ml_models and ml_models[sym] is not None:
|
||||
mf = MLFilter.__new__(MLFilter)
|
||||
mf._disabled = False
|
||||
mf._onnx_session = None
|
||||
mf._lgbm_model = ml_models[sym]
|
||||
mf._threshold = self.cfg.ml_threshold
|
||||
mf._onnx_path = Path("/dev/null")
|
||||
mf._lgbm_path = Path("/dev/null")
|
||||
mf._loaded_onnx_mtime = 0.0
|
||||
mf._loaded_lgbm_mtime = 0.0
|
||||
self.ml_filters[sym] = mf
|
||||
self.ml_filters[sym] = MLFilter.from_model(
|
||||
ml_models[sym], threshold=self.cfg.ml_threshold
|
||||
)
|
||||
else:
|
||||
self.ml_filters[sym] = None
|
||||
|
||||
|
||||
@@ -155,6 +155,21 @@ class MLFilter:
|
||||
logger.warning(f"ML 필터 예측 오류 (진입 차단): {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model, threshold: float = 0.55) -> "MLFilter":
|
||||
"""외부에서 학습된 LightGBM 모델을 주입하여 MLFilter를 생성한다.
|
||||
backtester walk-forward에서 사용."""
|
||||
instance = cls.__new__(cls)
|
||||
instance._disabled = False
|
||||
instance._onnx_session = None
|
||||
instance._lgbm_model = model
|
||||
instance._threshold = threshold
|
||||
instance._onnx_path = Path("/dev/null")
|
||||
instance._lgbm_path = Path("/dev/null")
|
||||
instance._loaded_onnx_mtime = 0.0
|
||||
instance._loaded_lgbm_mtime = 0.0
|
||||
return instance
|
||||
|
||||
def reload_model(self):
|
||||
"""외부에서 강제 리로드할 때 사용 (하위 호환)."""
|
||||
prev_backend = self.active_backend
|
||||
|
||||
@@ -100,3 +100,16 @@ def test_mlx_no_double_normalization():
|
||||
|
||||
assert np.allclose(model._mean, 0.0), "normalize=False시 mean은 0이어야 한다"
|
||||
assert np.allclose(model._std, 1.0), "normalize=False시 std는 1이어야 한다"
|
||||
|
||||
|
||||
def test_ml_filter_from_model():
|
||||
"""MLFilter.from_model()로 LightGBM 모델을 주입할 수 있어야 한다."""
|
||||
from src.ml_filter import MLFilter
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict_proba.return_value = [[0.3, 0.7]]
|
||||
|
||||
mf = MLFilter.from_model(mock_model, threshold=0.55)
|
||||
assert mf.is_model_loaded()
|
||||
assert mf.active_backend == "LightGBM"
|
||||
|
||||
Reference in New Issue
Block a user