"""Train and save the MLX-based neural-network filter.

Automatically uses the M4 unified GPU (Metal).

Usage: python scripts/train_mlx_model.py --data data/xrpusdt_1m.parquet
"""
import sys
from pathlib import Path

# Make the project root importable before pulling in src.* modules
# (must run before the `from src...` imports below).
sys.path.insert(0, str(Path(__file__).parent.parent))

import argparse
import json
import time
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, classification_report

from src.dataset_builder import generate_dataset_vectorized
from src.ml_features import FEATURE_COLS
from src.mlx_filter import MLXFilter

# Output locations for the trained weights and the training-run history.
MLX_MODEL_PATH = Path("models/mlx_filter.weights")
LOG_PATH = Path("models/training_log.json")
def train_mlx(data_path: str) -> float:
    """Train the MLX neural-network filter and persist it to disk.

    Loads 1-minute candles from *data_path* (parquet), builds a labelled
    dataset, fits an ``MLXFilter`` on a chronological 80/20 split, prints a
    validation report, saves the weights to ``MLX_MODEL_PATH`` and appends a
    run record to ``LOG_PATH``.

    Args:
        data_path: Path to the parquet file of candle data.

    Returns:
        Validation ROC-AUC of the trained model.

    Raises:
        ValueError: If dataset generation yields no samples / no label
            column, or fewer than 200 samples.
    """
    print(f"데이터 로드: {data_path}")
    df = pd.read_parquet(data_path)
    print(f"캔들 수: {len(df)}")

    print("\n데이터셋 생성 중...")
    t0 = time.perf_counter()
    dataset = generate_dataset_vectorized(df)
    t1 = time.perf_counter()
    print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")

    if dataset.empty or "label" not in dataset.columns:
        raise ValueError("데이터셋 생성 실패: 샘플 0개")

    print(f"학습 샘플: {len(dataset)}개 (양성={dataset['label'].sum():.0f}, 음성={(dataset['label']==0).sum():.0f})")

    if len(dataset) < 200:
        raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")

    X = dataset[FEATURE_COLS]
    y = dataset["label"]

    # Chronological split (no shuffle) so validation never sees the future.
    split = int(len(X) * 0.8)
    X_train, X_val = X.iloc[:split], X.iloc[split:]
    y_train, y_val = y.iloc[:split], y.iloc[split:]

    print("\nMLX 신경망 학습 시작 (GPU)...")
    t2 = time.perf_counter()
    model = MLXFilter(
        input_dim=len(FEATURE_COLS),
        hidden_dim=128,
        lr=1e-3,
        epochs=100,
        batch_size=256,
    )
    model.fit(X_train, y_train)
    t3 = time.perf_counter()
    print(f"학습 완료: {t3 - t2:.1f}초")

    val_proba = model.predict_proba(X_val)
    auc = roc_auc_score(y_val, val_proba)
    print(f"\n검증 AUC: {auc:.4f}")
    # 0.60 — presumably the live probability threshold; confirm it matches
    # the runtime filter configuration.
    print(classification_report(y_val, (val_proba >= 0.60).astype(int)))

    # parents=True: a missing intermediate directory must not break the save.
    MLX_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    model.save(MLX_MODEL_PATH)
    print(f"모델 저장: {MLX_MODEL_PATH}")

    # Append this run to the training history. The log is best-effort: a
    # corrupt/unreadable file starts a fresh history instead of aborting
    # after an otherwise successful training run.
    log = []
    if LOG_PATH.exists():
        try:
            with open(LOG_PATH, encoding="utf-8") as f:
                log = json.load(f)
        except (json.JSONDecodeError, OSError):
            log = []
    log.append({
        "date": datetime.now().isoformat(),
        "backend": "mlx",
        "auc": round(auc, 4),
        "samples": len(dataset),
        "train_sec": round(t3 - t2, 1),
        "model_path": str(MLX_MODEL_PATH),
    })
    with open(LOG_PATH, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2, ensure_ascii=False)

    return auc
def main() -> None:
    """CLI entry point: parse ``--data`` and run training."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--data", default="data/xrpusdt_1m.parquet")
    parsed = cli.parse_args()
    train_mlx(parsed.data)


if __name__ == "__main__":
    main()