fix: 기존 테스트를 현재 코드 구조에 맞게 수정 — MLFilter API, FEATURE_COLS 수, 버퍼 최솟값 반영
Made-with: Cursor
This commit is contained in:
@@ -23,6 +23,7 @@ def test_multi_symbol_stream_get_dataframe_returns_none_when_empty():
|
|||||||
|
|
||||||
def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
|
def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from src.data_stream import _MIN_CANDLES_FOR_SIGNAL
|
||||||
stream = MultiSymbolStream(
|
stream = MultiSymbolStream(
|
||||||
symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"],
|
symbols=["XRPUSDT", "BTCUSDT", "ETHUSDT"],
|
||||||
interval="1m",
|
interval="1m",
|
||||||
@@ -32,13 +33,13 @@ def test_multi_symbol_stream_get_dataframe_returns_df_when_full():
|
|||||||
"timestamp": 1000, "open": 1.0, "high": 1.1,
|
"timestamp": 1000, "open": 1.0, "high": 1.1,
|
||||||
"low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True,
|
"low": 0.9, "close": 1.05, "volume": 100.0, "is_closed": True,
|
||||||
}
|
}
|
||||||
for i in range(50):
|
for i in range(_MIN_CANDLES_FOR_SIGNAL):
|
||||||
c = candle.copy()
|
c = candle.copy()
|
||||||
c["timestamp"] = 1000 + i
|
c["timestamp"] = 1000 + i
|
||||||
stream.buffers["xrpusdt"].append(c)
|
stream.buffers["xrpusdt"].append(c)
|
||||||
df = stream.get_dataframe("XRPUSDT")
|
df = stream.get_dataframe("XRPUSDT")
|
||||||
assert df is not None
|
assert df is not None
|
||||||
assert len(df) == 50
|
assert len(df) == _MIN_CANDLES_FOR_SIGNAL
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -49,9 +49,9 @@ def test_build_features_rs_zero_when_btc_ret_zero():
|
|||||||
features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df)
|
features = build_features(xrp_df, "LONG", btc_df=btc_df, eth_df=eth_df)
|
||||||
assert features["xrp_btc_rs"] == 0.0
|
assert features["xrp_btc_rs"] == 0.0
|
||||||
|
|
||||||
def test_feature_cols_has_21_items():
|
def test_feature_cols_has_23_items():
|
||||||
from src.ml_features import FEATURE_COLS
|
from src.ml_features import FEATURE_COLS
|
||||||
assert len(FEATURE_COLS) == 21
|
assert len(FEATURE_COLS) == 23
|
||||||
|
|
||||||
|
|
||||||
def make_df(n=100):
|
def make_df(n=100):
|
||||||
|
|||||||
@@ -12,13 +12,19 @@ def make_features(side="LONG") -> pd.Series:
|
|||||||
|
|
||||||
|
|
||||||
def test_no_model_file_is_not_loaded(tmp_path):
|
def test_no_model_file_is_not_loaded(tmp_path):
|
||||||
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl"))
|
f = MLFilter(
|
||||||
|
onnx_path=str(tmp_path / "nonexistent.onnx"),
|
||||||
|
lgbm_path=str(tmp_path / "nonexistent.pkl"),
|
||||||
|
)
|
||||||
assert not f.is_model_loaded()
|
assert not f.is_model_loaded()
|
||||||
|
|
||||||
|
|
||||||
def test_no_model_should_enter_returns_true(tmp_path):
|
def test_no_model_should_enter_returns_true(tmp_path):
|
||||||
"""모델 없으면 항상 진입 허용 (폴백)"""
|
"""모델 없으면 항상 진입 허용 (폴백)"""
|
||||||
f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl"))
|
f = MLFilter(
|
||||||
|
onnx_path=str(tmp_path / "nonexistent.onnx"),
|
||||||
|
lgbm_path=str(tmp_path / "nonexistent.pkl"),
|
||||||
|
)
|
||||||
features = make_features()
|
features = make_features()
|
||||||
assert f.should_enter(features) is True
|
assert f.should_enter(features) is True
|
||||||
|
|
||||||
@@ -28,7 +34,7 @@ def test_should_enter_above_threshold():
|
|||||||
f = MLFilter(threshold=0.60)
|
f = MLFilter(threshold=0.60)
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
mock_model.predict_proba.return_value = np.array([[0.35, 0.65]])
|
mock_model.predict_proba.return_value = np.array([[0.35, 0.65]])
|
||||||
f._model = mock_model
|
f._lgbm_model = mock_model
|
||||||
features = make_features()
|
features = make_features()
|
||||||
assert f.should_enter(features) is True
|
assert f.should_enter(features) is True
|
||||||
|
|
||||||
@@ -38,7 +44,7 @@ def test_should_enter_below_threshold():
|
|||||||
f = MLFilter(threshold=0.60)
|
f = MLFilter(threshold=0.60)
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
mock_model.predict_proba.return_value = np.array([[0.55, 0.45]])
|
mock_model.predict_proba.return_value = np.array([[0.55, 0.45]])
|
||||||
f._model = mock_model
|
f._lgbm_model = mock_model
|
||||||
features = make_features()
|
features = make_features()
|
||||||
assert f.should_enter(features) is False
|
assert f.should_enter(features) is False
|
||||||
|
|
||||||
@@ -48,8 +54,10 @@ def test_reload_model(tmp_path):
|
|||||||
import joblib
|
import joblib
|
||||||
|
|
||||||
# 모델 파일이 없는 상태에서 시작
|
# 모델 파일이 없는 상태에서 시작
|
||||||
model_path = tmp_path / "lgbm_filter.pkl"
|
f = MLFilter(
|
||||||
f = MLFilter(model_path=str(model_path))
|
onnx_path=str(tmp_path / "nonexistent.onnx"),
|
||||||
|
lgbm_path=str(tmp_path / "lgbm_filter.pkl"),
|
||||||
|
)
|
||||||
assert not f.is_model_loaded()
|
assert not f.is_model_loaded()
|
||||||
|
|
||||||
# _model을 직접 주입해서 is_model_loaded가 True인지 확인
|
# _model을 직접 주입해서 is_model_loaded가 True인지 확인
|
||||||
|
|||||||
Reference in New Issue
Block a user