fix(ml): pass SL/TP multipliers to dataset generation — align train/serve
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -45,7 +45,7 @@ def _split_combined(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame | None
|
||||
return xrp_df, btc_df, eth_df
|
||||
|
||||
|
||||
def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||
def train_mlx(data_path: str, time_weight_decay: float = 2.0, atr_sl_mult: float = 2.0, atr_tp_mult: float = 2.0) -> float:
|
||||
print(f"데이터 로드: {data_path}")
|
||||
raw = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(raw)}")
|
||||
@@ -58,7 +58,8 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||
|
||||
print("\n데이터셋 생성 중...")
|
||||
t0 = time.perf_counter()
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay)
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult)
|
||||
t1 = time.perf_counter()
|
||||
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
||||
|
||||
@@ -170,6 +171,8 @@ def walk_forward_auc(
|
||||
time_weight_decay: float = 2.0,
|
||||
n_splits: int = 5,
|
||||
train_ratio: float = 0.6,
|
||||
atr_sl_mult: float = 2.0,
|
||||
atr_tp_mult: float = 2.0,
|
||||
) -> None:
|
||||
"""Walk-Forward 검증: 슬라이딩 윈도우로 n_splits번 학습/검증 반복."""
|
||||
print(f"\n=== Walk-Forward 검증 ({n_splits}폴드, decay={time_weight_decay}) ===")
|
||||
@@ -177,7 +180,8 @@ def walk_forward_auc(
|
||||
df, btc_df, eth_df = _split_combined(raw)
|
||||
|
||||
dataset = generate_dataset_vectorized(
|
||||
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay
|
||||
df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay,
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult,
|
||||
)
|
||||
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
||||
for col in missing:
|
||||
@@ -260,12 +264,16 @@ def main():
|
||||
)
|
||||
parser.add_argument("--wf", action="store_true", help="Walk-Forward 검증 실행")
|
||||
parser.add_argument("--wf-splits", type=int, default=5, help="Walk-Forward 폴드 수")
|
||||
parser.add_argument("--sl-mult", type=float, default=2.0, help="SL ATR 배수 (기본 2.0)")
|
||||
parser.add_argument("--tp-mult", type=float, default=2.0, help="TP ATR 배수 (기본 2.0)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.wf:
|
||||
walk_forward_auc(args.data, time_weight_decay=args.decay, n_splits=args.wf_splits)
|
||||
walk_forward_auc(args.data, time_weight_decay=args.decay, n_splits=args.wf_splits,
|
||||
atr_sl_mult=args.sl_mult, atr_tp_mult=args.tp_mult)
|
||||
else:
|
||||
train_mlx(args.data, time_weight_decay=args.decay)
|
||||
train_mlx(args.data, time_weight_decay=args.decay,
|
||||
atr_sl_mult=args.sl_mult, atr_tp_mult=args.tp_mult)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user