fix: 기존 테스트를 현재 코드 구조에 맞게 수정 — MLFilter API, FEATURE_COLS 수, 버퍼 최솟값 반영

Made-with: Cursor
This commit is contained in:
21in7
2026-03-02 00:36:13 +09:00
parent 3bfd1ca5a3
commit 518f1846b8
3 changed files with 19 additions and 10 deletions

View File

@@ -23,6 +23,7 @@ def test_multi_symbol_stream_get_dataframe_returns_none_when_empty():
def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
import pandas as pd
from src.data_stream import _MIN_CANDLES_FOR_SIGNAL
stream = MultiSymbolStream(
symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"],
interval="1m",
@@ -32,13 +33,13 @@ def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
"timestamp": 1000, "open": 1.0, "high": 1.1,
"low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True,
}
for i in range(50):
for i in range(_MIN_CANDLES_FOR_SIGNAL):
c = candle.copy()
c["timestamp"] = 1000 + i
stream.buffers["xrpusdt"].append(c)
df = stream.get_dataframe("XRPUSDT")
assert df is not None
assert len(df) == 50
assert len(df) == _MIN_CANDLES_FOR_SIGNAL
@pytest.mark.asyncio

View File

@@ -49,9 +49,9 @@ def test_build_features_rs_zero_when_btc_ret_zero():
features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df)
assert features["xrp_btc_rs"] == 0.0
def test_feature_cols_has_21_items():
def test_feature_cols_has_23_items():
from src.ml_features import FEATURE_COLS
assert len(FEATURE_COLS) == 21
assert len(FEATURE_COLS) == 23
def make_df(n=100):

View File

@@ -12,13 +12,19 @@ def make_features(side="LONG") -> pd.Series:
def test_no_model_file_is_not_loaded(tmp_path):
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl"))
f = MLFilter(
onnx_path=str(tmp_path / "nonexistent.onnx"),
lgbm_path=str(tmp_path / "nonexistent.pkl"),
)
assert not f.is_model_loaded()
def test_no_model_should_enter_returns_true(tmp_path):
"""모델 없으면 항상 진입 허용 (폴백)"""
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl"))
f = MLFilter(
onnx_path=str(tmp_path / "nonexistent.onnx"),
lgbm_path=str(tmp_path / "nonexistent.pkl"),
)
features = make_features()
assert f.should_enter(features) is True
@@ -28,7 +34,7 @@ def test_should_enter_above_threshold():
f = MLFilter(threshold=0.60)
mock_model = MagicMock()
mock_model.predict_proba.return_value = np.array([[0.35, 0.65]])
f._model = mock_model
f._lgbm_model = mock_model
features = make_features()
assert f.should_enter(features) is True
@@ -38,7 +44,7 @@ def test_should_enter_below_threshold():
f = MLFilter(threshold=0.60)
mock_model = MagicMock()
mock_model.predict_proba.return_value = np.array([[0.55, 0.45]])
f._model = mock_model
f._lgbm_model = mock_model
features = make_features()
assert f.should_enter(features) is False
@@ -48,8 +54,10 @@ def test_reload_model(tmp_path):
import joblib
# 모델 파일이 없는 상태에서 시작
model_path = tmp_path / "lgbm_filter.pkl"
f = MLFilter(model_path=str(model_path))
f = MLFilter(
onnx_path=str(tmp_path / "nonexistent.onnx"),
lgbm_path=str(tmp_path / "lgbm_filter.pkl"),
)
assert not f.is_model_loaded()
# _model을 직접 주입해서 is_model_loaded가 True인지 확인