perf: replace generate_dataset with vectorized version in train_model
Made-with: Cursor
This commit is contained in:
@@ -22,6 +22,7 @@ from sklearn.metrics import roc_auc_score, classification_report
|
||||
from src.indicators import Indicators
|
||||
from src.ml_features import build_features, FEATURE_COLS
|
||||
from src.label_builder import build_labels
|
||||
from src.dataset_builder import generate_dataset_vectorized
|
||||
|
||||
def _cgroup_cpu_count() -> int:
|
||||
"""cgroup v1/v2 쿼터를 읽어 실제 할당된 CPU 수를 반환한다.
|
||||
@@ -107,7 +108,8 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
|
||||
total = len(df)
|
||||
indices = range(60, total - LOOKAHEAD)
|
||||
|
||||
workers = n_jobs or max(1, _cgroup_cpu_count() - 1)
|
||||
# M4 mini: 10코어(P4+E6). 너무 많은 worker는 IPC 오버헤드를 늘리므로 8로 제한
|
||||
workers = n_jobs or min(max(1, _cgroup_cpu_count() - 1), 8)
|
||||
print(f" 병렬 처리: {workers}코어 사용 (총 {len(indices):,}개 인덱스)")
|
||||
|
||||
# DataFrame을 numpy로 변환해서 worker 간 전달 비용 최소화
|
||||
@@ -117,7 +119,8 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
|
||||
|
||||
rows = []
|
||||
errors = []
|
||||
chunk = max(1, len(task_args) // (workers * 10))
|
||||
# chunksize를 크게 잡아 IPC 직렬화 횟수를 줄임
|
||||
chunk = max(100, len(task_args) // workers)
|
||||
with Pool(processes=workers) as pool:
|
||||
for idx, result in enumerate(pool.imap(_process_index, task_args, chunksize=chunk)):
|
||||
if isinstance(result, dict):
|
||||
@@ -143,13 +146,13 @@ def generate_dataset(df: pd.DataFrame, n_jobs: int | None = None) -> pd.DataFram
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def train(data_path: str, n_jobs: int | None = None):
|
||||
def train(data_path: str):
|
||||
print(f"데이터 로드: {data_path}")
|
||||
df = pd.read_parquet(data_path)
|
||||
print(f"캔들 수: {len(df)}")
|
||||
|
||||
print("데이터셋 생성 중...")
|
||||
dataset = generate_dataset(df, n_jobs=n_jobs)
|
||||
dataset = generate_dataset_vectorized(df)
|
||||
|
||||
if dataset.empty or "label" not in dataset.columns:
|
||||
raise ValueError(f"데이터셋 생성 실패: 샘플 0개. 위 오류 메시지를 확인하세요.")
|
||||
@@ -216,10 +219,8 @@ def train(data_path: str, n_jobs: int | None = None):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data", default="data/xrpusdt_1m.parquet")
|
||||
parser.add_argument("--jobs", type=int, default=None,
|
||||
help="병렬 worker 수 (기본: CPU 수 - 1)")
|
||||
args = parser.parse_args()
|
||||
train(args.data, n_jobs=args.jobs)
|
||||
train(args.data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user