mirror of
https://github.com/syssec-utd/pylingual.git
synced 2026-05-10 18:39:03 -07:00
rename
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
# seq2seq
|
||||
|
||||
- train_tokenizer_auto.py:
|
||||
- trains the manual tokenizer
|
||||
|
||||
- tokenize_seq2seq.py:
|
||||
- tokenize the dataset for the seq2seq model
|
||||
|
||||
- train_seq2seq.py:
|
||||
- finetuning the pretrained model
|
||||
- will create a sequence-to-sequence translation model
|
||||
|
||||
- StatementConfiguration.py
|
||||
- defines the JSON format for statement translation training
|
||||
|
||||
# manual1
|
||||
|
||||
Contains JSONs mapping bytecode instructions and their configurations to use in training.
|
||||
@@ -0,0 +1,59 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pathlib
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingParameters:
|
||||
batch_size: int
|
||||
epochs: int
|
||||
learning_rate: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatementConfiguration:
|
||||
base_repo_name: str
|
||||
dataset_repo_name: str
|
||||
tokenizer_repo_name: str
|
||||
pretrained_seq2seq_repo_name: str
|
||||
cache_dir: pathlib.Path
|
||||
max_token_length: int
|
||||
dataset_percentage: int
|
||||
do_eval: bool
|
||||
fp16: bool
|
||||
statement_training_parameters: TrainingParameters
|
||||
|
||||
@property
|
||||
def tokenized_dataset_repo_name(self):
|
||||
return self.dataset_repo_name + "-tokenized"
|
||||
|
||||
@property
|
||||
def statement_model_repo_name(self):
|
||||
return self.base_repo_name + "-statement"
|
||||
|
||||
@property
|
||||
def statement_model_dir(self):
|
||||
return self.cache_dir / "models" / self.statement_model_repo_name
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return self.statement_model_dir / "logs"
|
||||
|
||||
def __post_init__(self):
|
||||
self.cache_dir = pathlib.Path(self.cache_dir)
|
||||
|
||||
|
||||
def parse_statement_config_json(json_file_path: pathlib.Path, logger: logging.Logger = None) -> StatementConfiguration:
|
||||
if not json_file_path.exists():
|
||||
raise FileNotFoundError(f"{json_file_path} does not exist")
|
||||
|
||||
if logger:
|
||||
logger.info(f"Loading model description from {json_file_path}...")
|
||||
|
||||
with json_file_path.open() as json_file:
|
||||
statement_config_dict = json.load(json_file)
|
||||
|
||||
statement_config_dict["statement_training_parameters"] = TrainingParameters(**statement_config_dict["statement_training_parameters"])
|
||||
return StatementConfiguration(**statement_config_dict)
|
||||
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import pathlib
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import RobertaTokenizer
|
||||
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
|
||||
import functools
|
||||
|
||||
|
||||
def preprocess_function(tokenizer: RobertaTokenizer, max_token_length: int, input_key: str, examples: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Set up Huggingface tokenizers for both inputs and targets"""
|
||||
inputs = [ex if ex else "" for ex in examples[input_key]]
|
||||
targets = [ex if ex else "" for ex in examples["source"]]
|
||||
|
||||
return tokenizer(text=inputs, text_target=targets, max_length=max_token_length, truncation=True)
|
||||
|
||||
|
||||
def tokenize_seq2seq_dataset(config: StatementConfiguration):
|
||||
# ref: https://huggingface.co/Salesforce/codet5-base
|
||||
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
|
||||
raw_datasets = load_dataset(config.dataset_repo_name, token=True)
|
||||
|
||||
column_names = raw_datasets["train"].column_names
|
||||
input_key = "bytecode"
|
||||
prepped_preprocess_function = functools.partial(preprocess_function, tokenizer, config.max_token_length, input_key)
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
prepped_preprocess_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
num_proc=os.cpu_count(),
|
||||
desc="Tokenizing datasets",
|
||||
)
|
||||
|
||||
tokenized_datasets.push_to_hub(config.tokenized_dataset_repo_name, private=True)
|
||||
|
||||
|
||||
@click.command(help="Tokenization script for Statement Translation model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
tokenize_seq2seq_dataset(statement_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,929 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"additional_special_tokens": [
|
||||
"<pad>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
"<unk>",
|
||||
"<mask>",
|
||||
"!",
|
||||
"\"",
|
||||
"#",
|
||||
"$",
|
||||
"%",
|
||||
"&",
|
||||
"'",
|
||||
"(",
|
||||
")",
|
||||
"*",
|
||||
"+",
|
||||
",",
|
||||
"-",
|
||||
".",
|
||||
"/",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
":",
|
||||
";",
|
||||
"<",
|
||||
"=",
|
||||
">",
|
||||
"?",
|
||||
"@",
|
||||
"A",
|
||||
"B",
|
||||
"C",
|
||||
"D",
|
||||
"E",
|
||||
"F",
|
||||
"G",
|
||||
"H",
|
||||
"I",
|
||||
"J",
|
||||
"K",
|
||||
"L",
|
||||
"M",
|
||||
"N",
|
||||
"O",
|
||||
"P",
|
||||
"Q",
|
||||
"R",
|
||||
"S",
|
||||
"T",
|
||||
"U",
|
||||
"V",
|
||||
"W",
|
||||
"X",
|
||||
"Y",
|
||||
"Z",
|
||||
"[",
|
||||
"\\",
|
||||
"]",
|
||||
"^",
|
||||
"_",
|
||||
"`",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
"g",
|
||||
"h",
|
||||
"i",
|
||||
"j",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"o",
|
||||
"p",
|
||||
"q",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"u",
|
||||
"v",
|
||||
"w",
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"{",
|
||||
"|",
|
||||
"}",
|
||||
"~",
|
||||
"Ġ",
|
||||
"-=",
|
||||
"<<",
|
||||
">>",
|
||||
":=",
|
||||
">=",
|
||||
"<=",
|
||||
"==",
|
||||
"!=",
|
||||
"+=",
|
||||
"//=",
|
||||
"**=",
|
||||
"/=",
|
||||
"//",
|
||||
"%=",
|
||||
"@=",
|
||||
"&=",
|
||||
"|=",
|
||||
"^=",
|
||||
">>=",
|
||||
"<<=",
|
||||
"*=",
|
||||
"()",
|
||||
"):",
|
||||
"~>>",
|
||||
"**",
|
||||
"<codeobj:",
|
||||
"<KWARG_PAD>",
|
||||
"E->",
|
||||
"<TAP_0>",
|
||||
"defaults",
|
||||
"args:",
|
||||
"vararg:",
|
||||
"<TAP_1>",
|
||||
"<START_LINE>",
|
||||
"<SUB>",
|
||||
"<TAP_UP>",
|
||||
"E-END",
|
||||
"<TAP_2>",
|
||||
"~~>",
|
||||
"<TAP_ST>",
|
||||
"<SEP>",
|
||||
"</SUB>",
|
||||
"<TAP_3>",
|
||||
"<TAP_4>",
|
||||
"<TAP_5>",
|
||||
"<TAP_6>",
|
||||
"<TAP_7>",
|
||||
"False",
|
||||
"None",
|
||||
"True",
|
||||
"and",
|
||||
"assert",
|
||||
"async",
|
||||
"await",
|
||||
"break",
|
||||
"class",
|
||||
"continue",
|
||||
"def",
|
||||
"del",
|
||||
"elif",
|
||||
"else",
|
||||
"else:",
|
||||
"except",
|
||||
"except:",
|
||||
"finally",
|
||||
"finally:",
|
||||
"for",
|
||||
"from",
|
||||
"global",
|
||||
"if",
|
||||
"import",
|
||||
"in",
|
||||
"is",
|
||||
"lambda",
|
||||
"nonlocal",
|
||||
"not",
|
||||
"or",
|
||||
"pass",
|
||||
"raise",
|
||||
"return",
|
||||
"try",
|
||||
"try:",
|
||||
"while",
|
||||
"with",
|
||||
"yield",
|
||||
"case",
|
||||
"as",
|
||||
"ASYNC_GEN_WRAP",
|
||||
"BEFORE_ASYNC_WITH",
|
||||
"BEFORE_WITH",
|
||||
"BEGIN_FINALLY",
|
||||
"BINARY_ADD",
|
||||
"BINARY_AND",
|
||||
"BINARY_FLOOR_DIVIDE",
|
||||
"BINARY_LSHIFT",
|
||||
"BINARY_MATRIX_MULTIPLY",
|
||||
"BINARY_MODULO",
|
||||
"BINARY_MULTIPLY",
|
||||
"BINARY_OP",
|
||||
"BINARY_OR",
|
||||
"BINARY_POWER",
|
||||
"BINARY_RSHIFT",
|
||||
"BINARY_SLICE",
|
||||
"BINARY_SUBSCR",
|
||||
"BINARY_SUBTRACT",
|
||||
"BINARY_TRUE_DIVIDE",
|
||||
"BINARY_XOR",
|
||||
"BREAK_LOOP",
|
||||
"BUILD_CONST_KEY_MAP",
|
||||
"BUILD_LIST",
|
||||
"BUILD_LIST_UNPACK",
|
||||
"BUILD_MAP",
|
||||
"BUILD_MAP_UNPACK",
|
||||
"BUILD_MAP_UNPACK_WITH_CALL",
|
||||
"BUILD_SET",
|
||||
"BUILD_SET_UNPACK",
|
||||
"BUILD_SLICE",
|
||||
"BUILD_STRING",
|
||||
"BUILD_TUPLE",
|
||||
"BUILD_TUPLE_UNPACK",
|
||||
"BUILD_TUPLE_UNPACK_WITH_CALL",
|
||||
"CACHE",
|
||||
"CALL_FINALLY",
|
||||
"CALL_FUNCTION",
|
||||
"CALL_FUNCTION_EX",
|
||||
"CALL_FUNCTION_KW",
|
||||
"CALL_METHOD",
|
||||
"CHECK_EG_MATCH",
|
||||
"CHECK_EXC_MATCH",
|
||||
"CLEANUP_THROW",
|
||||
"COMPARE_OP",
|
||||
"CONTAINS_OP",
|
||||
"CONTINUE_LOOP",
|
||||
"COPY",
|
||||
"COPY_DICT_WITHOUT_KEYS",
|
||||
"COPY_FREE_VARS",
|
||||
"DELETE_ATTR",
|
||||
"DELETE_DEREF",
|
||||
"DELETE_FAST",
|
||||
"DELETE_GLOBAL",
|
||||
"DELETE_NAME",
|
||||
"DELETE_SUBSCR",
|
||||
"DICT_MERGE",
|
||||
"DICT_UPDATE",
|
||||
"DUP_TOP",
|
||||
"DUP_TOP_TWO",
|
||||
"END_ASYNC_FOR",
|
||||
"END_FINALLY",
|
||||
"END_FOR",
|
||||
"END_SEND",
|
||||
"EXTENDED_ARG",
|
||||
"FOR_ITER",
|
||||
"FORMAT_VALUE",
|
||||
"GEN_START",
|
||||
"GET_AITER",
|
||||
"GET_ANEXT",
|
||||
"GET_AWAITABLE",
|
||||
"GET_ITER",
|
||||
"GET_LEN",
|
||||
"GET_YIELD_FROM_ITER",
|
||||
"IMPORT_FROM",
|
||||
"IMPORT_NAME",
|
||||
"IMPORT_STAR",
|
||||
"INPLACE_ADD",
|
||||
"INPLACE_AND",
|
||||
"INPLACE_FLOOR_DIVIDE",
|
||||
"INPLACE_LSHIFT",
|
||||
"INPLACE_MATRIX_MULTIPLY",
|
||||
"INPLACE_MODULO",
|
||||
"INPLACE_MULTIPLY",
|
||||
"INPLACE_OR",
|
||||
"INPLACE_POWER",
|
||||
"INPLACE_RSHIFT",
|
||||
"INPLACE_SUBTRACT",
|
||||
"INPLACE_TRUE_DIVIDE",
|
||||
"INPLACE_XOR",
|
||||
"INTERPRETER_EXIT",
|
||||
"IS_OP",
|
||||
"JUMP_ABSOLUTE",
|
||||
"JUMP_BACKWARD",
|
||||
"JUMP_BACKWARD_NO_INTERRUPT",
|
||||
"JUMP_FORWARD",
|
||||
"JUMP_IF_FALSE_OR_POP",
|
||||
"JUMP_IF_NOT_EXC_MATCH",
|
||||
"JUMP_IF_TRUE_OR_POP",
|
||||
"LIST_APPEND",
|
||||
"LIST_EXTEND",
|
||||
"LIST_TO_TUPLE",
|
||||
"LOAD_ASSERTION_ERROR",
|
||||
"LOAD_ATTR",
|
||||
"LOAD_BUILD_CLASS",
|
||||
"LOAD_CLASSDEREF",
|
||||
"LOAD_CLOSURE",
|
||||
"LOAD_CONST",
|
||||
"LOAD_DEREF",
|
||||
"LOAD_FAST",
|
||||
"LOAD_FAST_AND_CLEAR",
|
||||
"LOAD_FAST_CHECK",
|
||||
"LOAD_GLOBAL",
|
||||
"LOAD_LOCALS",
|
||||
"LOAD_METHOD",
|
||||
"LOAD_NAME",
|
||||
"LOAD_SUPER_ATTR",
|
||||
"MAKE_CELL",
|
||||
"MAKE_FUNCTION",
|
||||
"MAP_ADD",
|
||||
"MATCH_CLASS",
|
||||
"MATCH_KEYS",
|
||||
"MATCH_MAPPING",
|
||||
"MATCH_SEQUENCE",
|
||||
"NOP",
|
||||
"POP_BLOCK",
|
||||
"POP_EXCEPT",
|
||||
"POP_FINALLY",
|
||||
"POP_JUMP_FORWARD_IF_FALSE",
|
||||
"POP_JUMP_FORWARD_IF_NONE",
|
||||
"POP_JUMP_FORWARD_IF_NOT_NONE",
|
||||
"POP_JUMP_FORWARD_IF_TRUE",
|
||||
"POP_JUMP_IF_FALSE",
|
||||
"POP_JUMP_IF_NONE",
|
||||
"POP_JUMP_IF_NOT_NONE",
|
||||
"POP_JUMP_IF_TRUE",
|
||||
"POP_TOP",
|
||||
"PRECALL",
|
||||
"PREP_RERAISE_STAR",
|
||||
"PRINT_EXPR",
|
||||
"PUSH_EXC_INFO",
|
||||
"PUSH_NULL",
|
||||
"RAISE_VARARGS",
|
||||
"RERAISE",
|
||||
"RESERVED",
|
||||
"RESUME",
|
||||
"RETURN_CONST",
|
||||
"RETURN_GENERATOR",
|
||||
"RETURN_VALUE",
|
||||
"ROT_FOUR",
|
||||
"ROT_N",
|
||||
"ROT_THREE",
|
||||
"ROT_TWO",
|
||||
"SEND",
|
||||
"SET_ADD",
|
||||
"SET_UPDATE",
|
||||
"SETUP_ANNOTATIONS",
|
||||
"SETUP_ASYNC_WITH",
|
||||
"SETUP_EXCEPT",
|
||||
"SETUP_FINALLY",
|
||||
"SETUP_LOOP",
|
||||
"SETUP_WITH",
|
||||
"STORE_ATTR",
|
||||
"STORE_DEREF",
|
||||
"STORE_FAST",
|
||||
"STORE_GLOBAL",
|
||||
"STORE_NAME",
|
||||
"STORE_SLICE",
|
||||
"STORE_SUBSCR",
|
||||
"SWAP",
|
||||
"UNARY_INVERT",
|
||||
"UNARY_NEGATIVE",
|
||||
"UNARY_NOT",
|
||||
"UNARY_POSITIVE",
|
||||
"UNPACK_EX",
|
||||
"UNPACK_SEQUENCE",
|
||||
"WITH_CLEANUP_FINISH",
|
||||
"WITH_CLEANUP_START",
|
||||
"WITH_EXCEPT_START",
|
||||
"YIELD_FROM",
|
||||
"YIELD_VALUE",
|
||||
"<mask_0>",
|
||||
"<mask_1>",
|
||||
"<mask_2>",
|
||||
"<mask_3>",
|
||||
"<mask_4>",
|
||||
"<mask_5>",
|
||||
"<mask_6>",
|
||||
"<mask_7>",
|
||||
"<mask_8>",
|
||||
"<mask_9>",
|
||||
"<mask_10>",
|
||||
"<mask_11>",
|
||||
"<mask_12>",
|
||||
"<mask_13>",
|
||||
"<mask_14>",
|
||||
"<mask_15>",
|
||||
"<mask_16>",
|
||||
"<mask_17>",
|
||||
"<mask_18>",
|
||||
"<mask_19>",
|
||||
"<mask_20>",
|
||||
"<mask_21>",
|
||||
"<mask_22>",
|
||||
"<mask_23>",
|
||||
"<mask_24>",
|
||||
"<mask_25>",
|
||||
"<mask_26>",
|
||||
"<mask_27>",
|
||||
"<mask_28>",
|
||||
"<mask_29>",
|
||||
"<mask_30>",
|
||||
"<mask_31>",
|
||||
"<mask_32>",
|
||||
"<mask_33>",
|
||||
"<mask_34>",
|
||||
"<mask_35>",
|
||||
"<mask_36>",
|
||||
"<mask_37>",
|
||||
"<mask_38>",
|
||||
"<mask_39>",
|
||||
"<mask_40>",
|
||||
"<mask_41>",
|
||||
"<mask_42>",
|
||||
"<mask_43>",
|
||||
"<mask_44>",
|
||||
"<mask_45>",
|
||||
"<mask_46>",
|
||||
"<mask_47>",
|
||||
"<mask_48>",
|
||||
"<mask_49>",
|
||||
"<mask_50>",
|
||||
"<mask_51>",
|
||||
"<mask_52>",
|
||||
"<mask_53>",
|
||||
"<mask_54>",
|
||||
"<mask_55>",
|
||||
"<mask_56>",
|
||||
"<mask_57>",
|
||||
"<mask_58>",
|
||||
"<mask_59>",
|
||||
"<mask_60>",
|
||||
"<mask_61>",
|
||||
"<mask_62>",
|
||||
"<mask_63>",
|
||||
"<mask_64>",
|
||||
"<mask_65>",
|
||||
"<mask_66>",
|
||||
"<mask_67>",
|
||||
"<mask_68>",
|
||||
"<mask_69>",
|
||||
"<mask_70>",
|
||||
"<mask_71>",
|
||||
"<mask_72>",
|
||||
"<mask_73>",
|
||||
"<mask_74>",
|
||||
"<mask_75>",
|
||||
"<mask_76>",
|
||||
"<mask_77>",
|
||||
"<mask_78>",
|
||||
"<mask_79>",
|
||||
"<mask_80>",
|
||||
"<mask_81>",
|
||||
"<mask_82>",
|
||||
"<mask_83>",
|
||||
"<mask_84>",
|
||||
"<mask_85>",
|
||||
"<mask_86>",
|
||||
"<mask_87>",
|
||||
"<mask_88>",
|
||||
"<mask_89>",
|
||||
"<mask_90>",
|
||||
"<mask_91>",
|
||||
"<mask_92>",
|
||||
"<mask_93>",
|
||||
"<mask_94>",
|
||||
"<mask_95>",
|
||||
"<mask_96>",
|
||||
"<mask_97>",
|
||||
"<mask_98>",
|
||||
"<mask_99>",
|
||||
"<mask_100>",
|
||||
"<mask_101>",
|
||||
"<mask_102>",
|
||||
"<mask_103>",
|
||||
"<mask_104>",
|
||||
"<mask_105>",
|
||||
"<mask_106>",
|
||||
"<mask_107>",
|
||||
"<mask_108>",
|
||||
"<mask_109>",
|
||||
"<mask_110>",
|
||||
"<mask_111>",
|
||||
"<mask_112>",
|
||||
"<mask_113>",
|
||||
"<mask_114>",
|
||||
"<mask_115>",
|
||||
"<mask_116>",
|
||||
"<mask_117>",
|
||||
"<mask_118>",
|
||||
"<mask_119>",
|
||||
"<mask_120>",
|
||||
"<mask_121>",
|
||||
"<mask_122>",
|
||||
"<mask_123>",
|
||||
"<mask_124>",
|
||||
"<mask_125>",
|
||||
"<mask_126>",
|
||||
"<mask_127>",
|
||||
"<mask_128>",
|
||||
"<mask_129>",
|
||||
"<mask_130>",
|
||||
"<mask_131>",
|
||||
"<mask_132>",
|
||||
"<mask_133>",
|
||||
"<mask_134>",
|
||||
"<mask_135>",
|
||||
"<mask_136>",
|
||||
"<mask_137>",
|
||||
"<mask_138>",
|
||||
"<mask_139>",
|
||||
"<mask_140>",
|
||||
"<mask_141>",
|
||||
"<mask_142>",
|
||||
"<mask_143>",
|
||||
"<mask_144>",
|
||||
"<mask_145>",
|
||||
"<mask_146>",
|
||||
"<mask_147>",
|
||||
"<mask_148>",
|
||||
"<mask_149>",
|
||||
"<mask_150>",
|
||||
"<mask_151>",
|
||||
"<mask_152>",
|
||||
"<mask_153>",
|
||||
"<mask_154>",
|
||||
"<mask_155>",
|
||||
"<mask_156>",
|
||||
"<mask_157>",
|
||||
"<mask_158>",
|
||||
"<mask_159>",
|
||||
"<mask_160>",
|
||||
"<mask_161>",
|
||||
"<mask_162>",
|
||||
"<mask_163>",
|
||||
"<mask_164>",
|
||||
"<mask_165>",
|
||||
"<mask_166>",
|
||||
"<mask_167>",
|
||||
"<mask_168>",
|
||||
"<mask_169>",
|
||||
"<mask_170>",
|
||||
"<mask_171>",
|
||||
"<mask_172>",
|
||||
"<mask_173>",
|
||||
"<mask_174>",
|
||||
"<mask_175>",
|
||||
"<mask_176>",
|
||||
"<mask_177>",
|
||||
"<mask_178>",
|
||||
"<mask_179>",
|
||||
"<mask_180>",
|
||||
"<mask_181>",
|
||||
"<mask_182>",
|
||||
"<mask_183>",
|
||||
"<mask_184>",
|
||||
"<mask_185>",
|
||||
"<mask_186>",
|
||||
"<mask_187>",
|
||||
"<mask_188>",
|
||||
"<mask_189>",
|
||||
"<mask_190>",
|
||||
"<mask_191>",
|
||||
"<mask_192>",
|
||||
"<mask_193>",
|
||||
"<mask_194>",
|
||||
"<mask_195>",
|
||||
"<mask_196>",
|
||||
"<mask_197>",
|
||||
"<mask_198>",
|
||||
"<mask_199>",
|
||||
"<mask_200>",
|
||||
"<mask_201>",
|
||||
"<mask_202>",
|
||||
"<mask_203>",
|
||||
"<mask_204>",
|
||||
"<mask_205>",
|
||||
"<mask_206>",
|
||||
"<mask_207>",
|
||||
"<mask_208>",
|
||||
"<mask_209>",
|
||||
"<mask_210>",
|
||||
"<mask_211>",
|
||||
"<mask_212>",
|
||||
"<mask_213>",
|
||||
"<mask_214>",
|
||||
"<mask_215>",
|
||||
"<mask_216>",
|
||||
"<mask_217>",
|
||||
"<mask_218>",
|
||||
"<mask_219>",
|
||||
"<mask_220>",
|
||||
"<mask_221>",
|
||||
"<mask_222>",
|
||||
"<mask_223>",
|
||||
"<mask_224>",
|
||||
"<mask_225>",
|
||||
"<mask_226>",
|
||||
"<mask_227>",
|
||||
"<mask_228>",
|
||||
"<mask_229>",
|
||||
"<mask_230>",
|
||||
"<mask_231>",
|
||||
"<mask_232>",
|
||||
"<mask_233>",
|
||||
"<mask_234>",
|
||||
"<mask_235>",
|
||||
"<mask_236>",
|
||||
"<mask_237>",
|
||||
"<mask_238>",
|
||||
"<mask_239>",
|
||||
"<mask_240>",
|
||||
"<mask_241>",
|
||||
"<mask_242>",
|
||||
"<mask_243>",
|
||||
"<mask_244>",
|
||||
"<mask_245>",
|
||||
"<mask_246>",
|
||||
"<mask_247>",
|
||||
"<mask_248>",
|
||||
"<mask_249>",
|
||||
"<mask_250>",
|
||||
"<mask_251>",
|
||||
"<mask_252>",
|
||||
"<mask_253>",
|
||||
"<mask_254>",
|
||||
"<mask_255>",
|
||||
"<mask_256>",
|
||||
"<mask_257>",
|
||||
"<mask_258>",
|
||||
"<mask_259>",
|
||||
"<mask_260>",
|
||||
"<mask_261>",
|
||||
"<mask_262>",
|
||||
"<mask_263>",
|
||||
"<mask_264>",
|
||||
"<mask_265>",
|
||||
"<mask_266>",
|
||||
"<mask_267>",
|
||||
"<mask_268>",
|
||||
"<mask_269>",
|
||||
"<mask_270>",
|
||||
"<mask_271>",
|
||||
"<mask_272>",
|
||||
"<mask_273>",
|
||||
"<mask_274>",
|
||||
"<mask_275>",
|
||||
"<mask_276>",
|
||||
"<mask_277>",
|
||||
"<mask_278>",
|
||||
"<mask_279>",
|
||||
"<mask_280>",
|
||||
"<mask_281>",
|
||||
"<mask_282>",
|
||||
"<mask_283>",
|
||||
"<mask_284>",
|
||||
"<mask_285>",
|
||||
"<mask_286>",
|
||||
"<mask_287>",
|
||||
"<mask_288>",
|
||||
"<mask_289>",
|
||||
"<mask_290>",
|
||||
"<mask_291>",
|
||||
"<mask_292>",
|
||||
"<mask_293>",
|
||||
"<mask_294>",
|
||||
"<mask_295>",
|
||||
"<mask_296>",
|
||||
"<mask_297>",
|
||||
"<mask_298>",
|
||||
"<mask_299>",
|
||||
"<mask_300>",
|
||||
"<mask_301>",
|
||||
"<mask_302>",
|
||||
"<mask_303>",
|
||||
"<mask_304>",
|
||||
"<mask_305>",
|
||||
"<mask_306>",
|
||||
"<mask_307>",
|
||||
"<mask_308>",
|
||||
"<mask_309>",
|
||||
"<mask_310>",
|
||||
"<mask_311>",
|
||||
"<mask_312>",
|
||||
"<mask_313>",
|
||||
"<mask_314>",
|
||||
"<mask_315>",
|
||||
"<mask_316>",
|
||||
"<mask_317>",
|
||||
"<mask_318>",
|
||||
"<mask_319>",
|
||||
"<mask_320>",
|
||||
"<mask_321>",
|
||||
"<mask_322>",
|
||||
"<mask_323>",
|
||||
"<mask_324>",
|
||||
"<mask_325>",
|
||||
"<mask_326>",
|
||||
"<mask_327>",
|
||||
"<mask_328>",
|
||||
"<mask_329>",
|
||||
"<mask_330>",
|
||||
"<mask_331>",
|
||||
"<mask_332>",
|
||||
"<mask_333>",
|
||||
"<mask_334>",
|
||||
"<mask_335>",
|
||||
"<mask_336>",
|
||||
"<mask_337>",
|
||||
"<mask_338>",
|
||||
"<mask_339>",
|
||||
"<mask_340>",
|
||||
"<mask_341>",
|
||||
"<mask_342>",
|
||||
"<mask_343>",
|
||||
"<mask_344>",
|
||||
"<mask_345>",
|
||||
"<mask_346>",
|
||||
"<mask_347>",
|
||||
"<mask_348>",
|
||||
"<mask_349>",
|
||||
"<mask_350>",
|
||||
"<mask_351>",
|
||||
"<mask_352>",
|
||||
"<mask_353>",
|
||||
"<mask_354>",
|
||||
"<mask_355>",
|
||||
"<mask_356>",
|
||||
"<mask_357>",
|
||||
"<mask_358>",
|
||||
"<mask_359>",
|
||||
"<mask_360>",
|
||||
"<mask_361>",
|
||||
"<mask_362>",
|
||||
"<mask_363>",
|
||||
"<mask_364>",
|
||||
"<mask_365>",
|
||||
"<mask_366>",
|
||||
"<mask_367>",
|
||||
"<mask_368>",
|
||||
"<mask_369>",
|
||||
"<mask_370>",
|
||||
"<mask_371>",
|
||||
"<mask_372>",
|
||||
"<mask_373>",
|
||||
"<mask_374>",
|
||||
"<mask_375>",
|
||||
"<mask_376>",
|
||||
"<mask_377>",
|
||||
"<mask_378>",
|
||||
"<mask_379>",
|
||||
"<mask_380>",
|
||||
"<mask_381>",
|
||||
"<mask_382>",
|
||||
"<mask_383>",
|
||||
"<extra_id_99>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_0>",
|
||||
"match",
|
||||
"type",
|
||||
"HAVE_ARGUMENT",
|
||||
"CALL_INTRINSIC_1",
|
||||
"CALL_INTRINSIC_2",
|
||||
"JUMP_NO_INTERRUPT",
|
||||
"nargs",
|
||||
"vargs",
|
||||
"compare",
|
||||
"name",
|
||||
"const",
|
||||
"local"
|
||||
],
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"cls_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"mask_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<mask>",
|
||||
"lstrip": true,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"model_max_length": 512,
|
||||
"pad_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"sep_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"tokenizer_class": "RobertaTokenizer",
|
||||
"trim_offsets": true,
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import click
|
||||
|
||||
from datasets import ReadInstruction, load_dataset
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
from transformers import (
|
||||
DataCollatorForSeq2Seq,
|
||||
RobertaTokenizer,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
def load_tokenized_train_dataset(dataset_repo_name: str, dataset_percentage: int):
|
||||
# Load the tokenized dataset
|
||||
tokenized_train_dataset = load_dataset(
|
||||
dataset_repo_name,
|
||||
token=True,
|
||||
split=ReadInstruction("train", to=dataset_percentage, unit="%"),
|
||||
)
|
||||
return tokenized_train_dataset
|
||||
|
||||
|
||||
def train_statement_model(config: StatementConfiguration):
|
||||
# load model, Salesforce/codet5-base is a pretrained model solving the code generation task.
|
||||
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
|
||||
model = T5ForConditionalGeneration.from_pretrained(config.pretrained_seq2seq_repo_name)
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
|
||||
|
||||
model_dir = str(config.statement_model_dir)
|
||||
model_repo_name = config.statement_model_repo_name
|
||||
|
||||
train_args = Seq2SeqTrainingArguments(
|
||||
output_dir=model_dir,
|
||||
learning_rate=config.statement_training_parameters.learning_rate,
|
||||
per_device_train_batch_size=config.statement_training_parameters.batch_size,
|
||||
per_device_eval_batch_size=config.statement_training_parameters.batch_size,
|
||||
weight_decay=0.01,
|
||||
fp16=config.fp16,
|
||||
logging_dir=str(config.log_dir),
|
||||
report_to="tensorboard",
|
||||
logging_strategy="steps",
|
||||
logging_steps=1000,
|
||||
save_strategy="steps",
|
||||
save_steps=10000,
|
||||
save_total_limit=2,
|
||||
num_train_epochs=config.statement_training_parameters.epochs,
|
||||
predict_with_generate=True,
|
||||
push_to_hub=True,
|
||||
hub_model_id=model_repo_name,
|
||||
hub_private_repo=True,
|
||||
ddp_backend="nccl",
|
||||
ddp_find_unused_parameters=False,
|
||||
)
|
||||
|
||||
tokenized_train_dataset = load_tokenized_train_dataset(config.tokenized_dataset_repo_name, config.dataset_percentage)
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=train_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=tokenized_train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
trainer.train()
|
||||
duration = str(timedelta(seconds=time.time() - start))
|
||||
|
||||
if int(os.environ["LOCAL_RANK"]) == 0:
|
||||
# upload the latest version of the model to the Model Hub on Huggingface
|
||||
trainer.save_model(str(config.statement_model_dir))
|
||||
# this command returns the URL of the commit it just did
|
||||
trainer.push_to_hub(
|
||||
commit_message=duration,
|
||||
finetuned_from=config.pretrained_seq2seq_repo_name,
|
||||
dataset=config.tokenized_dataset_repo_name,
|
||||
)
|
||||
|
||||
|
||||
@click.command(help="Training script for the statement translation model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
train_statement_model(statement_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,93 @@
|
||||
import logging
|
||||
import pathlib
|
||||
import click
|
||||
|
||||
from datasets import ReadInstruction, load_dataset
|
||||
from huggingface_hub import HfApi, repo_exists
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
from tokenizers import Tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_untrained_tokenizer(tokenizer_repo_name: str) -> AutoTokenizer:
|
||||
tokenizer_dir = pathlib.Path(__file__).parent / tokenizer_repo_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def save_and_upload_tokenizer(
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer_json_path: pathlib.Path,
|
||||
tokenizer_repo_name: str,
|
||||
dataset_name: str,
|
||||
):
|
||||
# Save the tokenizer locally
|
||||
tokenizer.save_pretrained(str(tokenizer_json_path.parent.resolve()))
|
||||
|
||||
# Upload files to Hugging Face Hub
|
||||
api = HfApi()
|
||||
api.create_repo(tokenizer_repo_name, exist_ok=True, private=True)
|
||||
api.upload_file(
|
||||
path_in_repo="tokenizer_config.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "tokenizer_config.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message=f"Trained tokenizer using {dataset_name}",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="vocab.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "vocab.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted vocabulary from tokenizer",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="merges.txt",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "merges.txt"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted merges from tokenizer",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="tokenizer.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "tokenizer.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted tokenizer",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="special_tokens_map.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "special_tokens_map.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted special tokens map",
|
||||
)
|
||||
|
||||
|
||||
def train_tokenizer(config: StatementConfiguration, tokenizer_json_path: pathlib.Path):
|
||||
if repo_exists(config.base_repo_name):
|
||||
logging.error(f"{config.base_repo_name} has already exists")
|
||||
exit(1)
|
||||
|
||||
tokenizer = get_untrained_tokenizer("tokenizer")
|
||||
|
||||
train_dataset = load_dataset(
|
||||
config.dataset_repo_name,
|
||||
token=True,
|
||||
split=ReadInstruction("train", to=config.dataset_percentage, unit="%"),
|
||||
)["bytecode"]
|
||||
|
||||
tokenizer = tokenizer.train_new_from_iterator(train_dataset, vocab_size=30000)
|
||||
save_and_upload_tokenizer(
|
||||
tokenizer,
|
||||
tokenizer_json_path,
|
||||
config.tokenizer_repo_name,
|
||||
config.dataset_repo_name,
|
||||
)
|
||||
|
||||
|
||||
@click.command(help="Training script for the bytecode tokenizer for the statement model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
train_tokenizer(statement_config, json_file_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user