# NOTE: this function originates from a patch against scripts/train_model.py.
# The same patch adds a `--compare` CLI flag whose handler in main() calls
# compare(args.data, time_weight_decay=args.decay,
#         tuned_params_path=args.tuned_params).
def compare(
    data_path: str,
    time_weight_decay: float = 2.0,
    tuned_params_path: str | None = None,
) -> None:
    """Print an A/B comparison of the feature set without vs. with OI-derived features.

    The parquet file at ``data_path`` is loaded and turned into a dataset via
    ``generate_dataset_vectorized``. Two LightGBM classifiers are then trained
    on the first 80% of the dataset rows (after ``stratified_undersample``) and
    scored on the remaining 20%:

    * "Baseline": every available ``FEATURE_COLS`` entry except
      ``oi_change_ma5`` and ``oi_price_spread``.
    * "New": every available ``FEATURE_COLS`` entry.

    For each arm the AUC is reported together with the precision, recall and
    threshold of the highest-precision operating point whose recall is at
    least 0.15. The top-10 feature importances of the "New" model are printed
    as well.

    Args:
        data_path: Path of the parquet file to read.
        time_weight_decay: Forwarded to ``generate_dataset_vectorized``.
        tuned_params_path: Forwarded to ``_load_lgbm_params``.

    Raises:
        ValueError: If the generated dataset is empty.
    """
    import warnings

    print("=" * 70)
    print(" OI 파생 피처 A/B 비교 (30일 데이터 기반, 방향성 참고용)")
    print("=" * 70)

    df_raw = pd.read_parquet(data_path)
    base_cols = ["open", "high", "low", "close", "volume"]

    # Optional BTC/ETH frames: the suffixed columns are renamed to the plain
    # OHLCV names before being handed to the dataset generator.
    btc_df = eth_df = None
    if "close_btc" in df_raw.columns:
        btc_df = df_raw[[c + "_btc" for c in base_cols]].copy()
        btc_df.columns = base_cols
    if "close_eth" in df_raw.columns:
        eth_df = df_raw[[c + "_eth" for c in base_cols]].copy()
        eth_df.columns = base_cols

    df = df_raw[base_cols].copy()
    for optional_col in ("oi_change", "funding_rate"):
        if optional_col in df_raw.columns:
            df[optional_col] = df_raw[optional_col]

    dataset = generate_dataset_vectorized(
        df,
        btc_df=btc_df,
        eth_df=eth_df,
        time_weight_decay=time_weight_decay,
        negative_ratio=5,
    )
    if dataset.empty:
        raise ValueError("데이터셋 생성 실패")

    lgbm_params, weight_scale = _load_lgbm_params(tuned_params_path)

    # Baseline arm leaves out the OI-derived features; the new arm keeps them.
    oi_derived = {"oi_change_ma5", "oi_price_spread"}
    new_cols = [c for c in FEATURE_COLS if c in dataset.columns]
    baseline_cols = [c for c in new_cols if c not in oi_derived]
    if len(baseline_cols) == len(new_cols):
        # Without any OI-derived column both arms are identical, so the
        # deltas below carry no information. Say so instead of staying silent.
        print("⚠ 데이터셋에 OI 파생 피처가 없어 Baseline과 New의 피처 구성이 동일합니다")

    # Everything that does not depend on the feature subset is computed once,
    # so both arms are trained and evaluated on exactly the same rows.
    labels = dataset["label"]
    weights = dataset["sample_weight"].values
    if "source" in dataset.columns:
        source = dataset["source"].values
    else:
        source = np.full(len(dataset), "signal")

    split = int(len(dataset) * 0.8)
    y_tr, y_val = labels.iloc[:split], labels.iloc[split:]
    w_tr = (weights[:split] * weight_scale).astype(np.float32)

    balanced_idx = stratified_undersample(y_tr.values, source[:split], seed=42)
    y_tr_b = y_tr.iloc[balanced_idx]
    w_tr_b = w_tr[balanced_idx]

    n_val = len(y_val)
    n_val_pos = int(y_val.sum())
    has_both_classes = len(np.unique(y_val)) > 1

    results = {}
    for label, cols in (("Baseline", baseline_cols), ("New", new_cols)):
        features = dataset[cols]
        X_tr_b = features.iloc[:split].iloc[balanced_idx]
        X_val = features.iloc[split:]

        model = lgb.LGBMClassifier(**lgbm_params, random_state=42, verbose=-1)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_tr_b, y_tr_b, sample_weight=w_tr_b)

        proba = model.predict_proba(X_val)[:, 1]
        # AUC is undefined with a single class in the validation labels.
        auc = roc_auc_score(y_val, proba) if has_both_classes else 0.5

        # Fallback operating point; also used when the validation set has no
        # positives, where precision/recall are not meaningful.
        thr, prec, rec = 0.50, 0.0, 0.0
        if n_val_pos > 0:
            precs, recs, thrs = precision_recall_curve(y_val, proba)
            # Drop the trailing curve point, which has no matching threshold.
            precs, recs = precs[:-1], recs[:-1]
            valid_idx = np.where(recs >= 0.15)[0]
            if len(valid_idx) > 0:
                best_i = valid_idx[np.argmax(precs[valid_idx])]
                thr = float(thrs[best_i])
                prec = float(precs[best_i])
                rec = float(recs[best_i])

        # Feature importance
        importance = dict(zip(cols, model.feature_importances_))
        top10 = sorted(importance.items(), key=lambda kv: kv[1], reverse=True)[:10]

        results[label] = {
            "auc": auc,
            "precision": prec,
            "recall": rec,
            "threshold": thr,
            "n_val": n_val,
            "n_val_pos": n_val_pos,
            "top10": top10,
        }

    # Comparison table
    n_base = len(baseline_cols)
    n_new = len(new_cols)
    print(f"\n{'지표':<20} {f'Baseline({n_base})':>15} {f'New({n_new})':>15} {'Delta':>10}")
    print("-" * 62)
    for metric in ("auc", "precision", "recall", "threshold"):
        base_val = results["Baseline"][metric]
        new_val = results["New"][metric]
        delta = new_val - base_val
        # The `+` format flag keeps the sign attached to the number and every
        # row the same width as the 10-character "Delta" header.
        print(f"{metric:<20} {base_val:>15.4f} {new_val:>15.4f} {delta:>+10.4f}")

    print(f"\n검증셋: n={n_val} (양성={n_val_pos}, 음성={n_val - n_val_pos})")
    print("⚠ 30일 데이터 기반 — 방향성 참고용\n")

    print("Feature Importance Top 10 (New):")
    for feat_name, imp_val in results["New"]["top10"]:
        marker = " ← NEW" if feat_name in oi_derived else ""
        print(f" {feat_name:<25} {imp_val:>6}{marker}")