fix(ml): pass SL/TP multipliers to dataset generation — align train/serve
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -39,7 +39,7 @@ from src.dataset_builder import generate_dataset_vectorized, stratified_undersam
|
||||
# 데이터 로드 및 데이터셋 생성 (1회 캐싱)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
def load_dataset(data_path: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def load_dataset(data_path: str, atr_sl_mult: float = 2.0, atr_tp_mult: float = 2.0) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
parquet 로드 → 벡터화 데이터셋 생성 → (X, y, w) numpy 배열 반환.
|
||||
study 시작 전 1회만 호출하여 모든 trial이 공유한다.
|
||||
@@ -64,7 +64,8 @@ def load_dataset(data_path: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np
|
||||
df = df_raw[base_cols].copy()
|
||||
|
||||
print("\n데이터셋 생성 중 (1회만 실행)...")
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=0.0, negative_ratio=5)
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=0.0, negative_ratio=5,
|
||||
atr_sl_mult=atr_sl_mult, atr_tp_mult=atr_tp_mult)
|
||||
|
||||
if dataset.empty or "label" not in dataset.columns:
|
||||
raise ValueError("데이터셋 생성 실패: 샘플 0개")
|
||||
@@ -527,6 +528,8 @@ def main():
|
||||
parser.add_argument("--train-ratio", type=float, default=0.6, help="학습 구간 비율 (기본: 0.6)")
|
||||
parser.add_argument("--min-recall", type=float, default=0.35, help="최소 재현율 제약 (기본: 0.35)")
|
||||
parser.add_argument("--no-baseline", action="store_true", help="베이스라인 측정 건너뜀")
|
||||
parser.add_argument("--sl-mult", type=float, default=2.0, help="SL ATR 배수 (기본 2.0)")
|
||||
parser.add_argument("--tp-mult", type=float, default=2.0, help="TP ATR 배수 (기본 2.0)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# --symbol 모드: 심볼별 디렉토리 경로 자동 결정
|
||||
@@ -538,7 +541,7 @@ def main():
|
||||
args.data = "data/combined_15m.parquet"
|
||||
|
||||
# 1. 데이터셋 로드 (1회)
|
||||
X, y, w, source = load_dataset(args.data)
|
||||
X, y, w, source = load_dataset(args.data, atr_sl_mult=args.sl_mult, atr_tp_mult=args.tp_mult)
|
||||
|
||||
# 2. 베이스라인 측정
|
||||
if args.symbol:
|
||||
|
||||
Reference in New Issue
Block a user