From 0af138d8ee65781d4e99b42a141bfeb03d5e1370 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Mon, 2 Mar 2026 23:58:15 +0900 Subject: [PATCH] feat: add stratified_undersample helper function Co-Authored-By: Claude Opus 4.6 --- src/dataset_builder.py | 30 ++++++++++++++++++++++++++++++ tests/test_dataset_builder.py | 17 +++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/src/dataset_builder.py b/src/dataset_builder.py index b960956..37dc4cb 100644 --- a/src/dataset_builder.py +++ b/src/dataset_builder.py @@ -459,3 +459,33 @@ def generate_dataset_vectorized( print(f" 최종 데이터셋: {n:,}개 (시그널={total_sig:,}, HOLD={total_hold:,})") return feat_final + + +def stratified_undersample( + y: np.ndarray, + source: np.ndarray, + seed: int = 42, +) -> np.ndarray: + """Signal 샘플 전수 유지 + HOLD negative만 양성 수 만큼 샘플링. + + Args: + y: 라벨 배열 (0 or 1) + source: 소스 배열 ("signal" or "hold_negative") + seed: 랜덤 시드 + + Returns: + 정렬된 인덱스 배열 (학습에 사용할 행 인덱스) + """ + pos_idx = np.where(y == 1)[0] # Signal Win + sig_neg_idx = np.where((y == 0) & (source == "signal"))[0] # Signal Loss + hold_neg_idx = np.where(source == "hold_negative")[0] # HOLD negative + + # HOLD negative에서 양성 수 만큼만 샘플링 + n_hold = min(len(hold_neg_idx), len(pos_idx)) + rng = np.random.default_rng(seed) + if n_hold > 0: + hold_sampled = rng.choice(hold_neg_idx, size=n_hold, replace=False) + else: + hold_sampled = np.array([], dtype=np.intp) + + return np.sort(np.concatenate([pos_idx, sig_neg_idx, hold_sampled])) diff --git a/tests/test_dataset_builder.py b/tests/test_dataset_builder.py index 55bc3b4..0a29e6d 100644 --- a/tests/test_dataset_builder.py +++ b/tests/test_dataset_builder.py @@ -249,3 +249,20 @@ def test_signal_samples_preserved_after_sampling(signal_producing_df): signal_count = (result_with_hold["source"] == "signal").sum() assert signal_count == len(result_signal_only), \ f"Signal 샘플 손실: 원본={len(result_signal_only)}, 유지={signal_count}" + + +def test_stratified_undersample_preserves_signal(): + """stratified_undersample은 signal 샘플을 전수 유지해야 한다.""" + from src.dataset_builder import stratified_undersample + + y = np.array([1, 0, 0, 0, 0, 0, 0, 0, 1, 0]) + source = np.array(["signal", "signal", "signal", "hold_negative", + "hold_negative", "hold_negative", "hold_negative", + "hold_negative", "signal", "signal"]) + + idx = stratified_undersample(y, source, seed=42) + + # signal 인덱스: 0, 1, 2, 8, 9 → 전부 포함 + signal_indices = np.where(source == "signal")[0] + for si in signal_indices: + assert si in idx, f"signal 인덱스 {si}가 누락됨"