feat: enhance model training and deployment scripts with time-weighted sampling
- Updated `train_model.py` and `train_mlx_model.py` to include a time-weight decay parameter for improved sample weighting during training.
- Modified dataset generation to incorporate sample weights based on time decay, enhancing model performance.
- Adjusted deployment scripts to support new backend options and improved error handling for model file transfers.
- Added new entries to the training log for better tracking of model performance metrics over time.
- Included ONNX model export functionality in the MLX filter for compatibility with Linux servers.
This commit is contained in:
@@ -146,7 +146,7 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def train(data_path: str):
|
||||
def train(data_path: str, time_weight_decay: float = 2.0):
|
||||
print(f"데이터 로드: {data_path}")
|
||||
df_raw = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(df_raw)}, 컬럼: {list(df_raw.columns)}")
|
||||
@@ -169,7 +169,7 @@ def train(data_path: str):
|
||||
df = df_raw[base_cols].copy()
|
||||
|
||||
print("데이터셋 생성 중...")
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df)
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay)
|
||||
|
||||
if dataset.empty or "label" not in dataset.columns:
|
||||
raise ValueError(f"데이터셋 생성 실패: 샘플 0개. 위 오류 메시지를 확인하세요.")
|
||||
@@ -183,10 +183,30 @@ def train(data_path: str):
|
||||
print(f"사용 피처: {len(actual_feature_cols)}개 {actual_feature_cols}")
|
||||
X = dataset[actual_feature_cols]
|
||||
y = dataset["label"]
|
||||
w = dataset["sample_weight"].values
|
||||
|
||||
split = int(len(X) * 0.8)
|
||||
X_train, X_val = X.iloc[:split], X.iloc[split:]
|
||||
y_train, y_val = y.iloc[:split], y.iloc[split:]
|
||||
w_train = w[:split]
|
||||
|
||||
# --- 클래스 불균형 처리: 언더샘플링 (가중치 인덱스 보존) ---
|
||||
pos_idx = np.where(y_train == 1)[0]
|
||||
neg_idx = np.where(y_train == 0)[0]
|
||||
|
||||
if len(neg_idx) > len(pos_idx):
|
||||
np.random.seed(42)
|
||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
||||
|
||||
balanced_idx = np.concatenate([pos_idx, neg_idx])
|
||||
np.random.shuffle(balanced_idx)
|
||||
|
||||
X_train = X_train.iloc[balanced_idx]
|
||||
y_train = y_train.iloc[balanced_idx]
|
||||
w_train = w_train[balanced_idx]
|
||||
|
||||
print(f"\n언더샘플링 적용 후 학습 데이터: {len(X_train)}개 (양성={y_train.sum()}, 음성={(y_train==0).sum()})")
|
||||
# --------------------------------------
|
||||
|
||||
model = lgb.LGBMClassifier(
|
||||
n_estimators=300,
|
||||
@@ -201,6 +221,7 @@ def train(data_path: str):
|
||||
)
|
||||
model.fit(
|
||||
X_train, y_train,
|
||||
sample_weight=w_train,
|
||||
eval_set=[(X_val, y_val)],
|
||||
callbacks=[lgb.early_stopping(30, verbose=False), lgb.log_evaluation(50)],
|
||||
)
|
||||
@@ -225,9 +246,11 @@ def train(data_path: str):
|
||||
log = json.load(f)
|
||||
log.append({
|
||||
"date": datetime.now().isoformat(),
|
||||
"backend": "lgbm",
|
||||
"auc": round(auc, 4),
|
||||
"samples": len(dataset),
|
||||
"features": len(actual_feature_cols),
|
||||
"time_weight_decay": time_weight_decay,
|
||||
"model_path": str(MODEL_PATH),
|
||||
})
|
||||
with open(LOG_PATH, "w") as f:
|
||||
@@ -239,8 +262,12 @@ def train(data_path: str):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
|
||||
parser.add_argument(
|
||||
"--decay", type=float, default=2.0,
|
||||
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
train(args.data)
|
||||
train(args.data, time_weight_decay=args.decay)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user