From 518f1846b8b5acf570bd60a3f0a32776d0bf20f5 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Mon, 2 Mar 2026 00:36:13 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20=EA=B8=B0=EC=A1=B4=20=ED=85=8C=EC=8A=A4?= =?UTF-8?q?=ED=8A=B8=EB=A5=BC=20=ED=98=84=EC=9E=AC=20=EC=BD=94=EB=93=9C=20?= =?UTF-8?q?=EA=B5=AC=EC=A1=B0=EC=97=90=20=EB=A7=9E=EA=B2=8C=20=EC=88=98?= =?UTF-8?q?=EC=A0=95=20=E2=80=94=20MLFilter=20API,=20FEATURE=5FCOLS=20?= =?UTF-8?q?=EC=88=98,=20=EB=B2=84=ED=8D=BC=20=EC=B5=9C=EC=86=9F=EA=B0=92?= =?UTF-8?q?=20=EB=B0=98=EC=98=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Made-with: Cursor --- tests/test_data_stream.py | 5 +++-- tests/test_ml_features.py | 4 ++-- tests/test_ml_filter.py | 20 ++++++++++++++------ 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/test_data_stream.py b/tests/test_data_stream.py index 935c3f2..6f4e4ac 100644 --- a/tests/test_data_stream.py +++ b/tests/test_data_stream.py @@ -23,6 +23,7 @@ def test_multi_symbol_stream_get_dataframe_returns_none_when_empty(): def test_multi_symbol_stream_get_dataframe_returns_df_when_full(): import pandas as pd + from src.data_stream import _MIN_CANDLES_FOR_SIGNAL stream = MultiSymbolStream( symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"], interval="1m", @@ -32,13 +33,13 @@ def test_multi_symbol_stream_get_dataframe_returns_df_when_full(): "timestamp": 1000, "open": 1.0, "high": 1.1, "low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True, } - for i in range(50): + for i in range(_MIN_CANDLES_FOR_SIGNAL): c = candle.copy() c["timestamp"] = 1000 + i stream.buffers["xrpusdt"].append(c) df = stream.get_dataframe("XRPUSDT") assert df is not None - assert len(df) == 50 + assert len(df) == _MIN_CANDLES_FOR_SIGNAL @pytest.mark.asyncio diff --git a/tests/test_ml_features.py b/tests/test_ml_features.py index b7a645f..1f3ffff 100644 --- a/tests/test_ml_features.py +++ b/tests/test_ml_features.py @@ -49,9 +49,9 @@ def test_build_features_rs_zero_when_btc_ret_zero(): features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df) assert features["xrp_btc_rs"] == 0.0 -def test_feature_cols_has_21_items(): +def test_feature_cols_has_23_items(): from src.ml_features import FEATURE_COLS - assert len(FEATURE_COLS) == 21 + assert len(FEATURE_COLS) == 23 def make_df(n=100): diff --git a/tests/test_ml_filter.py b/tests/test_ml_filter.py index 27580f3..0c837ae 100644 --- a/tests/test_ml_filter.py +++ b/tests/test_ml_filter.py @@ -12,13 +12,19 @@ def make_features(side="LONG") -> pd.Series: def test_no_model_file_is_not_loaded(tmp_path): - f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + f = MLFilter( + onnx_path=str(tmp_path / "nonexistent.onnx"), + lgbm_path=str(tmp_path / "nonexistent.pkl"), + ) assert not f.is_model_loaded() def test_no_model_should_enter_returns_true(tmp_path): """모델 없으면 항상 진입 허용 (폴백)""" - f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + f = MLFilter( + onnx_path=str(tmp_path / "nonexistent.onnx"), + lgbm_path=str(tmp_path / "nonexistent.pkl"), + ) features = make_features() assert f.should_enter(features) is True @@ -28,7 +34,7 @@ def test_should_enter_above_threshold(): f = MLFilter(threshold=0.60) mock_model = MagicMock() mock_model.predict_proba.return_value = np.array([[0.35, 0.65]]) - f._model = mock_model + f._lgbm_model = mock_model features = make_features() assert f.should_enter(features) is True @@ -38,7 +44,7 @@ def test_should_enter_below_threshold(): f = MLFilter(threshold=0.60) mock_model = MagicMock() mock_model.predict_proba.return_value = np.array([[0.55, 0.45]]) - f._model = mock_model + f._lgbm_model = mock_model features = make_features() assert f.should_enter(features) is False @@ -48,8 +54,10 @@ def test_reload_model(tmp_path): import joblib # 모델 파일이 없는 상태에서 시작 - model_path = tmp_path / "lgbm_filter.pkl" - f = MLFilter(model_path=str(model_path)) + f = MLFilter( + onnx_path=str(tmp_path / "nonexistent.onnx"), + lgbm_path=str(tmp_path / "lgbm_filter.pkl"), + ) assert not f.is_model_loaded() # _model을 직접 주입해서 is_model_loaded가 True인지 확인