fix(mlx): use stratified_undersample consistent with LightGBM

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
21in7
2026-03-21 18:33:36 +09:00
parent 24f0faa540
commit a34fc6f996

View File

@@ -17,7 +17,7 @@ import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, classification_report
from src.dataset_builder import generate_dataset_vectorized
from src.dataset_builder import generate_dataset_vectorized, stratified_undersample
from src.ml_features import FEATURE_COLS
from src.mlx_filter import MLXFilter
@@ -59,7 +59,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float
print("\n데이터셋 생성 중...")
t0 = time.perf_counter()
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult)
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5)
t1 = time.perf_counter()
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
@@ -86,16 +86,10 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float
y_train, y_val = y.iloc[:split], y.iloc[split:]
w_train = w[:split]
# --- 클래스 불균형 처리: 언더샘플링 (가중치 인덱스 보존) ---
pos_idx = np.where(y_train == 1)[0]
neg_idx = np.where(y_train == 0)[0]
if len(neg_idx) > len(pos_idx):
np.random.seed(42)
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
balanced_idx = np.concatenate([pos_idx, neg_idx])
np.random.shuffle(balanced_idx)
# --- 클래스 불균형 처리: stratified 언더샘플링 (Signal 전수 유지, HOLD만 샘플링) ---
source = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal")
source_train = source[:split]
balanced_idx = stratified_undersample(y_train.values, source_train, seed=42)
X_train = X_train.iloc[balanced_idx]
y_train = y_train.iloc[balanced_idx]
@@ -181,7 +175,7 @@ def walk_forward_auc(
dataset = generate_dataset_vectorized(
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult,
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5,
)
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
for col in missing:
@@ -190,6 +184,7 @@ def walk_forward_auc(
X_all = dataset[FEATURE_COLS].values.astype(np.float32)
y_all = dataset["label"].values.astype(np.float32)
w_all = dataset["sample_weight"].values.astype(np.float32)
source_all = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal")
n = len(dataset)
step = max(1, int(n * (1 - train_ratio) / n_splits))
@@ -208,12 +203,8 @@ def walk_forward_auc(
X_val_raw = X_all[tr_end:val_end]
y_val = y_all[tr_end:val_end]
pos_idx = np.where(y_tr == 1)[0]
neg_idx = np.where(y_tr == 0)[0]
if len(neg_idx) > len(pos_idx):
np.random.seed(42)
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
bal_idx = np.sort(np.concatenate([pos_idx, neg_idx]))
source_tr = source_all[:tr_end]
bal_idx = stratified_undersample(y_tr, source_tr, seed=42)
X_tr_bal = X_tr_raw[bal_idx]
y_tr_bal = y_tr[bal_idx]