fix: make HOLD negative sampling tests non-vacuous
The two HOLD negative tests (test_hold_negative_labels_are_all_zero, test_signal_samples_preserved_after_sampling) were passing vacuously because sample_df produces 0 signal candles (ADX ~18, below threshold 25). Added signal_producing_df fixture with higher volatility and volume surges to reliably generate signals. Removed if-guards so assertions are mandatory. Also restored the full docstring for generate_dataset_vectorized() documenting btc_df/eth_df, time_weight_decay, and negative_ratio parameters. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -366,6 +366,13 @@ def generate_dataset_vectorized(
|
|||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
전체 시계열을 1회 계산해 학습 데이터셋을 생성한다.
|
전체 시계열을 1회 계산해 학습 데이터셋을 생성한다.
|
||||||
|
기존 generate_dataset()의 drop-in 대체제.
|
||||||
|
btc_df, eth_df가 제공되면 21개 피처로 확장한다.
|
||||||
|
|
||||||
|
time_weight_decay: 지수 감쇠 강도. 0이면 균등 가중치.
|
||||||
|
양수일수록 최신 샘플에 더 높은 가중치를 부여한다.
|
||||||
|
예) 2.0 → 최신 샘플이 가장 오래된 샘플보다 e^2 ≈ 7.4배 높은 가중치.
|
||||||
|
결과 DataFrame에 'sample_weight' 컬럼으로 포함된다.
|
||||||
|
|
||||||
negative_ratio: 시그널 샘플 대비 HOLD negative 샘플 비율.
|
negative_ratio: 시그널 샘플 대비 HOLD negative 샘플 비율.
|
||||||
0이면 기존 동작 (시그널만). 5면 시그널의 5배만큼 HOLD 샘플 추가.
|
0이면 기존 동작 (시그널만). 5면 시그널의 5배만큼 HOLD 샘플 추가.
|
||||||
|
|||||||
@@ -210,22 +210,42 @@ def test_rs_zero_denominator():
|
|||||||
"xrp_btc_rs가 전부 nan이면 안 됨"
|
"xrp_btc_rs가 전부 nan이면 안 됨"
|
||||||
|
|
||||||
|
|
||||||
def test_hold_negative_labels_are_all_zero(sample_df):
|
@pytest.fixture
|
||||||
|
def signal_producing_df():
|
||||||
|
"""시그널이 반드시 발생하는 더미 데이터. 높은 변동성 + 거래량 급증."""
|
||||||
|
rng = np.random.default_rng(7)
|
||||||
|
n = 800
|
||||||
|
trend = np.linspace(1.5, 3.0, n)
|
||||||
|
noise = np.cumsum(rng.normal(0, 0.04, n))
|
||||||
|
close = np.clip(trend + noise, 0.01, None)
|
||||||
|
high = close * (1 + rng.uniform(0, 0.015, n))
|
||||||
|
low = close * (1 - rng.uniform(0, 0.015, n))
|
||||||
|
volume = rng.uniform(1e6, 3e6, n)
|
||||||
|
volume[::30] *= 3.0 # 30봉마다 거래량 급증
|
||||||
|
return pd.DataFrame({
|
||||||
|
"open": close, "high": high, "low": low,
|
||||||
|
"close": close, "volume": volume,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def test_hold_negative_labels_are_all_zero(signal_producing_df):
|
||||||
"""HOLD negative 샘플의 label은 전부 0이어야 한다."""
|
"""HOLD negative 샘플의 label은 전부 0이어야 한다."""
|
||||||
result = generate_dataset_vectorized(sample_df, negative_ratio=3)
|
result = generate_dataset_vectorized(signal_producing_df, negative_ratio=3)
|
||||||
if len(result) > 0 and "source" in result.columns:
|
assert len(result) > 0, "시그널이 발생하지 않아 테스트 불가"
|
||||||
hold_neg = result[result["source"] == "hold_negative"]
|
assert "source" in result.columns
|
||||||
if len(hold_neg) > 0:
|
hold_neg = result[result["source"] == "hold_negative"]
|
||||||
assert (hold_neg["label"] == 0).all(), \
|
assert len(hold_neg) > 0, "HOLD negative 샘플이 0개"
|
||||||
f"HOLD negative 중 label != 0인 샘플 존재: {hold_neg['label'].value_counts().to_dict()}"
|
assert (hold_neg["label"] == 0).all(), \
|
||||||
|
f"HOLD negative 중 label != 0인 샘플 존재: {hold_neg['label'].value_counts().to_dict()}"
|
||||||
|
|
||||||
|
|
||||||
def test_signal_samples_preserved_after_sampling(sample_df):
|
def test_signal_samples_preserved_after_sampling(signal_producing_df):
|
||||||
"""계층적 샘플링 후 source='signal' 샘플이 하나도 버려지지 않아야 한다."""
|
"""계층적 샘플링 후 source='signal' 샘플이 하나도 버려지지 않아야 한다."""
|
||||||
result_signal_only = generate_dataset_vectorized(sample_df, negative_ratio=0)
|
result_signal_only = generate_dataset_vectorized(signal_producing_df, negative_ratio=0)
|
||||||
result_with_hold = generate_dataset_vectorized(sample_df, negative_ratio=3)
|
result_with_hold = generate_dataset_vectorized(signal_producing_df, negative_ratio=3)
|
||||||
|
|
||||||
if len(result_with_hold) > 0 and "source" in result_with_hold.columns:
|
assert len(result_signal_only) > 0, "시그널이 발생하지 않아 테스트 불가"
|
||||||
signal_count = (result_with_hold["source"] == "signal").sum()
|
assert "source" in result_with_hold.columns
|
||||||
assert signal_count == len(result_signal_only), \
|
signal_count = (result_with_hold["source"] == "signal").sum()
|
||||||
f"Signal 샘플 손실: 원본={len(result_signal_only)}, 유지={signal_count}"
|
assert signal_count == len(result_signal_only), \
|
||||||
|
f"Signal 샘플 손실: 원본={len(result_signal_only)}, 유지={signal_count}"
|
||||||
|
|||||||
Reference in New Issue
Block a user