feat: implement ML filter with LightGBM for trading signal validation

- Added MLFilter class to load and evaluate LightGBM model for trading signals.
- Introduced retraining mechanism to update the model daily based on new data.
- Created feature engineering and label building utilities for model training.
- Updated bot logic to incorporate ML filter for signal validation (see the sketch after this list).
- Added scripts for data fetching and model training.
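
Only the data-collection and training scripts appear in the diff excerpt below; the MLFilter class itself lives in src/ and is not shown here. As a rough sketch of the intended bot-side gate, assuming a hypothetical MLFilter API that loads the pickled model and exposes a probability score (the class constructor, predict_proba method, and validated_signal helper are illustrative names, not the committed code):

# Hypothetical illustration only: MLFilter's actual API is defined in the
# src/ module added by this commit, which is not shown in this excerpt.
from src.ml_filter import MLFilter          # assumed path, following the src/ layout
from src.ml_features import build_features  # real helper, imported by train_model.py below

ml_filter = MLFilter("models/lgbm_filter.pkl")  # matches MODEL_PATH in train_model.py

def validated_signal(df_ind, signal: str) -> str:
    """Pass the rule-based signal through only when the model agrees."""
    if signal == "HOLD":
        return signal
    features = build_features(df_ind, signal)
    # 0.60 mirrors the report threshold in train_model.py; the production
    # cutoff may differ.
    if ml_filter.predict_proba(features) >= 0.60:
        return signal
    return "HOLD"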

Made-with: Cursor
Author: 21in7
Date: 2026-03-01 17:07:18 +09:00
parent ce57479b93
commit 7e4e9315c2
24 changed files with 2916 additions and 6 deletions

scripts/__init__.py (new file, empty)

scripts/fetch_history.py (new file, 69 lines)

@@ -0,0 +1,69 @@
"""
바이낸스 선물 REST API로 과거 캔들 데이터를 수집해 parquet으로 저장한다.
사용법: python scripts/fetch_history.py --symbol XRPUSDT --interval 1m --days 90
"""
import asyncio
import argparse
from datetime import datetime, timedelta
import pandas as pd
from binance import AsyncClient
from dotenv import load_dotenv
import os
load_dotenv()
async def fetch_klines(symbol: str, interval: str, days: int) -> pd.DataFrame:
client = await AsyncClient.create(
api_key=os.getenv("BINANCE_API_KEY", ""),
api_secret=os.getenv("BINANCE_API_SECRET", ""),
)
try:
start_ts = int((datetime.utcnow() - timedelta(days=days)).timestamp() * 1000)
all_klines = []
while True:
klines = await client.futures_klines(
symbol=symbol,
interval=interval,
startTime=start_ts,
limit=1500,
)
if not klines:
break
all_klines.extend(klines)
last_ts = klines[-1][0]
if last_ts >= int(datetime.utcnow().timestamp() * 1000):
break
start_ts = last_ts + 1
print(f"수집 중... {len(all_klines)}")
finally:
await client.close_connection()
df = pd.DataFrame(all_klines, columns=[
"timestamp", "open", "high", "low", "close", "volume",
"close_time", "quote_volume", "trades",
"taker_buy_base", "taker_buy_quote", "ignore",
])
df = df[["timestamp", "open", "high", "low", "close", "volume"]].copy()
for col in ["open", "high", "low", "close", "volume"]:
df[col] = df[col].astype(float)
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
df.set_index("timestamp", inplace=True)
return df
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--symbol", default="XRPUSDT")
parser.add_argument("--interval", default="1m")
parser.add_argument("--days", type=int, default=90)
parser.add_argument("--output", default="data/xrpusdt_1m.parquet")
args = parser.parse_args()
df = asyncio.run(fetch_klines(args.symbol, args.interval, args.days))
df.to_parquet(args.output)
print(f"저장 완료: {args.output} ({len(df)}행)")
if __name__ == "__main__":
main()
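
The parquet written above is the direct input to scripts/train_model.py below. Before training, a quick sanity check (a minimal sketch, not part of this commit) can confirm the pagination loop produced a monotonically increasing, gapless index:

import pandas as pd

df = pd.read_parquet("data/xrpusdt_1m.parquet")
assert df.index.is_monotonic_increasing
# For 1m candles, consecutive timestamps should be exactly one minute apart;
# any other value in this distribution indicates a gap in the fetched history.
print(df.index.to_series().diff().value_counts().head())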

scripts/train_model.py (new file, 151 lines)

@@ -0,0 +1,151 @@
"""
과거 캔들 데이터로 LightGBM 필터 모델을 학습하고 저장한다.
사용법: python scripts/train_model.py --data data/xrpusdt_1m.parquet
"""
import argparse
import json
from datetime import datetime
from pathlib import Path
import joblib
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import TimeSeriesSplit
from src.indicators import Indicators
from src.ml_features import build_features, FEATURE_COLS
from src.label_builder import build_labels
LOOKAHEAD = 60
ATR_SL_MULT = 1.5
ATR_TP_MULT = 3.0
MODEL_PATH = Path("models/lgbm_filter.pkl")
PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl")
LOG_PATH = Path("models/training_log.json")
def generate_dataset(df: pd.DataFrame) -> pd.DataFrame:
"""신호 발생 시점마다 피처와 레이블을 생성한다."""
rows = []
total = len(df)
for i in range(60, total - LOOKAHEAD):
window = df.iloc[i - 60: i + 1].copy()
ind = Indicators(window)
df_ind = ind.calculate_all()
if df_ind.isna().any().any():
continue
signal = ind.get_signal(df_ind)
if signal == "HOLD":
continue
entry_price = float(df_ind["close"].iloc[-1])
atr = float(df_ind["atr"].iloc[-1])
if atr <= 0:
continue
stop_loss = entry_price - atr * ATR_SL_MULT if signal == "LONG" else entry_price + atr * ATR_SL_MULT
take_profit = entry_price + atr * ATR_TP_MULT if signal == "LONG" else entry_price - atr * ATR_TP_MULT
future = df.iloc[i + 1: i + 1 + LOOKAHEAD]
label = build_labels(
future_closes=future["close"].tolist(),
future_highs=future["high"].tolist(),
future_lows=future["low"].tolist(),
take_profit=take_profit,
stop_loss=stop_loss,
side=signal,
)
if label is None:
continue
features = build_features(df_ind, signal)
row = features.to_dict()
row["label"] = label
rows.append(row)
if len(rows) % 500 == 0:
print(f" 샘플 생성 중: {len(rows)}개 (인덱스 {i}/{total})")
return pd.DataFrame(rows)
def train(data_path: str):
print(f"데이터 로드: {data_path}")
df = pd.read_parquet(data_path)
print(f"캔들 수: {len(df)}")
print("데이터셋 생성 중...")
dataset = generate_dataset(df)
print(f"학습 샘플: {len(dataset)}개 (양성={dataset['label'].sum():.0f}, 음성={(dataset['label']==0).sum():.0f})")
if len(dataset) < 200:
raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")
X = dataset[FEATURE_COLS]
y = dataset["label"]
split = int(len(X) * 0.8)
X_train, X_val = X.iloc[:split], X.iloc[split:]
y_train, y_val = y.iloc[:split], y.iloc[split:]
model = lgb.LGBMClassifier(
n_estimators=300,
learning_rate=0.05,
num_leaves=31,
min_child_samples=20,
subsample=0.8,
colsample_bytree=0.8,
class_weight="balanced",
random_state=42,
verbose=-1,
)
model.fit(
X_train, y_train,
eval_set=[(X_val, y_val)],
callbacks=[lgb.early_stopping(30, verbose=False), lgb.log_evaluation(50)],
)
val_proba = model.predict_proba(X_val)[:, 1]
auc = roc_auc_score(y_val, val_proba)
print(f"\n검증 AUC: {auc:.4f}")
print(classification_report(y_val, (val_proba >= 0.60).astype(int)))
if MODEL_PATH.exists():
import shutil
shutil.copy(MODEL_PATH, PREV_MODEL_PATH)
print(f"기존 모델 백업: {PREV_MODEL_PATH}")
MODEL_PATH.parent.mkdir(exist_ok=True)
joblib.dump(model, MODEL_PATH)
print(f"모델 저장: {MODEL_PATH}")
log = []
if LOG_PATH.exists():
with open(LOG_PATH) as f:
log = json.load(f)
log.append({
"date": datetime.now().isoformat(),
"auc": round(auc, 4),
"samples": len(dataset),
"model_path": str(MODEL_PATH),
})
with open(LOG_PATH, "w") as f:
json.dump(log, f, indent=2)
return auc
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
args = parser.parse_args()
train(args.data)
if __name__ == "__main__":
main()
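
src/label_builder.py is part of this commit but not shown in this excerpt. From the call site in generate_dataset, build_labels plausibly implements a first-touch (triple-barrier-style) rule over the LOOKAHEAD window. A minimal sketch consistent with that call site (an assumption, not the committed implementation; the real code may also use future_closes and may resolve same-bar TP/SL ties differently):

from typing import Optional

def build_labels(
    future_closes: list[float],   # kept to match the call site; unused in this sketch
    future_highs: list[float],
    future_lows: list[float],
    take_profit: float,
    stop_loss: float,
    side: str,
) -> Optional[int]:
    """1 if take-profit is touched before stop-loss within the lookahead,
    0 if stop-loss is touched first, None if neither level is reached
    (generate_dataset drops None rows)."""
    for high, low in zip(future_highs, future_lows):
        if side == "LONG":
            # Checking the stop first when both levels fall in the same bar
            # is the conservative choice; bar data cannot order intrabar touches.
            if low <= stop_loss:
                return 0
            if high >= take_profit:
                return 1
        else:  # SHORT
            if high >= stop_loss:
                return 0
            if low <= take_profit:
                return 1
    return None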