fix: 기존 테스트를 현재 코드 구조에 맞게 수정 — MLFilter API, FEATURE_COLS 수, 버퍼 최솟값 반영

Made-with: Cursor
This commit is contained in:
21in7
2026-03-02 00:36:13 +09:00
parent 3bfd1ca5a3
commit 518f1846b8
3 changed files with 19 additions and 10 deletions

View File

@@ -23,6 +23,7 @@ def test_multi_symbol_stream_get_dataframe_returns_none_when_empty():
def test_multi_symbol_stream_get_dataframe_returns_df_when_full(): def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
import pandas as pd import pandas as pd
from src.data_stream import _MIN_CANDLES_FOR_SIGNAL
stream = MultiSymbolStream( stream = MultiSymbolStream(
symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"], symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"],
interval="1m", interval="1m",
@@ -32,13 +33,13 @@ def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
"timestamp": 1000, "open": 1.0, "high": 1.1, "timestamp": 1000, "open": 1.0, "high": 1.1,
"low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True, "low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True,
} }
for i in range(50): for i in range(_MIN_CANDLES_FOR_SIGNAL):
c = candle.copy() c = candle.copy()
c["timestamp"] = 1000 + i c["timestamp"] = 1000 + i
stream.buffers["xrpusdt"].append(c) stream.buffers["xrpusdt"].append(c)
df = stream.get_dataframe("XRPUSDT") df = stream.get_dataframe("XRPUSDT")
assert df is not None assert df is not None
assert len(df) == 50 assert len(df) == _MIN_CANDLES_FOR_SIGNAL
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -49,9 +49,9 @@ def test_build_features_rs_zero_when_btc_ret_zero():
features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df) features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df)
assert features["xrp_btc_rs"] == 0.0 assert features["xrp_btc_rs"] == 0.0
def test_feature_cols_has_21_items(): def test_feature_cols_has_23_items():
from src.ml_features import FEATURE_COLS from src.ml_features import FEATURE_COLS
assert len(FEATURE_COLS) == 21 assert len(FEATURE_COLS) == 23
def make_df(n=100): def make_df(n=100):

View File

@@ -12,13 +12,19 @@ def make_features(side="LONG") -> pd.Series:
def test_no_model_file_is_not_loaded(tmp_path): def test_no_model_file_is_not_loaded(tmp_path):
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) f = MLFilter(
onnx_path=str(tmp_path / "nonexistent.onnx"),
lgbm_path=str(tmp_path / "nonexistent.pkl"),
)
assert not f.is_model_loaded() assert not f.is_model_loaded()
def test_no_model_should_enter_returns_true(tmp_path): def test_no_model_should_enter_returns_true(tmp_path):
"""모델 없으면 항상 진입 허용 (폴백)""" """모델 없으면 항상 진입 허용 (폴백)"""
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) f = MLFilter(
onnx_path=str(tmp_path / "nonexistent.onnx"),
lgbm_path=str(tmp_path / "nonexistent.pkl"),
)
features = make_features() features = make_features()
assert f.should_enter(features) is True assert f.should_enter(features) is True
@@ -28,7 +34,7 @@ def test_should_enter_above_threshold():
f = MLFilter(threshold=0.60) f = MLFilter(threshold=0.60)
mock_model = MagicMock() mock_model = MagicMock()
mock_model.predict_proba.return_value = np.array([[0.35, 0.65]]) mock_model.predict_proba.return_value = np.array([[0.35, 0.65]])
f._model = mock_model f._lgbm_model = mock_model
features = make_features() features = make_features()
assert f.should_enter(features) is True assert f.should_enter(features) is True
@@ -38,7 +44,7 @@ def test_should_enter_below_threshold():
f = MLFilter(threshold=0.60) f = MLFilter(threshold=0.60)
mock_model = MagicMock() mock_model = MagicMock()
mock_model.predict_proba.return_value = np.array([[0.55, 0.45]]) mock_model.predict_proba.return_value = np.array([[0.55, 0.45]])
f._model = mock_model f._lgbm_model = mock_model
features = make_features() features = make_features()
assert f.should_enter(features) is False assert f.should_enter(features) is False
@@ -48,8 +54,10 @@ def test_reload_model(tmp_path):
import joblib import joblib
# 모델 파일이 없는 상태에서 시작 # 모델 파일이 없는 상태에서 시작
model_path = tmp_path / "lgbm_filter.pkl" f = MLFilter(
f = MLFilter(model_path=str(model_path)) onnx_path=str(tmp_path / "nonexistent.onnx"),
lgbm_path=str(tmp_path / "lgbm_filter.pkl"),
)
assert not f.is_model_loaded() assert not f.is_model_loaded()
# _model을 직접 주입해서 is_model_loaded가 True인지 확인 # _model을 직접 주입해서 is_model_loaded가 True인지 확인