diff --git a/scripts/train_model.py b/scripts/train_model.py
index 89d80b8..f0081e8 100644
--- a/scripts/train_model.py
+++ b/scripts/train_model.py
@@ -22,6 +22,7 @@ from sklearn.metrics import roc_auc_score, classification_report
 from src.indicators import Indicators
 from src.ml_features import build_features, FEATURE_COLS
 from src.label_builder import build_labels
+from src.dataset_builder import generate_dataset_vectorized
 
 def _cgroup_cpu_count() -> int:
     """Read the cgroup v1/v2 quota and return the number of CPUs actually allocated.
@@ -107,7 +108,8 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
     total = len(df)
     indices = range(60, total - LOOKAHEAD)
 
-    workers = n_jobs or max(1, _cgroup_cpu_count() - 1)
+    # M4 mini: 10 cores (4P + 6E). Too many workers add IPC overhead, so cap at 8
+    workers = n_jobs or min(max(1, _cgroup_cpu_count() - 1), 8)
     print(f" Parallel processing: {workers} cores ({len(indices):,} indices total)")
 
     # Convert the DataFrame to numpy to minimize transfer cost between workers
@@ -117,7 +119,8 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
     rows = []
     errors = []
 
-    chunk = max(1, len(task_args) // (workers * 10))
+    # Use a large chunksize to cut down on IPC serialization round trips
+    chunk = max(100, len(task_args) // workers)
     with Pool(processes=workers) as pool:
         for idx, result in enumerate(pool.imap(_process_index, task_args, chunksize=chunk)):
             if isinstance(result, dict):
@@ -143,13 +146,13 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
     return pd.DataFrame(rows)
 
 
-def train(data_path: str, n_jobs: int | None = None):
+def train(data_path: str):
     print(f"Loading data: {data_path}")
     df = pd.read_parquet(data_path)
     print(f"Candle count: {len(df)}")
 
     print("Generating dataset...")
-    dataset = generate_dataset(df, n_jobs=n_jobs)
+    dataset = generate_dataset_vectorized(df)
 
     if dataset.empty or "label" not in dataset.columns:
         raise ValueError("Dataset generation failed: 0 samples. Check the error messages above.")
@@ -216,10 +219,8 @@ def train(data_path: str, n_jobs: int | None = None):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
-    parser.add_argument("--jobs", type=int, default=None,
-                        help="Number of parallel workers (default: CPU count - 1)")
     args = parser.parse_args()
-    train(args.data, n_jobs=args.jobs)
+    train(args.data)
 
 
 if __name__ == "__main__":
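
Note on the worker cap: the body of _cgroup_cpu_count is outside these hunks;
its docstring only says it reads the cgroup v1/v2 quota. A minimal sketch of
the standard way to do that, assuming the usual sysfs paths (this is not the
file's actual implementation):

    import math
    import os

    def _cgroup_cpu_count() -> int:
        """Read the cgroup v1/v2 CPU quota; fall back to os.cpu_count()."""
        try:
            # cgroup v2: cpu.max holds "<quota> <period>" or "max <period>"
            quota_s, period_s = open("/sys/fs/cgroup/cpu.max").read().split()
            if quota_s != "max":
                return max(1, math.ceil(int(quota_s) / int(period_s)))
        except (OSError, ValueError):
            pass
        try:
            # cgroup v1: a quota of -1 means unlimited
            quota = int(open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us").read())
            period = int(open("/sys/fs/cgroup/cpu/cpu.cfs_period_us").read())
            if quota > 0:
                return max(1, math.ceil(quota / period))
        except (OSError, ValueError):
            pass
        return os.cpu_count() or 1

On macOS (the M4 mini named in the comment) neither sysfs path exists, so the
os.cpu_count() fallback returns all 10 cores; that is exactly why the new code
applies the explicit min(..., 8) cap on top.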
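
Note on the chunksize change: pool.imap pickles and dispatches tasks in units
of chunksize, so each chunk is one serialization round trip. As illustrative
arithmetic, with 1,000,000 task_args and 8 workers the old formula
max(1, n // (workers * 10)) yields chunks of 12,500 items (80 batches), while
max(100, n // workers) yields 125,000 items (8 batches), i.e. roughly one
large batch per worker. The trade-off is coarser load balancing, which is
reasonable when per-index work is fairly uniform.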
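
Note on generate_dataset_vectorized: it is imported from src.dataset_builder
but its body is not part of this diff. A rough sketch of what a vectorized
replacement for the per-index worker loop could look like, assuming
build_features and build_labels accept the whole frame and return outputs
aligned to df (neither signature is confirmed by this diff), with LOOKAHEAD
as a placeholder constant:

    import pandas as pd

    from src.ml_features import build_features, FEATURE_COLS
    from src.label_builder import build_labels

    LOOKAHEAD = 15  # placeholder; reuse the project's real lookahead constant

    def generate_dataset_vectorized(df: pd.DataFrame) -> pd.DataFrame:
        """Build the dataset with column-wise ops instead of per-index workers."""
        features = build_features(df)  # assumed: frame aligned to df
        labels = build_labels(df)      # assumed: label Series aligned to df

        out = features[FEATURE_COLS].copy()
        out["label"] = labels
        # Mirror range(60, total - LOOKAHEAD): drop the indicator warm-up rows
        # and the tail rows that lack a full lookahead horizon.
        return out.iloc[60 : len(df) - LOOKAHEAD].reset_index(drop=True)

A fully vectorized builder removes the Pool/imap machinery entirely, which is
why train() and main() no longer need the n_jobs/--jobs plumbing.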