feat: implement ML filter with LightGBM for trading signal validation
- Added MLFilter class to load and evaluate LightGBM model for trading signals. - Introduced retraining mechanism to update the model daily based on new data. - Created feature engineering and label building utilities for model training. - Updated bot logic to incorporate ML filter for signal validation. - Added scripts for data fetching and model training. Made-with: Cursor
This commit is contained in:
13
src/bot.py
13
src/bot.py
@@ -6,6 +6,9 @@ from src.indicators import Indicators
|
||||
from src.data_stream import KlineStream
|
||||
from src.notifier import DiscordNotifier
|
||||
from src.risk_manager import RiskManager
|
||||
from src.ml_filter import MLFilter
|
||||
from src.ml_features import build_features
|
||||
from src.retrainer import Retrainer
|
||||
|
||||
|
||||
class TradingBot:
|
||||
@@ -14,6 +17,8 @@ class TradingBot:
|
||||
self.exchange = BinanceFuturesClient(config)
|
||||
self.notifier = DiscordNotifier(config.discord_webhook_url)
|
||||
self.risk = RiskManager(config)
|
||||
self.ml_filter = MLFilter()
|
||||
self.retrainer = Retrainer(ml_filter=self.ml_filter)
|
||||
self.current_trade_side: str | None = None # "LONG" | "SHORT"
|
||||
self.stream = KlineStream(
|
||||
symbol=config.symbol,
|
||||
@@ -52,6 +57,13 @@ class TradingBot:
|
||||
ind = Indicators(df)
|
||||
df_with_indicators = ind.calculate_all()
|
||||
signal = ind.get_signal(df_with_indicators)
|
||||
|
||||
if signal != "HOLD" and self.ml_filter.is_model_loaded():
|
||||
features = build_features(df_with_indicators, signal)
|
||||
if not self.ml_filter.should_enter(features):
|
||||
logger.info(f"ML 필터 차단: {signal} 신호 무시")
|
||||
signal = "HOLD"
|
||||
|
||||
current_price = df_with_indicators["close"].iloc[-1]
|
||||
logger.info(f"신호: {signal} | 현재가: {current_price:.4f} USDT")
|
||||
|
||||
@@ -153,6 +165,7 @@ class TradingBot:
|
||||
async def run(self):
|
||||
logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x")
|
||||
await self._recover_position()
|
||||
asyncio.create_task(self.retrainer.schedule_daily(hour=3))
|
||||
await self.stream.start(
|
||||
api_key=self.config.api_key,
|
||||
api_secret=self.config.api_secret,
|
||||
|
||||
29
src/label_builder.py
Normal file
29
src/label_builder.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def build_labels(
|
||||
future_closes: list[float],
|
||||
future_highs: list[float],
|
||||
future_lows: list[float],
|
||||
take_profit: float,
|
||||
stop_loss: float,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
진입 이후 미래 캔들을 순서대로 확인해 TP/SL 도달 여부를 판단한다.
|
||||
LONG: high >= TP → 1, low <= SL → 0
|
||||
SHORT: low <= TP → 1, high >= SL → 0
|
||||
둘 다 미도달 → None (학습 데이터에서 제외)
|
||||
"""
|
||||
for high, low in zip(future_highs, future_lows):
|
||||
if side == "LONG":
|
||||
if high >= take_profit:
|
||||
return 1
|
||||
if low <= stop_loss:
|
||||
return 0
|
||||
else: # SHORT
|
||||
if low <= take_profit:
|
||||
return 1
|
||||
if high >= stop_loss:
|
||||
return 0
|
||||
return None
|
||||
82
src/ml_features.py
Normal file
82
src/ml_features.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
FEATURE_COLS = [
|
||||
"rsi", "macd_hist", "bb_pct", "ema_align",
|
||||
"stoch_k", "stoch_d", "atr_pct", "vol_ratio",
|
||||
"ret_1", "ret_3", "ret_5", "signal_strength", "side",
|
||||
]
|
||||
|
||||
|
||||
def build_features(df: pd.DataFrame, signal: str) -> pd.Series:
|
||||
"""
|
||||
기술 지표가 계산된 DataFrame의 마지막 행에서 ML 피처를 추출한다.
|
||||
signal: "LONG" | "SHORT"
|
||||
"""
|
||||
last = df.iloc[-1]
|
||||
close = last["close"]
|
||||
|
||||
bb_upper = last.get("bb_upper", close)
|
||||
bb_lower = last.get("bb_lower", close)
|
||||
bb_range = bb_upper - bb_lower
|
||||
bb_pct = (close - bb_lower) / bb_range if bb_range > 0 else 0.5
|
||||
|
||||
ema9 = last.get("ema9", close)
|
||||
ema21 = last.get("ema21", close)
|
||||
ema50 = last.get("ema50", close)
|
||||
if ema9 > ema21 > ema50:
|
||||
ema_align = 1
|
||||
elif ema9 < ema21 < ema50:
|
||||
ema_align = -1
|
||||
else:
|
||||
ema_align = 0
|
||||
|
||||
atr = last.get("atr", 0)
|
||||
atr_pct = atr / close if close > 0 else 0
|
||||
|
||||
vol_ma20 = last.get("vol_ma20", last.get("volume", 1))
|
||||
vol_ratio = last["volume"] / vol_ma20 if vol_ma20 > 0 else 1.0
|
||||
|
||||
closes = df["close"]
|
||||
ret_1 = (close - closes.iloc[-2]) / closes.iloc[-2] if len(closes) >= 2 else 0.0
|
||||
ret_3 = (close - closes.iloc[-4]) / closes.iloc[-4] if len(closes) >= 4 else 0.0
|
||||
ret_5 = (close - closes.iloc[-6]) / closes.iloc[-6] if len(closes) >= 6 else 0.0
|
||||
|
||||
prev = df.iloc[-2] if len(df) >= 2 else last
|
||||
strength = 0
|
||||
rsi = last.get("rsi", 50)
|
||||
macd = last.get("macd", 0)
|
||||
macd_sig = last.get("macd_signal", 0)
|
||||
prev_macd = prev.get("macd", 0)
|
||||
prev_macd_sig = prev.get("macd_signal", 0)
|
||||
stoch_k = last.get("stoch_k", 50)
|
||||
stoch_d = last.get("stoch_d", 50)
|
||||
|
||||
if signal == "LONG":
|
||||
if rsi < 35: strength += 1
|
||||
if prev_macd < prev_macd_sig and macd > macd_sig: strength += 2
|
||||
if close < last.get("bb_lower", close): strength += 1
|
||||
if ema_align == 1: strength += 1
|
||||
if stoch_k < 20 and stoch_k > stoch_d: strength += 1
|
||||
else:
|
||||
if rsi > 65: strength += 1
|
||||
if prev_macd > prev_macd_sig and macd < macd_sig: strength += 2
|
||||
if close > last.get("bb_upper", close): strength += 1
|
||||
if ema_align == -1: strength += 1
|
||||
if stoch_k > 80 and stoch_k < stoch_d: strength += 1
|
||||
|
||||
return pd.Series({
|
||||
"rsi": float(rsi),
|
||||
"macd_hist": float(last.get("macd_hist", 0)),
|
||||
"bb_pct": float(bb_pct),
|
||||
"ema_align": float(ema_align),
|
||||
"stoch_k": float(stoch_k),
|
||||
"stoch_d": float(last.get("stoch_d", 50)),
|
||||
"atr_pct": float(atr_pct),
|
||||
"vol_ratio": float(vol_ratio),
|
||||
"ret_1": float(ret_1),
|
||||
"ret_3": float(ret_3),
|
||||
"ret_5": float(ret_5),
|
||||
"signal_strength": float(strength),
|
||||
"side": 1.0 if signal == "LONG" else 0.0,
|
||||
})
|
||||
50
src/ml_filter.py
Normal file
50
src/ml_filter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from pathlib import Path
|
||||
import joblib
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MLFilter:
|
||||
"""
|
||||
LightGBM 모델을 로드하고 진입 여부를 판단한다.
|
||||
모델 파일이 없으면 항상 진입을 허용한다 (폴백).
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = "models/lgbm_filter.pkl", threshold: float = 0.60):
|
||||
self._model_path = Path(model_path)
|
||||
self._threshold = threshold
|
||||
self._model = None
|
||||
self._try_load()
|
||||
|
||||
def _try_load(self):
|
||||
if self._model_path.exists():
|
||||
try:
|
||||
self._model = joblib.load(self._model_path)
|
||||
logger.info(f"ML 필터 모델 로드 완료: {self._model_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"ML 필터 모델 로드 실패: {e}")
|
||||
self._model = None
|
||||
|
||||
def is_model_loaded(self) -> bool:
|
||||
return self._model is not None
|
||||
|
||||
def should_enter(self, features: pd.Series) -> bool:
|
||||
"""
|
||||
확률 >= threshold 이면 True (진입 허용).
|
||||
모델 없으면 True 반환 (폴백).
|
||||
"""
|
||||
if not self.is_model_loaded():
|
||||
return True
|
||||
try:
|
||||
X = features.to_frame().T
|
||||
proba = self._model.predict_proba(X)[0][1]
|
||||
logger.debug(f"ML 필터 확률: {proba:.3f} (임계값: {self._threshold})")
|
||||
return bool(proba >= self._threshold)
|
||||
except Exception as e:
|
||||
logger.warning(f"ML 필터 예측 오류 (폴백 허용): {e}")
|
||||
return True
|
||||
|
||||
def reload_model(self):
|
||||
"""재학습 후 모델을 핫 리로드한다."""
|
||||
self._try_load()
|
||||
logger.info("ML 필터 모델 리로드 완료")
|
||||
92
src/retrainer.py
Normal file
92
src/retrainer.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from src.ml_filter import MLFilter
|
||||
|
||||
MODEL_PATH = Path("models/lgbm_filter.pkl")
|
||||
PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl")
|
||||
LOG_PATH = Path("models/training_log.json")
|
||||
|
||||
|
||||
def get_current_auc() -> float:
|
||||
"""training_log.json에서 가장 최근 AUC를 읽는다."""
|
||||
if not LOG_PATH.exists():
|
||||
return 0.0
|
||||
with open(LOG_PATH) as f:
|
||||
log = json.load(f)
|
||||
return log[-1]["auc"] if log else 0.0
|
||||
|
||||
|
||||
def rollback_model():
|
||||
"""이전 모델로 롤백한다."""
|
||||
if PREV_MODEL_PATH.exists():
|
||||
import shutil
|
||||
shutil.copy(PREV_MODEL_PATH, MODEL_PATH)
|
||||
logger.warning("ML 모델 롤백 완료")
|
||||
else:
|
||||
logger.warning("롤백할 이전 모델 없음")
|
||||
|
||||
|
||||
async def fetch_and_save(data_path: str):
|
||||
"""증분 데이터 수집 (fetch_history.py 로직 재사용)."""
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["python", "scripts/fetch_history.py", "--output", data_path, "--days", "90"],
|
||||
capture_output=True, text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"데이터 수집 실패: {result.stderr}")
|
||||
logger.info(f"데이터 수집 완료: {data_path}")
|
||||
|
||||
|
||||
def run_training(data_path: str) -> float:
|
||||
"""train_model.py를 실행하고 새 AUC를 반환한다."""
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["python", "scripts/train_model.py", "--data", data_path],
|
||||
capture_output=True, text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"학습 실패: {result.stderr}")
|
||||
new_auc = get_current_auc()
|
||||
return new_auc
|
||||
|
||||
|
||||
class Retrainer:
|
||||
def __init__(self, ml_filter: MLFilter, data_path: str = "data/xrpusdt_1m.parquet"):
|
||||
self._ml_filter = ml_filter
|
||||
self._data_path = data_path
|
||||
|
||||
async def retrain(self):
|
||||
logger.info("자동 재학습 시작")
|
||||
old_auc = get_current_auc()
|
||||
try:
|
||||
await fetch_and_save(self._data_path)
|
||||
new_auc = run_training(self._data_path)
|
||||
logger.info(f"재학습 완료: 이전 AUC={old_auc:.4f} → 새 AUC={new_auc:.4f}")
|
||||
|
||||
if new_auc < old_auc - 0.01:
|
||||
logger.warning(f"새 모델 성능 저하 ({new_auc:.4f} < {old_auc:.4f}), 롤백")
|
||||
rollback_model()
|
||||
else:
|
||||
self._ml_filter.reload_model()
|
||||
logger.success("새 ML 모델 적용 완료")
|
||||
except Exception as e:
|
||||
logger.error(f"재학습 실패: {e}")
|
||||
|
||||
async def schedule_daily(self, hour: int = 3):
|
||||
"""매일 지정 시각(컨테이너 로컬 시간 기준)에 재학습을 실행한다."""
|
||||
from datetime import timedelta
|
||||
while True:
|
||||
now = datetime.now()
|
||||
next_run = now.replace(hour=hour, minute=0, second=0, microsecond=0)
|
||||
if next_run <= now:
|
||||
next_run += timedelta(days=1)
|
||||
wait_secs = (next_run - now).total_seconds()
|
||||
logger.info(f"다음 재학습까지 {wait_secs/3600:.1f}시간 대기")
|
||||
await asyncio.sleep(wait_secs)
|
||||
await self.retrain()
|
||||
Reference in New Issue
Block a user