feat: add README and enhance scripts for data fetching and model training

- Created README.md to document project features, structure, and setup instructions.
- Updated fetch_history.py to include path adjustments for module imports.
- Parallelized dataset generation in train_model.py with multiprocessing and added a --jobs command-line argument for the worker count.
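
For reference, a typical invocation of the updated script with the new flag (both flags appear in the diff below; the worker count here is illustrative):

    python scripts/train_model.py --data data/xrpusdt_1m.parquet --jobs 8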

Made-with: Cursor
commit b86c88a8d6
parent 7e4e9315c2
Author: 21in7
Date:   2026-03-01 17:42:12 +09:00
3 changed files with 259 additions and 42 deletions

scripts/train_model.py

@@ -2,9 +2,15 @@
Train and save a LightGBM filter model from historical candle data.
Usage: python scripts/train_model.py --data data/xrpusdt_1m.parquet
"""
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import argparse
import json
import os
from datetime import datetime
from multiprocessing import Pool, cpu_count

import joblib
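
The sys.path adjustment above is what lets the script resolve the src package when invoked directly. The repository layout this assumes, inferred from the imports (the commit does not show the tree):

# project_root/
#   scripts/
#     train_model.py   <- Path(__file__).parent.parent == project_root
#   src/
#     indicators.py
#     ml_features.py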
@@ -12,7 +18,6 @@ import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import TimeSeriesSplit
from src.indicators import Indicators
from src.ml_features import build_features, FEATURE_COLS
@@ -26,61 +31,100 @@ PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl")
LOG_PATH = Path("models/training_log.json")
def _process_index(args: tuple) -> dict | None:
    """Compute features and a label for a single index. Pool worker function."""
    i, df_values, df_columns = args
    df = pd.DataFrame(df_values, columns=df_columns)
    window = df.iloc[i - 60: i + 1].copy()
    ind = Indicators(window)
    df_ind = ind.calculate_all()
    if df_ind.isna().any().any():
        return None
    signal = ind.get_signal(df_ind)
    if signal == "HOLD":
        return None
    entry_price = float(df_ind["close"].iloc[-1])
    atr = float(df_ind["atr"].iloc[-1])
    if atr <= 0:
        return None
    stop_loss = entry_price - atr * ATR_SL_MULT if signal == "LONG" else entry_price + atr * ATR_SL_MULT
    take_profit = entry_price + atr * ATR_TP_MULT if signal == "LONG" else entry_price - atr * ATR_TP_MULT
    future = df.iloc[i + 1: i + 1 + LOOKAHEAD]
    label = build_labels(
        future_closes=future["close"].tolist(),
        future_highs=future["high"].tolist(),
        future_lows=future["low"].tolist(),
        take_profit=take_profit,
        stop_loss=stop_loss,
        side=signal,
    )
    if label is None:
        return None
    features = build_features(df_ind, signal)
    row = features.to_dict()
    row["label"] = label
    return row
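
Two details worth noting about the worker above: Pool dispatches tasks by pickling, so the worker must be a module-level function, and the parent ships the DataFrame as a raw numpy array plus column names rather than as a frame. A minimal, standalone sketch of that rebuild with toy data (not from the commit); the rebuilt frame gets a fresh RangeIndex and may widen dtypes, which is harmless here because _process_index slices positionally with iloc:

import pandas as pd

toy = pd.DataFrame({"close": [1.0, 2.0, 3.0], "high": [1.1, 2.1, 3.1]})
rebuilt = pd.DataFrame(toy.values, columns=list(toy.columns))
# Cell values survive the values/columns round-trip.
assert (rebuilt["close"] == toy["close"]).all()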
def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFrame:
    """Generate features and labels in parallel for every signal timestamp."""
    total = len(df)
    indices = range(60, total - LOOKAHEAD)
    workers = n_jobs or max(1, cpu_count() - 1)
    print(f"  Parallel processing: {workers} cores ({len(indices):,} indices total)")
    # Convert the DataFrame to numpy to minimize transfer cost between workers
    df_values = df.values
    df_columns = list(df.columns)
    task_args = [(i, df_values, df_columns) for i in indices]
    rows = []
    errors = []
    chunk = max(1, len(task_args) // (workers * 10))
    with Pool(processes=workers) as pool:
        for idx, result in enumerate(pool.imap(_process_index, task_args, chunksize=chunk)):
            if isinstance(result, dict):
                rows.append(result)
            elif result is not None:
                # Defensive: collect any unexpected non-dict, non-None results
                errors.append(result)
            if (idx + 1) % 10000 == 0:
                print(f"  Progress: {idx + 1:,}/{len(task_args):,} | samples: {len(rows):,}")
    if errors:
        print(f"  [WARN] {len(errors)} worker errors: {errors[0]}")
    if not rows:
        print("  [ERROR] No samples were generated. Checking for worker exceptions...")
        # Run the first index directly in a single process to surface the exception
        try:
            test_result = _process_index(task_args[0])
            print(f"  Single-run result: {test_result}")
        except Exception:
            import traceback
            print(f"  Single-run exception:\n{traceback.format_exc()}")
    return pd.DataFrame(rows)
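
The chunksize heuristic in generate_dataset targets roughly ten chunks per worker, trading per-chunk scheduling overhead against load balancing. A worked instance with illustrative numbers (not from the commit):

workers = 7             # e.g. cpu_count() - 1 on an 8-core machine
n_tasks = 100_000       # illustrative task count
chunk = max(1, n_tasks // (workers * 10))
print(chunk)            # 1428 -> each worker consumes roughly ten chunks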
def train(data_path: str, n_jobs: int | None = None):
    print(f"Loading data: {data_path}")
    df = pd.read_parquet(data_path)
    print(f"Candles: {len(df)}")
    print("Generating dataset...")
    dataset = generate_dataset(df, n_jobs=n_jobs)
    if dataset.empty or "label" not in dataset.columns:
        raise ValueError("Dataset generation failed: 0 samples. Check the error messages above.")
    print(f"Training samples: {len(dataset)} (positive={dataset['label'].sum():.0f}, negative={(dataset['label']==0).sum():.0f})")
    if len(dataset) < 200:
@@ -143,8 +187,10 @@ def train(data_path: str):
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
    parser.add_argument("--jobs", type=int, default=None,
                        help="Number of parallel workers (default: CPU count - 1)")
    args = parser.parse_args()
    train(args.data, n_jobs=args.jobs)


if __name__ == "__main__":