From 298d4ad95ebc2df192deef5c4cca44487effe45b Mon Sep 17 00:00:00 2001 From: 21in7 Date: Sun, 1 Mar 2026 17:46:40 +0900 Subject: [PATCH] feat: enhance train_model.py to dynamically determine CPU count for parallel processing - Added a new function to accurately retrieve the number of allocated CPUs in containerized environments, improving parallel processing efficiency. - Updated the dataset generation function to utilize the new CPU count function, ensuring optimal resource usage during model training. Made-with: Cursor --- scripts/train_model.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/scripts/train_model.py b/scripts/train_model.py index 516041d..6407fc7 100644 --- a/scripts/train_model.py +++ b/scripts/train_model.py @@ -8,7 +8,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) import argparse import json -import os +import math from datetime import datetime from multiprocessing import Pool, cpu_count from pathlib import Path @@ -23,6 +23,35 @@ from src.indicators import Indicators from src.ml_features import build_features, FEATURE_COLS from src.label_builder import build_labels +def _cgroup_cpu_count() -> int: + """cgroup v1/v2 쿼터를 읽어 실제 할당된 CPU 수를 반환한다. + LXC/컨테이너 환경에서 cpu_count()가 호스트 전체 코어를 반환하는 문제를 방지한다. + 쿼터를 읽을 수 없으면 cpu_count()를 그대로 사용한다. + """ + # cgroup v2 + try: + quota_path = Path("/sys/fs/cgroup/cpu.max") + if quota_path.exists(): + parts = quota_path.read_text().split() + if parts[0] != "max": + quota = int(parts[0]) + period = int(parts[1]) + return max(1, math.floor(quota / period)) + except Exception: + pass + + # cgroup v1 + try: + quota = int(Path("/sys/fs/cgroup/cpu/cpu.cfs_quota_us").read_text()) + period = int(Path("/sys/fs/cgroup/cpu/cpu.cfs_period_us").read_text()) + if quota > 0: + return max(1, math.floor(quota / period)) + except Exception: + pass + + return cpu_count() + + LOOKAHEAD = 60 ATR_SL_MULT = 1.5 ATR_TP_MULT = 3.0 @@ -78,7 +107,7 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram total = len(df) indices = range(60, total - LOOKAHEAD) - workers = n_jobs or max(1, cpu_count() - 1) + workers = n_jobs or max(1, _cgroup_cpu_count() - 1) print(f" 병렬 처리: {workers}코어 사용 (총 {len(indices):,}개 인덱스)") # DataFrame을 numpy로 변환해서 worker 간 전달 비용 최소화