fix: MLXFilter fit/predict에 nan-safe 정규화 적용 (nanmean + nan_to_num)

Made-with: Cursor
This commit is contained in:
21in7
2026-03-01 23:53:49 +09:00
parent 820d8e0213
commit 6ae0f9d81b
2 changed files with 31 additions and 2 deletions

View File

@@ -65,6 +65,31 @@ def test_mlx_filter_fit_and_predict():
assert np.all((proba >= 0.0) & (proba <= 1.0))
def test_fit_with_nan_features():
"""oi_change 피처에 nan이 포함된 경우 학습이 정상 완료되어야 한다."""
import numpy as np
import pandas as pd
from src.mlx_filter import MLXFilter
from src.ml_features import FEATURE_COLS
n = 300
np.random.seed(42)
X = pd.DataFrame(
np.random.randn(n, len(FEATURE_COLS)).astype(np.float32),
columns=FEATURE_COLS,
)
# oi_change 앞 절반을 nan으로
X["oi_change"] = np.where(np.arange(n) < n // 2, np.nan, X["oi_change"])
y = pd.Series((np.random.rand(n) > 0.5).astype(np.float32))
model = MLXFilter(input_dim=len(FEATURE_COLS), hidden_dim=32, epochs=3)
model.fit(X, y) # nan 있어도 예외 없이 완료되어야 함
proba = model.predict_proba(X)
assert not np.any(np.isnan(proba)), "예측 확률에 nan이 없어야 함"
assert proba.min() >= 0.0 and proba.max() <= 1.0
def test_mlx_filter_save_load(tmp_path):
"""저장 후 로드한 모델이 동일한 예측값을 반환해야 한다."""
from src.mlx_filter import MLXFilter