feat: enhance MLX model training with combined data handling
- Introduced a new function `_split_combined` to separate XRP, BTC, and ETH data from a combined DataFrame.
- Updated `train_mlx` to utilize the new function, improving data management and feature handling.
- Adjusted dataset generation to accommodate BTC and ETH features, with warnings for missing features.
- Changed default data path in `train_mlx` and `train_model` to point to the combined dataset for consistency.
- Increased `LOOKAHEAD` from 60 to 90 and adjusted `ATR_TP_MULT` for better model performance.
This commit is contained in:
@@ -25,14 +25,40 @@ MLX_MODEL_PATH = Path("models/mlx_filter.weights")
|
||||
LOG_PATH = Path("models/training_log.json")
|
||||
|
||||
|
||||
def _split_combined(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame | None, pd.DataFrame | None]:
|
||||
"""combined parquet에서 XRP/BTC/ETH DataFrame을 분리한다."""
|
||||
xrp_cols = ["open", "high", "low", "close", "volume"]
|
||||
xrp_df = df[xrp_cols].copy()
|
||||
|
||||
btc_df = None
|
||||
eth_df = None
|
||||
btc_raw = [c for c in df.columns if c.endswith("_btc")]
|
||||
eth_raw = [c for c in df.columns if c.endswith("_eth")]
|
||||
|
||||
if btc_raw:
|
||||
btc_df = df[btc_raw].copy()
|
||||
btc_df.columns = [c.replace("_btc", "") for c in btc_raw]
|
||||
if eth_raw:
|
||||
eth_df = df[eth_raw].copy()
|
||||
eth_df.columns = [c.replace("_eth", "") for c in eth_raw]
|
||||
|
||||
return xrp_df, btc_df, eth_df
|
||||
|
||||
|
||||
def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||
print(f"데이터 로드: {data_path}")
|
||||
df = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(df)}")
|
||||
raw = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(raw)}")
|
||||
|
||||
df, btc_df, eth_df = _split_combined(raw)
|
||||
if btc_df is not None:
|
||||
print(f" BTC/ETH 피처 활성화 (21개 피처)")
|
||||
else:
|
||||
print(f" XRP 단독 데이터 (13개 피처)")
|
||||
|
||||
print("\n데이터셋 생성 중...")
|
||||
t0 = time.perf_counter()
|
||||
dataset = generate_dataset_vectorized(df, time_weight_decay=time_weight_decay)
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay)
|
||||
t1 = time.perf_counter()
|
||||
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
||||
|
||||
@@ -44,6 +70,12 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||
if len(dataset) < 200:
|
||||
raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")
|
||||
|
||||
actual_cols = [c for c in FEATURE_COLS if c in dataset.columns]
|
||||
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
||||
if missing:
|
||||
print(f" 경고: 데이터셋에 없는 피처 {missing} → 0으로 채움 (BTC/ETH 데이터 미제공)")
|
||||
for col in missing:
|
||||
dataset[col] = 0.0
|
||||
X = dataset[FEATURE_COLS]
|
||||
y = dataset["label"]
|
||||
w = dataset["sample_weight"].values
|
||||
@@ -114,7 +146,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
|
||||
parser.add_argument("--data", default="data/combined_1m.parquet")
|
||||
parser.add_argument(
|
||||
"--decay", type=float, default=2.0,
|
||||
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
||||
|
||||
Reference in New Issue
Block a user