feat: implement ML filter with LightGBM for trading signal validation
- Added MLFilter class to load and evaluate a LightGBM model for trading signals.
- Introduced a retraining mechanism to update the model daily based on new data.
- Created feature engineering and label building utilities for model training.
- Updated bot logic to incorporate the ML filter for signal validation.
- Added scripts for data fetching and model training.

Made-with: Cursor
This commit is contained in:
@@ -37,10 +37,8 @@ def sample_df():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bot_processes_signal(config, sample_df):
|
||||
with patch("src.bot.BinanceFuturesClient") as MockExchange, \
|
||||
patch("src.bot.TradeRepository") as MockRepo:
|
||||
with patch("src.bot.BinanceFuturesClient") as MockExchange:
|
||||
MockExchange.return_value = AsyncMock()
|
||||
MockRepo.return_value = MagicMock()
|
||||
bot = TradingBot(config)
|
||||
|
||||
bot.exchange = AsyncMock()
|
||||
@@ -48,8 +46,8 @@ async def test_bot_processes_signal(config, sample_df):
|
||||
bot.exchange.get_position = AsyncMock(return_value=None)
|
||||
bot.exchange.place_order = AsyncMock(return_value={"orderId": "123"})
|
||||
bot.exchange.set_leverage = AsyncMock(return_value={})
|
||||
bot.db = MagicMock()
|
||||
bot.db.save_trade = MagicMock(return_value={"id": "trade1"})
|
||||
bot.exchange.calculate_quantity = MagicMock(return_value=100.0)
|
||||
bot.exchange.MIN_NOTIONAL = 5.0
|
||||
|
||||
with patch("src.bot.Indicators") as MockInd:
|
||||
mock_ind = MagicMock()
|
||||
|
||||
73
tests/test_label_builder.py
Normal file
73
tests/test_label_builder.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from src.label_builder import build_labels
|
||||
|
||||
|
||||
def make_signal_df():
    """Build future price series for the take-profit scenario.

    Scenario described by the original author: prices after the signal
    rise until they reach the take-profit (entry=100, TP=103, SL=98.5).

    Returns:
        A tuple ``(closes, highs, lows)`` of three equally long lists, where
        every high is the close plus 0.3 and every low is the close minus 0.3.
    """
    closes = [100.5, 101.0, 101.8, 102.5, 103.1, 103.5]
    highs = []
    lows = []
    for price in closes:
        highs.append(price + 0.3)
        lows.append(price - 0.3)
    return closes, highs, lows
|
||||
|
||||
|
||||
def test_label_tp_reached():
    """A LONG whose future prices reach take-profit must be labelled 1."""
    future_closes, future_highs, future_lows = make_signal_df()
    params = {
        "future_closes": future_closes,
        "future_highs": future_highs,
        "future_lows": future_lows,
        "take_profit": 103.0,
        "stop_loss": 98.5,
        "side": "LONG",
    }
    result = build_labels(**params)
    assert result == 1, "TP 먼저 도달해야 레이블 1"
|
||||
|
||||
|
||||
def test_label_sl_reached():
    """A LONG whose future prices fall through the stop-loss must be labelled 0."""
    closes = [99.5, 99.0, 98.8, 98.4, 98.0]
    # Highs and lows sit 0.3 above and below each close.
    highs = []
    lows = []
    for price in closes:
        highs.append(price + 0.3)
        lows.append(price - 0.3)

    result = build_labels(
        future_closes=closes,
        future_highs=highs,
        future_lows=lows,
        take_profit=103.0,
        stop_loss=98.5,
        side="LONG",
    )
    assert result == 0, "SL 먼저 도달해야 레이블 0"
|
||||
|
||||
|
||||
def test_label_neither_reached_returns_none():
    """When neither take-profit nor stop-loss is touched, the label must be None."""
    closes = [100.1, 100.2, 100.3]
    # A narrow 0.1 band around each close keeps prices away from both levels.
    bands = [(price + 0.1, price - 0.1) for price in closes]
    highs = [upper for upper, _ in bands]
    lows = [lower for _, lower in bands]

    result = build_labels(
        future_closes=closes,
        future_highs=highs,
        future_lows=lows,
        take_profit=103.0,
        stop_loss=98.5,
        side="LONG",
    )
    assert result is None, "미결 시 None 반환"
|
||||
|
||||
|
||||
def test_label_short_tp():
    """A SHORT whose future prices drop to the take-profit must be labelled 1."""
    closes = [99.5, 99.0, 98.0, 97.0]
    highs = [price + 0.3 for price in closes]
    lows = [price - 0.3 for price in closes]
    params = dict(
        future_closes=closes,
        future_highs=highs,
        future_lows=lows,
        take_profit=97.0,
        stop_loss=101.5,
        side="SHORT",
    )
    result = build_labels(**params)
    assert result == 1
|
||||
57
tests/test_ml_features.py
Normal file
57
tests/test_ml_features.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from src.ml_features import build_features, FEATURE_COLS
|
||||
|
||||
|
||||
def make_df(n=100):
    """Create a minimal synthetic OHLCV DataFrame for tests.

    The global NumPy random state is seeded with 42 on every call, so the
    same ``n`` always yields the same data.

    Args:
        n: Number of rows to generate.

    Returns:
        A DataFrame with columns ``open``, ``high``, ``low``, ``close`` and
        ``volume`` (in that order).
    """
    np.random.seed(42)
    # Draw order matters for reproducibility: price steps first, then volume.
    steps = np.random.randn(n) * 0.5
    close = 100 + np.cumsum(steps)
    volume = np.random.uniform(1000, 5000, n)

    columns = {
        "open": close * 0.999,
        "high": close * 1.002,
        "low": close * 0.998,
        "close": close,
    }
    columns["volume"] = volume
    return pd.DataFrame(columns)
|
||||
|
||||
|
||||
def test_build_features_returns_series():
    """build_features must return a pandas Series for a LONG signal."""
    from src.indicators import Indicators

    enriched = Indicators(make_df(100)).calculate_all()
    result = build_features(enriched, signal="LONG")
    assert isinstance(result, pd.Series)
|
||||
|
||||
|
||||
def test_build_features_has_all_cols():
    """Every name in FEATURE_COLS must appear in the index of the feature Series."""
    from src.indicators import Indicators

    enriched = Indicators(make_df(100)).calculate_all()
    feature_index = build_features(enriched, signal="LONG").index
    for name in FEATURE_COLS:
        assert name in feature_index, f"피처 누락: {name}"
|
||||
|
||||
|
||||
def test_build_features_no_nan():
    """The feature Series built from 100 rows must not contain NaN values."""
    from src.indicators import Indicators

    enriched = Indicators(make_df(100)).calculate_all()
    result = build_features(enriched, signal="LONG")
    nan_mask = result.isna()
    # On failure the message lists exactly the entries that are NaN.
    assert not nan_mask.any(), f"NaN 존재: {result[nan_mask]}"
|
||||
|
||||
|
||||
def test_side_encoding():
    """The ``side`` feature must be 1 for LONG and 0 for SHORT."""
    from src.indicators import Indicators

    enriched = Indicators(make_df(100)).calculate_all()
    features_long = build_features(enriched, signal="LONG")
    features_short = build_features(enriched, signal="SHORT")
    assert features_long["side"] == 1
    assert features_short["side"] == 0
|
||||
63
tests/test_ml_filter.py
Normal file
63
tests/test_ml_filter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from pathlib import Path
|
||||
from src.ml_filter import MLFilter
|
||||
from src.ml_features import FEATURE_COLS
|
||||
|
||||
|
||||
def make_features(side="LONG") -> pd.Series:
    """Build a dummy feature Series for MLFilter tests.

    Every column in FEATURE_COLS is set to 0.5; ``side`` is then set to 1.0
    for ``"LONG"`` and 0.0 for any other value.
    """
    values = dict.fromkeys(FEATURE_COLS, 0.5)
    values["side"] = 1.0 if side == "LONG" else 0.0
    return pd.Series(values)
|
||||
|
||||
|
||||
def test_no_model_file_is_not_loaded(tmp_path):
    """An MLFilter pointed at a non-existent file must report no model loaded."""
    missing_file = tmp_path / "nonexistent.pkl"
    ml_filter = MLFilter(model_path=str(missing_file))
    assert not ml_filter.is_model_loaded()
|
||||
|
||||
|
||||
def test_no_model_should_enter_returns_true(tmp_path):
    """Without a model, entry is always allowed (fallback behaviour)."""
    missing_file = tmp_path / "nonexistent.pkl"
    ml_filter = MLFilter(model_path=str(missing_file))
    decision = ml_filter.should_enter(make_features())
    assert decision is True
|
||||
|
||||
|
||||
def test_should_enter_above_threshold():
    """A probability >= 0.60 must yield True."""
    ml_filter = MLFilter(threshold=0.60)

    # Stub model whose predict_proba reports 0.65 in the second column.
    stub_model = MagicMock()
    stub_model.predict_proba.return_value = np.array([[0.35, 0.65]])
    ml_filter._model = stub_model

    decision = ml_filter.should_enter(make_features())
    assert decision is True
|
||||
|
||||
|
||||
def test_should_enter_below_threshold():
    """A probability < 0.60 must yield False."""
    ml_filter = MLFilter(threshold=0.60)

    # Stub model whose predict_proba reports 0.45 in the second column.
    stub_model = MagicMock()
    stub_model.predict_proba.return_value = np.array([[0.55, 0.45]])
    ml_filter._model = stub_model

    decision = ml_filter.should_enter(make_features())
    assert decision is False
|
||||
|
||||
|
||||
def test_reload_model(tmp_path):
    """reload_model must keep an already-set model when the model file is missing.

    The filter starts without a model because the file does not exist. A mock
    model is then injected directly, and reload_model() is called while the
    file is still absent; the filter must still report a loaded model.
    """
    # Start with no model file on disk.
    model_path = tmp_path / "lgbm_filter.pkl"
    f = MLFilter(model_path=str(model_path))
    assert not f.is_model_loaded()

    # Inject _model directly and check that is_model_loaded turns True.
    mock_model = MagicMock()
    f._model = mock_model
    assert f.is_model_loaded()

    # Per the original author's note: when the file is missing, _try_load does
    # not touch _model (existing behaviour - it is not reset to None).
    f.reload_model()
    assert f.is_model_loaded()  # mock_model is retained
|
||||
35
tests/test_retrainer.py
Normal file
35
tests/test_retrainer.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from src.retrainer import Retrainer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retrain_calls_train(tmp_path):
    """Retraining must call both the data fetch and the training function once."""
    ml_filter = MagicMock()
    data_file = tmp_path / "data.parquet"
    retrainer = Retrainer(ml_filter=ml_filter, data_path=str(data_file))

    # Patches are entered in the same order as before: fetch, train, current AUC.
    with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as fetch_mock:
        with patch("src.retrainer.run_training", return_value=0.72) as train_mock:
            with patch("src.retrainer.get_current_auc", return_value=0.65):
                await retrainer.retrain()

    fetch_mock.assert_called_once()
    train_mock.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retrain_rollback_when_worse(tmp_path):
    """If the new model scores worse than the current one, rollback must be called."""
    ml_filter = MagicMock()
    data_file = tmp_path / "data.parquet"
    retrainer = Retrainer(ml_filter=ml_filter, data_path=str(data_file))

    # The new training result (0.55) is below the current AUC (0.70).
    with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock):
        with patch("src.retrainer.run_training", return_value=0.55):
            with patch("src.retrainer.get_current_auc", return_value=0.70):
                with patch("src.retrainer.rollback_model") as rollback_mock:
                    await retrainer.retrain()

    rollback_mock.assert_called_once()
|
||||
Reference in New Issue
Block a user