import logging
import os
import pathlib

import click
import evaluate
import numpy as np
from datasets import ReadInstruction, load_dataset
from huggingface_hub import hf_hub_download, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification, PreTrainedTokenizerFast, Trainer, TrainingArguments

# id2label and label2id map between label IDs and the segmentation labels.
label_names = ["B", "I", "E"]
id2label = {str(i): label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}
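
# For reference, the resulting mappings are:
#   id2label == {"0": "B", "1": "I", "2": "E"}
#   label2id == {"B": "0", "I": "1", "E": "2"}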


# compute_metrics: seqeval metrics reported during training and evaluation.
def compute_metrics(eval_preds):
    metric = evaluate.load("seqeval")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove the ignored index (-100, used for special tokens) and convert IDs to label strings.
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]  # noqa: E741
    true_predictions = [[label_names[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }
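

# Illustrative call with hypothetical values (not from a real run): one sequence of
# two tokens where the second is a special token masked out with -100.
#   logits = np.array([[[2.0, 0.1, 0.1], [0.1, 2.0, 0.1]]])  # (batch=1, seq=2, len(label_names))
#   labels = np.array([[0, -100]])
#   compute_metrics((logits, labels))  # scores the single surviving "B" prediction via seqeval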


def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
    tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name

    # Download the raw tokenizer.json from the Hub and wrap it with the expected special tokens.
    tokenizer_file = hf_hub_download(
        repo_id=tokenizer_repo_name,
        filename="tokenizer.json",
        token=True,
        cache_dir=str(tokenizer_dir),
    )
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
    )

    return tokenizer
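

# Usage sketch (the repo name is hypothetical):
#   tokenizer = load_tokenizer("your-org/segmentation-tokenizer", pathlib.Path(".cache"))
#   tokenizer("print(1)")  # returns input_ids/attention_mask built from the downloaded vocabulary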


def load_tokenized_train_and_valid_dataset(dataset_repo_name: str, cache_dir: pathlib.Path, dataset_percentage: int = 100):
    dataset_dir = cache_dir / "datasets" / dataset_repo_name

    # Load the tokenized dataset; dataset_percentage limits how much of the train split is read.
    tokenized_train_dataset = load_dataset(
        dataset_repo_name,
        token=True,
        cache_dir=str(dataset_dir),
        split=ReadInstruction("train", to=dataset_percentage, unit="%"),
    )

    tokenized_validation_dataset = load_dataset(
        dataset_repo_name,
        token=True,
        cache_dir=str(dataset_dir),
        split="valid",
    )

    return tokenized_train_dataset, tokenized_validation_dataset
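

# For example (hypothetical repo name), dataset_percentage=10 reads only the first
# tenth of the train split while still validating on the full "valid" split:
#   train, valid = load_tokenized_train_and_valid_dataset("your-org/tokenized-segments", pathlib.Path(".cache"), 10)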


def train_segmentation_model(config: SegmentationConfiguration):
    # Refuse to clobber an existing model repo on the Hub.
    if repo_exists(config.base_repo_name):
        logging.error(f"{config.base_repo_name} already exists")
        exit(1)

    # Training arguments.
    training_args = TrainingArguments(
        output_dir=str(config.segmenter_dir),
        overwrite_output_dir=True,
        eval_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.segmentation_training_parameters.learning_rate,
        num_train_epochs=config.segmentation_training_parameters.epochs,
        per_device_train_batch_size=config.segmentation_training_parameters.batch_size,
        save_steps=1000,
        weight_decay=0.01,
        fp16=True,
        push_to_hub=True,
        hub_model_id=config.segmenter_repo_name,
        hub_private_repo=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=True,
        save_total_limit=5,
    )
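
    # Note: fp16 and the "nccl" DDP backend assume CUDA GPUs, and with
    # save_strategy="epoch" the save_steps value above has no effect.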

    # Load the pretrained MLM (a BERT-style model) with a fresh token-classification head.
    model = AutoModelForTokenClassification.from_pretrained(
        pretrained_model_name_or_path=config.mlm_repo_name,
        id2label=id2label,
        label2id=label2id,
        token=True,
    )

    # Build the data collator; DataCollatorForTokenClassification pads inputs and labels together.
    tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, max_length=config.max_token_length)

    (
        tokenized_train_dataset,
        tokenized_validation_dataset,
    ) = load_tokenized_train_and_valid_dataset(config.tokenized_dataset_repo_name, config.cache_dir, config.dataset_percentage)

    # Hugging Face Trainer: fine-tunes the pretrained model with the arguments above.
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_validation_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )

    # Training
    trainer.train()

    # Under distributed training, only the main process (local rank 0) saves and pushes;
    # defaulting to "0" keeps single-process runs working when LOCAL_RANK is unset.
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        # Save the model
        trainer.save_model(str(config.segmenter_dir))

        trainer.push_to_hub(
            finetuned_from=config.mlm_repo_name,
            dataset=config.tokenized_dataset_repo_name,
            commit_message=f"Trained on {config.tokenized_dataset_repo_name} using {config.mlm_repo_name}",
        )
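

# Launch sketch (the script name is illustrative): the NCCL backend and the
# LOCAL_RANK check above assume a distributed launcher, e.g.
#   torchrun --nproc_per_node=4 train_segmenter.py config/segmentation.json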


@click.command(help="Training script for the segmentation model given a segmentation JSON.")
@click.argument("json_path", type=str)
def main(json_path: str):
    json_file_path = pathlib.Path(json_path)
    segmentation_config = parse_segmentation_config_json(json_file_path)
    train_segmentation_model(segmentation_config)


if __name__ == "__main__":
    main()
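
# The SegmentationConfiguration parsed from the JSON is expected to expose the
# attributes used in train_segmentation_model above: base_repo_name,
# segmenter_repo_name, mlm_repo_name, tokenizer_repo_name,
# tokenized_dataset_repo_name, segmenter_dir, cache_dir, max_token_length,
# dataset_percentage, and segmentation_training_parameters
# (learning_rate / epochs / batch_size).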