feat: implement BTC/ETH correlation features for improved model accuracy
- Added a new design document outlining the integration of BTC/ETH candle data as additional features in the XRP ML filter, enhancing prediction accuracy. - Introduced `MultiSymbolStream` for combined WebSocket data retrieval of XRP, BTC, and ETH. - Expanded feature set from 13 to 21 by including 8 new BTC/ETH-related features. - Updated various scripts and modules to support the new feature set and data handling. - Enhanced training and deployment scripts to accommodate the new dataset structure. This commit lays the groundwork for improved model performance by leveraging the correlation between BTC and ETH with XRP.
This commit is contained in:
@@ -54,6 +54,13 @@ fi
|
||||
|
||||
echo "=== 전송 완료 ==="
|
||||
echo ""
|
||||
echo "봇이 실행 중이라면 아래 명령으로 모델을 즉시 리로드할 수 있습니다:"
|
||||
echo " docker exec cointrader python -c \\"
|
||||
echo " \"from src.ml_filter import MLFilter; f=MLFilter(); f.reload_model(); print('리로드 완료')\""
|
||||
|
||||
# 봇 컨테이너가 실행 중이면 모델 핫리로드, 아니면 건너뜀
|
||||
echo "=== 핫리로드 시도 ==="
|
||||
if ssh "${LXC_HOST}" "docker inspect -f '{{.State.Running}}' cointrader 2>/dev/null | grep -q true"; then
|
||||
ssh "${LXC_HOST}" "docker exec cointrader python -c \
|
||||
\"from src.ml_filter import MLFilter; f=MLFilter(); f.reload_model(); print('리로드 완료')\""
|
||||
echo "=== 핫리로드 완료 ==="
|
||||
else
|
||||
echo " cointrader 컨테이너가 실행 중이 아닙니다. 건너뜁니다."
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
바이낸스 선물 REST API로 과거 캔들 데이터를 수집해 parquet으로 저장한다.
|
||||
사용법: python scripts/fetch_history.py --symbol XRPUSDT --interval 1m --days 90
|
||||
python scripts/fetch_history.py --symbols XRPUSDT BTCUSDT ETHUSDT --days 90
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -8,7 +9,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import pandas as pd
|
||||
from binance import AsyncClient
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,32 +17,41 @@ import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 요청 사이 딜레이 (초). 바이낸스 선물 기본 한도: 2400 req/min = 40 req/s
|
||||
# 1500개씩 가져오므로 90일 1m 데이터 = ~65회 요청/심볼
|
||||
# 심볼 간 딜레이 없이 연속 요청하면 레이트 리밋(-1003) 발생
|
||||
_REQUEST_DELAY = 0.3 # 초당 ~3.3 req → 안전 마진 충분
|
||||
|
||||
async def fetch_klines(symbol: str, interval: str, days: int) -> pd.DataFrame:
|
||||
client = await AsyncClient.create(
|
||||
api_key=os.getenv("BINANCE_API_KEY", ""),
|
||||
api_secret=os.getenv("BINANCE_API_SECRET", ""),
|
||||
)
|
||||
try:
|
||||
start_ts = int((datetime.utcnow() - timedelta(days=days)).timestamp() * 1000)
|
||||
all_klines = []
|
||||
while True:
|
||||
klines = await client.futures_klines(
|
||||
symbol=symbol,
|
||||
interval=interval,
|
||||
startTime=start_ts,
|
||||
limit=1500,
|
||||
)
|
||||
if not klines:
|
||||
break
|
||||
all_klines.extend(klines)
|
||||
last_ts = klines[-1][0]
|
||||
if last_ts >= int(datetime.utcnow().timestamp() * 1000):
|
||||
break
|
||||
start_ts = last_ts + 1
|
||||
print(f"수집 중... {len(all_klines)}개")
|
||||
finally:
|
||||
await client.close_connection()
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
|
||||
|
||||
async def _fetch_klines_with_client(
|
||||
client: AsyncClient,
|
||||
symbol: str,
|
||||
interval: str,
|
||||
days: int,
|
||||
) -> pd.DataFrame:
|
||||
"""기존 클라이언트를 재사용해 단일 심볼 캔들을 수집한다."""
|
||||
start_ts = int((datetime.now(timezone.utc) - timedelta(days=days)).timestamp() * 1000)
|
||||
all_klines = []
|
||||
while True:
|
||||
klines = await client.futures_klines(
|
||||
symbol=symbol,
|
||||
interval=interval,
|
||||
startTime=start_ts,
|
||||
limit=1500,
|
||||
)
|
||||
if not klines:
|
||||
break
|
||||
all_klines.extend(klines)
|
||||
last_ts = klines[-1][0]
|
||||
if last_ts >= _now_ms():
|
||||
break
|
||||
start_ts = last_ts + 1
|
||||
print(f" [{symbol}] 수집 중... {len(all_klines):,}개")
|
||||
await asyncio.sleep(_REQUEST_DELAY)
|
||||
|
||||
df = pd.DataFrame(all_klines, columns=[
|
||||
"timestamp", "open", "high", "low", "close", "volume",
|
||||
@@ -51,22 +61,87 @@ async def fetch_klines(symbol: str, interval: str, days: int) -> pd.DataFrame:
|
||||
df = df[["timestamp", "open", "high", "low", "close", "volume"]].copy()
|
||||
for col in ["open", "high", "low", "close", "volume"]:
|
||||
df[col] = df[col].astype(float)
|
||||
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
|
||||
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True)
|
||||
df.set_index("timestamp", inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
async def fetch_klines(symbol: str, interval: str, days: int) -> pd.DataFrame:
|
||||
"""단일 심볼 수집 (하위 호환용)."""
|
||||
client = await AsyncClient.create(
|
||||
api_key=os.getenv("BINANCE_API_KEY", ""),
|
||||
api_secret=os.getenv("BINANCE_API_SECRET", ""),
|
||||
)
|
||||
try:
|
||||
return await _fetch_klines_with_client(client, symbol, interval, days)
|
||||
finally:
|
||||
await client.close_connection()
|
||||
|
||||
|
||||
async def fetch_klines_all(
|
||||
symbols: list[str],
|
||||
interval: str,
|
||||
days: int,
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
단일 클라이언트로 여러 심볼을 순차 수집한다.
|
||||
asyncio.run()을 심볼마다 반복하면 연결 오버헤드와 레이트 리밋 위험이 있으므로
|
||||
하나의 연결 안에서 심볼 간 딜레이를 두고 순차 처리한다.
|
||||
"""
|
||||
client = await AsyncClient.create(
|
||||
api_key=os.getenv("BINANCE_API_KEY", ""),
|
||||
api_secret=os.getenv("BINANCE_API_SECRET", ""),
|
||||
)
|
||||
dfs = {}
|
||||
try:
|
||||
for i, symbol in enumerate(symbols):
|
||||
print(f"\n[{i+1}/{len(symbols)}] {symbol} 수집 시작...")
|
||||
dfs[symbol] = await _fetch_klines_with_client(client, symbol, interval, days)
|
||||
print(f" [{symbol}] 완료: {len(dfs[symbol]):,}행")
|
||||
# 심볼 간 추가 대기: 레이트 리밋 카운터가 리셋될 시간 확보
|
||||
if i < len(symbols) - 1:
|
||||
print(f" 다음 심볼 수집 전 5초 대기...")
|
||||
await asyncio.sleep(5)
|
||||
finally:
|
||||
await client.close_connection()
|
||||
return dfs
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--symbol", default="XRPUSDT")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="바이낸스 선물 과거 캔들 수집. 단일 심볼 또는 멀티 심볼 병합 저장."
|
||||
)
|
||||
parser.add_argument("--symbols", nargs="+", default=["XRPUSDT"])
|
||||
parser.add_argument("--symbol", default=None, help="단일 심볼 (--symbols 미사용 시)")
|
||||
parser.add_argument("--interval", default="1m")
|
||||
parser.add_argument("--days", type=int, default=90)
|
||||
parser.add_argument("--output", default="data/xrpusdt_1m.parquet")
|
||||
args = parser.parse_args()
|
||||
|
||||
df = asyncio.run(fetch_klines(args.symbol, args.interval, args.days))
|
||||
df.to_parquet(args.output)
|
||||
print(f"저장 완료: {args.output} ({len(df)}행)")
|
||||
# 하위 호환: --symbol 단독 사용 시 symbols로 통합
|
||||
if args.symbol and args.symbols == ["XRPUSDT"]:
|
||||
args.symbols = [args.symbol]
|
||||
|
||||
if len(args.symbols) == 1:
|
||||
df = asyncio.run(fetch_klines(args.symbols[0], args.interval, args.days))
|
||||
df.to_parquet(args.output)
|
||||
print(f"저장 완료: {args.output} ({len(df):,}행)")
|
||||
else:
|
||||
# 멀티 심볼: 단일 클라이언트로 순차 수집 후 타임스탬프 기준 inner join 병합
|
||||
dfs = asyncio.run(fetch_klines_all(args.symbols, args.interval, args.days))
|
||||
|
||||
primary = args.symbols[0]
|
||||
merged = dfs[primary].copy()
|
||||
for symbol in args.symbols[1:]:
|
||||
suffix = "_" + symbol.lower().replace("usdt", "")
|
||||
merged = merged.join(
|
||||
dfs[symbol].add_suffix(suffix),
|
||||
how="inner",
|
||||
)
|
||||
|
||||
output = args.output.replace("xrpusdt", "combined")
|
||||
merged.to_parquet(output)
|
||||
print(f"\n병합 저장 완료: {output} ({len(merged):,}행, {len(merged.columns)}컬럼)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -16,19 +16,24 @@ PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
|
||||
cd "$PROJECT_ROOT"
|
||||
|
||||
echo "=== [1/3] 데이터 수집 ==="
|
||||
python scripts/fetch_history.py --symbol XRPUSDT --interval 1m --days 90 --output data/xrpusdt_1m.parquet
|
||||
echo "=== [1/3] 데이터 수집 (XRP + BTC + ETH 3심볼) ==="
|
||||
python scripts/fetch_history.py \
|
||||
--symbols XRPUSDT BTCUSDT ETHUSDT \
|
||||
--interval 1m \
|
||||
--days 90 \
|
||||
--output data/xrpusdt_1m.parquet
|
||||
# 결과: data/combined_1m.parquet (타임스탬프 기준 병합)
|
||||
|
||||
echo ""
|
||||
echo "=== [2/3] 모델 학습 ==="
|
||||
echo "=== [2/3] 모델 학습 (21개 피처: XRP 13 + BTC/ETH 상관관계 8) ==="
|
||||
# TRAIN_BACKEND=mlx 로 설정하면 Apple Silicon GPU(Metal)를 사용한다 (기본: lgbm)
|
||||
BACKEND="${TRAIN_BACKEND:-lgbm}"
|
||||
if [ "$BACKEND" = "mlx" ]; then
|
||||
echo " 백엔드: MLX (Apple Silicon GPU)"
|
||||
python scripts/train_mlx_model.py --data data/xrpusdt_1m.parquet
|
||||
python scripts/train_mlx_model.py --data data/combined_1m.parquet
|
||||
else
|
||||
echo " 백엔드: LightGBM (CPU)"
|
||||
python scripts/train_model.py --data data/xrpusdt_1m.parquet
|
||||
python scripts/train_model.py --data data/combined_1m.parquet
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
@@ -148,11 +148,28 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
|
||||
|
||||
def train(data_path: str):
|
||||
print(f"데이터 로드: {data_path}")
|
||||
df = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(df)}")
|
||||
df_raw = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(df_raw)}, 컬럼: {list(df_raw.columns)}")
|
||||
|
||||
# 병합 데이터셋 여부 판별
|
||||
btc_df = None
|
||||
eth_df = None
|
||||
base_cols = ["open", "high", "low", "close", "volume"]
|
||||
|
||||
if "close_btc" in df_raw.columns:
|
||||
btc_df = df_raw[[c + "_btc" for c in base_cols]].copy()
|
||||
btc_df.columns = base_cols
|
||||
print("BTC 피처 활성화")
|
||||
|
||||
if "close_eth" in df_raw.columns:
|
||||
eth_df = df_raw[[c + "_eth" for c in base_cols]].copy()
|
||||
eth_df.columns = base_cols
|
||||
print("ETH 피처 활성화")
|
||||
|
||||
df = df_raw[base_cols].copy()
|
||||
|
||||
print("데이터셋 생성 중...")
|
||||
dataset = generate_dataset_vectorized(df)
|
||||
dataset = generate_dataset_vectorized(df, btc_df=btc_df, eth_df=eth_df)
|
||||
|
||||
if dataset.empty or "label" not in dataset.columns:
|
||||
raise ValueError(f"데이터셋 생성 실패: 샘플 0개. 위 오류 메시지를 확인하세요.")
|
||||
@@ -162,7 +179,9 @@ def train(data_path: str):
|
||||
if len(dataset) < 200:
|
||||
raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)")
|
||||
|
||||
X = dataset[FEATURE_COLS]
|
||||
actual_feature_cols = [c for c in FEATURE_COLS if c in dataset.columns]
|
||||
print(f"사용 피처: {len(actual_feature_cols)}개 {actual_feature_cols}")
|
||||
X = dataset[actual_feature_cols]
|
||||
y = dataset["label"]
|
||||
|
||||
split = int(len(X) * 0.8)
|
||||
@@ -208,6 +227,7 @@ def train(data_path: str):
|
||||
"date": datetime.now().isoformat(),
|
||||
"auc": round(auc, 4),
|
||||
"samples": len(dataset),
|
||||
"features": len(actual_feature_cols),
|
||||
"model_path": str(MODEL_PATH),
|
||||
})
|
||||
with open(LOG_PATH, "w") as f:
|
||||
|
||||
Reference in New Issue
Block a user