feat: enhance MLX model training with combined data handling
- Introduced a new function `_split_combined` to separate XRP, BTC, and ETH data from a combined DataFrame. - Updated `train_mlx` to utilize the new function, improving data management and feature handling. - Adjusted dataset generation to accommodate BTC and ETH features, with warnings for missing features. - Changed default data path in `train_mlx` and `train_model` to point to the combined dataset for consistency. - Increased `LOOKAHEAD` from 60 to 90 and adjusted `ATR_TP_MULT` for better model performance.
This commit is contained in:
@@ -25,14 +25,40 @@ MLX_MODEL_PATH = Path("models/mlx_filter.weights")
|
|||||||
LOG_PATH = Path("models/training_log.json")
|
LOG_PATH = Path("models/training_log.json")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_combined(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame | None, pd.DataFrame | None]:
|
||||||
|
"""combined parquet에서 XRP/BTC/ETH DataFrame을 분리한다."""
|
||||||
|
xrp_cols = ["open", "high", "low", "close", "volume"]
|
||||||
|
xrp_df = df[xrp_cols].copy()
|
||||||
|
|
||||||
|
btc_df = None
|
||||||
|
eth_df = None
|
||||||
|
btc_raw = [c for c in df.columns if c.endswith("_btc")]
|
||||||
|
eth_raw = [c for c in df.columns if c.endswith("_eth")]
|
||||||
|
|
||||||
|
if btc_raw:
|
||||||
|
btc_df = df[btc_raw].copy()
|
||||||
|
btc_df.columns = [c.replace("_btc", "") for c in btc_raw]
|
||||||
|
if eth_raw:
|
||||||
|
eth_df = df[eth_raw].copy()
|
||||||
|
eth_df.columns = [c.replace("_eth", "") for c in eth_raw]
|
||||||
|
|
||||||
|
return xrp_df, btc_df, eth_df
|
||||||
|
|
||||||
|
|
||||||
def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
||||||
print(f"데이터 로드: {data_path}")
|
print(f"데이터 로드: {data_path}")
|
||||||
df = pd.read_parquet(data_path)
|
raw = pd.read_parquet(data_path)
|
||||||
print(f"캔들 수: {len(df)}")
|
print(f"캔들 수: {len(raw)}")
|
||||||
|
|
||||||
|
df, btc_df, eth_df = _split_combined(raw)
|
||||||
|
if btc_df is not None:
|
||||||
|
print(f" BTC/ETH 피처 활성화 (21개 피처)")
|
||||||
|
else:
|
||||||
|
print(f" XRP 단독 데이터 (13개 피처)")
|
||||||
|
|
||||||
print("\n데이터셋 생성 중...")
|
print("\n데이터셋 생성 중...")
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
dataset = generate_dataset_vectorized(df, time_weight_decay=time_weight_decay)
|
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df, time_weight_decay=time_weight_decay)
|
||||||
t1 = time.perf_counter()
|
t1 = time.perf_counter()
|
||||||
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
print(f"데이터셋 생성 완료: {t1 - t0:.1f}초, {len(dataset)}개 샘플")
|
||||||
|
|
||||||
@@ -44,6 +70,12 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
|||||||
if len(dataset) < 200:
|
if len(dataset) < 200:
|
||||||
raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")
|
raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")
|
||||||
|
|
||||||
|
actual_cols = [c for c in FEATURE_COLS if c in dataset.columns]
|
||||||
|
missing = [c for c in FEATURE_COLS if c not in dataset.columns]
|
||||||
|
if missing:
|
||||||
|
print(f" 경고: 데이터셋에 없는 피처 {missing} → 0으로 채움 (BTC/ETH 데이터 미제공)")
|
||||||
|
for col in missing:
|
||||||
|
dataset[col] = 0.0
|
||||||
X = dataset[FEATURE_COLS]
|
X = dataset[FEATURE_COLS]
|
||||||
y = dataset["label"]
|
y = dataset["label"]
|
||||||
w = dataset["sample_weight"].values
|
w = dataset["sample_weight"].values
|
||||||
@@ -114,7 +146,7 @@ def train_mlx(data_path: str, time_weight_decay: float = 2.0) -> float:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
|
parser.add_argument("--data", default="data/combined_1m.parquet")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decay", type=float, default=2.0,
|
"--decay", type=float, default=2.0,
|
||||||
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ def train(data_path: str, time_weight_decay: float = 2.0):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
|
parser.add_argument("--data", default="data/combined_1m.parquet")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decay", type=float, default=2.0,
|
"--decay", type=float, default=2.0,
|
||||||
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
help="시간 가중치 감쇠 강도 (0=균등, 2.0=최신이 ~7.4배 높음)",
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import pandas_ta as ta
|
|||||||
|
|
||||||
from src.ml_features import FEATURE_COLS
|
from src.ml_features import FEATURE_COLS
|
||||||
|
|
||||||
LOOKAHEAD = 60
|
LOOKAHEAD = 90
|
||||||
ATR_SL_MULT = 1.5
|
ATR_SL_MULT = 1.5
|
||||||
ATR_TP_MULT = 3.0
|
ATR_TP_MULT = 2.0
|
||||||
WARMUP = 60 # 지표 안정화에 필요한 최소 행 수
|
WARMUP = 60 # 지표 안정화에 필요한 최소 행 수
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user