mirror of
https://github.com/syssec-utd/pylingual.git
synced 2026-05-10 18:39:03 -07:00
rename
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import click
|
||||
|
||||
from datasets import ReadInstruction, load_dataset
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
from transformers import (
|
||||
DataCollatorForSeq2Seq,
|
||||
RobertaTokenizer,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
def load_tokenized_train_dataset(dataset_repo_name: str, dataset_percentage: int):
    """Return the first *dataset_percentage* percent of the train split.

    Pulls the already-tokenized dataset from the given Hub repository,
    authenticating with the locally stored HF token (``token=True``).
    """
    split_instruction = ReadInstruction("train", to=dataset_percentage, unit="%")
    return load_dataset(dataset_repo_name, token=True, split=split_instruction)
|
||||
|
||||
|
||||
def train_statement_model(config: StatementConfiguration):
    """Fine-tune a CodeT5-style seq2seq model for statement translation.

    Loads the tokenizer and pretrained model named in ``config``, trains on
    the tokenized dataset from the Hub, and — from the rank-0 process only —
    saves the model locally and pushes it to the HuggingFace Model Hub.
    """
    # load model, Salesforce/codet5-base is a pretrained model solving the code generation task.
    tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
    model = T5ForConditionalGeneration.from_pretrained(config.pretrained_seq2seq_repo_name)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    model_dir = str(config.statement_model_dir)
    model_repo_name = config.statement_model_repo_name

    train_args = Seq2SeqTrainingArguments(
        output_dir=model_dir,
        learning_rate=config.statement_training_parameters.learning_rate,
        per_device_train_batch_size=config.statement_training_parameters.batch_size,
        per_device_eval_batch_size=config.statement_training_parameters.batch_size,
        weight_decay=0.01,
        fp16=config.fp16,
        logging_dir=str(config.log_dir),
        report_to="tensorboard",
        logging_strategy="steps",
        logging_steps=1000,
        save_strategy="steps",
        save_steps=10000,
        save_total_limit=2,  # keep only the two most recent checkpoints on disk
        num_train_epochs=config.statement_training_parameters.epochs,
        predict_with_generate=True,
        push_to_hub=True,
        hub_model_id=model_repo_name,
        hub_private_repo=True,
        ddp_backend="nccl",  # NOTE(review): assumes NVIDIA GPUs — NCCL is CUDA-only
        ddp_find_unused_parameters=False,
    )

    tokenized_train_dataset = load_tokenized_train_dataset(config.tokenized_dataset_repo_name, config.dataset_percentage)
    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        tokenizer=tokenizer,
    )

    start = time.time()
    trainer.train()
    # Wall-clock training time, recorded as the Hub commit message below.
    duration = str(timedelta(seconds=time.time() - start))

    # Only the rank-0 process saves/uploads. Default LOCAL_RANK to "0" so a
    # plain (non-torchrun) single-process run does not crash with a KeyError
    # on the unset environment variable; under torchrun the behavior is
    # unchanged because LOCAL_RANK is always set.
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        # upload the latest version of the model to the Model Hub on Huggingface
        trainer.save_model(str(config.statement_model_dir))
        # this command returns the URL of the commit it just did
        trainer.push_to_hub(
            commit_message=duration,
            finetuned_from=config.pretrained_seq2seq_repo_name,
            dataset=config.tokenized_dataset_repo_name,
        )
|
||||
|
||||
|
||||
@click.command(help="Training script for the statement translation model given a statement json.")
@click.argument("json_path", type=str)
def main(json_path: str):
    """Parse the statement configuration JSON and launch model training."""
    statement_config = parse_statement_config_json(pathlib.Path(json_path))
    train_statement_model(statement_config)
|
||||
|
||||
|
||||
# Script entry point: dispatch to the click command.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user