From 6ae0f9d81bf70fcaca6850de8aabe5cbb706dd93 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Sun, 1 Mar 2026 23:53:49 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20MLXFilter=20fit/predict=EC=97=90=20nan-s?= =?UTF-8?q?afe=20=EC=A0=95=EA=B7=9C=ED=99=94=20=EC=A0=81=EC=9A=A9=20(nanme?= =?UTF-8?q?an=20+=20nan=5Fto=5Fnum)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Made-with: Cursor --- src/mlx_filter.py | 8 ++++++-- tests/test_mlx_filter.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/mlx_filter.py b/src/mlx_filter.py index 4545986..ccf7d04 100644 --- a/src/mlx_filter.py +++ b/src/mlx_filter.py @@ -140,9 +140,12 @@ class MLXFilter: X_np = X[FEATURE_COLS].values.astype(np.float32) y_np = y.values.astype(np.float32) - self._mean = X_np.mean(axis=0) - self._std = X_np.std(axis=0) + 1e-8 + # nan-safe 정규화: nanmean/nanstd로 통계 계산 후 nan → 0.0 대치 + # (z-score 후 0.0 = 평균값, 신경망에 줄 수 있는 가장 무난한 결측 대치값) + self._mean = np.nanmean(X_np, axis=0) + self._std = np.nanstd(X_np, axis=0) + 1e-8 X_np = (X_np - self._mean) / self._std + X_np = np.nan_to_num(X_np, nan=0.0) w_np = sample_weight.astype(np.float32) if sample_weight is not None else None @@ -186,6 +189,7 @@ class MLXFilter: X_np = X[FEATURE_COLS].values.astype(np.float32) if self._trained and self._mean is not None: X_np = (X_np - self._mean) / self._std + X_np = np.nan_to_num(X_np, nan=0.0) x = mx.array(X_np) self._model.eval() logits = self._model(x) diff --git a/tests/test_mlx_filter.py b/tests/test_mlx_filter.py index e19e937..af4a2c9 100644 --- a/tests/test_mlx_filter.py +++ b/tests/test_mlx_filter.py @@ -65,6 +65,31 @@ def test_mlx_filter_fit_and_predict(): assert np.all((proba >= 0.0) & (proba <= 1.0)) +def test_fit_with_nan_features(): + """oi_change 피처에 nan이 포함된 경우 학습이 정상 완료되어야 한다.""" + import numpy as np + import pandas as pd + from src.mlx_filter import MLXFilter + from src.ml_features import FEATURE_COLS + + n = 300 + np.random.seed(42) + X = pd.DataFrame( + np.random.randn(n, len(FEATURE_COLS)).astype(np.float32), + columns=FEATURE_COLS, + ) + # oi_change 앞 절반을 nan으로 + X["oi_change"] = np.where(np.arange(n) < n // 2, np.nan, X["oi_change"]) + y = pd.Series((np.random.rand(n) > 0.5).astype(np.float32)) + + model = MLXFilter(input_dim=len(FEATURE_COLS), hidden_dim=32, epochs=3) + model.fit(X, y) # nan 있어도 예외 없이 완료되어야 함 + + proba = model.predict_proba(X) + assert not np.any(np.isnan(proba)), "예측 확률에 nan이 없어야 함" + assert proba.min() >= 0.0 and proba.max() <= 1.0 + + def test_mlx_filter_save_load(tmp_path): """저장 후 로드한 모델이 동일한 예측값을 반환해야 한다.""" from src.mlx_filter import MLXFilter