"""
Train and save an MLX-based neural-network filter.

Automatically uses the M4 unified GPU (Metal) via the MLX backend.

Usage: python scripts/train_mlx_model.py --data data/xrpusdt_1m.parquet
"""
import sys
from pathlib import Path

# Make the project root importable when this file is run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

import argparse
import json
import time
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, classification_report

from src.dataset_builder import generate_dataset_vectorized
from src.ml_features import FEATURE_COLS
from src.mlx_filter import MLXFilter

MLX_MODEL_PATH = Path("models/mlx_filter.weights")
LOG_PATH = Path("models/training_log.json")


def _append_training_log(auc: float, samples: int, train_sec: float) -> None:
    """Append one training-run record to the JSON log at LOG_PATH.

    Best-effort: if the existing log file is corrupt or unreadable, it is
    replaced with a fresh list rather than crashing after the model has
    already been trained and saved.
    """
    log: list = []
    if LOG_PATH.exists():
        try:
            with open(LOG_PATH) as f:
                log = json.load(f)
        except (json.JSONDecodeError, OSError):
            # A broken log must not lose a finished training run.
            log = []
    log.append({
        "date": datetime.now().isoformat(),
        "backend": "mlx",
        "auc": round(auc, 4),
        "samples": samples,
        "train_sec": round(train_sec, 1),
        "model_path": str(MLX_MODEL_PATH),
    })
    with open(LOG_PATH, "w") as f:
        json.dump(log, f, indent=2)


def train_mlx(data_path: str) -> float:
    """Build the dataset, train the MLX filter, save weights, and log the run.

    Args:
        data_path: Path to a parquet file of 1-minute candles.

    Returns:
        Validation ROC-AUC of the trained model.

    Raises:
        ValueError: If dataset generation yields no samples, lacks a
            "label" column, or produces fewer than 200 samples.
    """
    print(f"데이터 로드: {data_path}")
    df = pd.read_parquet(data_path)
    print(f"캔들 수: {len(df)}")

    print("\n데이터셋 생성 중...")
    t0 = time.perf_counter()
    dataset = generate_dataset_vectorized(df)
    t1 = time.perf_counter()
    print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")

    if dataset.empty or "label" not in dataset.columns:
        raise ValueError("데이터셋 생성 실패: 샘플 0개")

    print(f"학습 샘플: {len(dataset)}개 (양성={dataset['label'].sum():.0f}, 음성={(dataset['label']==0).sum():.0f})")

    if len(dataset) < 200:
        raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")

    X = dataset[FEATURE_COLS]
    y = dataset["label"]

    # Chronological 80/20 split — no shuffling, to avoid look-ahead leakage
    # in time-series data.
    split = int(len(X) * 0.8)
    X_train, X_val = X.iloc[:split], X.iloc[split:]
    y_train, y_val = y.iloc[:split], y.iloc[split:]

    print("\nMLX 신경망 학습 시작 (GPU)...")
    t2 = time.perf_counter()
    model = MLXFilter(
        input_dim=len(FEATURE_COLS),
        hidden_dim=128,
        lr=1e-3,
        epochs=100,
        batch_size=256,
    )
    model.fit(X_train, y_train)
    t3 = time.perf_counter()
    print(f"학습 완료: {t3 - t2:.1f}초")

    val_proba = model.predict_proba(X_val)
    auc = roc_auc_score(y_val, val_proba)
    print(f"\n검증 AUC: {auc:.4f}")
    # 0.60 is the probability threshold the live filter uses.
    print(classification_report(y_val, (val_proba >= 0.60).astype(int)))

    # parents=True so a missing "models/" tree is created, not a crash.
    MLX_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    model.save(MLX_MODEL_PATH)
    print(f"모델 저장: {MLX_MODEL_PATH}")

    _append_training_log(auc, len(dataset), t3 - t2)

    return auc


def main():
    """CLI entry point: parse --data and run one training session."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
    args = parser.parse_args()
    train_mlx(args.data)


if __name__ == "__main__":
    main()