Files
pylingual/dev_scripts/train_models.py
T
2025-09-12 11:40:46 -05:00

130 lines
4.9 KiB
Python

# /// script
# requires-python = ">= 3.12"
# dependencies = [
# "pylingual",
# ]
# [tool.uv.sources]
# pylingual = { path = "../" }
# ///
import logging
import os
import pathlib
import subprocess
import click
from pylingual.utils.get_logger import get_logger
def train_segmentation(segmentation_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400) -> None:
    """Run the segmentation training pipeline as a sequence of subprocesses.

    Four stages are executed in order, each as a script from the
    ``segmentation`` directory that sits next to this file:

    1. ``train_tokenizer.py`` (via ``uv run``)
    2. ``train_mlm.py``       (via ``uv run torchrun``)
    3. ``tokenize_seg.py``    (via ``uv run``)
    4. ``train_seg.py``       (via ``uv run torchrun``)

    Every script receives ``segmentation_config_path`` as its only argument.

    Args:
        segmentation_config_path: Path handed to each stage script.
        logger: Logger used to announce stages 1-4 before they start.
        nnodes: Forwarded to torchrun as ``--nnodes``.
        nproc_per_node: Forwarded to torchrun as ``--nproc-per-node``.
        rdzv_port: Port on ``localhost`` used for torchrun's
            ``--rdzv-endpoint``.

    Raises:
        subprocess.CalledProcessError: If any stage exits with a non-zero
            status. Later stages consume the output of earlier ones, so the
            pipeline stops at the first failure instead of continuing.
    """
    segmentation_root = pathlib.Path(__file__).parent / "segmentation"

    def run_plain(script_name: str) -> None:
        # check=True: a failed stage must abort the pipeline rather than let
        # the following stages run against missing or stale artifacts.
        subprocess.run(
            ["uv", "run", segmentation_root / script_name, segmentation_config_path],
            check=True,
        )

    def run_distributed(script_name: str) -> None:
        # Launch through torchrun with the caller-supplied topology. The child
        # inherits the current environment plus NCCL_P2P_DISABLE=1.
        subprocess.run(
            [
                "uv", "run",
                "torchrun",
                f"--nnodes={nnodes}",
                f"--nproc-per-node={nproc_per_node}",
                "--rdzv-backend=c10d",
                f"--rdzv-endpoint=localhost:{rdzv_port}",
                segmentation_root / script_name,
                segmentation_config_path,
            ],
            env=dict(os.environ, NCCL_P2P_DISABLE="1"),
            check=True,
        )

    # train tokenizer
    logger.info("training tokenizer...")
    run_plain("train_tokenizer.py")

    # train mlm
    # NOTE(review): an earlier comment said this stage should run on a single
    # GPU "to avoid conflicts with local tokenized data", but the launch has
    # always used the same nnodes/nproc_per_node as the final stage. Confirm
    # whether multi-process MLM training is actually safe.
    logger.info("training masked language model...")
    run_distributed("train_mlm.py")

    # tokenize dataset
    logger.info("tokenizing segmentation dataset...")
    run_plain("tokenize_seg.py")

    # train segmentation model (process count is set by nnodes/nproc_per_node)
    logger.info("training segmentation model...")
    run_distributed("train_seg.py")
def train_statement(statement_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400) -> None:
    """Run the statement-model training pipeline as a sequence of subprocesses.

    Three stages are executed in order, each as a script from the
    ``statement`` directory that sits next to this file:

    1. ``train_tokenizer_auto.py`` (via ``uv run``)
    2. ``tokenize_seq2seq.py``     (via ``uv run``)
    3. ``train_seq2seq.py``        (via ``uv run torchrun``)

    Every script receives ``statement_config_path`` as its only argument.

    Args:
        statement_config_path: Path handed to each stage script.
        logger: Logger used to announce each stage before it starts.
        nnodes: Forwarded to torchrun as ``--nnodes``.
        nproc_per_node: Forwarded to torchrun as ``--nproc-per-node``.
        rdzv_port: Port on ``localhost`` used for torchrun's
            ``--rdzv-endpoint``.

    Raises:
        subprocess.CalledProcessError: If any stage exits with a non-zero
            status. Later stages consume the output of earlier ones, so the
            pipeline stops at the first failure instead of continuing.
    """
    statement_root = pathlib.Path(__file__).parent / "statement"

    # train tokenizer (logged like the segmentation pipeline's tokenizer stage)
    logger.info("training tokenizer...")
    # check=True on every stage: abort on the first failure rather than
    # tokenizing/training against missing or stale artifacts.
    subprocess.run(
        ["uv", "run", statement_root / "train_tokenizer_auto.py", statement_config_path],
        check=True,
    )

    # tokenize statement dataset
    # NOTE(review): the original comment says this uses the Salesforce
    # tokenizer; that cannot be confirmed from this file.
    logger.info("tokenizing statement dataset...")
    subprocess.run(
        ["uv", "run", statement_root / "tokenize_seq2seq.py", statement_config_path],
        check=True,
    )

    # train statement model (process count is set by nnodes/nproc_per_node)
    logger.info("training statement model...")
    subprocess.run(
        [
            "uv", "run",
            "torchrun",
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=c10d",
            f"--rdzv-endpoint=localhost:{rdzv_port}",
            statement_root / "train_seq2seq.py",
            statement_config_path,
        ],
        # The child inherits the current environment plus NCCL_P2P_DISABLE=1.
        env=dict(os.environ, NCCL_P2P_DISABLE="1"),
        check=True,
    )
@click.command(help="Full tokenization and training pipeline for the segmentation and statement translation models.")
@click.option("--segmentation", type=str, default=None, help="The path to the segmentation model description JSON file.")
@click.option("--statement", type=str, default=None, help="The path to the statement model description JSON file.")
@click.option("--nnodes", type=int, default=1, help="Torchrun nnodes arg")
@click.option("--nproc_per_node", type=int, default=1, help="Torchrun nproc_per_node arg")
@click.option("--rdzv_port", "-p", type=int, default=29400, help="Port to use for torchrun rendezvous endpoint")
def main(segmentation: str, statement: str, nnodes: int, nproc_per_node: int, rdzv_port: int):
    """Command-line entry point: train whichever models were given a config.

    Each of ``segmentation`` and ``statement`` is optional. A provided path
    is resolved to an absolute path and handed to the matching ``train_*``
    function together with the torchrun settings; an omitted one is skipped
    with a warning. Segmentation is handled before statement.
    """
    log = get_logger("train-models")

    log.info("Training pipeline starting...")
    log.info("Loading dataset description JSON files...")

    # Resolve the supplied paths; the config files themselves are not opened here.
    seg_config = None if segmentation is None else pathlib.Path(segmentation).resolve()
    stmt_config = None if statement is None else pathlib.Path(statement).resolve()

    log.info("Dataset description JSON files loaded!")

    # Segmentation model
    if seg_config is None:
        log.warning("Segmentation model configuration json path not provided in --segmentation; skipping segmentation model training...")
    else:
        log.info("Segmentation model training starting...")
        train_segmentation(seg_config, log, nnodes, nproc_per_node, rdzv_port)
        log.info("Segmentation model training complete!")

    # Statement model
    if stmt_config is None:
        log.warning("Statement model configuration json path not provided in --statement; skipping statement model training...")
    else:
        log.info("Statement model training starting...")
        train_statement(stmt_config, log, nnodes, nproc_per_node, rdzv_port)
        log.info("Statement model training complete!")

    log.info("Training pipeline complete!")
# Invoke the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()