Files
cointrader/tests/test_retrainer.py
21in7 7e4e9315c2 feat: implement ML filter with LightGBM for trading signal validation
- Added MLFilter class to load and evaluate LightGBM model for trading signals.
- Introduced retraining mechanism to update the model daily based on new data.
- Created feature engineering and label building utilities for model training.
- Updated bot logic to incorporate ML filter for signal validation.
- Added scripts for data fetching and model training.

Made-with: Cursor
2026-03-01 17:07:18 +09:00

36 lines
1.3 KiB
Python

import pytest
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from src.retrainer import Retrainer
@pytest.mark.asyncio
async def test_retrain_calls_train(tmp_path):
    """Retraining fetches fresh data and trains a new model.

    The new model's AUC (0.72) beats the current one (0.65), so no rollback
    is expected. ``rollback_model`` is patched so that an unexpected rollback
    never runs the real implementation, and it is asserted not to be called.
    """
    ml_filter = MagicMock()
    r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet"))
    with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as mock_fetch, \
            patch("src.retrainer.run_training", return_value=0.72) as mock_train, \
            patch("src.retrainer.get_current_auc", return_value=0.65), \
            patch("src.retrainer.rollback_model") as mock_rollback:
        await r.retrain()
        # assert_awaited_once (rather than assert_called_once) also catches a
        # coroutine that was created but never awaited.
        mock_fetch.assert_awaited_once()
        mock_train.assert_called_once()
        # The better model must be kept.
        mock_rollback.assert_not_called()


@pytest.mark.asyncio
async def test_retrain_rollback_when_worse(tmp_path):
    """Roll back when the freshly trained model is worse than the current one.

    The new model's AUC (0.55) is below the current one (0.70), so
    ``rollback_model`` must be called exactly once.
    """
    ml_filter = MagicMock()
    r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet"))
    with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as mock_fetch, \
            patch("src.retrainer.run_training", return_value=0.55) as mock_train, \
            patch("src.retrainer.get_current_auc", return_value=0.70), \
            patch("src.retrainer.rollback_model") as mock_rollback:
        await r.retrain()
        # The rollback decision is only meaningful if new data was fetched
        # (and awaited) and a new model was actually trained.
        mock_fetch.assert_awaited_once()
        mock_train.assert_called_once()
        mock_rollback.assert_called_once()