feat: connect Optuna tuning results to train_model.py via --tuned-params
- _load_lgbm_params() 헬퍼 추가: 기본 파라미터 반환, JSON 주어지면 덮어씀
- train(): tuned_params_path 인자 추가, weight_scale 적용
- walk_forward_auc(): tuned_params_path 인자 추가, weight_scale 적용
- main(): --tuned-params argparse 인자 추가, 두 함수에 전달
- training_log.json에 tuned_params_path, lgbm_params, weight_scale 기록

Made-with: Cursor
This commit is contained in:
@@ -146,7 +146,36 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def train(data_path: str, time_weight_decay: float = 2.0):
|
||||
def _load_lgbm_params(tuned_params_path: str | None) -> tuple[dict, float]:
|
||||
"""기본 LightGBM 파라미터를 반환하고, 튜닝 JSON이 주어지면 덮어쓴다.
|
||||
반환: (lgbm_params, weight_scale)
|
||||
"""
|
||||
lgbm_params: dict = {
|
||||
"n_estimators": 500,
|
||||
"learning_rate": 0.05,
|
||||
"num_leaves": 31,
|
||||
"min_child_samples": 15,
|
||||
"subsample": 0.8,
|
||||
"colsample_bytree": 0.8,
|
||||
"reg_alpha": 0.05,
|
||||
"reg_lambda": 0.1,
|
||||
}
|
||||
weight_scale = 1.0
|
||||
|
||||
if tuned_params_path:
|
||||
with open(tuned_params_path, "r", encoding="utf-8") as f:
|
||||
tune_data = json.load(f)
|
||||
best_params = dict(tune_data["best_trial"]["params"])
|
||||
weight_scale = float(best_params.pop("weight_scale", 1.0))
|
||||
lgbm_params.update(best_params)
|
||||
print(f"\n[Optuna] 튜닝 파라미터 로드: {tuned_params_path}")
|
||||
print(f"[Optuna] 적용 파라미터: {lgbm_params}")
|
||||
print(f"[Optuna] weight_scale: {weight_scale}\n")
|
||||
|
||||
return lgbm_params, weight_scale
|
||||
|
||||
|
||||
def train(data_path: str, time_weight_decay: float = 2.0, tuned_params_path: str | None = None):
|
||||
print(f"데이터 로드: {data_path}")
|
||||
df_raw = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(df_raw)}, 컬럼: {list(df_raw.columns)}")
|
||||
@@ -188,7 +217,10 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
||||
split = int(len(X) * 0.8)
|
||||
X_train, X_val = X.iloc[:split], X.iloc[split:]
|
||||
y_train, y_val = y.iloc[:split], y.iloc[split:]
|
||||
w_train = w[:split]
|
||||
|
||||
# 튜닝 파라미터 로드 (없으면 기본값 사용)
|
||||
lgbm_params, weight_scale = _load_lgbm_params(tuned_params_path)
|
||||
w_train = (w[:split] * weight_scale).astype(np.float32)
|
||||
|
||||
# --- 클래스 불균형 처리: 언더샘플링 (시간 가중치 인덱스 보존) ---
|
||||
pos_idx = np.where(y_train == 1)[0]
|
||||
@@ -208,18 +240,7 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
||||
print(f"검증 데이터: {len(X_val)}개 (양성={int(y_val.sum())}, 음성={int((y_val==0).sum())})")
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
model = lgb.LGBMClassifier(
|
||||
n_estimators=500,
|
||||
learning_rate=0.05,
|
||||
num_leaves=31,
|
||||
min_child_samples=15,
|
||||
subsample=0.8,
|
||||
colsample_bytree=0.8,
|
||||
reg_alpha=0.05,
|
||||
reg_lambda=0.1,
|
||||
random_state=42,
|
||||
verbose=-1,
|
||||
)
|
||||
model = lgb.LGBMClassifier(**lgbm_params, random_state=42, verbose=-1)
|
||||
model.fit(
|
||||
X_train, y_train,
|
||||
sample_weight=w_train,
|
||||
@@ -268,7 +289,7 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
||||
if LOG_PATH.exists():
|
||||
with open(LOG_PATH) as f:
|
||||
log = json.load(f)
|
||||
log.append({
|
||||
log_entry: dict = {
|
||||
"date": datetime.now().isoformat(),
|
||||
"backend": "lgbm",
|
||||
"auc": round(auc, 4),
|
||||
@@ -279,7 +300,11 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
||||
"features": len(actual_feature_cols),
|
||||
"time_weight_decay": time_weight_decay,
|
||||
"model_path": str(MODEL_PATH),
|
||||
})
|
||||
"tuned_params_path": tuned_params_path,
|
||||
"lgbm_params": lgbm_params,
|
||||
"weight_scale": weight_scale,
|
||||
}
|
||||
log.append(log_entry)
|
||||
with open(LOG_PATH, "w") as f:
|
||||
json.dump(log, f, indent=2)
|
||||
|
||||
@@ -291,6 +316,7 @@ def walk_forward_auc(
|
||||
time_weight_decay: float = 2.0,
|
||||
n_splits: int = 5,
|
||||
train_ratio: float = 0.6,
|
||||
tuned_params_path: str | None = None,
|
||||
) -> None:
|
||||
"""Walk-Forward 검증: 슬라이딩 윈도우로 n_splits번 학습/검증 반복.
|
||||
|
||||
@@ -320,6 +346,9 @@ def walk_forward_auc(
|
||||
w = dataset["sample_weight"].values
|
||||
n = len(dataset)
|
||||
|
||||
lgbm_params, weight_scale = _load_lgbm_params(tuned_params_path)
|
||||
w = (w * weight_scale).astype(np.float32)
|
||||
|
||||
step = max(1, int(n * (1 - train_ratio) / n_splits))
|
||||
train_end_start = int(n * train_ratio)
|
||||
|
||||
@@ -340,18 +369,7 @@ def walk_forward_auc(
|
||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
||||
idx = np.sort(np.concatenate([pos_idx, neg_idx]))
|
||||
|
||||
model = lgb.LGBMClassifier(
|
||||
n_estimators=500,
|
||||
learning_rate=0.05,
|
||||
num_leaves=31,
|
||||
min_child_samples=15,
|
||||
subsample=0.8,
|
||||
colsample_bytree=0.8,
|
||||
reg_alpha=0.05,
|
||||
reg_lambda=0.1,
|
||||
random_state=42,
|
||||
verbose=-1,
|
||||
)
|
||||
model = lgb.LGBMClassifier(**lgbm_params, random_state=42, verbose=-1)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
model.fit(X_tr[idx], y_tr[idx], sample_weight=w_tr[idx])
|
||||
@@ -377,12 +395,21 @@ def main():
|
||||
)
|
||||
parser.add_argument("--wf", action="store_true", help="Walk-Forward 검증 실행")
|
||||
parser.add_argument("--wf-splits", type=int, default=5, help="Walk-Forward 폴드 수")
|
||||
parser.add_argument(
|
||||
"--tuned-params", type=str, default=None,
|
||||
help="Optuna 튜닝 결과 JSON 경로 (지정 시 기본 파라미터를 덮어씀)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.wf:
|
||||
walk_forward_auc(args.data, time_weight_decay=args.decay, n_splits=args.wf_splits)
|
||||
walk_forward_auc(
|
||||
args.data,
|
||||
time_weight_decay=args.decay,
|
||||
n_splits=args.wf_splits,
|
||||
tuned_params_path=args.tuned_params,
|
||||
)
|
||||
else:
|
||||
train(args.data, time_weight_decay=args.decay)
|
||||
train(args.data, time_weight_decay=args.decay, tuned_params_path=args.tuned_params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user