feat: remove in-container retraining, training is now mac-only
Made-with: Cursor
This commit is contained in:
@@ -8,7 +8,6 @@ from src.notifier import DiscordNotifier
|
||||
from src.risk_manager import RiskManager
|
||||
from src.ml_filter import MLFilter
|
||||
from src.ml_features import build_features
|
||||
from src.retrainer import Retrainer
|
||||
|
||||
|
||||
class TradingBot:
|
||||
@@ -18,7 +17,6 @@ class TradingBot:
|
||||
self.notifier = DiscordNotifier(config.discord_webhook_url)
|
||||
self.risk = RiskManager(config)
|
||||
self.ml_filter = MLFilter()
|
||||
self.retrainer = Retrainer(ml_filter=self.ml_filter)
|
||||
self.current_trade_side: str | None = None # "LONG" | "SHORT"
|
||||
self.stream = KlineStream(
|
||||
symbol=config.symbol,
|
||||
@@ -165,7 +163,6 @@ class TradingBot:
|
||||
async def run(self):
|
||||
logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x")
|
||||
await self._recover_position()
|
||||
asyncio.create_task(self.retrainer.schedule_daily(hour=3))
|
||||
await self.stream.start(
|
||||
api_key=self.config.api_key,
|
||||
api_secret=self.config.api_secret,
|
||||
|
||||
src/mlx_filter.py | 130 ++++++++++ (new file)
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Apple MLX 기반 경량 신경망 필터.
|
||||
M4의 통합 GPU를 자동으로 활용한다.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from pathlib import Path
|
||||
|
||||
from src.ml_features import FEATURE_COLS
|
||||
|
||||
|
||||
class _Net(nn.Module):
    """Three-layer MLP producing a single logit for binary classification."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        half = hidden_dim // 2
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half)
        self.fc3 = nn.Linear(half, 1)
        self.dropout = nn.Dropout(p=0.2)

    def __call__(self, x: mx.array) -> mx.array:
        # relu(fc1) -> dropout -> relu(fc2) -> fc3, squeezed to shape (batch,)
        hidden = self.dropout(nn.relu(self.fc1(x)))
        hidden = nn.relu(self.fc2(hidden))
        return self.fc3(hidden).squeeze(-1)
|
||||
|
||||
|
||||
class MLXFilter:
    """
    MLX neural-network filter with a scikit-learn-compatible interface
    (``fit`` / ``predict_proba``) plus ``save``/``load`` persistence.

    MLX dispatches to the unified GPU (Metal) automatically on Apple
    silicon, so no device management is needed here.
    """

    def __init__(
        self,
        input_dim: int = 13,
        hidden_dim: int = 64,
        lr: float = 1e-3,
        epochs: int = 50,
        batch_size: int = 256,
    ):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self._model = _Net(input_dim, hidden_dim)
        # Feature standardization statistics, learned in fit().
        self._mean: np.ndarray | None = None
        self._std: np.ndarray | None = None
        self._trained = False

    def fit(self, X: pd.DataFrame, y: pd.Series) -> "MLXFilter":
        """
        Train the network with Adam + binary cross-entropy.

        Args:
            X: DataFrame containing at least the FEATURE_COLS columns.
            y: Binary labels (0/1) aligned with X.

        Returns:
            self, for chaining.

        Raises:
            ValueError: if X has no rows (previously this crashed later
                with ZeroDivisionError when logging the epoch loss).
        """
        X_np = X[FEATURE_COLS].values.astype(np.float32)
        y_np = y.values.astype(np.float32)
        if len(X_np) == 0:
            raise ValueError("cannot fit MLXFilter on an empty dataset")

        # Standardize features; epsilon keeps constant columns from dividing by 0.
        self._mean = X_np.mean(axis=0)
        self._std = X_np.std(axis=0) + 1e-8
        X_np = (X_np - self._mean) / self._std

        optimizer = optim.Adam(learning_rate=self.lr)

        def loss_fn(model: _Net, x: mx.array, y: mx.array) -> mx.array:
            logits = model(x)
            return nn.losses.binary_cross_entropy(logits, y, with_logits=True)

        loss_and_grad = nn.value_and_grad(self._model, loss_fn)

        n = len(X_np)
        for epoch in range(self.epochs):
            # Fresh shuffle each epoch for SGD-style minibatches.
            idx = np.random.permutation(n)
            epoch_loss = 0.0
            steps = 0
            for start in range(0, n, self.batch_size):
                batch_idx = idx[start : start + self.batch_size]
                x_batch = mx.array(X_np[batch_idx])
                y_batch = mx.array(y_np[batch_idx])
                loss, grads = loss_and_grad(self._model, x_batch, y_batch)
                optimizer.update(self._model, grads)
                # Force MLX's lazy graph to materialize every step.
                mx.eval(self._model.parameters(), optimizer.state)
                epoch_loss += loss.item()
                steps += 1
            if (epoch + 1) % 10 == 0:
                print(f" Epoch {epoch + 1}/{self.epochs} loss={epoch_loss / steps:.4f}")

        self._trained = True
        return self

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        """
        Return sigmoid probabilities for the positive class, shape (n,).

        Applies the standardization learned in fit() when available;
        dropout is disabled via eval mode for deterministic inference.
        """
        X_np = X[FEATURE_COLS].values.astype(np.float32)
        if self._trained and self._mean is not None:
            X_np = (X_np - self._mean) / self._std
        x = mx.array(X_np)
        self._model.eval()
        logits = self._model(x)
        proba = mx.sigmoid(logits)
        mx.eval(proba)
        self._model.train()
        return np.array(proba)

    def save(self, path: str | Path) -> None:
        """
        Persist weights to ``<path>.npz`` and scaler/shape metadata to
        ``<path>.meta.npz``.

        Raises:
            RuntimeError: if called before fit() — serializing None
                mean/std would produce pickled object arrays that
                load() (np.load without allow_pickle) cannot read.
        """
        if not self._trained or self._mean is None:
            raise RuntimeError("MLXFilter.save() called before fit()")
        path = Path(path)
        # parents=True: create the full directory chain, not just the leaf.
        path.parent.mkdir(parents=True, exist_ok=True)
        weights_path = path.with_suffix(".npz")
        self._model.save_weights(str(weights_path))
        meta_path = path.with_suffix(".meta.npz")
        np.savez(
            meta_path,
            mean=self._mean,
            std=self._std,
            input_dim=np.array(self.input_dim),
            hidden_dim=np.array(self.hidden_dim),
        )

    @classmethod
    def load(cls, path: str | Path) -> "MLXFilter":
        """Reconstruct a trained filter from the files written by save()."""
        path = Path(path)
        meta = np.load(path.with_suffix(".meta.npz"))
        obj = cls(
            input_dim=int(meta["input_dim"]),
            hidden_dim=int(meta["hidden_dim"]),
        )
        obj._mean = meta["mean"]
        obj._std = meta["std"]
        obj._model.load_weights(str(path.with_suffix(".npz")))
        obj._trained = True
        return obj
|
||||
@@ -1,92 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from src.ml_filter import MLFilter
|
||||
|
||||
# Artifact locations for the LightGBM filter and its training history.
MODEL_PATH = Path("models/lgbm_filter.pkl")
PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl")
LOG_PATH = Path("models/training_log.json")


def get_current_auc() -> float:
    """Return the AUC of the most recent entry in training_log.json, or 0.0."""
    if not LOG_PATH.exists():
        return 0.0
    with open(LOG_PATH) as f:
        entries = json.load(f)
    if not entries:
        return 0.0
    return entries[-1]["auc"]
|
||||
|
||||
|
||||
def rollback_model():
    """Restore the previously deployed model file, if a backup exists."""
    if not PREV_MODEL_PATH.exists():
        logger.warning("롤백할 이전 모델 없음")
        return
    import shutil

    shutil.copy(PREV_MODEL_PATH, MODEL_PATH)
    logger.warning("ML 모델 롤백 완료")
|
||||
|
||||
|
||||
async def fetch_and_save(data_path: str):
    """
    Incrementally fetch history into *data_path* (reuses scripts/fetch_history.py).

    Spawns the fetch script as an asyncio subprocess instead of the
    previous blocking ``subprocess.run`` — a blocking call inside an
    ``async def`` would stall the whole event loop (and the bot's
    websocket handling) for the duration of the download.

    Raises:
        RuntimeError: if the fetch script exits non-zero.
    """
    proc = await asyncio.create_subprocess_exec(
        "python", "scripts/fetch_history.py", "--output", data_path, "--days", "90",
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        raise RuntimeError(f"데이터 수집 실패: {stderr.decode(errors='replace')}")
    logger.info(f"데이터 수집 완료: {data_path}")
|
||||
|
||||
|
||||
def run_training(data_path: str) -> float:
    """Run scripts/train_model.py on *data_path* and return the new AUC."""
    import subprocess

    result = subprocess.run(
        ["python", "scripts/train_model.py", "--data", data_path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"학습 실패: {result.stderr}")
    # train_model.py appends its metrics to training_log.json; read them back.
    return get_current_auc()
|
||||
|
||||
|
||||
class Retrainer:
    """
    Periodically retrains the ML filter and hot-swaps the model,
    rolling back to the previous model when the new one underperforms.
    """

    # Maximum AUC regression tolerated before rolling back the new model.
    _AUC_TOLERANCE = 0.01

    def __init__(self, ml_filter: MLFilter, data_path: str = "data/xrpusdt_1m.parquet"):
        self._ml_filter = ml_filter
        self._data_path = data_path

    async def retrain(self):
        """Fetch fresh data, retrain, and apply (or roll back) the new model."""
        logger.info("자동 재학습 시작")
        old_auc = get_current_auc()
        try:
            await fetch_and_save(self._data_path)
            # run_training shells out to a blocking subprocess; run it in a
            # worker thread so the event loop (and the bot) keeps running.
            new_auc = await asyncio.to_thread(run_training, self._data_path)
            logger.info(f"재학습 완료: 이전 AUC={old_auc:.4f} → 새 AUC={new_auc:.4f}")

            if new_auc < old_auc - self._AUC_TOLERANCE:
                logger.warning(f"새 모델 성능 저하 ({new_auc:.4f} < {old_auc:.4f}), 롤백")
                rollback_model()
            else:
                self._ml_filter.reload_model()
                logger.success("새 ML 모델 적용 완료")
        except Exception as e:
            # Boundary handler: a failed retrain must never kill the scheduler.
            logger.error(f"재학습 실패: {e}")

    async def schedule_daily(self, hour: int = 3):
        """Run retrain() once per day at *hour* (container-local time)."""
        from datetime import timedelta

        while True:
            now = datetime.now()
            next_run = now.replace(hour=hour, minute=0, second=0, microsecond=0)
            if next_run <= now:
                # Today's slot already passed; schedule for tomorrow.
                next_run += timedelta(days=1)
            wait_secs = (next_run - now).total_seconds()
            logger.info(f"다음 재학습까지 {wait_secs/3600:.1f}시간 대기")
            await asyncio.sleep(wait_secs)
            await self.retrain()
|
||||
Reference in New Issue
Block a user