mirror of
https://github.com/syssec-utd/pylingual.git
synced 2026-05-10 18:39:03 -07:00
rename
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import pathlib
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import RobertaTokenizer
|
||||
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
|
||||
import functools
|
||||
|
||||
|
||||
def preprocess_function(tokenizer: RobertaTokenizer, max_token_length: int, input_key: str, examples: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Set up Huggingface tokenizers for both inputs and targets"""
|
||||
inputs = [ex if ex else "" for ex in examples[input_key]]
|
||||
targets = [ex if ex else "" for ex in examples["source"]]
|
||||
|
||||
return tokenizer(text=inputs, text_target=targets, max_length=max_token_length, truncation=True)
|
||||
|
||||
|
||||
def tokenize_seq2seq_dataset(config: StatementConfiguration):
|
||||
# ref: https://huggingface.co/Salesforce/codet5-base
|
||||
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
|
||||
raw_datasets = load_dataset(config.dataset_repo_name, token=True)
|
||||
|
||||
column_names = raw_datasets["train"].column_names
|
||||
input_key = "bytecode"
|
||||
prepped_preprocess_function = functools.partial(preprocess_function, tokenizer, config.max_token_length, input_key)
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
prepped_preprocess_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
num_proc=os.cpu_count(),
|
||||
desc="Tokenizing datasets",
|
||||
)
|
||||
|
||||
tokenized_datasets.push_to_hub(config.tokenized_dataset_repo_name, private=True)
|
||||
|
||||
|
||||
@click.command(help="Tokenization script for Statement Translation model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
tokenize_seq2seq_dataset(statement_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user