fix: 기존 테스트를 현재 코드 구조에 맞게 수정 — MLFilter API, FEATURE_COLS 수, 버퍼 최솟값 반영
Made-with: Cursor
This commit is contained in:
@@ -23,6 +23,7 @@ def test_multi_symbol_stream_get_dataframe_returns_none_when_empty():
|
||||
|
||||
def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
|
||||
import pandas as pd
|
||||
from src.data_stream import _MIN_CANDLES_FOR_SIGNAL
|
||||
stream = MultiSymbolStream(
|
||||
symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"],
|
||||
interval="1m",
|
||||
@@ -32,13 +33,13 @@ def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
|
||||
"timestamp": 1000, "open": 1.0, "high": 1.1,
|
||||
"low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True,
|
||||
}
|
||||
for i in range(50):
|
||||
for i in range(_MIN_CANDLES_FOR_SIGNAL):
|
||||
c = candle.copy()
|
||||
c["timestamp"] = 1000 + i
|
||||
stream.buffers["xrpusdt"].append(c)
|
||||
df = stream.get_dataframe("XRPUSDT")
|
||||
assert df is not None
|
||||
assert len(df) == 50
|
||||
assert len(df) == _MIN_CANDLES_FOR_SIGNAL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -49,9 +49,9 @@ def test_build_features_rs_zero_when_btc_ret_zero():
|
||||
features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df)
|
||||
assert features["xrp_btc_rs"] == 0.0
|
||||
|
||||
def test_feature_cols_has_21_items():
|
||||
def test_feature_cols_has_23_items():
|
||||
from src.ml_features import FEATURE_COLS
|
||||
assert len(FEATURE_COLS) == 21
|
||||
assert len(FEATURE_COLS) == 23
|
||||
|
||||
|
||||
def make_df(n=100):
|
||||
|
||||
@@ -12,13 +12,19 @@ def make_features(side="LONG") -> pd.Series:
|
||||
|
||||
|
||||
def test_no_model_file_is_not_loaded(tmp_path):
|
||||
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl"))
|
||||
f = MLFilter(
|
||||
onnx_path=str(tmp_path / "nonexistent.onnx"),
|
||||
lgbm_path=str(tmp_path / "nonexistent.pkl"),
|
||||
)
|
||||
assert not f.is_model_loaded()
|
||||
|
||||
|
||||
def test_no_model_should_enter_returns_true(tmp_path):
|
||||
"""모델 없으면 항상 진입 허용 (폴백)"""
|
||||
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl"))
|
||||
f = MLFilter(
|
||||
onnx_path=str(tmp_path / "nonexistent.onnx"),
|
||||
lgbm_path=str(tmp_path / "nonexistent.pkl"),
|
||||
)
|
||||
features = make_features()
|
||||
assert f.should_enter(features) is True
|
||||
|
||||
@@ -28,7 +34,7 @@ def test_should_enter_above_threshold():
|
||||
f = MLFilter(threshold=0.60)
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict_proba.return_value = np.array([[0.35, 0.65]])
|
||||
f._model = mock_model
|
||||
f._lgbm_model = mock_model
|
||||
features = make_features()
|
||||
assert f.should_enter(features) is True
|
||||
|
||||
@@ -38,7 +44,7 @@ def test_should_enter_below_threshold():
|
||||
f = MLFilter(threshold=0.60)
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict_proba.return_value = np.array([[0.55, 0.45]])
|
||||
f._model = mock_model
|
||||
f._lgbm_model = mock_model
|
||||
features = make_features()
|
||||
assert f.should_enter(features) is False
|
||||
|
||||
@@ -48,8 +54,10 @@ def test_reload_model(tmp_path):
|
||||
import joblib
|
||||
|
||||
# 모델 파일이 없는 상태에서 시작
|
||||
model_path = tmp_path / "lgbm_filter.pkl"
|
||||
f = MLFilter(model_path=str(model_path))
|
||||
f = MLFilter(
|
||||
onnx_path=str(tmp_path / "nonexistent.onnx"),
|
||||
lgbm_path=str(tmp_path / "lgbm_filter.pkl"),
|
||||
)
|
||||
assert not f.is_model_loaded()
|
||||
|
||||
# _model을 직접 주입해서 is_model_loaded가 True인지 확인
|
||||
|
||||
Reference in New Issue
Block a user