mirror of
https://github.com/syssec-utd/pylingual.git
synced 2026-05-11 02:40:13 -07:00
133 lines
5.0 KiB
Python
133 lines
5.0 KiB
Python
# /// script
|
|
# requires-python = ">= 3.12"
|
|
# dependencies = [
|
|
# "pylingual",
|
|
# ]
|
|
# [tool.uv.sources]
|
|
# pylingual = { path = "../", editable = true }
|
|
# ///
|
|
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
import subprocess
|
|
import click
|
|
|
|
from pylingual.utils.get_logger import get_logger
|
|
|
|
|
|
def train_segmentation(segmentation_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400):
    """Run the segmentation training pipeline as a sequence of subprocesses.

    The stages run in order, each receiving ``segmentation_config_path`` as its
    only script argument:

    1. ``segmentation/train_tokenizer.py`` via ``uv run``
    2. ``segmentation/train_mlm.py`` via ``uv run torchrun``
    3. ``segmentation/tokenize_seg.py`` via ``uv run``
    4. ``segmentation/train_seg.py`` via ``uv run torchrun``

    Args:
        segmentation_config_path: Path to the segmentation model description
            file handed to every stage.
        logger: Logger used to announce each stage.
        nnodes: Value passed to torchrun's ``--nnodes``.
        nproc_per_node: Value passed to torchrun's ``--nproc-per-node``.
        rdzv_port: Port used in torchrun's ``--rdzv-endpoint=localhost:<port>``.

    Raises:
        subprocess.CalledProcessError: If any stage exits with a non-zero
            status. The remaining stages are not started, since each one
            depends on the output of the previous one.
    """
    segmentation_root = pathlib.Path(__file__).parent / "segmentation"

    # Both torchrun stages share the same launcher arguments and environment.
    torchrun_cmd = [
        "uv",
        "run",
        "torchrun",
        f"--nnodes={nnodes}",
        f"--nproc-per-node={nproc_per_node}",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{rdzv_port}",
    ]
    # Inherit the current environment and additionally set NCCL_P2P_DISABLE=1.
    torchrun_env = dict(os.environ, NCCL_P2P_DISABLE="1")

    # train tokenizer
    logger.info("training tokenizer...")
    subprocess.run(["uv", "run", segmentation_root / "train_tokenizer.py", segmentation_config_path], check=True)

    # train mlm
    # NOTE(review): an earlier comment said this stage runs on a single GPU to
    # avoid conflicts with local tokenized data, but the launch uses the
    # caller-supplied nnodes / nproc_per_node -- confirm which is intended.
    logger.info("training masked language model...")
    subprocess.run(
        [*torchrun_cmd, segmentation_root / "train_mlm.py", segmentation_config_path],
        env=torchrun_env,
        check=True,
    )

    # tokenize dataset
    logger.info("tokenizing segmentation dataset...")
    subprocess.run(["uv", "run", segmentation_root / "tokenize_seg.py", segmentation_config_path], check=True)

    # train segmentation model (process count is set by nnodes / nproc_per_node)
    logger.info("training segmentation model...")
    subprocess.run(
        [*torchrun_cmd, segmentation_root / "train_seg.py", segmentation_config_path],
        env=torchrun_env,
        check=True,
    )
|
|
|
|
|
|
def train_statement(statement_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400):
    """Run the statement-model training pipeline as a sequence of subprocesses.

    The stages run in order, each receiving ``statement_config_path`` as its
    only script argument:

    1. ``statement/train_tokenizer_auto.py`` via ``uv run``
    2. ``statement/tokenize_seq2seq.py`` via ``uv run``
    3. ``statement/train_seq2seq.py`` via ``uv run torchrun``

    Args:
        statement_config_path: Path to the statement model description file
            handed to every stage.
        logger: Logger used to announce each stage.
        nnodes: Value passed to torchrun's ``--nnodes``.
        nproc_per_node: Value passed to torchrun's ``--nproc-per-node``.
        rdzv_port: Port used in torchrun's ``--rdzv-endpoint=localhost:<port>``.

    Raises:
        subprocess.CalledProcessError: If any stage exits with a non-zero
            status. The remaining stages are not started, since each one
            depends on the output of the previous one.
    """
    statement_root = pathlib.Path(__file__).parent / "statement"

    # train tokenizer (logged like every other stage so progress is visible)
    logger.info("training tokenizer...")
    subprocess.run(["uv", "run", statement_root / "train_tokenizer_auto.py", statement_config_path], check=True)

    # tokenize statement dataset
    # NOTE(review): the original comment says this uses the salesforce
    # tokenizer; that cannot be confirmed from this file.
    logger.info("tokenizing statement dataset...")
    subprocess.run(["uv", "run", statement_root / "tokenize_seq2seq.py", statement_config_path], check=True)

    # train statement model (process count is set by nnodes / nproc_per_node)
    logger.info("training statement model...")
    subprocess.run(
        [
            "uv",
            "run",
            "torchrun",
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=c10d",
            f"--rdzv-endpoint=localhost:{rdzv_port}",
            statement_root / "train_seq2seq.py",
            statement_config_path,
        ],
        # Inherit the current environment and additionally set NCCL_P2P_DISABLE=1.
        env=dict(os.environ, NCCL_P2P_DISABLE="1"),
        check=True,
    )
|
|
|
|
|
|
@click.command(help="Full tokenization and training pipeline for the segmentation and statement translation models.")
@click.option("--segmentation", type=str, default=None, help="The path to the segmentation model description JSON file.")
@click.option("--statement", type=str, default=None, help="The path to the statement model description JSON file.")
@click.option("--nnodes", type=int, default=1, help="Torchrun nnodes arg")
@click.option("--nproc_per_node", type=int, default=1, help="Torchrun nproc_per_node arg")
@click.option("--rdzv_port", "-p", type=int, default=29400, help="Port to use for torchrun rendezvous endpoint")
def main(segmentation: str, statement: str, nnodes: int, nproc_per_node: int, rdzv_port: int):
    """Command-line entry point: train whichever models were given a config path.

    Each of ``--segmentation`` and ``--statement`` is optional. When one is
    omitted, a warning is logged and that model's training is skipped. The
    torchrun settings are forwarded unchanged to both training routines.
    """
    logger = get_logger("train-models")

    logger.info("Training pipeline starting...")
    logger.info("Loading dataset description JSON files...")

    # Only the paths are resolved here; the files themselves are read by the
    # training scripts.
    seg_config = None if segmentation is None else pathlib.Path(segmentation).resolve()
    stmt_config = None if statement is None else pathlib.Path(statement).resolve()

    logger.info("Dataset description JSON files loaded!")

    # Segmentation model
    if seg_config is None:
        logger.warning("Segmentation model configuration json path not provided in --segmentation; skipping segmentation model training...")
    else:
        logger.info("Segmentation model training starting...")
        train_segmentation(seg_config, logger, nnodes=nnodes, nproc_per_node=nproc_per_node, rdzv_port=rdzv_port)
        logger.info("Segmentation model training complete!")

    # Statement model
    if stmt_config is None:
        logger.warning("Statement model configuration json path not provided in --statement; skipping statement model training...")
    else:
        logger.info("Statement model training starting...")
        train_statement(stmt_config, logger, nnodes=nnodes, nproc_per_node=nproc_per_node, rdzv_port=rdzv_port)
        logger.info("Statement model training complete!")

    logger.info("Training pipeline complete!")
|
|
|
|
|
|
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
|