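"""Training script for the statement translation model.

Fine-tunes a pretrained CodeT5-style seq2seq model on a tokenized statement
dataset and pushes the result to the Hugging Face Hub. Intended to be launched
with torchrun for multi-GPU training (torchrun sets LOCAL_RANK), e.g.:

    torchrun --nproc_per_node=<gpus> train_statement_model.py <statement_json>

(The script filename above is illustrative.)
"""
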
import os
import pathlib
import time
from datetime import timedelta

import click
from datasets import ReadInstruction, load_dataset
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
from transformers import (
    DataCollatorForSeq2Seq,
    RobertaTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
)


def load_tokenized_train_dataset(dataset_repo_name: str, dataset_percentage: int):
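    """Load the first `dataset_percentage` percent of the tokenized train split from the Hub."""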
    tokenized_train_dataset = load_dataset(
        dataset_repo_name,
        token=True,  # authenticate with the locally stored Hugging Face token
        split=ReadInstruction("train", to=dataset_percentage, unit="%"),
    )
    return tokenized_train_dataset


def train_statement_model(config: StatementConfiguration):
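    """Fine-tune the seq2seq statement model described by `config` and push it to the Hub."""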
    # Load the model; Salesforce/codet5-base is a pretrained model for the code generation task.
    tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
    model = T5ForConditionalGeneration.from_pretrained(config.pretrained_seq2seq_repo_name)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    model_dir = str(config.statement_model_dir)
    model_repo_name = config.statement_model_repo_name

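    # Checkpoints are written to model_dir and pushed to a private Hub repo as
    # training progresses (push_to_hub / hub_model_id below).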
    train_args = Seq2SeqTrainingArguments(
        output_dir=model_dir,
        learning_rate=config.statement_training_parameters.learning_rate,
        per_device_train_batch_size=config.statement_training_parameters.batch_size,
        per_device_eval_batch_size=config.statement_training_parameters.batch_size,
        weight_decay=0.01,
        fp16=config.fp16,
        logging_dir=str(config.log_dir),
        report_to="tensorboard",
        logging_strategy="steps",
        logging_steps=1000,
        save_strategy="steps",
        save_steps=10000,
        save_total_limit=2,
        num_train_epochs=config.statement_training_parameters.epochs,
        predict_with_generate=True,
        push_to_hub=True,
        hub_model_id=model_repo_name,
        hub_private_repo=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=False,
    )
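
    # Only the training split is loaded; no eval_dataset is passed to the
    # trainer, so the evaluation settings above are inert during this run.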
    tokenized_train_dataset = load_tokenized_train_dataset(
        config.tokenized_dataset_repo_name, config.dataset_percentage
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        tokenizer=tokenizer,
    )
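
    # Time the full training run; the duration string doubles as the Hub commit
    # message below.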
    start = time.time()
    trainer.train()
    duration = str(timedelta(seconds=time.time() - start))

    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        # Upload the latest version of the model to the Model Hub on Hugging Face
        # from rank 0 only, so distributed workers do not race on the same repo.
        # LOCAL_RANK is set by torchrun; we default to 0 for single-process runs.
        trainer.save_model(str(config.statement_model_dir))
        # push_to_hub returns the URL of the commit it just created
        trainer.push_to_hub(
            commit_message=duration,
            finetuned_from=config.pretrained_seq2seq_repo_name,
            dataset=config.tokenized_dataset_repo_name,
        )


@click.command(help="Training script for the statement translation model given a statement json.")
@click.argument("json_path", type=str)
def main(json_path: str):
    json_file_path = pathlib.Path(json_path)
    statement_config = parse_statement_config_json(json_file_path)
    train_statement_model(statement_config)


if __name__ == "__main__":
    main()