feat: LightGBM 임계값 탐색을 정밀도 우선(recall>=0.15 조건부)으로 변경
Made-with: Cursor
This commit is contained in:
@@ -233,16 +233,26 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
|||||||
|
|
||||||
val_proba = model.predict_proba(X_val)[:, 1]
|
val_proba = model.predict_proba(X_val)[:, 1]
|
||||||
auc = roc_auc_score(y_val, val_proba)
|
auc = roc_auc_score(y_val, val_proba)
|
||||||
# 최적 임계값 탐색 (F1 기준)
|
|
||||||
thresholds = np.arange(0.40, 0.70, 0.05)
|
# 최적 임계값 탐색: 최소 재현율(0.15) 조건부 정밀도 최대화
|
||||||
best_thr, best_f1 = 0.50, 0.0
|
from sklearn.metrics import precision_recall_curve
|
||||||
for thr in thresholds:
|
precisions, recalls, thresholds = precision_recall_curve(y_val, val_proba)
|
||||||
pred = (val_proba >= thr).astype(int)
|
# precision_recall_curve의 마지막 원소는 (1.0, 0.0)이므로 제외
|
||||||
from sklearn.metrics import f1_score
|
precisions, recalls = precisions[:-1], recalls[:-1]
|
||||||
f1 = f1_score(y_val, pred, zero_division=0)
|
|
||||||
if f1 > best_f1:
|
MIN_RECALL = 0.15
|
||||||
best_f1, best_thr = f1, thr
|
valid_idx = np.where(recalls >= MIN_RECALL)[0]
|
||||||
print(f"\n검증 AUC: {auc:.4f} | 최적 임계값: {best_thr:.2f} (F1={best_f1:.3f})")
|
if len(valid_idx) > 0:
|
||||||
|
best_idx = valid_idx[np.argmax(precisions[valid_idx])]
|
||||||
|
best_thr = float(thresholds[best_idx])
|
||||||
|
best_prec = float(precisions[best_idx])
|
||||||
|
best_rec = float(recalls[best_idx])
|
||||||
|
else:
|
||||||
|
best_thr, best_prec, best_rec = 0.50, 0.0, 0.0
|
||||||
|
print(f" [경고] recall >= {MIN_RECALL} 조건 만족 임계값 없음 → 기본값 0.50 사용")
|
||||||
|
|
||||||
|
print(f"\n검증 AUC: {auc:.4f} | 최적 임계값: {best_thr:.4f} "
|
||||||
|
f"(Precision={best_prec:.3f}, Recall={best_rec:.3f})")
|
||||||
print(classification_report(y_val, (val_proba >= best_thr).astype(int), zero_division=0))
|
print(classification_report(y_val, (val_proba >= best_thr).astype(int), zero_division=0))
|
||||||
|
|
||||||
if MODEL_PATH.exists():
|
if MODEL_PATH.exists():
|
||||||
@@ -262,6 +272,9 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
|||||||
"date": datetime.now().isoformat(),
|
"date": datetime.now().isoformat(),
|
||||||
"backend": "lgbm",
|
"backend": "lgbm",
|
||||||
"auc": round(auc, 4),
|
"auc": round(auc, 4),
|
||||||
|
"best_threshold": round(best_thr, 4),
|
||||||
|
"best_precision": round(best_prec, 3),
|
||||||
|
"best_recall": round(best_rec, 3),
|
||||||
"samples": len(dataset),
|
"samples": len(dataset),
|
||||||
"features": len(actual_feature_cols),
|
"features": len(actual_feature_cols),
|
||||||
"time_weight_decay": time_weight_decay,
|
"time_weight_decay": time_weight_decay,
|
||||||
|
|||||||
Reference in New Issue
Block a user