feat: enhance model training and deployment scripts with time-weighted sampling
- Updated `train_model.py` and `train_mlx_model.py` to include a time weight decay parameter for improved sample weighting during training. - Modified dataset generation to incorporate sample weights based on time decay, enhancing model performance. - Adjusted deployment scripts to support new backend options and improved error handling for model file transfers. - Added new entries to the training log for better tracking of model performance metrics over time. - Included ONNX model export functionality in the MLX filter for compatibility with Linux servers.
This commit is contained in:
@@ -1,32 +1,63 @@
|
||||
from pathlib import Path
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
from src.ml_features import FEATURE_COLS
|
||||
|
||||
ONNX_MODEL_PATH = Path("models/mlx_filter.weights.onnx")
|
||||
LGBM_MODEL_PATH = Path("models/lgbm_filter.pkl")
|
||||
|
||||
|
||||
class MLFilter:
|
||||
"""
|
||||
LightGBM 모델을 로드하고 진입 여부를 판단한다.
|
||||
모델 파일이 없으면 항상 진입을 허용한다 (폴백).
|
||||
ML 필터. ONNX(MLX 신경망) 우선 로드, 없으면 LightGBM으로 폴백한다.
|
||||
둘 다 없으면 항상 진입을 허용한다.
|
||||
|
||||
우선순위: ONNX > LightGBM > 폴백(항상 허용)
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = "models/lgbm_filter.pkl", threshold: float = 0.60):
|
||||
self._model_path = Path(model_path)
|
||||
def __init__(
|
||||
self,
|
||||
onnx_path: str = str(ONNX_MODEL_PATH),
|
||||
lgbm_path: str = str(LGBM_MODEL_PATH),
|
||||
threshold: float = 0.60,
|
||||
):
|
||||
self._onnx_path = Path(onnx_path)
|
||||
self._lgbm_path = Path(lgbm_path)
|
||||
self._threshold = threshold
|
||||
self._model = None
|
||||
self._onnx_session = None
|
||||
self._lgbm_model = None
|
||||
self._try_load()
|
||||
|
||||
def _try_load(self):
|
||||
if self._model_path.exists():
|
||||
# ONNX 우선 시도
|
||||
if self._onnx_path.exists():
|
||||
try:
|
||||
self._model = joblib.load(self._model_path)
|
||||
logger.info(f"ML 필터 모델 로드 완료: {self._model_path}")
|
||||
import onnxruntime as ort
|
||||
self._onnx_session = ort.InferenceSession(
|
||||
str(self._onnx_path),
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self._lgbm_model = None
|
||||
logger.info(f"ML 필터 ONNX 모델 로드 완료: {self._onnx_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"ML 필터 모델 로드 실패: {e}")
|
||||
self._model = None
|
||||
logger.warning(f"ONNX 모델 로드 실패: {e}")
|
||||
self._onnx_session = None
|
||||
|
||||
# LightGBM 폴백
|
||||
if self._lgbm_path.exists():
|
||||
try:
|
||||
self._lgbm_model = joblib.load(self._lgbm_path)
|
||||
logger.info(f"ML 필터 LightGBM 모델 로드 완료: {self._lgbm_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"LightGBM 모델 로드 실패: {e}")
|
||||
self._lgbm_model = None
|
||||
|
||||
def is_model_loaded(self) -> bool:
|
||||
return self._model is not None
|
||||
return self._onnx_session is not None or self._lgbm_model is not None
|
||||
|
||||
def should_enter(self, features: pd.Series) -> bool:
|
||||
"""
|
||||
@@ -36,8 +67,13 @@ class MLFilter:
|
||||
if not self.is_model_loaded():
|
||||
return True
|
||||
try:
|
||||
X = features.to_frame().T
|
||||
proba = self._model.predict_proba(X)[0][1]
|
||||
if self._onnx_session is not None:
|
||||
input_name = self._onnx_session.get_inputs()[0].name
|
||||
X = features[FEATURE_COLS].values.astype(np.float32).reshape(1, -1)
|
||||
proba = float(self._onnx_session.run(None, {input_name: X})[0][0])
|
||||
else:
|
||||
X = features.to_frame().T
|
||||
proba = float(self._lgbm_model.predict_proba(X)[0][1])
|
||||
logger.debug(f"ML 필터 확률: {proba:.3f} (임계값: {self._threshold})")
|
||||
return bool(proba >= self._threshold)
|
||||
except Exception as e:
|
||||
@@ -46,5 +82,7 @@ class MLFilter:
|
||||
|
||||
def reload_model(self):
|
||||
"""재학습 후 모델을 핫 리로드한다."""
|
||||
self._onnx_session = None
|
||||
self._lgbm_model = None
|
||||
self._try_load()
|
||||
logger.info("ML 필터 모델 리로드 완료")
|
||||
|
||||
Reference in New Issue
Block a user