From a34fc6f9962a3250b670e7d8b9f78da32106cab3 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Sat, 21 Mar 2026 18:33:36 +0900 Subject: [PATCH] fix(mlx): use stratified_undersample consistent with LightGBM Co-Authored-By: Claude Sonnet 4.6 --- scripts/train_mlx_model.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/scripts/train_mlx_model.py b/scripts/train_mlx_model.py index d04241c..e2ed162 100644 --- a/scripts/train_mlx_model.py +++ b/scripts/train_mlx_model.py @@ -17,7 +17,7 @@ import numpy as np import pandas as pd from sklearn.metrics import roc_auc_score, classification_report -from src.dataset_builder import generate_dataset_vectorized +from src.dataset_builder import generate_dataset_vectorized, stratified_undersample from src.ml_features import FEATURE_COLS from src.mlx_filter import MLXFilter @@ -59,7 +59,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float print("\n데이터셋 생성 중...") t0 = time.perf_counter() dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay, - atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult) + atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5) t1 = time.perf_counter() print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플") @@ -86,16 +86,10 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float y_train, y_val = y.iloc[:split], y.iloc[split:] w_train = w[:split] - # --- 클래스 불균형 처리: 언더샘플링 (가중치 인덱스 보존) --- - pos_idx = np.where(y_train == 1)[0] - neg_idx = np.where(y_train == 0)[0] - - if len(neg_idx) > len(pos_idx): - np.random.seed(42) - neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False) - - balanced_idx = np.concatenate([pos_idx, neg_idx]) - np.random.shuffle(balanced_idx) + # --- 클래스 불균형 처리: stratified 언더샘플링 (Signal 전수 유지, HOLD만 샘플링) --- + source = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal") + source_train = source[:split] + balanced_idx = stratified_undersample(y_train.values, source_train, seed=42) X_train = X_train.iloc[balanced_idx] y_train = y_train.iloc[balanced_idx] @@ -181,7 +175,7 @@ def walk_forward_auc( dataset = generate_dataset_vectorized( df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay, - atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, + atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5, ) missing = [c for c in FEATURE_COLS if c not in dataset.columns] for col in missing: @@ -190,6 +184,7 @@ def walk_forward_auc( X_all = dataset[FEATURE_COLS].values.astype(np.float32) y_all = dataset["label"].values.astype(np.float32) w_all = dataset["sample_weight"].values.astype(np.float32) + source_all = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal") n = len(dataset) step = max(1, int(n * (1 - train_ratio) / n_splits)) @@ -208,12 +203,8 @@ def walk_forward_auc( X_val_raw = X_all[tr_end:val_end] y_val = y_all[tr_end:val_end] - pos_idx = np.where(y_tr == 1)[0] - neg_idx = np.where(y_tr == 0)[0] - if len(neg_idx) > len(pos_idx): - np.random.seed(42) - neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False) - bal_idx = np.sort(np.concatenate([pos_idx, neg_idx])) + source_tr = source_all[:tr_end] + bal_idx = stratified_undersample(y_tr, source_tr, seed=42) X_tr_bal = X_tr_raw[bal_idx] y_tr_bal = y_tr[bal_idx]