feat: remove in-container retraining, training is now mac-only

Made-with: Cursor
This commit is contained in:
21in7
2026-03-01 18:54:00 +09:00
parent fd96055e73
commit de933b97cc
13 changed files with 955 additions and 132 deletions

86
tests/test_mlx_filter.py Normal file
View File

@@ -0,0 +1,86 @@
"""
MLXFilter 단위 테스트.
Apple Silicon GPU(Metal)가 없는 환경에서는 스킵한다.
"""
import numpy as np
import pandas as pd
import pytest
mlx = pytest.importorskip("mlx.core", reason="MLX 미설치")
def _make_X(n: int = 4) -> pd.DataFrame:
rng = np.random.default_rng(0)
return pd.DataFrame(
{
"rsi": rng.uniform(20, 80, n),
"macd_hist": rng.uniform(-0.1, 0.1, n),
"bb_pct": rng.uniform(0, 1, n),
"ema_align": rng.choice([-1.0, 0.0, 1.0], n),
"stoch_k": rng.uniform(0, 100, n),
"stoch_d": rng.uniform(0, 100, n),
"atr_pct": rng.uniform(0.001, 0.05, n),
"vol_ratio": rng.uniform(0.5, 3.0, n),
"ret_1": rng.uniform(-0.01, 0.01, n),
"ret_3": rng.uniform(-0.02, 0.02, n),
"ret_5": rng.uniform(-0.03, 0.03, n),
"signal_strength": rng.integers(0, 6, n).astype(float),
"side": rng.choice([0.0, 1.0], n),
}
)
def test_mlx_gpu_device():
"""MLX가 GPU 디바이스를 기본으로 사용해야 한다."""
import mlx.core as mx
device = mx.default_device()
assert "gpu" in str(device)
def test_mlx_filter_predict_shape_untrained():
"""학습 전에도 predict_proba가 (N,) 형태를 반환해야 한다."""
from src.mlx_filter import MLXFilter
X = _make_X(4)
model = MLXFilter(input_dim=13, hidden_dim=32)
proba = model.predict_proba(X)
assert proba.shape == (4,)
assert np.all((proba >= 0.0) & (proba <= 1.0))
def test_mlx_filter_fit_and_predict():
"""학습 후 predict_proba가 유효한 확률값을 반환해야 한다."""
from src.mlx_filter import MLXFilter
n = 100
X = _make_X(n)
y = pd.Series(np.random.randint(0, 2, n))
model = MLXFilter(input_dim=13, hidden_dim=32, epochs=5, batch_size=32)
model.fit(X, y)
proba = model.predict_proba(X)
assert proba.shape == (n,)
assert np.all((proba >= 0.0) & (proba <= 1.0))
def test_mlx_filter_save_load(tmp_path):
"""저장 후 로드한 모델이 동일한 예측값을 반환해야 한다."""
from src.mlx_filter import MLXFilter
n = 50
X = _make_X(n)
y = pd.Series(np.random.randint(0, 2, n))
model = MLXFilter(input_dim=13, hidden_dim=32, epochs=3, batch_size=32)
model.fit(X, y)
proba_before = model.predict_proba(X)
save_path = tmp_path / "mlx_filter.weights"
model.save(save_path)
loaded = MLXFilter.load(save_path)
proba_after = loaded.predict_proba(X)
np.testing.assert_allclose(proba_before, proba_after, atol=1e-5)

View File

@@ -1,35 +0,0 @@
import pytest
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from src.retrainer import Retrainer
@pytest.mark.asyncio
async def test_retrain_calls_train(tmp_path):
"""재학습 시 train 함수가 호출되는지 확인"""
ml_filter = MagicMock()
r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet"))
with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as mock_fetch, \
patch("src.retrainer.run_training", return_value=0.72) as mock_train, \
patch("src.retrainer.get_current_auc", return_value=0.65):
await r.retrain()
mock_fetch.assert_called_once()
mock_train.assert_called_once()
@pytest.mark.asyncio
async def test_retrain_rollback_when_worse(tmp_path):
"""새 모델이 기존보다 나쁘면 롤백"""
ml_filter = MagicMock()
r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet"))
with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock), \
patch("src.retrainer.run_training", return_value=0.55), \
patch("src.retrainer.get_current_auc", return_value=0.70), \
patch("src.retrainer.rollback_model") as mock_rollback:
await r.retrain()
mock_rollback.assert_called_once()