chore: .worktrees/ gitignore에 추가
Made-with: Cursor
This commit is contained in:
@@ -144,6 +144,92 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||
return auc
|
||||
|
||||
|
||||
def walk_forward_auc(
|
||||
data_path: str,
|
||||
time_weight_decay: float = 2.0,
|
||||
n_splits: int = 5,
|
||||
train_ratio: float = 0.6,
|
||||
) -> None:
|
||||
"""Walk-Forward 검증: 슬라이딩 윈도우로 n_splits번 학습/검증 반복."""
|
||||
print(f"\n=== Walk-Forward 검증 ({n_splits}폴드, decay={time_weight_decay}) ===")
|
||||
raw = pd.read_parquet(data_path)
|
||||
df, btc_df, eth_df = _split_combined(raw)
|
||||
|
||||
dataset = generate_dataset_vectorized(
|
||||
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay
|
||||
)
|
||||
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
||||
for col in missing:
|
||||
dataset[col] = 0.0
|
||||
|
||||
X_all = dataset[FEATURE_COLS].values.astype(np.float32)
|
||||
y_all = dataset["label"].values.astype(np.float32)
|
||||
w_all = dataset["sample_weight"].values.astype(np.float32)
|
||||
n = len(dataset)
|
||||
|
||||
step = max(1, int(n * (1 - train_ratio) / n_splits))
|
||||
train_end_start = int(n * train_ratio)
|
||||
|
||||
aucs = []
|
||||
for i in range(n_splits):
|
||||
tr_end = train_end_start + i * step
|
||||
val_end = tr_end + step
|
||||
if val_end > n:
|
||||
break
|
||||
|
||||
X_tr_raw = X_all[:tr_end]
|
||||
y_tr = y_all[:tr_end]
|
||||
w_tr = w_all[:tr_end]
|
||||
X_val_raw = X_all[tr_end:val_end]
|
||||
y_val = y_all[tr_end:val_end]
|
||||
|
||||
pos_idx = np.where(y_tr == 1)[0]
|
||||
neg_idx = np.where(y_tr == 0)[0]
|
||||
if len(neg_idx) > len(pos_idx):
|
||||
np.random.seed(42)
|
||||
neg_idx = np.random.choice(neg_idx, size=len(pos_idx), replace=False)
|
||||
bal_idx = np.sort(np.concatenate([pos_idx, neg_idx]))
|
||||
|
||||
X_tr_bal = X_tr_raw[bal_idx]
|
||||
y_tr_bal = y_tr[bal_idx]
|
||||
w_tr_bal = w_tr[bal_idx]
|
||||
|
||||
# 폴드별 정규화 (학습 데이터 기준으로 계산, 검증에도 동일 적용)
|
||||
mean = X_tr_bal.mean(axis=0)
|
||||
std = X_tr_bal.std(axis=0) + 1e-8
|
||||
X_tr_norm = (X_tr_bal - mean) / std
|
||||
X_val_norm = (X_val_raw - mean) / std
|
||||
|
||||
# DataFrame으로 래핑해서 MLXFilter.fit()에 전달
|
||||
# fit() 내부 정규화가 덮어쓰지 않도록 이미 정규화된 데이터를 넘기고
|
||||
# _mean=0, _std=1로 고정해 이중 정규화를 방지
|
||||
X_tr_df = pd.DataFrame(X_tr_norm, columns=FEATURE_COLS)
|
||||
X_val_df = pd.DataFrame(X_val_norm, columns=FEATURE_COLS)
|
||||
|
||||
model = MLXFilter(
|
||||
input_dim=len(FEATURE_COLS),
|
||||
hidden_dim=128,
|
||||
lr=1e-3,
|
||||
epochs=100,
|
||||
batch_size=256,
|
||||
)
|
||||
model.fit(X_tr_df, pd.Series(y_tr_bal), sample_weight=w_tr_bal)
|
||||
# fit()이 내부에서 다시 정규화하므로 저장된 mean/std를 항등 변환으로 교체
|
||||
model._mean = np.zeros(len(FEATURE_COLS), dtype=np.float32)
|
||||
model._std = np.ones(len(FEATURE_COLS), dtype=np.float32)
|
||||
|
||||
proba = model.predict_proba(X_val_df)
|
||||
auc = roc_auc_score(y_val, proba) if len(np.unique(y_val)) > 1 else 0.5
|
||||
aucs.append(auc)
|
||||
print(
|
||||
f" 폴드 {i+1}/{n_splits}: 학습={tr_end}개, "
|
||||
f"검증={tr_end}~{val_end} ({step}개), AUC={auc:.4f}"
|
||||
)
|
||||
|
||||
print(f"\n Walk-Forward 평균 AUC: {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
|
||||
print(f" 폴드별: {[round(a, 4) for a in aucs]}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data", default="data/combined_15m.parquet")
|
||||
@@ -151,8 +237,14 @@ def main():
|
||||
"--decay", type=float, default=2.0,
|
||||
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
||||
)
|
||||
parser.add_argument("--wf", action="store_true", help="Walk-Forward 검증 실행")
|
||||
parser.add_argument("--wf-splits", type=int, default=5, help="Walk-Forward 폴드 수")
|
||||
args = parser.parse_args()
|
||||
train_mlx(args.data, time_weight_decay=args.decay)
|
||||
|
||||
if args.wf:
|
||||
walk_forward_auc(args.data, time_weight_decay=args.decay, n_splits=args.wf_splits)
|
||||
else:
|
||||
train_mlx(args.data, time_weight_decay=args.decay)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user