feat: add stratified_undersample helper function
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -459,3 +459,33 @@ def generate_dataset_vectorized(
|
||||
print(f" 최종 데이터셋: {n:,}개 (시그널={total_sig:,}, HOLD={total_hold:,})")
|
||||
|
||||
return feat_final
|
||||
|
||||
|
||||
def stratified_undersample(
|
||||
y: np.ndarray,
|
||||
source: np.ndarray,
|
||||
seed: int = 42,
|
||||
) -> np.ndarray:
|
||||
"""Signal 샘플 전수 유지 + HOLD negative만 양성 수 만큼 샘플링.
|
||||
|
||||
Args:
|
||||
y: 라벨 배열 (0 or 1)
|
||||
source: 소스 배열 ("signal" or "hold_negative")
|
||||
seed: 랜덤 시드
|
||||
|
||||
Returns:
|
||||
정렬된 인덱스 배열 (학습에 사용할 행 인덱스)
|
||||
"""
|
||||
pos_idx = np.where(y == 1)[0] # Signal Win
|
||||
sig_neg_idx = np.where((y == 0) & (source == "signal"))[0] # Signal Loss
|
||||
hold_neg_idx = np.where(source == "hold_negative")[0] # HOLD negative
|
||||
|
||||
# HOLD negative에서 양성 수 만큼만 샘플링
|
||||
n_hold = min(len(hold_neg_idx), len(pos_idx))
|
||||
rng = np.random.default_rng(seed)
|
||||
if n_hold > 0:
|
||||
hold_sampled = rng.choice(hold_neg_idx, size=n_hold, replace=False)
|
||||
else:
|
||||
hold_sampled = np.array([], dtype=np.intp)
|
||||
|
||||
return np.sort(np.concatenate([pos_idx, sig_neg_idx, hold_sampled]))
|
||||
|
||||
@@ -249,3 +249,20 @@ def test_signal_samples_preserved_after_sampling(signal_producing_df):
|
||||
signal_count = (result_with_hold["source"] == "signal").sum()
|
||||
assert signal_count == len(result_signal_only), \
|
||||
f"Signal 샘플 손실: 원본={len(result_signal_only)}, 유지={signal_count}"
|
||||
|
||||
|
||||
def test_stratified_undersample_preserves_signal():
|
||||
"""stratified_undersample은 signal 샘플을 전수 유지해야 한다."""
|
||||
from src.dataset_builder import stratified_undersample
|
||||
|
||||
y = np.array([1, 0, 0, 0, 0, 0, 0, 0, 1, 0])
|
||||
source = np.array(["signal", "signal", "signal", "hold_negative",
|
||||
"hold_negative", "hold_negative", "hold_negative",
|
||||
"hold_negative", "signal", "signal"])
|
||||
|
||||
idx = stratified_undersample(y, source, seed=42)
|
||||
|
||||
# signal 인덱스: 0, 1, 2, 8, 9 → 전부 포함
|
||||
signal_indices = np.where(source == "signal")[0]
|
||||
for si in signal_indices:
|
||||
assert si in idx, f"signal 인덱스 {si}가 누락됨"
|
||||
|
||||
Reference in New Issue
Block a user