fix(mlx): remove double normalization in walk-forward validation

Add normalize=False parameter to MLXFilter.fit() so external callers
can skip internal normalization. Remove the external normalization +
manual _mean/_std reset hack from walk_forward_auc() in train_mlx_model.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
21in7
2026-03-21 18:31:11 +09:00
parent 0fe87bb366
commit 24f0faa540
3 changed files with 40 additions and 22 deletions

View File

@@ -219,17 +219,8 @@ def walk_forward_auc(
y_tr_bal = y_tr[bal_idx]
w_tr_bal = w_tr[bal_idx]
# 폴드별 정규화 (학습 데이터 기준으로 계산, 검증에도 동일 적용)
mean = X_tr_bal.mean(axis=0)
std = X_tr_bal.std(axis=0) + 1e-8
X_tr_norm = (X_tr_bal - mean) / std
X_val_norm = (X_val_raw - mean) / std
# DataFrame으로 래핑해서 MLXFilter.fit()에 전달
# fit() 내부 정규화가 덮어쓰지 않도록 이미 정규화된 데이터를 넘기고
# _mean=0, _std=1로 고정해 이중 정규화를 방지
X_tr_df = pd.DataFrame(X_tr_norm, columns=FEATURE_COLS)
X_val_df = pd.DataFrame(X_val_norm, columns=FEATURE_COLS)
X_tr_df = pd.DataFrame(X_tr_bal, columns=FEATURE_COLS)
X_val_df = pd.DataFrame(X_val_raw, columns=FEATURE_COLS)
model = MLXFilter(
input_dim=len(FEATURE_COLS),
@@ -239,9 +230,7 @@ def walk_forward_auc(
batch_size=256,
)
model.fit(X_tr_df, pd.Series(y_tr_bal), sample_weight=w_tr_bal)
# fit()이 내부에서 다시 정규화하므로 저장된 mean/std를 항등 변환으로 교체
model._mean = np.zeros(len(FEATURE_COLS), dtype=np.float32)
model._std = np.ones(len(FEATURE_COLS), dtype=np.float32)
# fit() handles normalization internally, predict_proba() applies same mean/std
proba = model.predict_proba(X_val_df)
auc = roc_auc_score(y_val, proba) if len(np.unique(y_val)) > 1 else 0.5