feat: enhance MLX model training with combined data handling

- Introduced a new function `_split_combined` to separate XRP, BTC, and ETH data from a combined DataFrame.
- Updated `train_mlx` to utilize the new function, improving data management and feature handling.
- Adjusted dataset generation to accommodate BTC and ETH features, with warnings for missing features.
- Changed default data path in `train_mlx` and `train_model` to point to the combined dataset for consistency.
- Increased `LOOKAHEAD` from 60 to 90 and reduced `ATR_TP_MULT` from 3.0 to 2.0 for better model performance.
This commit is contained in:
21in7
2026-03-01 21:43:27 +09:00
parent db144750a3
commit d9238afaf9
3 changed files with 39 additions and 7 deletions

View File

@@ -25,14 +25,40 @@ MLX_MODEL_PATH = Path("models/mlx_filter.weights")
LOG_PATH = Path("models/training_log.json")
def _split_combined(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame | None, pd.DataFrame | None]:
"""combined parquet에서 XRP/BTC/ETH DataFrame을 분리한다."""
xrp_cols = ["open", "high", "low", "close", "volume"]
xrp_df = df[xrp_cols].copy()
btc_df = None
eth_df = None
btc_raw = [c for c in df.columns if c.endswith("_btc")]
eth_raw = [c for c in df.columns if c.endswith("_eth")]
if btc_raw:
btc_df = df[btc_raw].copy()
btc_df.columns = [c.replace("_btc", "") for c in btc_raw]
if eth_raw:
eth_df = df[eth_raw].copy()
eth_df.columns = [c.replace("_eth", "") for c in eth_raw]
return xrp_df, btc_df, eth_df
def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
print(f"데이터 로드: {data_path}")
df = pd.read_parquet(data_path)
print(f"캔들 수: {len(df)}")
raw = pd.read_parquet(data_path)
print(f"캔들 수: {len(raw)}")
df, btc_df, eth_df = _split_combined(raw)
if btc_df is not None:
print(f" BTC/ETH 피처 활성화 (21개 피처)")
else:
print(f" XRP 단독 데이터 (13개 피처)")
print("\n데이터셋 생성 중...")
t0 = time.perf_counter()
dataset = generate_dataset_vectorized(df, time_weight_decay=time_weight_decay)
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay)
t1 = time.perf_counter()
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
@@ -44,6 +70,12 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
if len(dataset) < 200:
raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")
actual_cols = [c for c in FEATURE_COLS if c in dataset.columns]
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
if missing:
print(f" 경고: 데이터셋에 없는 피처 {missing} → 0으로 채움 (BTC/ETH 데이터 미제공)")
for col in missing:
dataset[col] = 0.0
X = dataset[FEATURE_COLS]
y = dataset["label"]
w = dataset["sample_weight"].values
@@ -114,7 +146,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
parser.add_argument("--data", default="data/combined_1m.parquet")
parser.add_argument(
"--decay", type=float, default=2.0,
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",

View File

@@ -261,7 +261,7 @@ def train(data_path: str, time_weight_decay: float = 2.0):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
parser.add_argument("--data", default="data/combined_1m.parquet")
parser.add_argument(
"--decay", type=float, default=2.0,
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",

View File

@@ -11,9 +11,9 @@ import pandas_ta as ta
from src.ml_features import FEATURE_COLS
LOOKAHEAD = 60
LOOKAHEAD = 90
ATR_SL_MULT = 1.5
ATR_TP_MULT = 3.0
ATR_TP_MULT = 2.0
WARMUP = 60 # 지표 안정화에 필요한 최소 행 수