diff --git a/tests/test_data_stream.py b/tests/test_data_stream.py index 935c3f2..6f4e4ac 100644 --- a/tests/test_data_stream.py +++ b/tests/test_data_stream.py @@ -23,6 +23,7 @@ def test_multi_symbol_stream_get_dataframe_returns_none_when_empty(): def test_multi_symbol_stream_get_dataframe_returns_df_when_full(): import pandas as pd + from src.data_stream import _MIN_CANDLES_FOR_SIGNAL stream = MultiSymbolStream( symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"], interval="1m", @@ -32,13 +33,13 @@ def test_multi_symbol_stream_get_dataframe_returns_df_when_full(): "timestamp": 1000, "open": 1.0, "high": 1.1, "low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True, } - for i in range(50): + for i in range(_MIN_CANDLES_FOR_SIGNAL): c = candle.copy() c["timestamp"] = 1000 + i stream.buffers["xrpusdt"].append(c) df = stream.get_dataframe("XRPUSDT") assert df is not None - assert len(df) == 50 + assert len(df) == _MIN_CANDLES_FOR_SIGNAL @pytest.mark.asyncio diff --git a/tests/test_ml_features.py b/tests/test_ml_features.py index b7a645f..1f3ffff 100644 --- a/tests/test_ml_features.py +++ b/tests/test_ml_features.py @@ -49,9 +49,9 @@ def test_build_features_rs_zero_when_btc_ret_zero(): features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df) assert features["xrp_btc_rs"] == 0.0 -def test_feature_cols_has_21_items(): +def test_feature_cols_has_23_items(): from src.ml_features import FEATURE_COLS - assert len(FEATURE_COLS) == 21 + assert len(FEATURE_COLS) == 23 def make_df(n=100): diff --git a/tests/test_ml_filter.py b/tests/test_ml_filter.py index 27580f3..0c837ae 100644 --- a/tests/test_ml_filter.py +++ b/tests/test_ml_filter.py @@ -12,13 +12,19 @@ def make_features(side="LONG") -> pd.Series: def test_no_model_file_is_not_loaded(tmp_path): - f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + f = MLFilter( + onnx_path=str(tmp_path / "nonexistent.onnx"), + lgbm_path=str(tmp_path / "nonexistent.pkl"), + ) assert not f.is_model_loaded() def test_no_model_should_enter_returns_true(tmp_path): """모델 없으면 항상 진입 허용 (폴백)""" - f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + f = MLFilter( + onnx_path=str(tmp_path / "nonexistent.onnx"), + lgbm_path=str(tmp_path / "nonexistent.pkl"), + ) features = make_features() assert f.should_enter(features) is True @@ -28,7 +34,7 @@ def test_should_enter_above_threshold(): f = MLFilter(threshold=0.60) mock_model = MagicMock() mock_model.predict_proba.return_value = np.array([[0.35, 0.65]]) - f._model = mock_model + f._lgbm_model = mock_model features = make_features() assert f.should_enter(features) is True @@ -38,7 +44,7 @@ def test_should_enter_below_threshold(): f = MLFilter(threshold=0.60) mock_model = MagicMock() mock_model.predict_proba.return_value = np.array([[0.55, 0.45]]) - f._model = mock_model + f._lgbm_model = mock_model features = make_features() assert f.should_enter(features) is False @@ -48,8 +54,10 @@ def test_reload_model(tmp_path): import joblib # 모델 파일이 없는 상태에서 시작 - model_path = tmp_path / "lgbm_filter.pkl" - f = MLFilter(model_path=str(model_path)) + f = MLFilter( + onnx_path=str(tmp_path / "nonexistent.onnx"), + lgbm_path=str(tmp_path / "lgbm_filter.pkl"), + ) assert not f.is_model_loaded() # _model을 직접 주입해서 is_model_loaded가 True인지 확인