feat: implement ML filter with LightGBM for trading signal validation

- Added MLFilter class to load and evaluate LightGBM model for trading signals.
- Introduced retraining mechanism to update the model daily based on new data.
- Created feature engineering and label building utilities for model training.
- Updated bot logic to incorporate ML filter for signal validation.
- Added scripts for data fetching and model training.

Made-with: Cursor
This commit is contained in:
21in7
2026-03-01 17:07:18 +09:00
parent ce57479b93
commit 7e4e9315c2
24 changed files with 2916 additions and 6 deletions

View File

@@ -6,6 +6,9 @@ from src.indicators import Indicators
from src.data_stream import KlineStream
from src.notifier import DiscordNotifier
from src.risk_manager import RiskManager
from src.ml_filter import MLFilter
from src.ml_features import build_features
from src.retrainer import Retrainer
class TradingBot:
@@ -14,6 +17,8 @@ class TradingBot:
self.exchange = BinanceFuturesClient(config)
self.notifier = DiscordNotifier(config.discord_webhook_url)
self.risk = RiskManager(config)
self.ml_filter = MLFilter()
self.retrainer = Retrainer(ml_filter=self.ml_filter)
self.current_trade_side: str | None = None # "LONG" | "SHORT"
self.stream = KlineStream(
symbol=config.symbol,
@@ -52,6 +57,13 @@ class TradingBot:
ind = Indicators(df)
df_with_indicators = ind.calculate_all()
signal = ind.get_signal(df_with_indicators)
if signal != "HOLD" and self.ml_filter.is_model_loaded():
features = build_features(df_with_indicators, signal)
if not self.ml_filter.should_enter(features):
logger.info(f"ML 필터 차단: {signal} 신호 무시")
signal = "HOLD"
current_price = df_with_indicators["close"].iloc[-1]
logger.info(f"신호: {signal} | 현재가: {current_price:.4f} USDT")
@@ -153,6 +165,7 @@ class TradingBot:
async def run(self):
logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x")
await self._recover_position()
asyncio.create_task(self.retrainer.schedule_daily(hour=3))
await self.stream.start(
api_key=self.config.api_key,
api_secret=self.config.api_secret,

29
src/label_builder.py Normal file
View File

@@ -0,0 +1,29 @@
from typing import Optional
def build_labels(
    future_closes: list[float],
    future_highs: list[float],
    future_lows: list[float],
    take_profit: float,
    stop_loss: float,
    side: str,
) -> Optional[int]:
    """
    Label a trade by scanning the candles after entry in chronological order.

    For every candle the take-profit test runs before the stop-loss test, so
    a candle that touches both levels is labelled as a win.

    LONG:  high >= take_profit -> 1, low <= stop_loss -> 0
    other: (treated as SHORT) low <= take_profit -> 1, high >= stop_loss -> 0

    Returns None when neither level is reached within the supplied candles
    (such samples are meant to be dropped from the training data).
    ``future_closes`` is accepted but never read. Highs and lows are paired
    with ``zip``, so the shorter of the two lists bounds the scan.
    """
    is_long = side == "LONG"
    for candle_high, candle_low in zip(future_highs, future_lows):
        if is_long:
            hit_target = candle_high >= take_profit
            hit_stop = candle_low <= stop_loss
        else:
            hit_target = candle_low <= take_profit
            hit_stop = candle_high >= stop_loss
        # Target is checked first: same priority as the original branch order.
        if hit_target:
            return 1
        if hit_stop:
            return 0
    return None

82
src/ml_features.py Normal file
View File

@@ -0,0 +1,82 @@
import pandas as pd
import numpy as np
# Names of the model input features. The order matches the keys of the Series
# returned by build_features in this module.
FEATURE_COLS = [
"rsi", "macd_hist", "bb_pct", "ema_align",
"stoch_k", "stoch_d", "atr_pct", "vol_ratio",
"ret_1", "ret_3", "ret_5", "signal_strength", "side",
]
def build_features(df: pd.DataFrame, signal: str) -> pd.Series:
    """
    Build the ML feature vector from the last row of an indicator DataFrame.

    signal: "LONG" applies the long-side strength rules and sets side to 1.0;
    any other value is scored with the short-side rules and sets side to 0.0.

    The ``close`` column is required. Other indicator columns are read with
    ``Series.get`` and fall back to neutral defaults when the column is
    absent (band edges and EMAs default to the close, RSI and stochastics
    to 50, MACD values and ATR to 0).
    """
    row = df.iloc[-1]
    price = row["close"]

    # Position of the close inside the Bollinger band (0.5 when the band is flat).
    upper_band = row.get("bb_upper", price)
    lower_band = row.get("bb_lower", price)
    band_width = upper_band - lower_band
    if band_width > 0:
        bb_pct = (price - lower_band) / band_width
    else:
        bb_pct = 0.5

    # EMA stacking: 1 = fast above slow, -1 = fast below slow, 0 = mixed.
    ema9 = row.get("ema9", price)
    ema21 = row.get("ema21", price)
    ema50 = row.get("ema50", price)
    stacked_up = ema9 > ema21 > ema50
    stacked_down = ema9 < ema21 < ema50
    ema_align = 1 if stacked_up else (-1 if stacked_down else 0)

    # ATR relative to price.
    atr = row.get("atr", 0)
    if price > 0:
        atr_pct = atr / price
    else:
        atr_pct = 0

    # Volume relative to its 20-period mean; the volume column is only read
    # when the divisor is positive.
    vol_ma20 = row.get("vol_ma20", row.get("volume", 1))
    if vol_ma20 > 0:
        vol_ratio = row["volume"] / vol_ma20
    else:
        vol_ratio = 1.0

    closes = df["close"]

    def pct_change_over(bars: int):
        # Return over the last `bars` candles, or 0.0 without enough history.
        if len(closes) < bars + 1:
            return 0.0
        base = closes.iloc[-(bars + 1)]
        return (price - base) / base

    ret_1 = pct_change_over(1)
    ret_3 = pct_change_over(3)
    ret_5 = pct_change_over(5)

    # The previous candle is needed to detect a MACD cross on the last bar.
    before = df.iloc[-2] if len(df) >= 2 else row
    rsi = row.get("rsi", 50)
    macd = row.get("macd", 0)
    macd_sig = row.get("macd_signal", 0)
    prev_macd = before.get("macd", 0)
    prev_macd_sig = before.get("macd_signal", 0)
    stoch_k = row.get("stoch_k", 50)
    stoch_d = row.get("stoch_d", 50)

    # Each rule is (condition, weight); the MACD cross counts double.
    if signal == "LONG":
        rules = (
            (rsi < 35, 1),
            (prev_macd < prev_macd_sig and macd > macd_sig, 2),
            (price < lower_band, 1),
            (ema_align == 1, 1),
            (stoch_k < 20 and stoch_k > stoch_d, 1),
        )
    else:
        rules = (
            (rsi > 65, 1),
            (prev_macd > prev_macd_sig and macd < macd_sig, 2),
            (price > upper_band, 1),
            (ema_align == -1, 1),
            (stoch_k > 80 and stoch_k < stoch_d, 1),
        )
    strength = sum(weight for satisfied, weight in rules if satisfied)

    feature_values = {
        "rsi": float(rsi),
        "macd_hist": float(row.get("macd_hist", 0)),
        "bb_pct": float(bb_pct),
        "ema_align": float(ema_align),
        "stoch_k": float(stoch_k),
        "stoch_d": float(stoch_d),
        "atr_pct": float(atr_pct),
        "vol_ratio": float(vol_ratio),
        "ret_1": float(ret_1),
        "ret_3": float(ret_3),
        "ret_5": float(ret_5),
        "signal_strength": float(strength),
        "side": 1.0 if signal == "LONG" else 0.0,
    }
    return pd.Series(feature_values)

50
src/ml_filter.py Normal file
View File

@@ -0,0 +1,50 @@
from pathlib import Path
import joblib
import pandas as pd
from loguru import logger
class MLFilter:
    """
    Load a LightGBM model and decide whether a trade entry is allowed.

    When no model is loaded, every entry is allowed (fallback).
    """

    def __init__(self, model_path: str = "models/lgbm_filter.pkl", threshold: float = 0.60):
        """
        model_path: location of the joblib-serialized model file.
        threshold: minimum class-1 probability required to allow an entry.
        """
        self._model_path = Path(model_path)
        self._threshold = threshold
        self._model = None
        self._try_load()

    def _try_load(self) -> bool:
        """
        Try to load the model file and return True on success.

        On a missing file or a load error the currently held model (if any)
        is left untouched, so a failed hot reload does not silently switch
        the filter off.
        """
        if not self._model_path.exists():
            return False
        try:
            # NOTE(review): joblib.load unpickles the file, which can execute
            # arbitrary code. Only point model_path at trusted files.
            loaded = joblib.load(self._model_path)
        except Exception as e:
            logger.warning(f"ML 필터 모델 로드 실패: {e}")
            return False
        # Swap only after a successful load.
        self._model = loaded
        logger.info(f"ML 필터 모델 로드 완료: {self._model_path}")
        return True

    def is_model_loaded(self) -> bool:
        """Return True when a model object is currently held."""
        return self._model is not None

    def should_enter(self, features: pd.Series) -> bool:
        """
        Return True (allow entry) when the class-1 probability >= threshold.

        Returns True as a fallback when no model is loaded or when the
        prediction raises.
        """
        if not self.is_model_loaded():
            return True
        try:
            # One-row frame whose columns are the Series index labels.
            X = features.to_frame().T
            proba = self._model.predict_proba(X)[0][1]
            logger.debug(f"ML 필터 확률: {proba:.3f} (임계값: {self._threshold})")
            return bool(proba >= self._threshold)
        except Exception as e:
            # Deliberate best-effort: a broken prediction must not block trading.
            logger.warning(f"ML 필터 예측 오류 (폴백 허용): {e}")
            return True

    def reload_model(self):
        """Hot-reload the model after retraining and log the real outcome."""
        if self._try_load():
            logger.info("ML 필터 모델 리로드 완료")
        else:
            logger.warning("ML 필터 모델 리로드 실패: 기존 상태 유지")

92
src/retrainer.py Normal file
View File

@@ -0,0 +1,92 @@
import asyncio
import json
from datetime import datetime
from pathlib import Path
from loguru import logger
from src.ml_filter import MLFilter
MODEL_PATH = Path("models/lgbm_filter.pkl")
PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl")
LOG_PATH = Path("models/training_log.json")
def get_current_auc() -> float:
    """
    Read the most recent AUC from training_log.json.

    Returns the "auc" value of the last log entry, or 0.0 when the log file
    does not exist or the parsed log is empty.
    """
    if LOG_PATH.exists():
        with LOG_PATH.open() as log_file:
            entries = json.load(log_file)
        if entries:
            return entries[-1]["auc"]
    return 0.0
def rollback_model():
    """Copy the previous model file over the current one, if it exists."""
    if not PREV_MODEL_PATH.exists():
        logger.warning("롤백할 이전 모델 없음")
        return

    import shutil

    shutil.copy(PREV_MODEL_PATH, MODEL_PATH)
    logger.warning("ML 모델 롤백 완료")
async def fetch_and_save(data_path: str):
    """
    Collect data by running scripts/fetch_history.py (reuses its logic).

    The script is invoked with ``--output data_path --days 90``. The blocking
    ``subprocess.run`` call is executed in a worker thread so the event loop
    keeps running while the script works.

    Raises:
        RuntimeError: when the script exits with a non-zero return code; the
            message carries the script's stderr.
    """
    import subprocess

    result = await asyncio.to_thread(
        subprocess.run,
        ["python", "scripts/fetch_history.py", "--output", data_path, "--days", "90"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"데이터 수집 실패: {result.stderr}")
    logger.info(f"데이터 수집 완료: {data_path}")
def run_training(data_path: str) -> float:
    """
    Run scripts/train_model.py on ``data_path`` and return the new AUC.

    The AUC is read back through get_current_auc() once the script has
    finished. Raises RuntimeError, carrying the script's stderr, when the
    script exits with a non-zero return code.
    """
    import subprocess

    command = ["python", "scripts/train_model.py", "--data", data_path]
    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode != 0:
        raise RuntimeError(f"학습 실패: {completed.stderr}")
    return get_current_auc()
class Retrainer:
    """Retrain the ML filter model on a daily schedule and hot-reload it."""

    def __init__(self, ml_filter: MLFilter, data_path: str = "data/xrpusdt_1m.parquet"):
        """
        ml_filter: filter instance whose model is reloaded after retraining.
        data_path: file the fetch script writes to and the training script reads.
        """
        self._ml_filter = ml_filter
        self._data_path = data_path

    async def retrain(self):
        """
        Fetch fresh data, retrain, then either apply or roll back the model.

        The new model is rolled back when its AUC is more than 0.01 below
        the previously logged AUC; otherwise the filter reloads it. Any
        exception is logged and swallowed so the daily schedule keeps going.
        """
        logger.info("자동 재학습 시작")
        try:
            # Read inside the try: a malformed training log must not escape
            # and end the schedule_daily loop.
            old_auc = get_current_auc()
            await fetch_and_save(self._data_path)
            # run_training blocks on a subprocess; run it in a worker thread
            # so the event loop stays responsive during training.
            new_auc = await asyncio.to_thread(run_training, self._data_path)
            logger.info(f"재학습 완료: 이전 AUC={old_auc:.4f} → 새 AUC={new_auc:.4f}")
            if new_auc < old_auc - 0.01:
                logger.warning(f"새 모델 성능 저하 ({new_auc:.4f} < {old_auc:.4f}), 롤백")
                rollback_model()
            else:
                self._ml_filter.reload_model()
                logger.success("새 ML 모델 적용 완료")
        except Exception as e:
            logger.error(f"재학습 실패: {e}")

    async def schedule_daily(self, hour: int = 3):
        """Run retrain() every day at ``hour`` o'clock (container local time)."""
        from datetime import timedelta

        while True:
            now = datetime.now()
            next_run = now.replace(hour=hour, minute=0, second=0, microsecond=0)
            # Today's slot has already passed: schedule for tomorrow.
            if next_run <= now:
                next_run += timedelta(days=1)
            wait_secs = (next_run - now).total_seconds()
            logger.info(f"다음 재학습까지 {wait_secs/3600:.1f}시간 대기")
            await asyncio.sleep(wait_secs)
            await self.retrain()