From 820d8e02138dc2d7759482b6f318bf00a88bc261 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Sun, 1 Mar 2026 23:52:59 +0900 Subject: [PATCH] =?UTF-8?q?refactor:=20=EB=B6=84=EB=AA=A8=20=EC=97=B0?= =?UTF-8?q?=EC=82=B0=EC=9D=84=201e-8=20epsilon=20=ED=8C=A8=ED=84=B4?= =?UTF-8?q?=EC=9C=BC=EB=A1=9C=20=ED=86=B5=EC=9D=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Made-with: Cursor --- src/dataset_builder.py | 10 +++++----- tests/test_dataset_builder.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/dataset_builder.py b/src/dataset_builder.py index 7ca8085..8f36e4b 100644 --- a/src/dataset_builder.py +++ b/src/dataset_builder.py @@ -154,7 +154,7 @@ def _calc_features_vectorized( macd_sig = d["macd_signal"] bb_range = bb_upper - bb_lower - bb_pct = np.where(bb_range > 0, (close - bb_lower) / bb_range, 0.5) + bb_pct = (close - bb_lower) / (bb_range + 1e-8) ema_align = np.where( (ema9 > ema21) & (ema21 > ema50), 1, @@ -163,8 +163,8 @@ def _calc_features_vectorized( ) ).astype(np.float32) - atr_pct = np.where(close > 0, atr / close, 0.0) - vol_ratio = np.where(vol_ma20 > 0, volume / vol_ma20, 1.0) + atr_pct = atr / (close + 1e-8) + vol_ratio = volume / (vol_ma20 + 1e-8) ret_1 = close.pct_change(1).fillna(0).values ret_3 = close.pct_change(3).fillna(0).values @@ -242,8 +242,8 @@ def _calc_features_vectorized( eth_r5 = _align(eth_ret_5, n).astype(np.float32) xrp_r1 = ret_1.astype(np.float32) - xrp_btc_rs_raw = np.where(btc_r1 != 0, xrp_r1 / btc_r1, 0.0).astype(np.float32) - xrp_eth_rs_raw = np.where(eth_r1 != 0, xrp_r1 / eth_r1, 0.0).astype(np.float32) + xrp_btc_rs_raw = (xrp_r1 / (btc_r1 + 1e-8)).astype(np.float32) + xrp_eth_rs_raw = (xrp_r1 / (eth_r1 + 1e-8)).astype(np.float32) extra = pd.DataFrame({ "btc_ret_1": _rolling_zscore(btc_r1), diff --git a/tests/test_dataset_builder.py b/tests/test_dataset_builder.py index 03272fb..d72b893 100644 --- a/tests/test_dataset_builder.py +++ b/tests/test_dataset_builder.py @@ -93,6 +93,30 @@ def test_matches_original_generate_dataset(sample_df): ) +def test_epsilon_no_division_by_zero(): + """bb_range=0, close=0, vol_ma20=0 극단값에서 nan/inf가 발생하지 않아야 한다.""" + import numpy as np + import pandas as pd + from src.dataset_builder import _calc_features_vectorized, _calc_signals, _calc_indicators + + n = 100 + # close를 모두 같은 값으로 → bb_range=0 유발 + df = pd.DataFrame({ + "open": np.ones(n), + "high": np.ones(n), + "low": np.ones(n), + "close": np.ones(n), + "volume": np.ones(n), + }) + d = _calc_indicators(df) + sig = _calc_signals(d) + feat = _calc_features_vectorized(d, sig) + + numeric_cols = feat.select_dtypes(include=[np.number]).columns + assert not feat[numeric_cols].isin([np.inf, -np.inf]).any().any(), \ + "inf 값이 있으면 안 됨" + + def test_oi_nan_masking_no_column(): """oi_change 컬럼이 없으면 전체가 nan이어야 한다.""" import numpy as np