feat: remove in-container retraining, training is now mac-only
Made-with: Cursor
This commit is contained in:
86
tests/test_mlx_filter.py
Normal file
86
tests/test_mlx_filter.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
MLXFilter 단위 테스트.
|
||||
Apple Silicon GPU(Metal)가 없는 환경에서는 스킵한다.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
mlx = pytest.importorskip("mlx.core", reason="MLX 미설치")
|
||||
|
||||
|
||||
def _make_X(n: int = 4) -> pd.DataFrame:
|
||||
rng = np.random.default_rng(0)
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"rsi": rng.uniform(20, 80, n),
|
||||
"macd_hist": rng.uniform(-0.1, 0.1, n),
|
||||
"bb_pct": rng.uniform(0, 1, n),
|
||||
"ema_align": rng.choice([-1.0, 0.0, 1.0], n),
|
||||
"stoch_k": rng.uniform(0, 100, n),
|
||||
"stoch_d": rng.uniform(0, 100, n),
|
||||
"atr_pct": rng.uniform(0.001, 0.05, n),
|
||||
"vol_ratio": rng.uniform(0.5, 3.0, n),
|
||||
"ret_1": rng.uniform(-0.01, 0.01, n),
|
||||
"ret_3": rng.uniform(-0.02, 0.02, n),
|
||||
"ret_5": rng.uniform(-0.03, 0.03, n),
|
||||
"signal_strength": rng.integers(0, 6, n).astype(float),
|
||||
"side": rng.choice([0.0, 1.0], n),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_mlx_gpu_device():
|
||||
"""MLX가 GPU 디바이스를 기본으로 사용해야 한다."""
|
||||
import mlx.core as mx
|
||||
|
||||
device = mx.default_device()
|
||||
assert "gpu" in str(device)
|
||||
|
||||
|
||||
def test_mlx_filter_predict_shape_untrained():
|
||||
"""학습 전에도 predict_proba가 (N,) 형태를 반환해야 한다."""
|
||||
from src.mlx_filter import MLXFilter
|
||||
|
||||
X = _make_X(4)
|
||||
model = MLXFilter(input_dim=13, hidden_dim=32)
|
||||
proba = model.predict_proba(X)
|
||||
assert proba.shape == (4,)
|
||||
assert np.all((proba >= 0.0) & (proba <= 1.0))
|
||||
|
||||
|
||||
def test_mlx_filter_fit_and_predict():
|
||||
"""학습 후 predict_proba가 유효한 확률값을 반환해야 한다."""
|
||||
from src.mlx_filter import MLXFilter
|
||||
|
||||
n = 100
|
||||
X = _make_X(n)
|
||||
y = pd.Series(np.random.randint(0, 2, n))
|
||||
|
||||
model = MLXFilter(input_dim=13, hidden_dim=32, epochs=5, batch_size=32)
|
||||
model.fit(X, y)
|
||||
proba = model.predict_proba(X)
|
||||
|
||||
assert proba.shape == (n,)
|
||||
assert np.all((proba >= 0.0) & (proba <= 1.0))
|
||||
|
||||
|
||||
def test_mlx_filter_save_load(tmp_path):
|
||||
"""저장 후 로드한 모델이 동일한 예측값을 반환해야 한다."""
|
||||
from src.mlx_filter import MLXFilter
|
||||
|
||||
n = 50
|
||||
X = _make_X(n)
|
||||
y = pd.Series(np.random.randint(0, 2, n))
|
||||
|
||||
model = MLXFilter(input_dim=13, hidden_dim=32, epochs=3, batch_size=32)
|
||||
model.fit(X, y)
|
||||
proba_before = model.predict_proba(X)
|
||||
|
||||
save_path = tmp_path / "mlx_filter.weights"
|
||||
model.save(save_path)
|
||||
|
||||
loaded = MLXFilter.load(save_path)
|
||||
proba_after = loaded.predict_proba(X)
|
||||
|
||||
np.testing.assert_allclose(proba_before, proba_after, atol=1e-5)
|
||||
@@ -1,35 +0,0 @@
|
||||
import pytest
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from src.retrainer import Retrainer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrain_calls_train(tmp_path):
|
||||
"""재학습 시 train 함수가 호출되는지 확인"""
|
||||
ml_filter = MagicMock()
|
||||
r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet"))
|
||||
|
||||
with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as mock_fetch, \
|
||||
patch("src.retrainer.run_training", return_value=0.72) as mock_train, \
|
||||
patch("src.retrainer.get_current_auc", return_value=0.65):
|
||||
await r.retrain()
|
||||
|
||||
mock_fetch.assert_called_once()
|
||||
mock_train.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrain_rollback_when_worse(tmp_path):
|
||||
"""새 모델이 기존보다 나쁘면 롤백"""
|
||||
ml_filter = MagicMock()
|
||||
r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet"))
|
||||
|
||||
with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock), \
|
||||
patch("src.retrainer.run_training", return_value=0.55), \
|
||||
patch("src.retrainer.get_current_auc", return_value=0.70), \
|
||||
patch("src.retrainer.rollback_model") as mock_rollback:
|
||||
await r.retrain()
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
Reference in New Issue
Block a user