fix(ml): pass SL/TP multipliers to dataset generation — align train/serve

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
21in7
2026-03-21 18:16:50 +09:00
parent 75d1af7fcc
commit 0cc5835b3a
4 changed files with 41 additions and 14 deletions

View File

@@ -45,7 +45,7 @@ def _split_combined(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame | None
return xrp_df, btc_df, eth_df
def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float = 2.0, atr_tp_mult: float = 2.0) -> float:
print(f"데이터 로드: {data_path}")
raw = pd.read_parquet(data_path)
print(f"캔들 수: {len(raw)}")
@@ -58,7 +58,8 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
print("\n데이터셋 생성 중...")
t0 = time.perf_counter()
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay)
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult)
t1 = time.perf_counter()
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
@@ -170,6 +171,8 @@ def walk_forward_auc(
time_weight_decay: float = 2.0,
n_splits: int = 5,
train_ratio: float = 0.6,
atr_sl_mult: float = 2.0,
atr_tp_mult: float = 2.0,
) -> None:
"""Walk-Forward 검증: 슬라이딩 윈도우로 n_splits번 학습/검증 반복."""
print(f"\n=== Walk-Forward 검증 ({n_splits}폴드, decay={time_weight_decay}) ===")
@@ -177,7 +180,8 @@ def walk_forward_auc(
df, btc_df, eth_df = _split_combined(raw)
dataset = generate_dataset_vectorized(
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult,
)
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
for col in missing:
@@ -260,12 +264,16 @@ def main():
)
parser.add_argument("--wf", action="store_true", help="Walk-Forward 검증 실행")
parser.add_argument("--wf-splits", type=int, default=5, help="Walk-Forward 폴드 수")
parser.add_argument("--sl-mult", type=float, default=2.0, help="SL ATR 배수 (기본 2.0)")
parser.add_argument("--tp-mult", type=float, default=2.0, help="TP ATR 배수 (기본 2.0)")
args = parser.parse_args()
if args.wf:
walk_forward_auc(args.data, time_weight_decay=args.decay, n_splits=args.wf_splits)
walk_forward_auc(args.data, time_weight_decay=args.decay, n_splits=args.wf_splits,
atr_sl_mult=args.sl_mult, atr_tp_mult=args.tp_mult)
else:
train_mlx(args.data, time_weight_decay=args.decay)
train_mlx(args.data, time_weight_decay=args.decay,
atr_sl_mult=args.sl_mult, atr_tp_mult=args.tp_mult)
if __name__ == "__main__":