- Updated `train_model.py` and `train_mlx_model.py` to include a time weight decay parameter for improved sample weighting during training. - Modified dataset generation to incorporate sample weights based on time decay, enhancing model performance. - Adjusted deployment scripts to support new backend options and improved error handling for model file transfers. - Added new entries to the training log for better tracking of model performance metrics over time. - Included ONNX model export functionality in the MLX filter for compatibility with Linux servers.
89 lines
3.1 KiB
Python
89 lines
3.1 KiB
Python
from pathlib import Path
|
|
import joblib
|
|
import numpy as np
|
|
import pandas as pd
|
|
from loguru import logger
|
|
|
|
from src.ml_features import FEATURE_COLS
|
|
|
|
ONNX_MODEL_PATH = Path("models/mlx_filter.weights.onnx")
|
|
LGBM_MODEL_PATH = Path("models/lgbm_filter.pkl")
|
|
|
|
|
|
class MLFilter:
|
|
"""
|
|
ML 필터. ONNX(MLX 신경망) 우선 로드, 없으면 LightGBM으로 폴백한다.
|
|
둘 다 없으면 항상 진입을 허용한다.
|
|
|
|
우선순위: ONNX > LightGBM > 폴백(항상 허용)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
onnx_path: str = str(ONNX_MODEL_PATH),
|
|
lgbm_path: str = str(LGBM_MODEL_PATH),
|
|
threshold: float = 0.60,
|
|
):
|
|
self._onnx_path = Path(onnx_path)
|
|
self._lgbm_path = Path(lgbm_path)
|
|
self._threshold = threshold
|
|
self._onnx_session = None
|
|
self._lgbm_model = None
|
|
self._try_load()
|
|
|
|
def _try_load(self):
|
|
# ONNX 우선 시도
|
|
if self._onnx_path.exists():
|
|
try:
|
|
import onnxruntime as ort
|
|
self._onnx_session = ort.InferenceSession(
|
|
str(self._onnx_path),
|
|
providers=["CPUExecutionProvider"],
|
|
)
|
|
self._lgbm_model = None
|
|
logger.info(f"ML 필터 ONNX 모델 로드 완료: {self._onnx_path}")
|
|
return
|
|
except Exception as e:
|
|
logger.warning(f"ONNX 모델 로드 실패: {e}")
|
|
self._onnx_session = None
|
|
|
|
# LightGBM 폴백
|
|
if self._lgbm_path.exists():
|
|
try:
|
|
self._lgbm_model = joblib.load(self._lgbm_path)
|
|
logger.info(f"ML 필터 LightGBM 모델 로드 완료: {self._lgbm_path}")
|
|
except Exception as e:
|
|
logger.warning(f"LightGBM 모델 로드 실패: {e}")
|
|
self._lgbm_model = None
|
|
|
|
def is_model_loaded(self) -> bool:
|
|
return self._onnx_session is not None or self._lgbm_model is not None
|
|
|
|
def should_enter(self, features: pd.Series) -> bool:
|
|
"""
|
|
확률 >= threshold 이면 True (진입 허용).
|
|
모델 없으면 True 반환 (폴백).
|
|
"""
|
|
if not self.is_model_loaded():
|
|
return True
|
|
try:
|
|
if self._onnx_session is not None:
|
|
input_name = self._onnx_session.get_inputs()[0].name
|
|
X = features[FEATURE_COLS].values.astype(np.float32).reshape(1, -1)
|
|
proba = float(self._onnx_session.run(None, {input_name: X})[0][0])
|
|
else:
|
|
X = features.to_frame().T
|
|
proba = float(self._lgbm_model.predict_proba(X)[0][1])
|
|
logger.debug(f"ML 필터 확률: {proba:.3f} (임계값: {self._threshold})")
|
|
return bool(proba >= self._threshold)
|
|
except Exception as e:
|
|
logger.warning(f"ML 필터 예측 오류 (폴백 허용): {e}")
|
|
return True
|
|
|
|
def reload_model(self):
|
|
"""재학습 후 모델을 핫 리로드한다."""
|
|
self._onnx_session = None
|
|
self._lgbm_model = None
|
|
self._try_load()
|
|
logger.info("ML 필터 모델 리로드 완료")
|