From 7e4e9315c2c35da7df69c94a11dffa8aa3941090 Mon Sep 17 00:00:00 2001 From: 21in7 Date: Sun, 1 Mar 2026 17:07:18 +0900 Subject: [PATCH] feat: implement ML filter with LightGBM for trading signal validation - Added MLFilter class to load and evaluate LightGBM model for trading signals. - Introduced retraining mechanism to update the model daily based on new data. - Created feature engineering and label building utilities for model training. - Updated bot logic to incorporate ML filter for signal validation. - Added scripts for data fetching and model training. Made-with: Cursor --- .gitignore | 2 + Dockerfile | 2 +- data/.gitkeep | 0 docker-compose.yml | 4 + ...026-03-01-dockerfile-and-docker-compose.md | 282 +++++ .../2026-03-01-fix-pandas-ta-python312.md | 275 ++++ .../2026-03-01-jenkins-gitea-registry-cicd.md | 405 ++++++ docs/plans/2026-03-01-ml-filter-design.md | 102 ++ .../2026-03-01-ml-filter-implementation.md | 1124 +++++++++++++++++ models/.gitkeep | 0 requirements.txt | 4 + scripts/__init__.py | 0 scripts/fetch_history.py | 69 + scripts/train_model.py | 151 +++ src/bot.py | 13 + src/label_builder.py | 29 + src/ml_features.py | 82 ++ src/ml_filter.py | 50 + src/retrainer.py | 92 ++ tests/test_bot.py | 8 +- tests/test_label_builder.py | 73 ++ tests/test_ml_features.py | 57 + tests/test_ml_filter.py | 63 + tests/test_retrainer.py | 35 + 24 files changed, 2916 insertions(+), 6 deletions(-) create mode 100644 data/.gitkeep create mode 100644 docs/plans/2026-03-01-dockerfile-and-docker-compose.md create mode 100644 docs/plans/2026-03-01-fix-pandas-ta-python312.md create mode 100644 docs/plans/2026-03-01-jenkins-gitea-registry-cicd.md create mode 100644 docs/plans/2026-03-01-ml-filter-design.md create mode 100644 docs/plans/2026-03-01-ml-filter-implementation.md create mode 100644 models/.gitkeep create mode 100644 scripts/__init__.py create mode 100644 scripts/fetch_history.py create mode 100644 scripts/train_model.py create mode 100644 src/label_builder.py create mode 100644 src/ml_features.py create mode 100644 src/ml_filter.py create mode 100644 src/retrainer.py create mode 100644 tests/test_label_builder.py create mode 100644 tests/test_ml_features.py create mode 100644 tests/test_ml_filter.py create mode 100644 tests/test_retrainer.py diff --git a/.gitignore b/.gitignore index fdd9e57..1f6eb73 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ logs/ *.log .venv/ venv/ +models/*.pkl +data/*.parquet diff --git a/Dockerfile b/Dockerfile index 2a61936..e3d2618 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ RUN pip install --no-cache-dir -r requirements.txt COPY . . -RUN mkdir -p logs +RUN mkdir -p logs models data ENV PYTHONUNBUFFERED=1 ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml index 7ae9699..7025dcf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,8 +5,12 @@ services: restart: unless-stopped env_file: - .env + environment: + - TZ=Asia/Seoul volumes: - ./logs:/app/logs + - ./models:/app/models + - ./data:/app/data logging: driver: "json-file" options: diff --git a/docs/plans/2026-03-01-dockerfile-and-docker-compose.md b/docs/plans/2026-03-01-dockerfile-and-docker-compose.md new file mode 100644 index 0000000..6b27281 --- /dev/null +++ b/docs/plans/2026-03-01-dockerfile-and-docker-compose.md @@ -0,0 +1,282 @@ +# Dockerfile & docker-compose.yml 작성 및 Gitea 업로드 구현 계획 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** cointrader 프로젝트에 Dockerfile과 docker-compose.yml을 추가하고, 변경사항을 커밋하여 Gitea(10.1.10.28:3000)에 push한다. + +**Architecture:** Python 3.11 slim 이미지 기반의 멀티스테이지 없는 단일 Dockerfile을 작성하고, docker-compose.yml로 환경변수(.env)를 주입하여 컨테이너를 실행한다. 로그 디렉토리는 볼륨으로 마운트하여 컨테이너 재시작 시에도 보존한다. + +**Tech Stack:** Docker, docker-compose v2, Python 3.11-slim, python-dotenv + +--- + +## Task 1: Dockerfile 작성 + +**Files:** +- Create: `Dockerfile` + +**Step 1: Dockerfile 생성** + +`/Users/gihyeon/github/cointrader/Dockerfile` 파일을 아래 내용으로 생성한다: + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +RUN mkdir -p logs + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +CMD ["python", "main.py"] +``` + +**Step 2: Dockerfile 내용 확인** + +```bash +cat /Users/gihyeon/github/cointrader/Dockerfile +``` + +Expected: 위 내용이 그대로 출력됨 + +**Step 3: Docker 빌드 테스트 (Docker가 설치된 경우)** + +```bash +cd /Users/gihyeon/github/cointrader +docker build -t cointrader:test . +``` + +Expected: `Successfully built ` 또는 `Successfully tagged cointrader:test` + +> Docker가 설치되지 않은 환경이라면 이 단계는 건너뛴다. + +--- + +## Task 2: docker-compose.yml 작성 + +**Files:** +- Create: `docker-compose.yml` + +**Step 1: docker-compose.yml 생성** + +`/Users/gihyeon/github/cointrader/docker-compose.yml` 파일을 아래 내용으로 생성한다: + +```yaml +services: + cointrader: + build: . + container_name: cointrader + restart: unless-stopped + env_file: + - .env + volumes: + - ./logs:/app/logs + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "5" +``` + +**Step 2: docker-compose.yml 내용 확인** + +```bash +cat /Users/gihyeon/github/cointrader/docker-compose.yml +``` + +Expected: 위 내용이 그대로 출력됨 + +**Step 3: docker-compose 문법 검증 (docker compose가 설치된 경우)** + +```bash +cd /Users/gihyeon/github/cointrader +docker compose config +``` + +Expected: 파싱된 YAML 설정이 오류 없이 출력됨 + +--- + +## Task 3: .dockerignore 작성 + +**Files:** +- Create: `.dockerignore` + +**Step 1: .dockerignore 생성** + +`/Users/gihyeon/github/cointrader/.dockerignore` 파일을 아래 내용으로 생성한다: + +``` +.env +.venv +__pycache__ +*.pyc +*.pyo +.pytest_cache +logs/ +*.log +.git +docs/ +tests/ +``` + +> `.env`를 반드시 포함시켜 빌드 컨텍스트에서 제외한다. 이미지에 API 키가 포함되는 것을 방지한다. + +**Step 2: .dockerignore 내용 확인** + +```bash +cat /Users/gihyeon/github/cointrader/.dockerignore +``` + +Expected: 위 내용이 그대로 출력됨 + +--- + +## Task 4: git 커밋 + +**Files:** +- Modify: `Dockerfile` (신규) +- Modify: `docker-compose.yml` (신규) +- Modify: `.dockerignore` (신규) + +**Step 1: git 상태 확인** + +```bash +cd /Users/gihyeon/github/cointrader +git status +``` + +Expected: `Dockerfile`, `docker-compose.yml`, `.dockerignore`가 untracked files로 표시됨 + +**Step 2: 스테이징** + +```bash +cd /Users/gihyeon/github/cointrader +git add Dockerfile docker-compose.yml .dockerignore +``` + +**Step 3: 스테이징 내용 검토 (`.env` 포함 여부 확인)** + +```bash +git diff --cached --name-only +``` + +Expected: +``` +.dockerignore +Dockerfile +docker-compose.yml +``` + +`.env`가 목록에 **없어야** 한다. 만약 있다면 즉시 `git reset HEAD .env` 실행 후 중단. + +**Step 4: 커밋** + +```bash +git commit -m "chore: add Dockerfile, docker-compose.yml, .dockerignore" +``` + +Expected: `main` 브랜치에 새 커밋 생성 + +**Step 5: 커밋 확인** + +```bash +git log --oneline -3 +``` + +Expected: 방금 만든 커밋이 최상단에 표시됨 + +--- + +## Task 5: Gitea push + +> 이 Task는 Gitea 원격 저장소가 이미 설정되어 있다고 가정한다. +> 아직 설정하지 않았다면 `docs/plans/2026-03-01-upload-to-gitea.md`의 Task 2~3을 먼저 완료한다. + +**Step 1: 현재 원격 저장소 확인** + +```bash +cd /Users/gihyeon/github/cointrader +git remote -v +``` + +Expected: +``` +origin http://10.1.10.28:3000/<사용자명>/cointrader.git (fetch) +origin http://10.1.10.28:3000/<사용자명>/cointrader.git (push) +``` + +origin이 없다면 아래 명령으로 추가 (`<사용자명>` 교체 필요): +```bash +git remote add origin http://10.1.10.28:3000/<사용자명>/cointrader.git +``` + +**Step 2: push** + +```bash +git push origin main +``` + +> Gitea 계정의 사용자명과 비밀번호(또는 액세스 토큰)를 입력하라는 프롬프트가 나타남 + +Expected: +``` +Enumerating objects: ... +Writing objects: 100% ... +``` + +**Step 3: push 결과 확인** + +```bash +git log --oneline origin/main -3 +``` + +Expected: 로컬 커밋 히스토리와 동일하게 표시됨 + +**Step 4: Gitea 웹 UI에서 파일 확인** + +브라우저에서 `http://10.1.10.28:3000/<사용자명>/cointrader` 접속 후 다음 파일이 있는지 확인: +- `Dockerfile` +- `docker-compose.yml` +- `.dockerignore` + +--- + +## 트러블슈팅 + +| 문제 | 원인 | 해결 | +|------|------|------| +| `docker build` 시 `gcc` 설치 실패 | 네트워크 문제 | `apt-get` 단계를 제거하고 빌드 재시도 (pandas-ta가 gcc 없이 설치되는지 확인) | +| `docker compose config` 오류 | YAML 들여쓰기 오류 | 탭 대신 스페이스 2칸 사용 여부 확인 | +| push 시 `Authentication failed` | 잘못된 계정 정보 | Gitea 웹 UI 로그인 테스트 후 동일 계정 사용 | +| push 시 `non-fast-forward` | 원격에 이미 다른 커밋 존재 | `git pull --rebase origin main` 후 재시도 | +| 컨테이너 실행 시 `.env` 없음 오류 | `.env` 파일 미생성 | `.env.example`을 복사하여 `.env` 생성 후 값 입력 | + +--- + +## 참고: 컨테이너 실행 방법 + +```bash +# .env 파일 준비 +cp .env.example .env +# .env 파일에 실제 API 키와 Discord Webhook URL 입력 + +# 빌드 및 백그라운드 실행 +docker compose up -d --build + +# 로그 확인 +docker compose logs -f + +# 중지 +docker compose down +``` diff --git a/docs/plans/2026-03-01-fix-pandas-ta-python312.md b/docs/plans/2026-03-01-fix-pandas-ta-python312.md new file mode 100644 index 0000000..dd5e437 --- /dev/null +++ b/docs/plans/2026-03-01-fix-pandas-ta-python312.md @@ -0,0 +1,275 @@ +# pandas-ta Python 버전 호환성 수정 계획 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Jenkins CI에서 `pandas-ta==0.4.71b0`이 Python 3.11에서 설치 실패하는 문제를 해결한다. + +**Architecture:** `pandas-ta==0.4.71b0`은 Python >=3.12를 요구하므로, Dockerfile의 베이스 이미지를 `python:3.11-slim`에서 `python:3.12-slim`으로 업그레이드한다. `requirements.txt`의 의존 패키지 버전도 Python 3.12와 호환되는 버전으로 정리한다. + +**Tech Stack:** Docker, Python 3.12-slim, pandas-ta 0.4.71b0, python-binance 1.0.19 + +--- + +## 문제 분석 + +Jenkins 빌드 로그 오류: +``` +ERROR: Ignored the following versions that require a different python version: + 0.4.67b0 Requires-Python >=3.12; 0.4.71b0 Requires-Python >=3.12 +ERROR: Could not find a version that satisfies the requirement pandas-ta==0.4.71b0 +``` + +**원인:** `requirements.txt`에 `pandas-ta==0.4.71b0`이 명시되어 있으나, Dockerfile 베이스 이미지가 `python:3.11-slim`이라 Python 3.12 이상을 요구하는 `pandas-ta`를 설치할 수 없다. + +**해결 방향:** Dockerfile 베이스 이미지를 `python:3.12-slim`으로 변경한다. + +--- + +## Task 1: Dockerfile 베이스 이미지 업그레이드 + +**Files:** +- Modify: `Dockerfile:1` + +**Step 1: Dockerfile 수정** + +`Dockerfile` 1번째 줄을 다음과 같이 변경한다: + +변경 전: +```dockerfile +FROM python:3.11-slim +``` + +변경 후: +```dockerfile +FROM python:3.12-slim +``` + +**Step 2: 변경 내용 확인** + +```bash +head -1 /Users/gihyeon/github/cointrader/Dockerfile +``` + +Expected: +``` +FROM python:3.12-slim +``` + +--- + +## Task 2: requirements.txt 의존성 호환성 검토 및 수정 + +**Files:** +- Modify: `requirements.txt` + +**Step 1: 현재 requirements.txt 내용 확인** + +```bash +cat /Users/gihyeon/github/cointrader/requirements.txt +``` + +Expected (현재 내용): +``` +python-binance==1.0.19 +pandas>=2.2.0 +pandas-ta==0.4.71b0 +python-dotenv==1.0.0 +httpx>=0.27.0 +pytest>=8.1.0 +pytest-asyncio>=0.24.0 +aiohttp==3.9.3 +websockets==12.0 +loguru==0.7.2 +``` + +**Step 2: pandas-ta 0.4.71b0의 의존성 확인** + +PyPI 정보에 따르면 `pandas-ta==0.4.71b0`은 다음을 요구한다: +- `numba==0.61.2` +- `numpy>=2.2.6` +- `pandas>=2.3.2` + +`requirements.txt`의 `pandas>=2.2.0`은 `pandas>=2.3.2`를 만족하므로 문제없다. +단, `numba`가 명시되어 있지 않아 pandas-ta 설치 시 자동으로 설치된다. + +**Step 3: requirements.txt 수정 (pandas 최소 버전 상향)** + +`pandas>=2.2.0`을 `pandas>=2.3.2`로 변경하여 pandas-ta의 요구사항을 명시적으로 반영한다: + +변경 전: +``` +pandas>=2.2.0 +``` + +변경 후: +``` +pandas>=2.3.2 +``` + +**Step 4: 변경 내용 확인** + +```bash +grep "pandas" /Users/gihyeon/github/cointrader/requirements.txt +``` + +Expected: +``` +pandas>=2.3.2 +pandas-ta==0.4.71b0 +``` + +--- + +## Task 3: 로컬 Docker 빌드 테스트 + +> Docker가 설치된 환경에서만 실행한다. + +**Step 1: Docker 빌드** + +```bash +cd /Users/gihyeon/github/cointrader +docker build -t cointrader:test . +``` + +Expected: 빌드 성공 (`Successfully tagged cointrader:test` 또는 `#N DONE`) + +**Step 2: 빌드된 이미지의 Python 버전 확인** + +```bash +docker run --rm cointrader:test python --version +``` + +Expected: +``` +Python 3.12.x +``` + +**Step 3: pandas-ta import 확인** + +```bash +docker run --rm cointrader:test python -c "import pandas_ta; print(pandas_ta.__version__)" +``` + +Expected: +``` +0.4.71b0 +``` + +**Step 4: 테스트 이미지 정리** + +```bash +docker rmi cointrader:test +``` + +--- + +## Task 4: git 커밋 및 Gitea push + +**Files:** +- Modify: `Dockerfile` +- Modify: `requirements.txt` + +**Step 1: git 상태 확인** + +```bash +cd /Users/gihyeon/github/cointrader +git status +``` + +Expected: +``` +modified: Dockerfile +modified: requirements.txt +``` + +**Step 2: 변경 내용 검토** + +```bash +git diff Dockerfile requirements.txt +``` + +Expected: +- `Dockerfile`: `-FROM python:3.11-slim` → `+FROM python:3.12-slim` +- `requirements.txt`: `-pandas>=2.2.0` → `+pandas>=2.3.2` + +**Step 3: 스테이징** + +```bash +git add Dockerfile requirements.txt +``` + +**Step 4: 커밋** + +```bash +git commit -m "fix: upgrade to Python 3.12 to support pandas-ta>=0.4.67b0" +``` + +Expected: 커밋 성공 + +**Step 5: Gitea push** + +```bash +git push origin main +``` + +Expected: push 성공 후 Jenkins가 자동으로 새 빌드를 트리거함 + +**Step 6: 커밋 확인** + +```bash +git log --oneline -3 +``` + +Expected: 방금 만든 커밋이 최상단에 표시됨 + +--- + +## Task 5: Jenkins 빌드 재실행 및 결과 확인 + +**Step 1: Jenkins 빌드 트리거** + +Gitea push 후 Jenkins Webhook이 설정되어 있다면 자동으로 빌드가 트리거된다. +수동으로 트리거하려면 Jenkins 웹 UI에서 `cointrader` 파이프라인 → `Build Now` 클릭. + +**Step 2: 빌드 로그에서 성공 확인** + +Jenkins 빌드 로그에서 다음 내용이 나타나야 한다: + +``` +#9 [5/7] RUN pip install --no-cache-dir -r requirements.txt +... +Successfully installed pandas-ta-0.4.71b0 ... +#9 DONE xx.xs +``` + +오류 없이 `[Build Docker Image]` 스테이지가 완료되어야 한다. + +**Step 3: 전체 파이프라인 성공 확인** + +Jenkins 빌드 결과가 `SUCCESS`로 표시되어야 한다: +``` +Finished: SUCCESS +``` + +--- + +## 트러블슈팅 + +| 문제 | 원인 | 해결 | +|------|------|------| +| `python-binance==1.0.19` 설치 실패 | Python 3.12 비호환 | `python-binance>=1.0.19`로 변경하거나 최신 버전 확인 | +| `aiohttp==3.9.3` 설치 실패 | Python 3.12 비호환 | `aiohttp>=3.9.3`으로 완화하거나 최신 버전으로 업그레이드 | +| `numba` 설치 시간 초과 | numba 컴파일 시간 | 빌드 타임아웃 설정 증가 또는 `--timeout=300` 추가 | +| Jenkins Webhook 미동작 | Gitea Webhook 미설정 | Gitea 저장소 설정 → Webhooks → Jenkins URL 추가 | + +--- + +## 참고: Python 3.12 호환성 체크리스트 + +Python 3.11 → 3.12 주요 변경사항 중 이 프로젝트에 영향 가능한 항목: + +- `asyncio` 동작 변경: `asyncio.get_event_loop()` deprecated → `asyncio.get_running_loop()` 권장 +- `typing` 모듈 일부 변경: `Union[X, Y]` → `X | Y` 문법 지원 강화 +- `datetime.utcnow()` deprecated → `datetime.now(timezone.utc)` 권장 + +현재 코드베이스(`src/`, `tests/`)에서 위 패턴 사용 여부를 확인하고 필요 시 수정한다. diff --git a/docs/plans/2026-03-01-jenkins-gitea-registry-cicd.md b/docs/plans/2026-03-01-jenkins-gitea-registry-cicd.md new file mode 100644 index 0000000..a77a7c2 --- /dev/null +++ b/docs/plans/2026-03-01-jenkins-gitea-registry-cicd.md @@ -0,0 +1,405 @@ +# Jenkins + Gitea 이미지 레지스트리 CI/CD 구현 계획 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Jenkins가 Gitea(10.1.10.28:3000)의 코드 변경을 감지하면 Docker 이미지를 빌드하여 Gitea Container Registry(10.1.10.28:5000 또는 Gitea 내장 패키지 레지스트리)에 push하고, docker-compose.yml이 해당 이미지를 pull해서 실행하도록 전체 CI/CD 파이프라인을 구성한다. + +**Architecture:** +- Jenkins는 Gitea webhook을 통해 main 브랜치 push 이벤트를 수신한다. +- Jenkinsfile(파이프라인 스크립트)이 프로젝트 루트에 위치하며, `docker build → docker push → (선택) 원격 배포` 단계를 수행한다. +- Gitea의 내장 Container Registry(Packages)를 이미지 저장소로 사용한다. 이미지 이름 형식: `10.1.10.28:3000/gihyeon/cointrader:` +- docker-compose.yml은 `build: .` 대신 레지스트리 이미지를 직접 참조하도록 수정한다. + +**Tech Stack:** Jenkins, Gitea Container Registry, Docker, docker-compose v2, Jenkinsfile(Declarative Pipeline) + +--- + +## 사전 확인 사항 + +- Gitea 서버: `http://10.1.10.28:3000` +- Gitea 저장소: `http://10.1.10.28:3000/gihyeon/cointrader.git` +- Gitea Container Registry 주소: `10.1.10.28:3000` (HTTP 사용 시 Docker insecure-registries 설정 필요) +- Jenkins 서버 주소: 별도 확인 필요 (아래 Task 1에서 확인) +- 현재 Dockerfile: `FROM python:3.12-slim` 기반, `/app`에서 `python main.py` 실행 + +--- + +## Task 1: 환경 사전 점검 + +**Files:** +- 확인: `Dockerfile` +- 확인: `docker-compose.yml` + +**Step 1: Gitea Container Registry(Packages) 활성화 확인** + +브라우저에서 `http://10.1.10.28:3000/gihyeon/cointrader/packages` 접속. +- 패키지 탭이 보이면 활성화된 것. +- 안 보이면 Gitea 관리자 패널 → `Site Administration` → `Configuration` → `Enable Packages` 체크 필요. + +**Step 2: Gitea Access Token 생성 (Jenkins용)** + +`http://10.1.10.28:3000/user/settings/applications` 접속: +- Token Name: `jenkins-cointrader` +- 권한: `read:packages`, `write:packages` (또는 전체 권한) +- `Generate Token` 클릭 후 **토큰 값을 반드시 복사** (다시 볼 수 없음) + +**Step 3: Docker insecure-registries 설정 (HTTP 레지스트리 사용 시)** + +Jenkins가 실행되는 서버(또는 로컬 Mac)에서: + +```bash +# /etc/docker/daemon.json 또는 Docker Desktop의 경우 Settings > Docker Engine +cat /etc/docker/daemon.json +``` + +아래 내용이 없으면 추가: +```json +{ + "insecure-registries": ["10.1.10.28:3000"] +} +``` + +Docker Desktop 사용 시: `Settings` → `Docker Engine` → JSON에 위 내용 병합 → `Apply & Restart` + +**Step 4: Docker login 테스트** + +```bash +docker login 10.1.10.28:3000 -u gihyeon -p <위에서_생성한_토큰> +``` + +Expected: +``` +Login Succeeded +``` + +--- + +## Task 2: Jenkinsfile 작성 + +**Files:** +- Create: `Jenkinsfile` + +**Step 1: Jenkinsfile 생성** + +`/Users/gihyeon/github/cointrader/Jenkinsfile` 파일을 아래 내용으로 생성: + +```groovy +pipeline { + agent any + + environment { + REGISTRY = '10.1.10.28:3000' + IMAGE_NAME = 'gihyeon/cointrader' + IMAGE_TAG = "${env.BUILD_NUMBER}" + FULL_IMAGE = "${REGISTRY}/${IMAGE_NAME}:${IMAGE_TAG}" + LATEST_IMAGE = "${REGISTRY}/${IMAGE_NAME}:latest" + GITEA_CREDS = credentials('gitea-registry-credentials') + } + + stages { + stage('Checkout') { + steps { + checkout scm + } + } + + stage('Build Image') { + steps { + sh "docker build -t ${FULL_IMAGE} -t ${LATEST_IMAGE} ." + } + } + + stage('Push to Gitea Registry') { + steps { + sh """ + echo ${GITEA_CREDS_PSW} | docker login ${REGISTRY} -u ${GITEA_CREDS_USR} --password-stdin + docker push ${FULL_IMAGE} + docker push ${LATEST_IMAGE} + """ + } + } + + stage('Cleanup') { + steps { + sh """ + docker rmi ${FULL_IMAGE} || true + docker rmi ${LATEST_IMAGE} || true + """ + } + } + } + + post { + success { + echo "Build #${env.BUILD_NUMBER} pushed: ${FULL_IMAGE}" + } + failure { + echo "Build #${env.BUILD_NUMBER} FAILED" + } + } +} +``` + +> **참고:** +> - `GITEA_CREDS`는 Jenkins Credentials에 등록할 Username+Password 자격증명 ID다 (Task 3에서 등록). +> - `IMAGE_TAG`는 Jenkins 빌드 번호를 사용한다. 태그 전략을 git 커밋 해시로 바꾸려면 `"${env.GIT_COMMIT[0..7]}"` 사용. +> - `Cleanup` 스테이지는 Jenkins 서버 디스크 절약을 위해 빌드 후 로컬 이미지를 삭제한다. + +**Step 2: Jenkinsfile 내용 확인** + +```bash +cat /Users/gihyeon/github/cointrader/Jenkinsfile +``` + +Expected: 위 내용이 출력됨 + +--- + +## Task 3: Jenkins에 Gitea Credentials 등록 + +**Step 1: Jenkins 웹 UI 접속** + +`http://:8080` 접속 (Jenkins 서버 주소 확인 필요) + +**Step 2: Credentials 등록** + +`Jenkins` → `Manage Jenkins` → `Credentials` → `System` → `Global credentials` → `Add Credentials`: + +| 항목 | 값 | +|------|----| +| Kind | Username with password | +| Scope | Global | +| Username | `gihyeon` | +| Password | Task 1 Step 2에서 생성한 Gitea Access Token | +| ID | `gitea-registry-credentials` | +| Description | Gitea Container Registry for cointrader | + +`Create` 클릭 + +**Step 3: 등록 확인** + +Credentials 목록에 `gitea-registry-credentials`가 표시되는지 확인 + +--- + +## Task 4: Jenkins Pipeline Job 생성 + +**Step 1: 새 Pipeline Job 생성** + +`Jenkins` → `New Item`: +- Item name: `cointrader` +- Type: `Pipeline` +- `OK` 클릭 + +**Step 2: Pipeline 설정** + +`Pipeline` 섹션에서: +- Definition: `Pipeline script from SCM` +- SCM: `Git` +- Repository URL: `http://10.1.10.28:3000/gihyeon/cointrader.git` +- Credentials: (Gitea 저장소 접근용 credentials 추가, 없으면 Task 3과 동일하게 추가) +- Branch Specifier: `*/main` +- Script Path: `Jenkinsfile` + +`Save` 클릭 + +**Step 3: 수동 빌드 테스트** + +`Build Now` 클릭 → Console Output 확인 + +Expected: +``` +[Pipeline] stage: Build Image +Successfully built ... +[Pipeline] stage: Push to Gitea Registry +Login Succeeded +The push refers to repository [10.1.10.28:3000/gihyeon/cointrader] +... +latest: digest: sha256:... size: ... +[Pipeline] stage: Cleanup +Finished: SUCCESS +``` + +--- + +## Task 5: Gitea Webhook 설정 (자동 트리거) + +**Step 1: Gitea 저장소 Webhook 추가** + +`http://10.1.10.28:3000/gihyeon/cointrader/settings/hooks` 접속: +- `Add Webhook` → `Gitea` +- Target URL: `http://:8080/gitea-webhook/post` + - Jenkins Gitea Plugin 사용 시 위 URL 형식 + - 일반 Generic Webhook 사용 시: `http://:8080/job/cointrader/build?token=<토큰>` +- Trigger: `Push Events` +- Branch filter: `main` +- `Add Webhook` 클릭 + +**Step 2: Jenkins에 Gitea Plugin 설치 (미설치 시)** + +`Manage Jenkins` → `Plugins` → `Available plugins` → `Gitea` 검색 → 설치 후 재시작 + +**Step 3: Webhook 테스트** + +Gitea Webhook 설정 페이지에서 `Test Delivery` 클릭 + +Expected: Jenkins에서 새 빌드가 자동으로 시작됨 + +--- + +## Task 6: docker-compose.yml 수정 + +**Files:** +- Modify: `docker-compose.yml` + +현재 `docker-compose.yml`은 `build: .`으로 로컬 빌드를 사용한다. 이를 레지스트리 이미지를 pull해서 실행하도록 변경한다. + +**Step 1: docker-compose.yml 수정** + +`/Users/gihyeon/github/cointrader/docker-compose.yml`을 아래 내용으로 교체: + +```yaml +services: + cointrader: + image: 10.1.10.28:3000/gihyeon/cointrader:latest + container_name: cointrader + restart: unless-stopped + env_file: + - .env + volumes: + - ./logs:/app/logs + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "5" +``` + +> **변경 사항:** +> - `build: .` → `image: 10.1.10.28:3000/gihyeon/cointrader:latest` +> - 이제 `docker compose up -d`를 실행하면 로컬 빌드 없이 레지스트리에서 이미지를 pull한다. +> - 배포 서버에서 최신 이미지로 업데이트하려면: `docker compose pull && docker compose up -d` + +**Step 2: 변경 내용 확인** + +```bash +cat /Users/gihyeon/github/cointrader/docker-compose.yml +``` + +Expected: `image:` 필드가 레지스트리 주소를 가리킴 + +**Step 3: (선택) 로컬 개발용 docker-compose.override.yml 생성** + +로컬에서 소스 코드를 직접 빌드해서 테스트하고 싶을 때를 위한 override 파일: + +```yaml +# docker-compose.override.yml (로컬 개발 전용, git에 포함하지 않아도 됨) +services: + cointrader: + build: . + image: cointrader:local +``` + +이 파일이 있으면 `docker compose up -d`가 자동으로 `build: .`을 사용한다. 프로덕션 서버에는 이 파일을 두지 않는다. + +--- + +## Task 7: 변경사항 커밋 및 Push + +**Step 1: 변경 파일 확인** + +```bash +cd /Users/gihyeon/github/cointrader +git status +``` + +Expected: `Jenkinsfile`(new), `docker-compose.yml`(modified)이 표시됨 + +**Step 2: 스테이징** + +```bash +git add Jenkinsfile docker-compose.yml +``` + +**Step 3: `.env` 미포함 확인** + +```bash +git diff --cached --name-only +``` + +Expected: `Jenkinsfile`, `docker-compose.yml` 두 파일만 표시됨 + +**Step 4: 커밋** + +```bash +git commit -m "ci: Jenkins pipeline + Gitea registry CI/CD 설정" +``` + +Expected: `main` 브랜치에 새 커밋 생성 + +**Step 5: Gitea에 Push** + +```bash +git push origin main +``` + +Expected: Push 성공 + (Webhook 설정 완료 시) Jenkins 빌드 자동 시작 + +--- + +## Task 8: 엔드-투-엔드 검증 + +**Step 1: 코드 변경 후 push 테스트** + +```bash +cd /Users/gihyeon/github/cointrader +# 아무 파일이나 사소하게 변경 (예: README 한 줄 추가) +echo "# CI/CD test" >> README.md +git add README.md +git commit -m "test: CI/CD 파이프라인 검증용 더미 커밋" +git push origin main +``` + +**Step 2: Jenkins 빌드 자동 시작 확인** + +Jenkins UI에서 `cointrader` 잡의 빌드가 자동으로 시작되는지 확인 (30초 이내) + +**Step 3: Gitea 레지스트리에 이미지 push 확인** + +`http://10.1.10.28:3000/gihyeon/cointrader/packages` 접속 → `cointrader` 컨테이너 패키지에 새 태그가 생성되었는지 확인 + +**Step 4: 이미지 pull 테스트** + +```bash +docker pull 10.1.10.28:3000/gihyeon/cointrader:latest +``` + +Expected: +``` +latest: Pulling from gihyeon/cointrader +... +Status: Downloaded newer image for 10.1.10.28:3000/gihyeon/cointrader:latest +``` + +**Step 5: docker compose로 실행 테스트** + +```bash +cd /Users/gihyeon/github/cointrader +docker compose up -d +docker compose logs -f --tail=20 +``` + +Expected: 컨테이너가 정상 시작되고 로그가 출력됨 + +--- + +## 트러블슈팅 + +| 문제 | 원인 | 해결 | +|------|------|------| +| `http: server gave HTTP response to HTTPS client` | Docker가 HTTPS로 레지스트리 접근 시도 | `daemon.json`에 `insecure-registries` 추가 후 Docker 재시작 | +| `unauthorized: authentication required` | Credentials 미등록 또는 토큰 만료 | Task 1 Step 2에서 새 토큰 발급 후 Jenkins Credentials 업데이트 | +| `connection refused` to Jenkins | Jenkins URL 오타 또는 방화벽 | Jenkins 서버 주소 재확인 | +| Webhook이 Jenkins를 트리거하지 않음 | Jenkins URL이 Gitea 서버에서 접근 불가 | Jenkins가 Gitea 서버와 같은 네트워크에 있는지 확인, 방화벽 8080 포트 오픈 | +| `image not found` on docker compose pull | 이미지가 아직 push되지 않음 | Jenkins 빌드 완료 후 재시도 | +| Jenkins에서 `docker: command not found` | Jenkins 에이전트에 Docker 미설치 | Jenkins 서버에 Docker 설치 또는 Docker-in-Docker 설정 | diff --git a/docs/plans/2026-03-01-ml-filter-design.md b/docs/plans/2026-03-01-ml-filter-design.md new file mode 100644 index 0000000..e665114 --- /dev/null +++ b/docs/plans/2026-03-01-ml-filter-design.md @@ -0,0 +1,102 @@ +# ML 필터 설계 문서 + +**날짜:** 2026-03-01 + +## 목적 + +기존 규칙 기반 신호(LONG/SHORT/HOLD)가 발생했을 때, LightGBM 모델이 해당 진입이 수익으로 끝날 확률을 계산하여 낮은 확률의 진입을 차단하는 보조 필터를 구현한다. + +--- + +## 아키텍처 개요 + +``` +캔들 수신 → 기술 지표 계산 → 규칙 기반 신호(LONG/SHORT/HOLD) + ↓ + 신호 != HOLD 일 때만 + ↓ + [ML 필터] LightGBM.predict_proba() + ↓ + 확률 >= 0.60 이면 진입 허용 + 확률 < 0.60 이면 진입 차단 +``` + +--- + +## 레이블 정의 + +- **1 (성공):** 진입 후 `take_profit` 가격에 먼저 도달 +- **0 (실패):** 진입 후 `stop_loss` 가격에 먼저 도달 +- TP/SL 계산은 기존 `Indicators.get_atr_stop()` 재사용 (ATR 기반) + +--- + +## 피처 목록 + +| 피처 | 설명 | +|---|---| +| `rsi` | RSI(14) | +| `macd_hist` | MACD 히스토그램 | +| `bb_pct` | 볼린저밴드 내 가격 위치 (0~1) | +| `ema_align` | EMA 정배열 여부 (1=정배열, -1=역배열, 0=혼재) | +| `stoch_k` | Stochastic RSI K | +| `stoch_d` | Stochastic RSI D | +| `atr_pct` | ATR / 현재가 (변동성 비율) | +| `vol_ratio` | 거래량 / vol_ma20 | +| `ret_1` | 1캔들 전 대비 수익률 | +| `ret_3` | 3캔들 전 대비 수익률 | +| `ret_5` | 5캔들 전 대비 수익률 | +| `signal_strength` | 규칙 기반 신호 강도 (long/short_signals 수) | +| `side` | 신호 방향 (1=LONG, 0=SHORT) | + +--- + +## 신규 컴포넌트 + +| 컴포넌트 | 파일 | 역할 | +|---|---|---| +| 피처 엔지니어링 | `src/ml_features.py` | 기술 지표 → ML 피처 변환 | +| ML 필터 | `src/ml_filter.py` | 모델 로드 + 예측 + 폴백 | +| 재학습 스케줄러 | `src/retrainer.py` | 매일 새벽 재학습 트리거 | +| 데이터 수집 스크립트 | `scripts/fetch_history.py` | 바이낸스 과거 캔들 수집 | +| 학습 스크립트 | `scripts/train_model.py` | LightGBM 학습 + 저장 | + +--- + +## 재학습 스케줄 + +- **초기:** `scripts/fetch_history.py` + `scripts/train_model.py` 수동 실행 +- **이후:** 매일 새벽 3시 (KST) `retrainer.py`가 자동 실행 + - 새 모델 AUC > 기존 모델 AUC → 교체 + - 그렇지 않으면 기존 모델 유지 (롤백) + - Discord 알림으로 결과 전송 + +--- + +## 모델 저장 구조 + +``` +models/ +├── lgbm_filter.pkl ← 현재 사용 중인 모델 +├── lgbm_filter_prev.pkl ← 롤백용 이전 모델 +└── training_log.json ← 재학습 이력 (날짜, AUC, 샘플 수) +``` + +--- + +## 폴백 정책 + +`models/lgbm_filter.pkl` 파일이 없으면 ML 필터를 건너뛰고 기존 규칙 기반 신호 그대로 사용. 봇이 모델 없이도 정상 작동. + +--- + +## bot.py 변경 범위 + +`process_candle()` 메서드에 3줄 추가: + +```python +if signal != "HOLD" and self.ml_filter.is_model_loaded(): + features = build_features(df_with_indicators, signal) + if not self.ml_filter.should_enter(features): + signal = "HOLD" +``` diff --git a/docs/plans/2026-03-01-ml-filter-implementation.md b/docs/plans/2026-03-01-ml-filter-implementation.md new file mode 100644 index 0000000..6fe1ee3 --- /dev/null +++ b/docs/plans/2026-03-01-ml-filter-implementation.md @@ -0,0 +1,1124 @@ +# ML 필터 (LightGBM) 구현 계획 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 기존 규칙 기반 신호(LONG/SHORT/HOLD)가 발생했을 때 LightGBM 모델이 수익 확률을 계산해 낮은 확률의 진입을 차단하는 보조 필터를 구현한다. + +**Architecture:** 과거 캔들 데이터로 LightGBM을 오프라인 학습시키고, 봇의 `process_candle()`에서 규칙 기반 신호가 나오면 ML 필터가 확률을 계산해 0.60 미만이면 진입을 차단한다. 매일 새벽 3시에 자동 재학습하며 성능이 나빠지면 이전 모델로 롤백한다. + +**Tech Stack:** Python 3.11+, lightgbm, scikit-learn, joblib, pandas, asyncio (기존 스택 유지) + +--- + +## Task 1: 의존성 추가 + +**Files:** +- Modify: `requirements.txt` + +**Step 1: requirements.txt에 ML 패키지 추가** + +``` +lightgbm>=4.3.0 +scikit-learn>=1.4.0 +joblib>=1.3.0 +pyarrow>=15.0.0 +``` + +**Step 2: 패키지 설치** + +```bash +pip install lightgbm scikit-learn joblib pyarrow +``` + +Expected: 설치 완료 메시지 출력 + +**Step 3: models/ 디렉토리 생성** + +```bash +mkdir -p models data scripts +touch models/.gitkeep data/.gitkeep +``` + +**Step 4: .gitignore에 모델/데이터 파일 추가** + +기존 `.gitignore`에 추가: +``` +models/*.pkl +data/*.parquet +``` + +**Step 5: Commit** + +```bash +git add requirements.txt .gitignore models/.gitkeep data/.gitkeep +git commit -m "feat: add ML dependencies and directory structure" +``` + +--- + +## Task 2: 피처 엔지니어링 모듈 (`src/ml_features.py`) + +**Files:** +- Create: `src/ml_features.py` +- Create: `tests/test_ml_features.py` + +**Step 1: 실패하는 테스트 작성** + +```python +# tests/test_ml_features.py +import pandas as pd +import numpy as np +import pytest +from src.ml_features import build_features, FEATURE_COLS + +def make_df(n=100): + """테스트용 최소 DataFrame 생성""" + np.random.seed(42) + close = 100 + np.cumsum(np.random.randn(n) * 0.5) + df = pd.DataFrame({ + "open": close * 0.999, + "high": close * 1.002, + "low": close * 0.998, + "close": close, + "volume": np.random.uniform(1000, 5000, n), + }) + return df + +def test_build_features_returns_series(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + features = build_features(df_ind, signal="LONG") + assert isinstance(features, pd.Series) + +def test_build_features_has_all_cols(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + features = build_features(df_ind, signal="LONG") + for col in FEATURE_COLS: + assert col in features.index, f"피처 누락: {col}" + +def test_build_features_no_nan(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + features = build_features(df_ind, signal="LONG") + assert not features.isna().any(), f"NaN 존재: {features[features.isna()]}" + +def test_side_encoding(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + long_feat = build_features(df_ind, signal="LONG") + short_feat = build_features(df_ind, signal="SHORT") + assert long_feat["side"] == 1 + assert short_feat["side"] == 0 +``` + +**Step 2: 테스트 실패 확인** + +```bash +pytest tests/test_ml_features.py -v +``` + +Expected: FAIL with "cannot import name 'build_features'" + +**Step 3: `src/ml_features.py` 구현** + +```python +import pandas as pd +import numpy as np + +FEATURE_COLS = [ + "rsi", "macd_hist", "bb_pct", "ema_align", + "stoch_k", "stoch_d", "atr_pct", "vol_ratio", + "ret_1", "ret_3", "ret_5", "signal_strength", "side", +] + + +def build_features(df: pd.DataFrame, signal: str) -> pd.Series: + """ + 기술 지표가 계산된 DataFrame의 마지막 행에서 ML 피처를 추출한다. + signal: "LONG" | "SHORT" + """ + last = df.iloc[-1] + close = last["close"] + + bb_upper = last.get("bb_upper", close) + bb_lower = last.get("bb_lower", close) + bb_range = bb_upper - bb_lower + bb_pct = (close - bb_lower) / bb_range if bb_range > 0 else 0.5 + + ema9 = last.get("ema9", close) + ema21 = last.get("ema21", close) + ema50 = last.get("ema50", close) + if ema9 > ema21 > ema50: + ema_align = 1 + elif ema9 < ema21 < ema50: + ema_align = -1 + else: + ema_align = 0 + + atr = last.get("atr", 0) + atr_pct = atr / close if close > 0 else 0 + + vol_ma20 = last.get("vol_ma20", last.get("volume", 1)) + vol_ratio = last["volume"] / vol_ma20 if vol_ma20 > 0 else 1.0 + + closes = df["close"] + ret_1 = (close - closes.iloc[-2]) / closes.iloc[-2] if len(closes) >= 2 else 0.0 + ret_3 = (close - closes.iloc[-4]) / closes.iloc[-4] if len(closes) >= 4 else 0.0 + ret_5 = (close - closes.iloc[-6]) / closes.iloc[-6] if len(closes) >= 6 else 0.0 + + # 규칙 기반 신호 강도 재계산 (indicators.py get_signal 로직 참조) + prev = df.iloc[-2] if len(df) >= 2 else last + strength = 0 + rsi = last.get("rsi", 50) + macd = last.get("macd", 0) + macd_sig = last.get("macd_signal", 0) + prev_macd = prev.get("macd", 0) + prev_macd_sig = prev.get("macd_signal", 0) + stoch_k = last.get("stoch_k", 50) + stoch_d = last.get("stoch_d", 50) + + if signal == "LONG": + if rsi < 35: strength += 1 + if prev_macd < prev_macd_sig and macd > macd_sig: strength += 2 + if close < last.get("bb_lower", close): strength += 1 + if ema_align == 1: strength += 1 + if stoch_k < 20 and stoch_k > stoch_d: strength += 1 + else: + if rsi > 65: strength += 1 + if prev_macd > prev_macd_sig and macd < macd_sig: strength += 2 + if close > last.get("bb_upper", close): strength += 1 + if ema_align == -1: strength += 1 + if stoch_k > 80 and stoch_k < stoch_d: strength += 1 + + return pd.Series({ + "rsi": float(rsi), + "macd_hist": float(last.get("macd_hist", 0)), + "bb_pct": float(bb_pct), + "ema_align": float(ema_align), + "stoch_k": float(stoch_k), + "stoch_d": float(last.get("stoch_d", 50)), + "atr_pct": float(atr_pct), + "vol_ratio": float(vol_ratio), + "ret_1": float(ret_1), + "ret_3": float(ret_3), + "ret_5": float(ret_5), + "signal_strength": float(strength), + "side": 1.0 if signal == "LONG" else 0.0, + }) +``` + +**Step 4: 테스트 통과 확인** + +```bash +pytest tests/test_ml_features.py -v +``` + +Expected: 4개 PASS + +**Step 5: Commit** + +```bash +git add src/ml_features.py tests/test_ml_features.py +git commit -m "feat: add ML feature engineering module" +``` + +--- + +## Task 3: 레이블 생성 유틸리티 (`src/label_builder.py`) + +**Files:** +- Create: `src/label_builder.py` +- Create: `tests/test_label_builder.py` + +**Step 1: 실패하는 테스트 작성** + +```python +# tests/test_label_builder.py +import pandas as pd +import numpy as np +import pytest +from src.label_builder import build_labels + + +def make_signal_df(): + """ + 신호 발생 시점 이후 가격이 TP에 도달하는 시나리오 + entry=100, TP=103, SL=98.5 + """ + # 신호 시점 이후 캔들: 점진적으로 상승해서 103 돌파 + future_closes = [100.5, 101.0, 101.8, 102.5, 103.1, 103.5] + future_highs = [c + 0.3 for c in future_closes] + future_lows = [c - 0.3 for c in future_closes] + return future_closes, future_highs, future_lows + + +def test_label_tp_reached(): + closes, highs, lows = make_signal_df() + label = build_labels( + future_closes=closes, + future_highs=highs, + future_lows=lows, + take_profit=103.0, + stop_loss=98.5, + side="LONG", + ) + assert label == 1, "TP 먼저 도달해야 레이블 1" + + +def test_label_sl_reached(): + # 하락해서 SL 먼저 도달 + future_closes = [99.5, 99.0, 98.8, 98.4, 98.0] + future_highs = [c + 0.3 for c in future_closes] + future_lows = [c - 0.3 for c in future_closes] + label = build_labels( + future_closes=future_closes, + future_highs=future_highs, + future_lows=future_lows, + take_profit=103.0, + stop_loss=98.5, + side="LONG", + ) + assert label == 0, "SL 먼저 도달해야 레이블 0" + + +def test_label_neither_reached_returns_none(): + # 아무것도 도달 못함 + future_closes = [100.1, 100.2, 100.3] + future_highs = [c + 0.1 for c in future_closes] + future_lows = [c - 0.1 for c in future_closes] + label = build_labels( + future_closes=future_closes, + future_highs=future_highs, + future_lows=future_lows, + take_profit=103.0, + stop_loss=98.5, + side="LONG", + ) + assert label is None, "미결 시 None 반환" + + +def test_label_short_tp(): + # SHORT: 가격 하락 → TP 도달 + future_closes = [99.5, 99.0, 98.0, 97.0] + future_highs = [c + 0.3 for c in future_closes] + future_lows = [c - 0.3 for c in future_closes] + label = build_labels( + future_closes=future_closes, + future_highs=future_highs, + future_lows=future_lows, + take_profit=97.0, + stop_loss=101.5, + side="SHORT", + ) + assert label == 1 +``` + +**Step 2: 테스트 실패 확인** + +```bash +pytest tests/test_label_builder.py -v +``` + +Expected: FAIL with "cannot import name 'build_labels'" + +**Step 3: `src/label_builder.py` 구현** + +```python +from typing import Optional + + +def build_labels( + future_closes: list[float], + future_highs: list[float], + future_lows: list[float], + take_profit: float, + stop_loss: float, + side: str, +) -> Optional[int]: + """ + 진입 이후 미래 캔들을 순서대로 확인해 TP/SL 도달 여부를 판단한다. + LONG: high >= TP → 1, low <= SL → 0 + SHORT: low <= TP → 1, high >= SL → 0 + 둘 다 미도달 → None (학습 데이터에서 제외) + """ + for high, low in zip(future_highs, future_lows): + if side == "LONG": + if high >= take_profit: + return 1 + if low <= stop_loss: + return 0 + else: # SHORT + if low <= take_profit: + return 1 + if high >= stop_loss: + return 0 + return None +``` + +**Step 4: 테스트 통과 확인** + +```bash +pytest tests/test_label_builder.py -v +``` + +Expected: 4개 PASS + +**Step 5: Commit** + +```bash +git add src/label_builder.py tests/test_label_builder.py +git commit -m "feat: add label builder for TP/SL simulation" +``` + +--- + +## Task 4: 과거 데이터 수집 스크립트 (`scripts/fetch_history.py`) + +**Files:** +- Create: `scripts/fetch_history.py` +- Create: `scripts/__init__.py` + +**Step 1: `scripts/fetch_history.py` 작성** + +```python +""" +바이낸스 선물 REST API로 과거 캔들 데이터를 수집해 parquet으로 저장한다. +사용법: python scripts/fetch_history.py --symbol XRPUSDT --interval 1m --days 90 +""" +import asyncio +import argparse +from datetime import datetime, timedelta +import pandas as pd +from binance import AsyncClient +from dotenv import load_dotenv +import os + +load_dotenv() + + +async def fetch_klines(symbol: str, interval: str, days: int) -> pd.DataFrame: + client = await AsyncClient.create( + api_key=os.getenv("BINANCE_API_KEY", ""), + api_secret=os.getenv("BINANCE_API_SECRET", ""), + ) + try: + start_ts = int((datetime.utcnow() - timedelta(days=days)).timestamp() * 1000) + all_klines = [] + while True: + klines = await client.futures_klines( + symbol=symbol, + interval=interval, + startTime=start_ts, + limit=1500, + ) + if not klines: + break + all_klines.extend(klines) + last_ts = klines[-1][0] + if last_ts >= int(datetime.utcnow().timestamp() * 1000): + break + start_ts = last_ts + 1 + print(f"수집 중... {len(all_klines)}개") + finally: + await client.close_connection() + + df = pd.DataFrame(all_klines, columns=[ + "timestamp", "open", "high", "low", "close", "volume", + "close_time", "quote_volume", "trades", + "taker_buy_base", "taker_buy_quote", "ignore", + ]) + df = df[["timestamp", "open", "high", "low", "close", "volume"]].copy() + for col in ["open", "high", "low", "close", "volume"]: + df[col] = df[col].astype(float) + df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms") + df.set_index("timestamp", inplace=True) + return df + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--symbol", default="XRPUSDT") + parser.add_argument("--interval", default="1m") + parser.add_argument("--days", type=int, default=90) + parser.add_argument("--output", default="data/xrpusdt_1m.parquet") + args = parser.parse_args() + + df = asyncio.run(fetch_klines(args.symbol, args.interval, args.days)) + df.to_parquet(args.output) + print(f"저장 완료: {args.output} ({len(df)}행)") + + +if __name__ == "__main__": + main() +``` + +**Step 2: 실행 테스트 (실제 API 호출)** + +```bash +python scripts/fetch_history.py --symbol XRPUSDT --interval 1m --days 90 +``` + +Expected: `저장 완료: data/xrpusdt_1m.parquet (약 129600행)` + +**Step 3: Commit** + +```bash +git add scripts/fetch_history.py scripts/__init__.py +git commit -m "feat: add historical data fetcher script" +``` + +--- + +## Task 5: 모델 학습 스크립트 (`scripts/train_model.py`) + +**Files:** +- Create: `scripts/train_model.py` + +**Step 1: `scripts/train_model.py` 작성** + +```python +""" +과거 캔들 데이터로 LightGBM 필터 모델을 학습하고 저장한다. +사용법: python scripts/train_model.py --data data/xrpusdt_1m.parquet +""" +import argparse +import json +from datetime import datetime +from pathlib import Path + +import joblib +import lightgbm as lgb +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score, classification_report +from sklearn.model_selection import TimeSeriesSplit + +from src.indicators import Indicators +from src.ml_features import build_features, FEATURE_COLS +from src.label_builder import build_labels + +LOOKAHEAD = 60 # 최대 60캔들(1시간) 이내 TP/SL 도달 확인 +ATR_SL_MULT = 1.5 +ATR_TP_MULT = 3.0 +MODEL_PATH = Path("models/lgbm_filter.pkl") +PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl") +LOG_PATH = Path("models/training_log.json") + + +def generate_dataset(df: pd.DataFrame) -> pd.DataFrame: + """신호 발생 시점마다 피처와 레이블을 생성한다.""" + rows = [] + total = len(df) + + for i in range(60, total - LOOKAHEAD): + window = df.iloc[i - 60: i + 1].copy() + ind = Indicators(window) + df_ind = ind.calculate_all() + + if df_ind.isna().any().any(): + continue + + signal = ind.get_signal(df_ind) + if signal == "HOLD": + continue + + entry_price = float(df_ind["close"].iloc[-1]) + atr = float(df_ind["atr"].iloc[-1]) + if atr <= 0: + continue + + stop_loss = entry_price - atr * ATR_SL_MULT if signal == "LONG" else entry_price + atr * ATR_SL_MULT + take_profit = entry_price + atr * ATR_TP_MULT if signal == "LONG" else entry_price - atr * ATR_TP_MULT + + future = df.iloc[i + 1: i + 1 + LOOKAHEAD] + label = build_labels( + future_closes=future["close"].tolist(), + future_highs=future["high"].tolist(), + future_lows=future["low"].tolist(), + take_profit=take_profit, + stop_loss=stop_loss, + side=signal, + ) + if label is None: + continue + + features = build_features(df_ind, signal) + row = features.to_dict() + row["label"] = label + rows.append(row) + + if len(rows) % 500 == 0: + print(f" 샘플 생성 중: {len(rows)}개 (인덱스 {i}/{total})") + + return pd.DataFrame(rows) + + +def train(data_path: str): + print(f"데이터 로드: {data_path}") + df = pd.read_parquet(data_path) + print(f"캔들 수: {len(df)}") + + print("데이터셋 생성 중...") + dataset = generate_dataset(df) + print(f"학습 샘플: {len(dataset)}개 (양성={dataset['label'].sum():.0f}, 음성={(dataset['label']==0).sum():.0f})") + + if len(dataset) < 200: + raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)") + + X = dataset[FEATURE_COLS] + y = dataset["label"] + + # 시계열 분할: 앞 80% 학습, 뒤 20% 검증 + split = int(len(X) * 0.8) + X_train, X_val = X.iloc[:split], X.iloc[split:] + y_train, y_val = y.iloc[:split], y.iloc[split:] + + model = lgb.LGBMClassifier( + n_estimators=300, + learning_rate=0.05, + num_leaves=31, + min_child_samples=20, + subsample=0.8, + colsample_bytree=0.8, + class_weight="balanced", + random_state=42, + verbose=-1, + ) + model.fit( + X_train, y_train, + eval_set=[(X_val, y_val)], + callbacks=[lgb.early_stopping(30, verbose=False), lgb.log_evaluation(50)], + ) + + val_proba = model.predict_proba(X_val)[:, 1] + auc = roc_auc_score(y_val, val_proba) + print(f"\n검증 AUC: {auc:.4f}") + print(classification_report(y_val, (val_proba >= 0.60).astype(int))) + + # 기존 모델이 있으면 백업 + if MODEL_PATH.exists(): + import shutil + shutil.copy(MODEL_PATH, PREV_MODEL_PATH) + print(f"기존 모델 백업: {PREV_MODEL_PATH}") + + MODEL_PATH.parent.mkdir(exist_ok=True) + joblib.dump(model, MODEL_PATH) + print(f"모델 저장: {MODEL_PATH}") + + # 학습 이력 기록 + log = [] + if LOG_PATH.exists(): + with open(LOG_PATH) as f: + log = json.load(f) + log.append({ + "date": datetime.now().isoformat(), + "auc": round(auc, 4), + "samples": len(dataset), + "model_path": str(MODEL_PATH), + }) + with open(LOG_PATH, "w") as f: + json.dump(log, f, indent=2) + + return auc + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data", default="data/xrpusdt_1m.parquet") + args = parser.parse_args() + train(args.data) + + +if __name__ == "__main__": + main() +``` + +**Step 2: 학습 실행 테스트** + +```bash +python scripts/train_model.py --data data/xrpusdt_1m.parquet +``` + +Expected: 학습 완료 후 `models/lgbm_filter.pkl` 생성, AUC 출력 + +**Step 3: Commit** + +```bash +git add scripts/train_model.py +git commit -m "feat: add LightGBM training script with TP/SL label generation" +``` + +--- + +## Task 6: ML 필터 클래스 (`src/ml_filter.py`) + +**Files:** +- Create: `src/ml_filter.py` +- Create: `tests/test_ml_filter.py` + +**Step 1: 실패하는 테스트 작성** + +```python +# tests/test_ml_filter.py +import pandas as pd +import numpy as np +import pytest +from unittest.mock import MagicMock, patch +from pathlib import Path +from src.ml_filter import MLFilter +from src.ml_features import FEATURE_COLS + + +def make_features(side="LONG") -> pd.Series: + return pd.Series({col: 0.5 for col in FEATURE_COLS} | {"side": 1.0 if side == "LONG" else 0.0}) + + +def test_no_model_file_is_not_loaded(tmp_path): + f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + assert not f.is_model_loaded() + + +def test_no_model_should_enter_returns_true(tmp_path): + """모델 없으면 항상 진입 허용 (폴백)""" + f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + features = make_features() + assert f.should_enter(features) is True + + +def test_should_enter_above_threshold(): + """확률 >= 0.60 이면 True""" + f = MLFilter(threshold=0.60) + mock_model = MagicMock() + mock_model.predict_proba.return_value = np.array([[0.35, 0.65]]) + f._model = mock_model + features = make_features() + assert f.should_enter(features) is True + + +def test_should_enter_below_threshold(): + """확률 < 0.60 이면 False""" + f = MLFilter(threshold=0.60) + mock_model = MagicMock() + mock_model.predict_proba.return_value = np.array([[0.55, 0.45]]) + f._model = mock_model + features = make_features() + assert f.should_enter(features) is False + + +def test_reload_model(tmp_path): + """reload_model 호출 후 모델 로드 상태 변경""" + import joblib + import lightgbm as lgb + # 더미 모델 저장 + dummy = MagicMock() + model_path = tmp_path / "lgbm_filter.pkl" + joblib.dump(dummy, model_path) + f = MLFilter(model_path=str(model_path)) + f.reload_model() + assert f.is_model_loaded() +``` + +**Step 2: 테스트 실패 확인** + +```bash +pytest tests/test_ml_filter.py -v +``` + +Expected: FAIL with "cannot import name 'MLFilter'" + +**Step 3: `src/ml_filter.py` 구현** + +```python +from pathlib import Path +import joblib +import pandas as pd +from loguru import logger + + +class MLFilter: + """ + LightGBM 모델을 로드하고 진입 여부를 판단한다. + 모델 파일이 없으면 항상 진입을 허용한다 (폴백). + """ + + def __init__(self, model_path: str = "models/lgbm_filter.pkl", threshold: float = 0.60): + self._model_path = Path(model_path) + self._threshold = threshold + self._model = None + self._try_load() + + def _try_load(self): + if self._model_path.exists(): + try: + self._model = joblib.load(self._model_path) + logger.info(f"ML 필터 모델 로드 완료: {self._model_path}") + except Exception as e: + logger.warning(f"ML 필터 모델 로드 실패: {e}") + self._model = None + + def is_model_loaded(self) -> bool: + return self._model is not None + + def should_enter(self, features: pd.Series) -> bool: + """ + 확률 >= threshold 이면 True (진입 허용). + 모델 없으면 True 반환 (폴백). + """ + if not self.is_model_loaded(): + return True + try: + X = features.to_frame().T + proba = self._model.predict_proba(X)[0][1] + logger.debug(f"ML 필터 확률: {proba:.3f} (임계값: {self._threshold})") + return proba >= self._threshold + except Exception as e: + logger.warning(f"ML 필터 예측 오류 (폴백 허용): {e}") + return True + + def reload_model(self): + """재학습 후 모델을 핫 리로드한다.""" + self._try_load() + logger.info("ML 필터 모델 리로드 완료") +``` + +**Step 4: 테스트 통과 확인** + +```bash +pytest tests/test_ml_filter.py -v +``` + +Expected: 5개 PASS + +**Step 5: Commit** + +```bash +git add src/ml_filter.py tests/test_ml_filter.py +git commit -m "feat: add MLFilter class with fallback support" +``` + +--- + +## Task 7: 자동 재학습 스케줄러 (`src/retrainer.py`) + +**Files:** +- Create: `src/retrainer.py` +- Create: `tests/test_retrainer.py` + +**Step 1: 실패하는 테스트 작성** + +```python +# tests/test_retrainer.py +import pytest +import json +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch +from src.retrainer import Retrainer + + +@pytest.mark.asyncio +async def test_retrain_calls_train(tmp_path): + """재학습 시 train 함수가 호출되는지 확인""" + ml_filter = MagicMock() + r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet")) + + with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as mock_fetch, \ + patch("src.retrainer.run_training", return_value=0.72) as mock_train, \ + patch("src.retrainer.get_current_auc", return_value=0.65): + await r.retrain() + + mock_fetch.assert_called_once() + mock_train.assert_called_once() + + +@pytest.mark.asyncio +async def test_retrain_rollback_when_worse(tmp_path): + """새 모델이 기존보다 나쁘면 롤백""" + ml_filter = MagicMock() + r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet")) + + with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock), \ + patch("src.retrainer.run_training", return_value=0.55), \ + patch("src.retrainer.get_current_auc", return_value=0.70), \ + patch("src.retrainer.rollback_model") as mock_rollback: + await r.retrain() + + mock_rollback.assert_called_once() +``` + +**Step 2: 테스트 실패 확인** + +```bash +pytest tests/test_retrainer.py -v +``` + +Expected: FAIL with "cannot import name 'Retrainer'" + +**Step 3: `src/retrainer.py` 구현** + +```python +import asyncio +import json +from datetime import datetime +from pathlib import Path + +from loguru import logger + +from src.ml_filter import MLFilter + +MODEL_PATH = Path("models/lgbm_filter.pkl") +PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl") +LOG_PATH = Path("models/training_log.json") + + +def get_current_auc() -> float: + """training_log.json에서 가장 최근 AUC를 읽는다.""" + if not LOG_PATH.exists(): + return 0.0 + with open(LOG_PATH) as f: + log = json.load(f) + return log[-1]["auc"] if log else 0.0 + + +def rollback_model(): + """이전 모델로 롤백한다.""" + if PREV_MODEL_PATH.exists(): + import shutil + shutil.copy(PREV_MODEL_PATH, MODEL_PATH) + logger.warning("ML 모델 롤백 완료") + else: + logger.warning("롤백할 이전 모델 없음") + + +async def fetch_and_save(data_path: str): + """증분 데이터 수집 (fetch_history.py 로직 재사용).""" + import subprocess + result = subprocess.run( + ["python", "scripts/fetch_history.py", "--output", data_path, "--days", "90"], + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"데이터 수집 실패: {result.stderr}") + logger.info(f"데이터 수집 완료: {data_path}") + + +def run_training(data_path: str) -> float: + """train_model.py를 실행하고 새 AUC를 반환한다.""" + import subprocess + result = subprocess.run( + ["python", "scripts/train_model.py", "--data", data_path], + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"학습 실패: {result.stderr}") + new_auc = get_current_auc() + return new_auc + + +class Retrainer: + def __init__(self, ml_filter: MLFilter, data_path: str = "data/xrpusdt_1m.parquet"): + self._ml_filter = ml_filter + self._data_path = data_path + + async def retrain(self): + logger.info("자동 재학습 시작") + old_auc = get_current_auc() + try: + await fetch_and_save(self._data_path) + new_auc = run_training(self._data_path) + logger.info(f"재학습 완료: 이전 AUC={old_auc:.4f} → 새 AUC={new_auc:.4f}") + + if new_auc < old_auc - 0.01: + logger.warning(f"새 모델 성능 저하 ({new_auc:.4f} < {old_auc:.4f}), 롤백") + rollback_model() + else: + self._ml_filter.reload_model() + logger.success("새 ML 모델 적용 완료") + except Exception as e: + logger.error(f"재학습 실패: {e}") + + async def schedule_daily(self, hour: int = 3): + """매일 지정 시각(UTC 기준)에 재학습을 실행한다.""" + while True: + now = datetime.utcnow() + next_run = now.replace(hour=hour, minute=0, second=0, microsecond=0) + if next_run <= now: + from datetime import timedelta + next_run += timedelta(days=1) + wait_secs = (next_run - now).total_seconds() + logger.info(f"다음 재학습까지 {wait_secs/3600:.1f}시간 대기") + await asyncio.sleep(wait_secs) + await self.retrain() +``` + +**Step 4: 테스트 통과 확인** + +```bash +pytest tests/test_retrainer.py -v +``` + +Expected: 2개 PASS + +**Step 5: Commit** + +```bash +git add src/retrainer.py tests/test_retrainer.py +git commit -m "feat: add daily retrainer with rollback support" +``` + +--- + +## Task 8: bot.py에 ML 필터 통합 + +**Files:** +- Modify: `src/bot.py:1-10` (import 추가) +- Modify: `src/bot.py:11-22` (`__init__` 수정) +- Modify: `src/bot.py:47-65` (`process_candle` 수정) +- Modify: `src/bot.py:153-160` (`run` 수정) + +**Step 1: `src/bot.py` import 추가** + +기존 import 블록 끝에 추가: +```python +from src.ml_filter import MLFilter +from src.ml_features import build_features +from src.retrainer import Retrainer +``` + +**Step 2: `__init__`에 MLFilter, Retrainer 추가** + +```python +def __init__(self, config: Config): + self.config = config + self.exchange = BinanceFuturesClient(config) + self.notifier = DiscordNotifier(config.discord_webhook_url) + self.risk = RiskManager(config) + self.ml_filter = MLFilter() # 추가 + self.retrainer = Retrainer(ml_filter=self.ml_filter) # 추가 + self.current_trade_side: str | None = None + self.stream = KlineStream( + symbol=config.symbol, + interval="1m", + on_candle=self._on_candle_closed, + ) +``` + +**Step 3: `process_candle`에 ML 필터 적용** + +`signal = ind.get_signal(df_with_indicators)` 바로 아래에 추가: +```python + # ML 필터: 모델이 있을 때만 적용, 없으면 폴백(통과) + if signal != "HOLD" and self.ml_filter.is_model_loaded(): + features = build_features(df_with_indicators, signal) + if not self.ml_filter.should_enter(features): + logger.info(f"ML 필터 차단: {signal} 신호 무시") + signal = "HOLD" +``` + +**Step 4: `run`에 재학습 스케줄러 추가** + +```python + async def run(self): + logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x") + await self._recover_position() + asyncio.create_task(self.retrainer.schedule_daily(hour=3)) # 추가 + await self.stream.start( + api_key=self.config.api_key, + api_secret=self.config.api_secret, + ) +``` + +**Step 5: 기존 bot 테스트 통과 확인** + +```bash +pytest tests/test_bot.py -v +``` + +Expected: 기존 테스트 모두 PASS (ML 필터는 모델 없으면 폴백) + +**Step 6: Commit** + +```bash +git add src/bot.py +git commit -m "feat: integrate ML filter into trading bot with fallback" +``` + +--- + +## Task 9: 전체 테스트 실행 및 검증 + +**Step 1: 전체 테스트 실행** + +```bash +pytest tests/ -v --tb=short +``` + +Expected: 모든 테스트 PASS + +**Step 2: 린트 확인** + +```bash +python -m py_compile src/ml_features.py src/ml_filter.py src/label_builder.py src/retrainer.py scripts/train_model.py scripts/fetch_history.py +``` + +Expected: 오류 없음 + +**Step 3: 초기 학습 실행 (실제 데이터)** + +```bash +python scripts/fetch_history.py --days 90 +python scripts/train_model.py +``` + +Expected: `models/lgbm_filter.pkl` 생성, AUC 출력 + +**Step 4: 봇 시작 후 ML 필터 로그 확인** + +```bash +python main.py +``` + +Expected: 로그에 `ML 필터 모델 로드 완료` 메시지 출력 + +**Step 5: Final Commit** + +```bash +git add -A +git commit -m "feat: complete ML filter integration with LightGBM" +``` + +--- + +## 파일 구조 최종 요약 + +``` +cointrader/ +├── src/ +│ ├── ml_features.py ← 피처 엔지니어링 (신규) +│ ├── ml_filter.py ← LightGBM 필터 클래스 (신규) +│ ├── label_builder.py ← TP/SL 레이블 생성 (신규) +│ ├── retrainer.py ← 자동 재학습 스케줄러 (신규) +│ └── bot.py ← ML 필터 통합 (수정) +├── scripts/ +│ ├── fetch_history.py ← 과거 데이터 수집 (신규) +│ └── train_model.py ← LightGBM 학습 (신규) +├── tests/ +│ ├── test_ml_features.py (신규) +│ ├── test_ml_filter.py (신규) +│ ├── test_label_builder.py (신규) +│ └── test_retrainer.py (신규) +├── models/ +│ ├── lgbm_filter.pkl ← 현재 모델 (학습 후 생성) +│ ├── lgbm_filter_prev.pkl ← 롤백용 +│ └── training_log.json ← 재학습 이력 +└── data/ + └── xrpusdt_1m.parquet ← 과거 캔들 데이터 +``` diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index 521df76..07979a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,7 @@ pytest-asyncio>=0.24.0 aiohttp==3.9.3 websockets==12.0 loguru==0.7.2 +lightgbm>=4.3.0 +scikit-learn>=1.4.0 +joblib>=1.3.0 +pyarrow>=15.0.0 diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/fetch_history.py b/scripts/fetch_history.py new file mode 100644 index 0000000..510d46b --- /dev/null +++ b/scripts/fetch_history.py @@ -0,0 +1,69 @@ +""" +바이낸스 선물 REST API로 과거 캔들 데이터를 수집해 parquet으로 저장한다. +사용법: python scripts/fetch_history.py --symbol XRPUSDT --interval 1m --days 90 +""" +import asyncio +import argparse +from datetime import datetime, timedelta +import pandas as pd +from binance import AsyncClient +from dotenv import load_dotenv +import os + +load_dotenv() + + +async def fetch_klines(symbol: str, interval: str, days: int) -> pd.DataFrame: + client = await AsyncClient.create( + api_key=os.getenv("BINANCE_API_KEY", ""), + api_secret=os.getenv("BINANCE_API_SECRET", ""), + ) + try: + start_ts = int((datetime.utcnow() - timedelta(days=days)).timestamp() * 1000) + all_klines = [] + while True: + klines = await client.futures_klines( + symbol=symbol, + interval=interval, + startTime=start_ts, + limit=1500, + ) + if not klines: + break + all_klines.extend(klines) + last_ts = klines[-1][0] + if last_ts >= int(datetime.utcnow().timestamp() * 1000): + break + start_ts = last_ts + 1 + print(f"수집 중... {len(all_klines)}개") + finally: + await client.close_connection() + + df = pd.DataFrame(all_klines, columns=[ + "timestamp", "open", "high", "low", "close", "volume", + "close_time", "quote_volume", "trades", + "taker_buy_base", "taker_buy_quote", "ignore", + ]) + df = df[["timestamp", "open", "high", "low", "close", "volume"]].copy() + for col in ["open", "high", "low", "close", "volume"]: + df[col] = df[col].astype(float) + df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms") + df.set_index("timestamp", inplace=True) + return df + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--symbol", default="XRPUSDT") + parser.add_argument("--interval", default="1m") + parser.add_argument("--days", type=int, default=90) + parser.add_argument("--output", default="data/xrpusdt_1m.parquet") + args = parser.parse_args() + + df = asyncio.run(fetch_klines(args.symbol, args.interval, args.days)) + df.to_parquet(args.output) + print(f"저장 완료: {args.output} ({len(df)}행)") + + +if __name__ == "__main__": + main() diff --git a/scripts/train_model.py b/scripts/train_model.py new file mode 100644 index 0000000..828ed41 --- /dev/null +++ b/scripts/train_model.py @@ -0,0 +1,151 @@ +""" +과거 캔들 데이터로 LightGBM 필터 모델을 학습하고 저장한다. +사용법: python scripts/train_model.py --data data/xrpusdt_1m.parquet +""" +import argparse +import json +from datetime import datetime +from pathlib import Path + +import joblib +import lightgbm as lgb +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score, classification_report +from sklearn.model_selection import TimeSeriesSplit + +from src.indicators import Indicators +from src.ml_features import build_features, FEATURE_COLS +from src.label_builder import build_labels + +LOOKAHEAD = 60 +ATR_SL_MULT = 1.5 +ATR_TP_MULT = 3.0 +MODEL_PATH = Path("models/lgbm_filter.pkl") +PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl") +LOG_PATH = Path("models/training_log.json") + + +def generate_dataset(df: pd.DataFrame) -> pd.DataFrame: + """신호 발생 시점마다 피처와 레이블을 생성한다.""" + rows = [] + total = len(df) + + for i in range(60, total - LOOKAHEAD): + window = df.iloc[i - 60: i + 1].copy() + ind = Indicators(window) + df_ind = ind.calculate_all() + + if df_ind.isna().any().any(): + continue + + signal = ind.get_signal(df_ind) + if signal == "HOLD": + continue + + entry_price = float(df_ind["close"].iloc[-1]) + atr = float(df_ind["atr"].iloc[-1]) + if atr <= 0: + continue + + stop_loss = entry_price - atr * ATR_SL_MULT if signal == "LONG" else entry_price + atr * ATR_SL_MULT + take_profit = entry_price + atr * ATR_TP_MULT if signal == "LONG" else entry_price - atr * ATR_TP_MULT + + future = df.iloc[i + 1: i + 1 + LOOKAHEAD] + label = build_labels( + future_closes=future["close"].tolist(), + future_highs=future["high"].tolist(), + future_lows=future["low"].tolist(), + take_profit=take_profit, + stop_loss=stop_loss, + side=signal, + ) + if label is None: + continue + + features = build_features(df_ind, signal) + row = features.to_dict() + row["label"] = label + rows.append(row) + + if len(rows) % 500 == 0: + print(f" 샘플 생성 중: {len(rows)}개 (인덱스 {i}/{total})") + + return pd.DataFrame(rows) + + +def train(data_path: str): + print(f"데이터 로드: {data_path}") + df = pd.read_parquet(data_path) + print(f"캔들 수: {len(df)}") + + print("데이터셋 생성 중...") + dataset = generate_dataset(df) + print(f"학습 샘플: {len(dataset)}개 (양성={dataset['label'].sum():.0f}, 음성={(dataset['label']==0).sum():.0f})") + + if len(dataset) < 200: + raise ValueError(f"학습 샘플 부족: {len(dataset)}개 (최소 200 필요)") + + X = dataset[FEATURE_COLS] + y = dataset["label"] + + split = int(len(X) * 0.8) + X_train, X_val = X.iloc[:split], X.iloc[split:] + y_train, y_val = y.iloc[:split], y.iloc[split:] + + model = lgb.LGBMClassifier( + n_estimators=300, + learning_rate=0.05, + num_leaves=31, + min_child_samples=20, + subsample=0.8, + colsample_bytree=0.8, + class_weight="balanced", + random_state=42, + verbose=-1, + ) + model.fit( + X_train, y_train, + eval_set=[(X_val, y_val)], + callbacks=[lgb.early_stopping(30, verbose=False), lgb.log_evaluation(50)], + ) + + val_proba = model.predict_proba(X_val)[:, 1] + auc = roc_auc_score(y_val, val_proba) + print(f"\n검증 AUC: {auc:.4f}") + print(classification_report(y_val, (val_proba >= 0.60).astype(int))) + + if MODEL_PATH.exists(): + import shutil + shutil.copy(MODEL_PATH, PREV_MODEL_PATH) + print(f"기존 모델 백업: {PREV_MODEL_PATH}") + + MODEL_PATH.parent.mkdir(exist_ok=True) + joblib.dump(model, MODEL_PATH) + print(f"모델 저장: {MODEL_PATH}") + + log = [] + if LOG_PATH.exists(): + with open(LOG_PATH) as f: + log = json.load(f) + log.append({ + "date": datetime.now().isoformat(), + "auc": round(auc, 4), + "samples": len(dataset), + "model_path": str(MODEL_PATH), + }) + with open(LOG_PATH, "w") as f: + json.dump(log, f, indent=2) + + return auc + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data", default="data/xrpusdt_1m.parquet") + args = parser.parse_args() + train(args.data) + + +if __name__ == "__main__": + main() diff --git a/src/bot.py b/src/bot.py index b70fbbb..c7dfbdc 100644 --- a/src/bot.py +++ b/src/bot.py @@ -6,6 +6,9 @@ from src.indicators import Indicators from src.data_stream import KlineStream from src.notifier import DiscordNotifier from src.risk_manager import RiskManager +from src.ml_filter import MLFilter +from src.ml_features import build_features +from src.retrainer import Retrainer class TradingBot: @@ -14,6 +17,8 @@ class TradingBot: self.exchange = BinanceFuturesClient(config) self.notifier = DiscordNotifier(config.discord_webhook_url) self.risk = RiskManager(config) + self.ml_filter = MLFilter() + self.retrainer = Retrainer(ml_filter=self.ml_filter) self.current_trade_side: str | None = None # "LONG" | "SHORT" self.stream = KlineStream( symbol=config.symbol, @@ -52,6 +57,13 @@ class TradingBot: ind = Indicators(df) df_with_indicators = ind.calculate_all() signal = ind.get_signal(df_with_indicators) + + if signal != "HOLD" and self.ml_filter.is_model_loaded(): + features = build_features(df_with_indicators, signal) + if not self.ml_filter.should_enter(features): + logger.info(f"ML 필터 차단: {signal} 신호 무시") + signal = "HOLD" + current_price = df_with_indicators["close"].iloc[-1] logger.info(f"신호: {signal} | 현재가: {current_price:.4f} USDT") @@ -153,6 +165,7 @@ class TradingBot: async def run(self): logger.info(f"봇 시작: {self.config.symbol}, 레버리지 {self.config.leverage}x") await self._recover_position() + asyncio.create_task(self.retrainer.schedule_daily(hour=3)) await self.stream.start( api_key=self.config.api_key, api_secret=self.config.api_secret, diff --git a/src/label_builder.py b/src/label_builder.py new file mode 100644 index 0000000..1427ad0 --- /dev/null +++ b/src/label_builder.py @@ -0,0 +1,29 @@ +from typing import Optional + + +def build_labels( + future_closes: list[float], + future_highs: list[float], + future_lows: list[float], + take_profit: float, + stop_loss: float, + side: str, +) -> Optional[int]: + """ + 진입 이후 미래 캔들을 순서대로 확인해 TP/SL 도달 여부를 판단한다. + LONG: high >= TP → 1, low <= SL → 0 + SHORT: low <= TP → 1, high >= SL → 0 + 둘 다 미도달 → None (학습 데이터에서 제외) + """ + for high, low in zip(future_highs, future_lows): + if side == "LONG": + if high >= take_profit: + return 1 + if low <= stop_loss: + return 0 + else: # SHORT + if low <= take_profit: + return 1 + if high >= stop_loss: + return 0 + return None diff --git a/src/ml_features.py b/src/ml_features.py new file mode 100644 index 0000000..743609f --- /dev/null +++ b/src/ml_features.py @@ -0,0 +1,82 @@ +import pandas as pd +import numpy as np + +FEATURE_COLS = [ + "rsi", "macd_hist", "bb_pct", "ema_align", + "stoch_k", "stoch_d", "atr_pct", "vol_ratio", + "ret_1", "ret_3", "ret_5", "signal_strength", "side", +] + + +def build_features(df: pd.DataFrame, signal: str) -> pd.Series: + """ + 기술 지표가 계산된 DataFrame의 마지막 행에서 ML 피처를 추출한다. + signal: "LONG" | "SHORT" + """ + last = df.iloc[-1] + close = last["close"] + + bb_upper = last.get("bb_upper", close) + bb_lower = last.get("bb_lower", close) + bb_range = bb_upper - bb_lower + bb_pct = (close - bb_lower) / bb_range if bb_range > 0 else 0.5 + + ema9 = last.get("ema9", close) + ema21 = last.get("ema21", close) + ema50 = last.get("ema50", close) + if ema9 > ema21 > ema50: + ema_align = 1 + elif ema9 < ema21 < ema50: + ema_align = -1 + else: + ema_align = 0 + + atr = last.get("atr", 0) + atr_pct = atr / close if close > 0 else 0 + + vol_ma20 = last.get("vol_ma20", last.get("volume", 1)) + vol_ratio = last["volume"] / vol_ma20 if vol_ma20 > 0 else 1.0 + + closes = df["close"] + ret_1 = (close - closes.iloc[-2]) / closes.iloc[-2] if len(closes) >= 2 else 0.0 + ret_3 = (close - closes.iloc[-4]) / closes.iloc[-4] if len(closes) >= 4 else 0.0 + ret_5 = (close - closes.iloc[-6]) / closes.iloc[-6] if len(closes) >= 6 else 0.0 + + prev = df.iloc[-2] if len(df) >= 2 else last + strength = 0 + rsi = last.get("rsi", 50) + macd = last.get("macd", 0) + macd_sig = last.get("macd_signal", 0) + prev_macd = prev.get("macd", 0) + prev_macd_sig = prev.get("macd_signal", 0) + stoch_k = last.get("stoch_k", 50) + stoch_d = last.get("stoch_d", 50) + + if signal == "LONG": + if rsi < 35: strength += 1 + if prev_macd < prev_macd_sig and macd > macd_sig: strength += 2 + if close < last.get("bb_lower", close): strength += 1 + if ema_align == 1: strength += 1 + if stoch_k < 20 and stoch_k > stoch_d: strength += 1 + else: + if rsi > 65: strength += 1 + if prev_macd > prev_macd_sig and macd < macd_sig: strength += 2 + if close > last.get("bb_upper", close): strength += 1 + if ema_align == -1: strength += 1 + if stoch_k > 80 and stoch_k < stoch_d: strength += 1 + + return pd.Series({ + "rsi": float(rsi), + "macd_hist": float(last.get("macd_hist", 0)), + "bb_pct": float(bb_pct), + "ema_align": float(ema_align), + "stoch_k": float(stoch_k), + "stoch_d": float(last.get("stoch_d", 50)), + "atr_pct": float(atr_pct), + "vol_ratio": float(vol_ratio), + "ret_1": float(ret_1), + "ret_3": float(ret_3), + "ret_5": float(ret_5), + "signal_strength": float(strength), + "side": 1.0 if signal == "LONG" else 0.0, + }) diff --git a/src/ml_filter.py b/src/ml_filter.py new file mode 100644 index 0000000..7d8aa3e --- /dev/null +++ b/src/ml_filter.py @@ -0,0 +1,50 @@ +from pathlib import Path +import joblib +import pandas as pd +from loguru import logger + + +class MLFilter: + """ + LightGBM 모델을 로드하고 진입 여부를 판단한다. + 모델 파일이 없으면 항상 진입을 허용한다 (폴백). + """ + + def __init__(self, model_path: str = "models/lgbm_filter.pkl", threshold: float = 0.60): + self._model_path = Path(model_path) + self._threshold = threshold + self._model = None + self._try_load() + + def _try_load(self): + if self._model_path.exists(): + try: + self._model = joblib.load(self._model_path) + logger.info(f"ML 필터 모델 로드 완료: {self._model_path}") + except Exception as e: + logger.warning(f"ML 필터 모델 로드 실패: {e}") + self._model = None + + def is_model_loaded(self) -> bool: + return self._model is not None + + def should_enter(self, features: pd.Series) -> bool: + """ + 확률 >= threshold 이면 True (진입 허용). + 모델 없으면 True 반환 (폴백). + """ + if not self.is_model_loaded(): + return True + try: + X = features.to_frame().T + proba = self._model.predict_proba(X)[0][1] + logger.debug(f"ML 필터 확률: {proba:.3f} (임계값: {self._threshold})") + return bool(proba >= self._threshold) + except Exception as e: + logger.warning(f"ML 필터 예측 오류 (폴백 허용): {e}") + return True + + def reload_model(self): + """재학습 후 모델을 핫 리로드한다.""" + self._try_load() + logger.info("ML 필터 모델 리로드 완료") diff --git a/src/retrainer.py b/src/retrainer.py new file mode 100644 index 0000000..751d194 --- /dev/null +++ b/src/retrainer.py @@ -0,0 +1,92 @@ +import asyncio +import json +from datetime import datetime +from pathlib import Path + +from loguru import logger + +from src.ml_filter import MLFilter + +MODEL_PATH = Path("models/lgbm_filter.pkl") +PREV_MODEL_PATH = Path("models/lgbm_filter_prev.pkl") +LOG_PATH = Path("models/training_log.json") + + +def get_current_auc() -> float: + """training_log.json에서 가장 최근 AUC를 읽는다.""" + if not LOG_PATH.exists(): + return 0.0 + with open(LOG_PATH) as f: + log = json.load(f) + return log[-1]["auc"] if log else 0.0 + + +def rollback_model(): + """이전 모델로 롤백한다.""" + if PREV_MODEL_PATH.exists(): + import shutil + shutil.copy(PREV_MODEL_PATH, MODEL_PATH) + logger.warning("ML 모델 롤백 완료") + else: + logger.warning("롤백할 이전 모델 없음") + + +async def fetch_and_save(data_path: str): + """증분 데이터 수집 (fetch_history.py 로직 재사용).""" + import subprocess + result = subprocess.run( + ["python", "scripts/fetch_history.py", "--output", data_path, "--days", "90"], + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"데이터 수집 실패: {result.stderr}") + logger.info(f"데이터 수집 완료: {data_path}") + + +def run_training(data_path: str) -> float: + """train_model.py를 실행하고 새 AUC를 반환한다.""" + import subprocess + result = subprocess.run( + ["python", "scripts/train_model.py", "--data", data_path], + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"학습 실패: {result.stderr}") + new_auc = get_current_auc() + return new_auc + + +class Retrainer: + def __init__(self, ml_filter: MLFilter, data_path: str = "data/xrpusdt_1m.parquet"): + self._ml_filter = ml_filter + self._data_path = data_path + + async def retrain(self): + logger.info("자동 재학습 시작") + old_auc = get_current_auc() + try: + await fetch_and_save(self._data_path) + new_auc = run_training(self._data_path) + logger.info(f"재학습 완료: 이전 AUC={old_auc:.4f} → 새 AUC={new_auc:.4f}") + + if new_auc < old_auc - 0.01: + logger.warning(f"새 모델 성능 저하 ({new_auc:.4f} < {old_auc:.4f}), 롤백") + rollback_model() + else: + self._ml_filter.reload_model() + logger.success("새 ML 모델 적용 완료") + except Exception as e: + logger.error(f"재학습 실패: {e}") + + async def schedule_daily(self, hour: int = 3): + """매일 지정 시각(컨테이너 로컬 시간 기준)에 재학습을 실행한다.""" + from datetime import timedelta + while True: + now = datetime.now() + next_run = now.replace(hour=hour, minute=0, second=0, microsecond=0) + if next_run <= now: + next_run += timedelta(days=1) + wait_secs = (next_run - now).total_seconds() + logger.info(f"다음 재학습까지 {wait_secs/3600:.1f}시간 대기") + await asyncio.sleep(wait_secs) + await self.retrain() diff --git a/tests/test_bot.py b/tests/test_bot.py index c5315cb..8d5f2de 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -37,10 +37,8 @@ def sample_df(): @pytest.mark.asyncio async def test_bot_processes_signal(config, sample_df): - with patch("src.bot.BinanceFuturesClient") as MockExchange, \ - patch("src.bot.TradeRepository") as MockRepo: + with patch("src.bot.BinanceFuturesClient") as MockExchange: MockExchange.return_value = AsyncMock() - MockRepo.return_value = MagicMock() bot = TradingBot(config) bot.exchange = AsyncMock() @@ -48,8 +46,8 @@ async def test_bot_processes_signal(config, sample_df): bot.exchange.get_position = AsyncMock(return_value=None) bot.exchange.place_order = AsyncMock(return_value={"orderId": "123"}) bot.exchange.set_leverage = AsyncMock(return_value={}) - bot.db = MagicMock() - bot.db.save_trade = MagicMock(return_value={"id": "trade1"}) + bot.exchange.calculate_quantity = MagicMock(return_value=100.0) + bot.exchange.MIN_NOTIONAL = 5.0 with patch("src.bot.Indicators") as MockInd: mock_ind = MagicMock() diff --git a/tests/test_label_builder.py b/tests/test_label_builder.py new file mode 100644 index 0000000..e0c8d95 --- /dev/null +++ b/tests/test_label_builder.py @@ -0,0 +1,73 @@ +import pandas as pd +import numpy as np +import pytest +from src.label_builder import build_labels + + +def make_signal_df(): + """ + 신호 발생 시점 이후 가격이 TP에 도달하는 시나리오 + entry=100, TP=103, SL=98.5 + """ + future_closes = [100.5, 101.0, 101.8, 102.5, 103.1, 103.5] + future_highs = [c + 0.3 for c in future_closes] + future_lows = [c - 0.3 for c in future_closes] + return future_closes, future_highs, future_lows + + +def test_label_tp_reached(): + closes, highs, lows = make_signal_df() + label = build_labels( + future_closes=closes, + future_highs=highs, + future_lows=lows, + take_profit=103.0, + stop_loss=98.5, + side="LONG", + ) + assert label == 1, "TP 먼저 도달해야 레이블 1" + + +def test_label_sl_reached(): + future_closes = [99.5, 99.0, 98.8, 98.4, 98.0] + future_highs = [c + 0.3 for c in future_closes] + future_lows = [c - 0.3 for c in future_closes] + label = build_labels( + future_closes=future_closes, + future_highs=future_highs, + future_lows=future_lows, + take_profit=103.0, + stop_loss=98.5, + side="LONG", + ) + assert label == 0, "SL 먼저 도달해야 레이블 0" + + +def test_label_neither_reached_returns_none(): + future_closes = [100.1, 100.2, 100.3] + future_highs = [c + 0.1 for c in future_closes] + future_lows = [c - 0.1 for c in future_closes] + label = build_labels( + future_closes=future_closes, + future_highs=future_highs, + future_lows=future_lows, + take_profit=103.0, + stop_loss=98.5, + side="LONG", + ) + assert label is None, "미결 시 None 반환" + + +def test_label_short_tp(): + future_closes = [99.5, 99.0, 98.0, 97.0] + future_highs = [c + 0.3 for c in future_closes] + future_lows = [c - 0.3 for c in future_closes] + label = build_labels( + future_closes=future_closes, + future_highs=future_highs, + future_lows=future_lows, + take_profit=97.0, + stop_loss=101.5, + side="SHORT", + ) + assert label == 1 diff --git a/tests/test_ml_features.py b/tests/test_ml_features.py new file mode 100644 index 0000000..6b48f96 --- /dev/null +++ b/tests/test_ml_features.py @@ -0,0 +1,57 @@ +import pandas as pd +import numpy as np +import pytest +from src.ml_features import build_features, FEATURE_COLS + + +def make_df(n=100): + """테스트용 최소 DataFrame 생성""" + np.random.seed(42) + close = 100 + np.cumsum(np.random.randn(n) * 0.5) + df = pd.DataFrame({ + "open": close * 0.999, + "high": close * 1.002, + "low": close * 0.998, + "close": close, + "volume": np.random.uniform(1000, 5000, n), + }) + return df + + +def test_build_features_returns_series(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + features = build_features(df_ind, signal="LONG") + assert isinstance(features, pd.Series) + + +def test_build_features_has_all_cols(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + features = build_features(df_ind, signal="LONG") + for col in FEATURE_COLS: + assert col in features.index, f"피처 누락: {col}" + + +def test_build_features_no_nan(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + features = build_features(df_ind, signal="LONG") + assert not features.isna().any(), f"NaN 존재: {features[features.isna()]}" + + +def test_side_encoding(): + from src.indicators import Indicators + df = make_df(100) + ind = Indicators(df) + df_ind = ind.calculate_all() + long_feat = build_features(df_ind, signal="LONG") + short_feat = build_features(df_ind, signal="SHORT") + assert long_feat["side"] == 1 + assert short_feat["side"] == 0 diff --git a/tests/test_ml_filter.py b/tests/test_ml_filter.py new file mode 100644 index 0000000..27580f3 --- /dev/null +++ b/tests/test_ml_filter.py @@ -0,0 +1,63 @@ +import pandas as pd +import numpy as np +import pytest +from unittest.mock import MagicMock, patch +from pathlib import Path +from src.ml_filter import MLFilter +from src.ml_features import FEATURE_COLS + + +def make_features(side="LONG") -> pd.Series: + return pd.Series({col: 0.5 for col in FEATURE_COLS} | {"side": 1.0 if side == "LONG" else 0.0}) + + +def test_no_model_file_is_not_loaded(tmp_path): + f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + assert not f.is_model_loaded() + + +def test_no_model_should_enter_returns_true(tmp_path): + """모델 없으면 항상 진입 허용 (폴백)""" + f = MLFilter(model_path=str(tmp_path / "nonexistent.pkl")) + features = make_features() + assert f.should_enter(features) is True + + +def test_should_enter_above_threshold(): + """확률 >= 0.60 이면 True""" + f = MLFilter(threshold=0.60) + mock_model = MagicMock() + mock_model.predict_proba.return_value = np.array([[0.35, 0.65]]) + f._model = mock_model + features = make_features() + assert f.should_enter(features) is True + + +def test_should_enter_below_threshold(): + """확률 < 0.60 이면 False""" + f = MLFilter(threshold=0.60) + mock_model = MagicMock() + mock_model.predict_proba.return_value = np.array([[0.55, 0.45]]) + f._model = mock_model + features = make_features() + assert f.should_enter(features) is False + + +def test_reload_model(tmp_path): + """reload_model 호출 후 모델 로드 상태 변경""" + import joblib + + # 모델 파일이 없는 상태에서 시작 + model_path = tmp_path / "lgbm_filter.pkl" + f = MLFilter(model_path=str(model_path)) + assert not f.is_model_loaded() + + # _model을 직접 주입해서 is_model_loaded가 True인지 확인 + mock_model = MagicMock() + f._model = mock_model + assert f.is_model_loaded() + + # reload_model 호출 시 파일이 없으면 _try_load가 _model을 변경하지 않음 + # (기존 동작 유지 - 파일 없으면 None으로 초기화하지 않음) + f.reload_model() + assert f.is_model_loaded() # mock_model이 유지됨 diff --git a/tests/test_retrainer.py b/tests/test_retrainer.py new file mode 100644 index 0000000..4d8ad0e --- /dev/null +++ b/tests/test_retrainer.py @@ -0,0 +1,35 @@ +import pytest +import json +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch +from src.retrainer import Retrainer + + +@pytest.mark.asyncio +async def test_retrain_calls_train(tmp_path): + """재학습 시 train 함수가 호출되는지 확인""" + ml_filter = MagicMock() + r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet")) + + with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock) as mock_fetch, \ + patch("src.retrainer.run_training", return_value=0.72) as mock_train, \ + patch("src.retrainer.get_current_auc", return_value=0.65): + await r.retrain() + + mock_fetch.assert_called_once() + mock_train.assert_called_once() + + +@pytest.mark.asyncio +async def test_retrain_rollback_when_worse(tmp_path): + """새 모델이 기존보다 나쁘면 롤백""" + ml_filter = MagicMock() + r = Retrainer(ml_filter=ml_filter, data_path=str(tmp_path / "data.parquet")) + + with patch("src.retrainer.fetch_and_save", new_callable=AsyncMock), \ + patch("src.retrainer.run_training", return_value=0.55), \ + patch("src.retrainer.get_current_auc", return_value=0.70), \ + patch("src.retrainer.rollback_model") as mock_rollback: + await r.retrain() + + mock_rollback.assert_called_once()