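"""Training script for the segmentation model given a segmentation config JSON.

Fine-tunes a pretrained MLM for token classification with B/I/E labels and
pushes the result to the Hugging Face Hub.
"""
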
import logging
import os
import pathlib
import click
import evaluate
import numpy as np
from datasets import ReadInstruction, load_dataset
from huggingface_hub import hf_hub_download, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from transformers import (
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

# Label mappings for the B/I/E segmentation tags: id2label maps class IDs to
# label names, and label2id maps them back.
label_names = ["B", "I", "E"]
id2label = {i: label for i, label in enumerate(label_names)}
label2id = {label: i for i, label in id2label.items()}


# compute_metrics: metric callback passed to the Trainer for evaluation.
def compute_metrics(eval_preds):
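    """Compute seqeval precision, recall, F1, and accuracy for an evaluation batch.

    Positions labeled -100 (special tokens) are dropped before scoring.
    """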
    metric = evaluate.load("seqeval")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    # Remove the ignored index (-100, used for special tokens) and map IDs back to label names.
    true_labels = [[label_names[lab] for lab in label if lab != -100] for label in labels]
    true_predictions = [
        [label_names[pred] for (pred, lab) in zip(prediction, label) if lab != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }


def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
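    """Download tokenizer.json from the Hub and wrap it in a fast tokenizer.

    The special tokens are declared explicitly because a bare tokenizer.json
    does not carry the cls/sep/pad/unk/mask token roles.
    """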
    tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name
    tokenizer_file = hf_hub_download(
        repo_id=tokenizer_repo_name,
        filename="tokenizer.json",
        token=True,
        cache_dir=str(tokenizer_dir),
    )
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
    )
    return tokenizer


def load_tokenized_train_and_valid_dataset(dataset_repo_name: str, cache_dir: pathlib.Path, dataset_percentage: int = 100):
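    """Load the pre-tokenized dataset from the Hub.

    Returns the first `dataset_percentage` percent of the train split together
    with the full valid split.
    """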
    dataset_dir = cache_dir / "datasets" / dataset_repo_name
    # Load the tokenized dataset.
    tokenized_train_dataset = load_dataset(
        dataset_repo_name,
        token=True,
        cache_dir=str(dataset_dir),
        split=ReadInstruction("train", to=dataset_percentage, unit="%"),
    )
    tokenized_validation_dataset = load_dataset(
        dataset_repo_name,
        token=True,
        cache_dir=str(dataset_dir),
        split="valid",
    )
    return tokenized_train_dataset, tokenized_validation_dataset


def train_segmentation_model(config: SegmentationConfiguration):
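    """Fine-tune the pretrained MLM for B/I/E token segmentation and push the result to the Hub."""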
    # Abort if the target repo already exists on the Hub.
    if repo_exists(config.base_repo_name):
        logging.error(f"{config.base_repo_name} already exists")
        raise SystemExit(1)
    # Training arguments.
    training_args = TrainingArguments(
        output_dir=str(config.segmenter_dir),
        overwrite_output_dir=True,
        eval_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.segmentation_training_parameters.learning_rate,
        num_train_epochs=config.segmentation_training_parameters.epochs,
        per_device_train_batch_size=config.segmentation_training_parameters.batch_size,
        save_steps=1000,
        weight_decay=0.01,
        fp16=True,
        push_to_hub=True,
        hub_model_id=config.segmenter_repo_name,
        hub_private_repo=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=True,
        save_total_limit=5,
    )
    # Load the pretrained MLM checkpoint with a fresh token-classification head.
    model = AutoModelForTokenClassification.from_pretrained(
        pretrained_model_name_or_path=config.mlm_repo_name,
        id2label=id2label,
        label2id=label2id,
        token=True,
    )
    # Build the data collator that pads token and label sequences per batch.
    tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, max_length=config.max_token_length)
    tokenized_train_dataset, tokenized_validation_dataset = load_tokenized_train_and_valid_dataset(
        config.tokenized_dataset_repo_name, config.cache_dir, config.dataset_percentage
    )
    # Hugging Face Trainer: ties together the model, data, collator, and metrics.
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_validation_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )
    # Training
    trainer.train()
    # Save and push only from the main process; LOCAL_RANK is unset outside DDP.
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        # Save the model
        trainer.save_model(str(config.segmenter_dir))
        trainer.push_to_hub(
            finetuned_from=config.mlm_repo_name,
            dataset=config.tokenized_dataset_repo_name,
            commit_message=f"Trained on {config.tokenized_dataset_repo_name} using {config.mlm_repo_name}",
        )


@click.command(help="Training script for the segmentation model given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
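    """Parse the segmentation config JSON and run training."""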
    json_file_path = pathlib.Path(json_path)
    segmentation_config = parse_segmentation_config_json(json_file_path)
    train_segmentation_model(segmentation_config)


if __name__ == "__main__":
    main()
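
# Example launches (the script name is illustrative):
#   python train_segmentation.py configs/segmentation.json
#   torchrun --nproc_per_node=4 train_segmentation.py configs/segmentation.json  # multi-GPU DDP over nccl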