feat(backtest): add --compare-ml for ML on/off walk-forward comparison

Runs WalkForwardBacktester twice (use_ml=True/False), prints side-by-side
comparison of PF, win rate, MDD, Sharpe, and auto-judges ML filter value.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
21in7
2026-03-21 19:58:24 +09:00
parent c29d3e0569
commit b5a5510499
3 changed files with 465 additions and 0 deletions

View File

@@ -50,6 +50,8 @@ def parse_args():
# Walk-Forward
p.add_argument("--walk-forward", action="store_true", help="Walk-Forward 백테스트 (기간별 모델 학습/검증)")
p.add_argument("--compare-ml", action="store_true",
help="ML on vs off Walk-Forward 비교 (--walk-forward 자동 활성화)")
p.add_argument("--train-months", type=int, default=6, help="WF 학습 윈도우 개월 (기본: 6)")
p.add_argument("--test-months", type=int, default=1, help="WF 검증 윈도우 개월 (기본: 1)")
return p.parse_args()
@@ -148,6 +150,159 @@ def save_result(result: dict, cfg):
return path
def compare_ml(symbols: list[str], args):
"""ML on vs ML off Walk-Forward 백테스트 비교."""
base_kwargs = dict(
symbols=symbols,
start=args.start,
end=args.end,
initial_balance=args.balance,
leverage=args.leverage,
fee_pct=args.fee,
slippage_pct=args.slippage,
ml_threshold=args.ml_threshold,
atr_sl_mult=args.sl_atr,
atr_tp_mult=args.tp_atr,
signal_threshold=args.signal_threshold,
adx_threshold=args.adx_threshold,
volume_multiplier=args.vol_multiplier,
train_months=args.train_months,
test_months=args.test_months,
)
results = {}
for label, use_ml in [("ML OFF", False), ("ML ON", True)]:
print(f"\n{'='*60}")
print(f" Walk-Forward 백테스트: {label}")
print(f"{'='*60}")
cfg = WalkForwardConfig(**base_kwargs, use_ml=use_ml)
wf = WalkForwardBacktester(cfg)
result = wf.run()
results[label] = result
print_summary(result["summary"], cfg, mode="walk_forward")
if result.get("folds"):
print_fold_table(result["folds"])
_print_comparison(results, symbols)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
if len(symbols) == 1:
out_dir = Path(f"results/{symbols[0].lower()}")
else:
out_dir = Path("results/combined")
out_dir.mkdir(parents=True, exist_ok=True)
path = out_dir / f"ml_comparison_{ts}.json"
comparison = {
"timestamp": datetime.now().isoformat(),
"symbols": symbols,
"ml_off": results["ML OFF"]["summary"],
"ml_on": results["ML ON"]["summary"],
}
def sanitize(obj):
if isinstance(obj, bool):
return obj
if isinstance(obj, (int, float)):
if isinstance(obj, float) and obj == float("inf"):
return "Infinity"
return obj
if isinstance(obj, dict):
return {k: sanitize(v) for k, v in obj.items()}
if isinstance(obj, list):
return [sanitize(v) for v in obj]
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
return float(obj)
return obj
with open(path, "w") as f:
json.dump(sanitize(comparison), f, indent=2, ensure_ascii=False)
print(f"\n비교 결과 저장: {path}")
def _print_comparison(results: dict, symbols: list[str]):
"""ML on vs off 비교 리포트 출력."""
off = results["ML OFF"]["summary"]
on = results["ML ON"]["summary"]
print(f"\n{'='*64}")
print(f" ML ON vs OFF 비교 ({', '.join(symbols)})")
print(f"{'='*64}")
print(f" {'지표':<20} {'ML OFF':>12} {'ML ON':>12} {'Delta':>12}")
print(f"{''*64}")
metrics = [
("총 거래", "total_trades", "d"),
("총 PnL (USDT)", "total_pnl", ".2f"),
("수익률 (%)", "return_pct", ".2f"),
("승률 (%)", "win_rate", ".1f"),
("Profit Factor", "profit_factor", ".2f"),
("MDD (%)", "max_drawdown_pct", ".2f"),
("Sharpe", "sharpe_ratio", ".2f"),
]
for label, key, fmt in metrics:
v_off = off.get(key, 0)
v_on = on.get(key, 0)
if v_off == float("inf"):
v_off_str = "INF"
else:
v_off_str = f"{v_off:{fmt}}"
if v_on == float("inf"):
v_on_str = "INF"
else:
v_on_str = f"{v_on:{fmt}}"
if isinstance(v_off, (int, float)) and isinstance(v_on, (int, float)) \
and v_off != float("inf") and v_on != float("inf"):
delta = v_on - v_off
sign = "+" if delta > 0 else ""
delta_str = f"{sign}{delta:{fmt}}"
else:
delta_str = "N/A"
print(f" {label:<20} {v_off_str:>12} {v_on_str:>12} {delta_str:>12}")
pf_off = off.get("profit_factor", 0)
pf_on = on.get("profit_factor", 0)
wr_off = off.get("win_rate", 0)
wr_on = on.get("win_rate", 0)
mdd_off = off.get("max_drawdown_pct", 0)
mdd_on = on.get("max_drawdown_pct", 0)
print(f"{''*64}")
if pf_off == float("inf") or pf_on == float("inf"):
print(f" 판정: PF=INF — 한쪽 모드에서 손실 거래 없음 (거래 수 부족 가능), 판단 보류")
elif pf_off == 0:
print(f" 판정: ML OFF PF=0 — baseline 거래 없음, 판단 불가")
else:
pf_improvement = pf_on - pf_off
wr_improvement = wr_on - wr_off
mdd_improvement = mdd_off - mdd_on
improvements = []
if pf_improvement > 0.1:
improvements.append(f"PF +{pf_improvement:.2f}")
if wr_improvement > 2.0:
improvements.append(f"승률 +{wr_improvement:.1f}%p")
if mdd_improvement > 1.0:
improvements.append(f"MDD -{mdd_improvement:.1f}%p")
if len(improvements) >= 2:
verdict = f"ML 필터 투입 가치 있음 ({', '.join(improvements)})"
elif len(improvements) == 1:
verdict = f"ML 필터 조건부 투입 ({improvements[0]}, 다른 지표 변화 미미)"
else:
verdict = f"ML 필터 기여 미미 (PF {pf_improvement:+.2f}, 승률 {wr_improvement:+.1f}%p)"
print(f" 판정: {verdict}")
print(f"{'='*64}\n")
def main():
args = parse_args()
@@ -156,6 +311,12 @@ def main():
else:
symbols = [s.strip().upper() for s in args.symbols.split(",") if s.strip()]
if args.compare_ml:
if args.no_ml:
logger.warning("--no-ml is ignored when using --compare-ml")
compare_ml(symbols, args)
return
if args.walk_forward:
cfg = WalkForwardConfig(
symbols=symbols,