fix(mlx): use stratified_undersample consistent with LightGBM
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import roc_auc_score, classification_report
|
||||
|
||||
from src.dataset_builder import generate_dataset_vectorized
|
||||
from src.dataset_builder import generate_dataset_vectorized, stratified_undersample
|
||||
from src.ml_features import FEATURE_COLS
|
||||
from src.mlx_filter import MLXFilter
|
||||
|
||||
@@ -59,7 +59,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float
|
||||
print("\n데이터셋 생성 중...")
|
||||
t0 = time.perf_counter()
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult)
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5)
|
||||
t1 = time.perf_counter()
|
||||
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
||||
|
||||
@@ -86,16 +86,10 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float
|
||||
y_train, y_val = y.iloc[:split], y.iloc[split:]
|
||||
w_train = w[:split]
|
||||
|
||||
# --- 클래스 불균형 처리: 언더샘플링 (가중치 인덱스 보존) ---
|
||||
pos_idx = np.where(y_train == 1)[0]
|
||||
neg_idx = np.where(y_train == 0)[0]
|
||||
|
||||
if len(neg_idx) > len(pos_idx):
|
||||
np.random.seed(42)
|
||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
||||
|
||||
balanced_idx = np.concatenate([pos_idx, neg_idx])
|
||||
np.random.shuffle(balanced_idx)
|
||||
# --- 클래스 불균형 처리: stratified 언더샘플링 (Signal 전수 유지, HOLD만 샘플링) ---
|
||||
source = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal")
|
||||
source_train = source[:split]
|
||||
balanced_idx = stratified_undersample(y_train.values, source_train, seed=42)
|
||||
|
||||
X_train = X_train.iloc[balanced_idx]
|
||||
y_train = y_train.iloc[balanced_idx]
|
||||
@@ -181,7 +175,7 @@ def walk_forward_auc(
|
||||
|
||||
dataset = generate_dataset_vectorized(
|
||||
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult,
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5,
|
||||
)
|
||||
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
||||
for col in missing:
|
||||
@@ -190,6 +184,7 @@ def walk_forward_auc(
|
||||
X_all = dataset[FEATURE_COLS].values.astype(np.float32)
|
||||
y_all = dataset["label"].values.astype(np.float32)
|
||||
w_all = dataset["sample_weight"].values.astype(np.float32)
|
||||
source_all = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal")
|
||||
n = len(dataset)
|
||||
|
||||
step = max(1, int(n * (1 - train_ratio) / n_splits))
|
||||
@@ -208,12 +203,8 @@ def walk_forward_auc(
|
||||
X_val_raw = X_all[tr_end:val_end]
|
||||
y_val = y_all[tr_end:val_end]
|
||||
|
||||
pos_idx = np.where(y_tr == 1)[0]
|
||||
neg_idx = np.where(y_tr == 0)[0]
|
||||
if len(neg_idx) > len(pos_idx):
|
||||
np.random.seed(42)
|
||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
||||
bal_idx = np.sort(np.concatenate([pos_idx, neg_idx]))
|
||||
source_tr = source_all[:tr_end]
|
||||
bal_idx = stratified_undersample(y_tr, source_tr, seed=42)
|
||||
|
||||
X_tr_bal = X_tr_raw[bal_idx]
|
||||
y_tr_bal = y_tr[bal_idx]
|
||||
|
||||
Reference in New Issue
Block a user