fix(mlx): use stratified_undersample consistent with LightGBM
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.metrics import roc_auc_score, classification_report
|
from sklearn.metrics import roc_auc_score, classification_report
|
||||||
|
|
||||||
from src.dataset_builder import generate_dataset_vectorized
|
from src.dataset_builder import generate_dataset_vectorized, stratified_undersample
|
||||||
from src.ml_features import FEATURE_COLS
|
from src.ml_features import FEATURE_COLS
|
||||||
from src.mlx_filter import MLXFilter
|
from src.mlx_filter import MLXFilter
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float
|
|||||||
print("\n데이터셋 생성 중...")
|
print("\n데이터셋 생성 중...")
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
||||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult)
|
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5)
|
||||||
t1 = time.perf_counter()
|
t1 = time.perf_counter()
|
||||||
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
||||||
|
|
||||||
@@ -86,16 +86,10 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float
|
|||||||
y_train, y_val = y.iloc[:split], y.iloc[split:]
|
y_train, y_val = y.iloc[:split], y.iloc[split:]
|
||||||
w_train = w[:split]
|
w_train = w[:split]
|
||||||
|
|
||||||
# --- 클래스 불균형 처리: 언더샘플링 (가중치 인덱스 보존) ---
|
# --- 클래스 불균형 처리: stratified 언더샘플링 (Signal 전수 유지, HOLD만 샘플링) ---
|
||||||
pos_idx = np.where(y_train == 1)[0]
|
source = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal")
|
||||||
neg_idx = np.where(y_train == 0)[0]
|
source_train = source[:split]
|
||||||
|
balanced_idx = stratified_undersample(y_train.values, source_train, seed=42)
|
||||||
if len(neg_idx) > len(pos_idx):
|
|
||||||
np.random.seed(42)
|
|
||||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
|
||||||
|
|
||||||
balanced_idx = np.concatenate([pos_idx, neg_idx])
|
|
||||||
np.random.shuffle(balanced_idx)
|
|
||||||
|
|
||||||
X_train = X_train.iloc[balanced_idx]
|
X_train = X_train.iloc[balanced_idx]
|
||||||
y_train = y_train.iloc[balanced_idx]
|
y_train = y_train.iloc[balanced_idx]
|
||||||
@@ -181,7 +175,7 @@ def walk_forward_auc(
|
|||||||
|
|
||||||
dataset = generate_dataset_vectorized(
|
dataset = generate_dataset_vectorized(
|
||||||
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
||||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult,
|
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult, negative_ratio=5,
|
||||||
)
|
)
|
||||||
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
||||||
for col in missing:
|
for col in missing:
|
||||||
@@ -190,6 +184,7 @@ def walk_forward_auc(
|
|||||||
X_all = dataset[FEATURE_COLS].values.astype(np.float32)
|
X_all = dataset[FEATURE_COLS].values.astype(np.float32)
|
||||||
y_all = dataset["label"].values.astype(np.float32)
|
y_all = dataset["label"].values.astype(np.float32)
|
||||||
w_all = dataset["sample_weight"].values.astype(np.float32)
|
w_all = dataset["sample_weight"].values.astype(np.float32)
|
||||||
|
source_all = dataset["source"].values if "source" in dataset.columns else np.full(len(dataset), "signal")
|
||||||
n = len(dataset)
|
n = len(dataset)
|
||||||
|
|
||||||
step = max(1, int(n * (1 - train_ratio) / n_splits))
|
step = max(1, int(n * (1 - train_ratio) / n_splits))
|
||||||
@@ -208,12 +203,8 @@ def walk_forward_auc(
|
|||||||
X_val_raw = X_all[tr_end:val_end]
|
X_val_raw = X_all[tr_end:val_end]
|
||||||
y_val = y_all[tr_end:val_end]
|
y_val = y_all[tr_end:val_end]
|
||||||
|
|
||||||
pos_idx = np.where(y_tr == 1)[0]
|
source_tr = source_all[:tr_end]
|
||||||
neg_idx = np.where(y_tr == 0)[0]
|
bal_idx = stratified_undersample(y_tr, source_tr, seed=42)
|
||||||
if len(neg_idx) > len(pos_idx):
|
|
||||||
np.random.seed(42)
|
|
||||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
|
||||||
bal_idx = np.sort(np.concatenate([pos_idx, neg_idx]))
|
|
||||||
|
|
||||||
X_tr_bal = X_tr_raw[bal_idx]
|
X_tr_bal = X_tr_raw[bal_idx]
|
||||||
y_tr_bal = y_tr[bal_idx]
|
y_tr_bal = y_tr[bal_idx]
|
||||||
|
|||||||
Reference in New Issue
Block a user