feat: enhance train_model.py to dynamically determine CPU count for parallel processing

- Added a new function to accurately retrieve the number of allocated CPUs in containerized environments, improving parallel processing efficiency.
- Updated the dataset generation function to utilize the new CPU count function, ensuring optimal resource usage during model training.

Made-with: Cursor
This commit is contained in:
21in7
2026-03-01 17:46:40 +09:00
parent b86c88a8d6
commit 298d4ad95e

View File

@@ -8,7 +8,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
import argparse import argparse
import json import json
import os import math
from datetime import datetime from datetime import datetime
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
from pathlib import Path from pathlib import Path
@@ -23,6 +23,35 @@ from src.indicators import Indicators
from src.ml_features import build_features, FEATURE_COLS from src.ml_features import build_features, FEATURE_COLS
from src.label_builder import build_labels from src.label_builder import build_labels
def _cgroup_cpu_count() -> int:
"""cgroup v1/v2 쿼터를 읽어 실제 할당된 CPU 수를 반환한다.
LXC/컨테이너 환경에서 cpu_count()가 호스트 전체 코어를 반환하는 문제를 방지한다.
쿼터를 읽을 수 없으면 cpu_count()를 그대로 사용한다.
"""
# cgroup v2
try:
quota_path = Path("/sys/fs/cgroup/cpu.max")
if quota_path.exists():
parts = quota_path.read_text().split()
if parts[0] != "max":
quota = int(parts[0])
period = int(parts[1])
return max(1, math.floor(quota / period))
except Exception:
pass
# cgroup v1
try:
quota = int(Path("/sys/fs/cgroup/cpu/cpu.cfs_quota_us").read_text())
period = int(Path("/sys/fs/cgroup/cpu/cpu.cfs_period_us").read_text())
if quota > 0:
return max(1, math.floor(quota / period))
except Exception:
pass
return cpu_count()
LOOKAHEAD = 60 LOOKAHEAD = 60
ATR_SL_MULT = 1.5 ATR_SL_MULT = 1.5
ATR_TP_MULT = 3.0 ATR_TP_MULT = 3.0
@@ -78,7 +107,7 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
total = len(df) total = len(df)
indices = range(60, total - LOOKAHEAD) indices = range(60, total - LOOKAHEAD)
workers = n_jobs or max(1, cpu_count() - 1) workers = n_jobs or max(1, _cgroup_cpu_count() - 1)
print(f" 병렬 처리: {workers}코어 사용 (총 {len(indices):,}개 인덱스)") print(f" 병렬 처리: {workers}코어 사용 (총 {len(indices):,}개 인덱스)")
# DataFrame을 numpy로 변환해서 worker 간 전달 비용 최소화 # DataFrame을 numpy로 변환해서 worker 간 전달 비용 최소화