Files
pylingual/dev_scripts/statement/tokenize_seq2seq.py
2025-03-13 16:56:36 -05:00

52 lines
1.8 KiB
Python

import os
from typing import Any
import click
import pathlib
from datasets import load_dataset
from transformers import RobertaTokenizer
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
import functools
def preprocess_function(tokenizer: RobertaTokenizer, max_token_length: int, input_key: str, examples: dict[str, Any]) -> dict[str, Any]:
    """Tokenize one batch of examples into model inputs and targets.

    Args:
        tokenizer: Tokenizer that is called once for the whole batch.
        max_token_length: Forwarded as ``max_length``; truncation is enabled.
        input_key: Column of ``examples`` that holds the input texts.
        examples: Batch mapping column names to lists of values. Target
            texts are always read from the ``"source"`` column.

    Returns:
        The result of the tokenizer call.
    """

    def _blank_if_falsy(column):
        # Falsy entries (e.g. None or "") are replaced by the empty string.
        return [value or "" for value in column]

    model_inputs = _blank_if_falsy(examples[input_key])
    model_targets = _blank_if_falsy(examples["source"])
    return tokenizer(
        text=model_inputs,
        text_target=model_targets,
        max_length=max_token_length,
        truncation=True,
    )
def tokenize_seq2seq_dataset(config: StatementConfiguration):
    """Tokenize the configured dataset and push the result to the Hub.

    The tokenizer is loaded from ``config.tokenizer_repo_name`` and the
    dataset from ``config.dataset_repo_name``. Every split is mapped through
    ``preprocess_function`` with ``"bytecode"`` as the input column, and the
    tokenized dataset is pushed to ``config.tokenized_dataset_repo_name`` as a
    private repository.

    Args:
        config: Statement configuration supplying the repository names and
            ``max_token_length``.
    """
    # ref: https://huggingface.co/Salesforce/codet5-base
    seq2seq_tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
    dataset_splits = load_dataset(config.dataset_repo_name, token=True)

    # The column list of the "train" split is what gets removed after mapping.
    original_columns = dataset_splits["train"].column_names

    # Bind positionally so the batch passed by ``map`` fills the last
    # remaining parameter (``examples``).
    encode_batch = functools.partial(
        preprocess_function,
        seq2seq_tokenizer,
        config.max_token_length,
        "bytecode",
    )

    map_options = {
        "batched": True,
        "remove_columns": original_columns,
        "num_proc": os.cpu_count(),
        "desc": "Tokenizing datasets",
    }
    dataset_splits.map(encode_batch, **map_options).push_to_hub(
        config.tokenized_dataset_repo_name,
        private=True,
    )
@click.command(help="Tokenization script for Statement Translation model given a statement json.")
@click.argument("json_path", type=str)
def main(json_path: str):
    """Parse the statement config JSON at ``json_path`` and tokenize its dataset."""
    config = parse_statement_config_json(pathlib.Path(json_path))
    tokenize_seq2seq_dataset(config)
if __name__ == "__main__":
    # Run the click command only when executed as a script, not on import.
    main()