diff --git a/src/dataset_builder.py b/src/dataset_builder.py index 1a27acb..9b63761 100644 --- a/src/dataset_builder.py +++ b/src/dataset_builder.py @@ -47,6 +47,10 @@ def _calc_indicators(df: pd.DataFrame) -> pd.DataFrame: d["stoch_k"] = stoch["STOCHRSIk_14_14_3_3"] d["stoch_d"] = stoch["STOCHRSId_14_14_3_3"] + # ADX (14) — 횡보장 필터 + adx_df = ta.adx(high, low, close, length=14) + d["adx"] = adx_df["ADX_14"] + return d @@ -112,6 +116,12 @@ def _calc_signals(d: pd.DataFrame) -> np.ndarray: # 둘 다 해당하면 HOLD (충돌 방지) signal_arr[long_enter & short_enter] = "HOLD" + # ADX 횡보장 필터: ADX < 25이면 추세 부재로 판단하여 진입 차단 + if "adx" in d.columns: + adx = d["adx"].values + low_adx = (~np.isnan(adx)) & (adx < 25) + signal_arr[low_adx] = "HOLD" + return signal_arr diff --git a/src/indicators.py b/src/indicators.py index 5a2f166..4b34fc2 100644 --- a/src/indicators.py +++ b/src/indicators.py @@ -60,6 +60,12 @@ class Indicators: last = df.iloc[-1] prev = df.iloc[-2] + # ADX 횡보장 필터: ADX < 25이면 추세 부재로 판단하여 진입 차단 + adx = last.get("adx", None) + if adx is not None and not pd.isna(adx) and adx < 25: + logger.debug(f"ADX 필터: {adx:.1f} < 25 — HOLD") + return "HOLD" + long_signals = 0 short_signals = 0 diff --git a/tests/test_indicators.py b/tests/test_indicators.py index 9d20fbd..1135677 100644 --- a/tests/test_indicators.py +++ b/tests/test_indicators.py @@ -54,6 +54,33 @@ def test_adx_column_exists(sample_df): assert (valid >= 0).all() +def test_adx_filter_blocks_low_adx(sample_df): + """ADX < 25일 때 가중치와 무관하게 HOLD를 반환해야 한다.""" + ind = Indicators(sample_df) + df = ind.calculate_all() + # 강한 LONG 신호가 나오도록 지표 조작 + df.loc[df.index[-1], "rsi"] = 20 # RSI 과매도 → +1 + df.loc[df.index[-2], "macd"] = -1 # MACD 골든크로스 → +2 + df.loc[df.index[-2], "macd_signal"] = 0 + df.loc[df.index[-1], "macd"] = 1 + df.loc[df.index[-1], "macd_signal"] = 0 + df.loc[df.index[-1], "volume"] = df.loc[df.index[-1], "vol_ma20"] * 2 # 거래량 서지 + # ADX를 강제로 낮은 값으로 설정 + df["adx"] = 15.0 + signal = ind.get_signal(df) + assert signal == "HOLD" + + +def test_adx_nan_falls_through(sample_df): + """ADX가 NaN(초기 캔들)이면 기존 가중치 로직으로 폴백해야 한다.""" + ind = Indicators(sample_df) + df = ind.calculate_all() + df["adx"] = float("nan") + signal = ind.get_signal(df) + # NaN이면 차단하지 않고 기존 로직 실행 → LONG/SHORT/HOLD 중 하나 + assert signal in ("LONG", "SHORT", "HOLD") + + def test_signal_returns_direction(sample_df): ind = Indicators(sample_df) df = ind.calculate_all()