Mirror of https://github.com/syssec-utd/pylingual.git (synced 2026-05-10 18:39:03 -07:00)

Commit: rename
@@ -0,0 +1,48 @@
# Model Training

PyLingual's accuracy depends on having accurate segmentation and statement models [^1]. The segmentation model divides a list of bytecode instructions into groups, one group per source statement. The statement model then transforms each group of instructions into a line of source code.
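To make the two roles concrete, the training data (see the CSV generation code later in this commit) tags every bytecode instruction with a boundary label: `B` begins a statement's group, `I` continues it, and `E` ends it. The sketch below uses made-up instruction groups to show how those labels are derived; the opcode strings are placeholders, not real disassembly output.

```
# Hypothetical instruction groups, one group per source statement.
groups = [
    ["LOAD_CONST , 1", "STORE_NAME , x"],                                # x = 1
    ["LOAD_NAME , x", "LOAD_CONST , 2", "BINARY_ADD", "RETURN_VALUE"],   # return x + 2
]

# Boundary labels as emitted by the dataset generation code:
# "B" for a single-instruction statement, otherwise "B" + "I" * (n - 2) + "E".
boundaries = []
for group in groups:
    if len(group) == 1:
        boundaries.append("B")
    else:
        boundaries.append("B" + "I" * (len(group) - 2) + "E")

print(boundaries)  # ['BE', 'BIIE']
```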
The instructions for training these models are as follows.

## Dataset generation

First install [pyenv](https://github.com/pyenv/pyenv) and the required Python versions for the dataset. Create a dataset JSON file based on the sample (`sample_jsons/py36-sample-data.json`).
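For reference, an abridged version of that sample (all values are placeholders; each entry in `data_requests` names a source directory and how many train/test/validation files to draw from it):

```
{
    "name": "sample_dataset_name",
    "version": [3, 6],
    "save_to_dir": "./save_dir/",
    "huggingface_user": "sample_user",
    "data_requests": [
        {
            "name": "dataset",
            "source_path": "./dataset",
            "num_train": 200,
            "num_test": 200,
            "num_valid": 200
        }
    ]
}
```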
The dataset directory should be structured like so, with only one `.py` file per directory:

```
dataset
├── 0
│   └── file.py
├── 1
│   └── file.py
...
├── 999
│   └── file.py
└── 1000
    └── file.py
```

The names of the inner directories and files do not matter. Then create the dataset:

```
python prepare_dataset.py <path to JSON>
```

## Segmentation model

Create a segmentation model JSON file based on the sample (`sample_jsons/py36-sample-segmentation.json`). Then train the model:

```
python train_models.py --segmentation <path to JSON>
```

## Statement model

Create a statement model JSON file based on the sample (`sample_jsons/py36-sample-statement.json`). Then train the model:

```
python train_models.py --statement <path to JSON>
```

Once the models are trained, update `../pylingual/decompiler_config.yaml` (or create a separate config file) to replace the old models with the newly trained ones.

[^1]: [pylingual models](https://huggingface.co/syssec-utd).
@@ -0,0 +1,173 @@
import contextlib
import difflib
import json
import multiprocessing
import os
import warnings

from collections import defaultdict
from enum import Enum
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import rich
from rich.progress import track

from pylingual.control_flow_reconstruction.cfg import CFG
from pylingual.control_flow_reconstruction.structure import bc_to_cft
from pylingual.main import print_result
from pylingual.control_flow_reconstruction.source import SourceContext
from pylingual.editable_bytecode import PYCFile
from pylingual.equivalence_check import compare_pyc
from pylingual.utils.version import PythonVersion
from pylingual.utils.generate_bytecode import compile_version, CompileError
from dataset_generation.normalize_source import normalize_source


class Result(Enum):
    Success = "success"
    Failure = "failure"
    Error = "error"
    CompileError = "compile_error"


def edit_pyc_lines(pyc: PYCFile, src_lines: list[str]):
    if pyc.version == (3, 10):
        pyc.replace_duplicated_returns10(src_lines)
    elif pyc.version == (3, 12):
        pyc.replace_duplicated_returns12(src_lines)
    seen_lines = set()
    # multiple instructions can start on the same lno, but the segmentation model will only assign the lno to the first one
    for bc in pyc.iter_bytecodes():
        if bc.is_comprehension:
            continue

        # create a dict of line num : [bytecodes composing line]
        lno_bytecodes = bc.get_lno_insts(previously_seen_lines=seen_lines)
        seen_lines.update(lno_bytecodes.keys())

        for lno, line_insts in lno_bytecodes.items():
            line_insts[0].starts_line = lno
            for inst in line_insts[1:]:
                inst.starts_line = None


def run(file: Path, out_dir: Path, version: PythonVersion, print=False):
    try:
        out_dir = get_unused(out_dir / file.stem, False)
        out_dir.mkdir(parents=True)
        if file.is_dir():
            file = next(file.iterdir())

        in_src = normalize_source(file.read_text(), replace_docstrings=True)
        src_lines = in_src.split("\n")
        in_path = out_dir / "a.py"
        in_path.write_text(in_src, encoding="utf-8")
        in_pyc = out_dir / "a.pyc"

        compile_version(in_path, in_pyc, version)
        pyc = PYCFile(in_pyc)
        edit_pyc_lines(pyc, src_lines)

        cfts = {bc.codeobj: bc_to_cft(bc) for bc in pyc.iter_bytecodes()}
        out_src = normalize_source(str(SourceContext(pyc, src_lines, cfts)))

        out_path = out_dir / "b.py"
        out_path.write_text(out_src, encoding="utf-8")
        out_pyc = out_dir / "b.pyc"
        compile_version(out_path, out_pyc, version)
        result = compare_pyc(in_pyc, out_pyc)
        if print:
            print_result(f"Equivalence results for {file}", result)
        return Result.Success if all(x.success for x in result) else Result.Failure, file
    except (CompileError, SyntaxError):
        return Result.CompileError, file
    except Exception:
        rich.get_console().print_exception()
        return Result.Error, file


class NoPool:
    # stand-in for multiprocessing.Pool when running in a single process
    imap_unordered = map


def print_diff(a: Path, b: Path):
    a_lines = a.read_text().split("\n")
    b_lines = b.read_text().split("\n")
    console = rich.console.Console(highlight=False)
    for line in difflib.unified_diff(a_lines, b_lines, str(a), str(b)):
        style = "red" if line[0] == "-" else "green" if line[0] == "+" else "blue" if line[0] == "@" else ""
        console.print(line, style=style)


def get_unused(a: Path, _=True):
    # return a path that does not exist yet by appending a numeric suffix;
    # if the second argument is falsy and `a` itself is unused, return `a` unchanged
    if not _ and not a.exists():
        return a
    stem = a.stem
    i = 0
    while True:
        a = a.with_stem(f"{stem}_{i}")
        i += 1
        if not a.exists():
            return a


@click.command(help="Run the control-flow reconstructor")
@click.argument("input", type=Path)
@click.argument("output", type=str, default="")
@click.option("-v", "--version", type=PythonVersion, default=PythonVersion((3, 12)), help="Python version to compile as")
@click.option("-p", "--processes", type=int, default=os.cpu_count(), help="Number of processes")
@click.option("-d", "--prefix", type=Path, default=Path("/tmp/cflow_test"), help="Base dir for all output")
@click.option("-g", "--graph", is_flag=False, flag_value="graph", help="Enable CFG visualization")
@click.option("-f", "--graph-format", default="jpg", help="Output format supported by pydot")
def main(input: Path, output: str, version: PythonVersion, graph: str | None, prefix: Path, processes: int, graph_format: str):
    warnings.filterwarnings("ignore")
    print = rich.get_console().print
    if graph:
        CFG.enable_graphing(prefix / graph, graph_format)
    if input.is_file() and input.suffix == ".py":
        if output:
            out = contextlib.nullcontext(output)
        else:
            out = TemporaryDirectory()
        with out as o:
            o = Path(o)
            results = run(input, o, version)[0]
            if results in [Result.CompileError, Result.Error]:
                print(results)
            else:
                print_diff(o / input.name / "a.py", o / input.name / "b.py")
    else:
        if not output:
            out_dir = get_unused(prefix / input.stem)
        else:
            out_dir = prefix / output
        print(f"Saving results to {out_dir}")
        results = defaultdict(list)
        f = partial(run, out_dir=out_dir, version=version)
        if input.is_dir():
            files = list(input.iterdir())
        else:
            files = list(map(Path, input.read_text().strip().split("\n")))

        if processes > 1:
            pool = multiprocessing.Pool(processes=processes)
        else:
            pool = contextlib.nullcontext(NoPool)
        with pool as p:
            for result, input in track(p.imap_unordered(f, files), total=len(files)):
                results[result].append(input)

        for res in Result:
            print(f"{res}: {len(results[res])}")
        total = sum(len(x) for x in results.values())
        if total:
            print(f"{len(results[Result.Success])} / {total} succeeded ({len(results[Result.Success]) / total:.3%})")
        res = json.dumps({k.value: list(map(str, v)) for k, v in results.items()})
        (out_dir / "results.json").write_text(res)


if __name__ == "__main__":
    main()
@@ -0,0 +1,50 @@
from dataclasses import dataclass
import pathlib

from typing import Tuple, List


@dataclass
class DataRequest:
    name: str
    source_path: pathlib.Path
    num_train: int
    num_test: int
    num_valid: int

    @property
    def total_files(self):
        return self.num_train + self.num_test + self.num_valid

    def __post_init__(self):
        self.source_path = pathlib.Path(self.source_path)
        if not self.source_path.exists():
            raise FileNotFoundError(f"{self.source_path} for DataRequest {self.name} does not exist")

        if self.num_train < 0:
            raise ValueError(f"Training sample count for DataRequest {self.name} must be non-negative")
        if self.num_test < 0:
            raise ValueError(f"Testing sample count for DataRequest {self.name} must be non-negative")
        if self.num_valid < 0:
            raise ValueError(f"Validation sample count for DataRequest {self.name} must be non-negative")


@dataclass
class DatasetDescription:
    name: str
    version: Tuple[int, int]
    save_to_dir: pathlib.Path
    huggingface_user: str
    data_requests: List[DataRequest]

    @property
    def code_dir(self):
        return self.save_to_dir / self.name / "code"

    @property
    def csv_dir(self):
        return self.save_to_dir / self.name / "csv"

    def __post_init__(self):
        self.save_to_dir = pathlib.Path(self.save_to_dir)
        self.version = tuple(self.version)
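# Example usage (hypothetical paths; source_path must exist or DataRequest raises
# FileNotFoundError in __post_init__). prepare_dataset.py builds these objects from
# the dataset JSON, e.g.:
#
#   request = DataRequest(name="dataset", source_path="./dataset",
#                         num_train=200, num_test=200, num_valid=200)
#   description = DatasetDescription(
#       name="sample_dataset_name",
#       version=(3, 6),
#       save_to_dir="./save_dir/",
#       huggingface_user="sample_user",
#       data_requests=[request],
#   )
#   description.code_dir  # Path("save_dir/sample_dataset_name/code")
#   description.csv_dir   # Path("save_dir/sample_dataset_name/csv")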
@@ -0,0 +1,3 @@
from .create_code_dataset import transfer_and_compile_file

__all__ = ["transfer_and_compile_file"]
@@ -0,0 +1,216 @@
import csv
import itertools
import logging
import multiprocessing
import pathlib
import re
import signal
from typing import Callable, Tuple

import tqdm
from pylingual.editable_bytecode import PYCFile

from pylingual.masking.ast_masker import DUMMY_DECORATOR
from pylingual.masking.model_disasm import fix_jump_targets
from .DatasetDescription import DataRequest
from pylingual.masking.model_disasm import create_global_masker, mask_source

bytecode_separator = " <SEP> "
source_seperator = " <SEP> "
CSV_SGMT_HEADER = ["source", "bytecode", "boundary", "file"]
CSV_STMT_HEADER = ["source", "bytecode", "file"]


def create_csv_dataset(code_dataset_path: pathlib.Path, csv_dataset_path: pathlib.Path, data_requests: list[DataRequest], logger: logging.Logger = None):
    progress_bar = tqdm.tqdm(total=sum([request.total_files for request in data_requests]))
    for split in ("train", "test", "valid"):
        if logger:
            logger.info(f"Converting the {split} split to CSV...")
        write_csvs(code_dataset_path / split, csv_dataset_path / split, logger, progress_bar=progress_bar)


def write_csvs(source_path: pathlib.Path, csv_output_path: pathlib.Path, logger: logging.Logger = None, max_csv_rows: int = 30000, progress_bar: tqdm.tqdm = None):
    # validate output directory
    if csv_output_path.exists():
        if not csv_output_path.is_dir():
            raise OSError("CSV output path is not a directory")
    else:
        csv_output_path.mkdir(parents=True)

    ##### csv write wrappers to preserve csv row limit

    def csv_writer(file_prefix: str, csv_header: list) -> Callable:
        out_dir = csv_output_path.joinpath(file_prefix)
        out_dir.mkdir(exist_ok=True)

        for csv_idx in itertools.count():
            new_path = out_dir.joinpath(f"{file_prefix}_{csv_idx}.csv")
            new_path.touch()
            if logger:
                logger.info(f"Creating new csv {new_path.resolve()}...")
            with new_path.open(mode="w") as csv_file:
                writer = csv.writer(csv_file)
                writer.writerow(csv_header)
                for writer in itertools.repeat(writer, max_csv_rows):
                    yield writer.writerow

    segmentation_writer = csv_writer("segmentation", CSV_SGMT_HEADER)
    statement_writer = csv_writer("statement", CSV_STMT_HEADER)

    # create dirs
    code_dirs = (child for child in source_path.iterdir() if child.is_dir())

    def bytecode2csv_args():
        for dir in code_dirs:
            py_path = next(dir.glob("*.py"), None)
            pyc_path = next(dir.glob("*.pyc"), None)
            if None in (py_path, pyc_path):
                logging.debug(f"PY or PYC file not found in {dir}")
                continue
            else:
                yield (py_path, pyc_path)

    num_fails = 0
    with multiprocessing.Pool() as pool:
        for result in pool.imap_unordered(bytecode2csv_exception_wrapper, bytecode2csv_args()):
            if isinstance(result, Exception):
                num_fails += 1
                if logger:
                    logger.debug(f"ERR: {result}\nTYPE ERR: {type(result)}\n")
                continue

            (segmentation_rows, statement_rows) = result
            for row, writerow in zip(segmentation_rows, segmentation_writer):
                writerow(row)
            for row, writerow in zip(statement_rows, statement_writer):
                writerow(row)

            if progress_bar:
                progress_bar.update()
                progress_bar.set_postfix({"num_fails": num_fails})

    if logger:
        logger.info(f"NUMBER OF FAILS !!! {num_fails}")


def timeout_handler(signum, frame):
    raise TimeoutError()


def bytecode2csv_exception_wrapper(paths: Tuple[pathlib.Path, pathlib.Path]) -> Tuple[list, list] | Exception:
    signal.signal(signal.SIGALRM, timeout_handler)
    try:
        signal.alarm(30)  # set 30 second timeout
        results = bytecode2csv(*paths)
        signal.alarm(0)  # success; disable timer
        return results
    except Exception as error:
        signal.alarm(0)  # disable timer in case another exception triggered the fail
        return Exception(f"{type(error)}: {error} in file {paths}")


def bytecode2csv(py_path: pathlib.Path, pyc_path: pathlib.Path) -> tuple[list, list]:
    """Creates segmentation and statement csv rows for given bytecode and source file"""
    segmentation_rows = []
    statement_rows = []

    pyc = PYCFile(str(pyc_path.resolve()))
    if pyc.version == (3, 10):
        pyc.replace_duplicated_returns10(py_path.read_text().split("\n"))
    elif pyc.version == (3, 12):
        pyc.replace_duplicated_returns12(py_path.read_text().split("\n"))
    global_masker = create_global_masker(pyc)

    masked_source_text = mask_source(py_path, global_masker, pyc.version)
    masked_source_lines = masked_source_text.split("\n")

    # filter out dummy decorators added in <= 3.7
    dummy_lnos = []
    if pyc.version <= (3, 7):
        # remove dummy decorators from bytecode
        pyc._patch_dummy_decorator(dummy_decorator_name=DUMMY_DECORATOR)
        try:  # if no functions are in source, then dummy will not exist
            dummy_decorator_line = f"@{global_masker.mask(DUMMY_DECORATOR)}"
        except KeyError:
            dummy_decorator_line = None
        dummy_lnos = [lno + 1 for lno, source in enumerate(masked_source_lines) if source.strip() == dummy_decorator_line]

    seen_lines = set()

    # create rows for each bytecode
    for bc in pyc.iter_bytecodes():
        # we ignore comprehensions, hoisted later
        if bc.is_comprehension:
            continue

        # attempt to filter lines
        lno_insts = bc.get_lno_insts(previously_seen_lines=seen_lines)

        # create line num : model disasm view of insts
        lno_model_view_insts = {lno: [global_masker.get_model_view(inst) for inst in line_insts] for lno, line_insts in lno_insts.items()}
        seen_lines.update(lno_model_view_insts.keys())

        # segment source
        if pyc.version <= (3, 7):
            segmented_source_lines = []
            for line_num in lno_model_view_insts:
                if not line_num:
                    segmented_source_lines.append("")
                elif line_num in dummy_lnos:
                    segmented_source_lines.append(masked_source_lines[line_num].strip())
                else:
                    segmented_source_lines.append(masked_source_lines[line_num - 1].strip())
        else:
            segmented_source_lines = [masked_source_lines[line_num - 1].strip() if line_num else "" for line_num in lno_model_view_insts.keys()]  # -1 to convert from line num to index in array

        model_disasm_text = bytecode_separator.join(val for val in itertools.chain(*lno_model_view_insts.values()))

        if len(segmented_source_lines) != len(lno_model_view_insts):
            raise ValueError("Length mismatch between segmented source and segmented bytecodes")

        # create bytecode segmentation
        boundaries = []
        for bc_line in lno_model_view_insts.values():
            if len(bc_line) == 1:
                bounds = "B"
            elif len(bc_line) >= 2:
                bounds = "B" + "I" * (len(bc_line) - 2) + "E"
            else:
                raise ValueError("Unexpected amount of bytecodes segmented into a line")
            boundaries.extend(list(bounds))

        # append rows
        segmentation_rows.append([source_seperator.join(segmented_source_lines), model_disasm_text, boundaries, str(py_path)])
        for segmented_source, bytecodes in zip(segmented_source_lines, lno_model_view_insts.values()):
            # skip empty lines
            if not segmented_source or segmented_source == "None":
                continue
            # skip fillers
            if segmented_source in ("pass", "...") and ("RETURN_VALUE" in bytecodes or "RETURN_CONST , None" in bytecodes):
                continue
            # skip string-only lines that aren't docstrings
            if (segmented_source.startswith("'") or segmented_source.startswith('"')) and not any("__doc__" in b for b in bytecodes):
                continue
            if segmented_source.startswith("elif "):
                segmented_source = segmented_source[2:]

            joined_bytecode = bytecode_separator.join(bytecodes)

            # DUCT-TAPE; skip samples where model has to guess masks
            source_masks = set(re.findall(r"<mask_\d+>", segmented_source))
            bytecode_masks = set(re.findall(r"<mask_\d+>", joined_bytecode))
            if not source_masks <= bytecode_masks:
                continue

            # normalize source mask order for statements
            # replace mask values to start at 0 and count up
            mask_regex = re.compile(r"(?<=<mask_)\d+(?=>)")
            masks = mask_regex.findall(joined_bytecode)
            mask_order = [x for i, x in enumerate(masks) if masks.index(x) == i]
            normalized_mask_bytecode = mask_regex.sub(lambda x: str(mask_order.index(x.group(0))), joined_bytecode)
            normalized_mask_source = mask_regex.sub(lambda x: str(mask_order.index(x.group(0))), segmented_source)

            # normalize jump targets
            normalized_mask_bytecode = fix_jump_targets(normalized_mask_bytecode)

            statement_rows.append([normalized_mask_source, normalized_mask_bytecode, str(py_path)])

    return (segmentation_rows, statement_rows)
@@ -0,0 +1,114 @@
import itertools
import logging
import multiprocessing
import pathlib
import random
from typing import List, Optional, Set, Tuple

import tqdm

from .DatasetDescription import DataRequest
from pylingual.utils.generate_bytecode import compile_version
from .normalize_source import normalize_source
from pylingual.masking.ast_masker import add_dummy_decorators


def transfer_and_compile_file(
    original_file: pathlib.Path,
    destination_file: pathlib.Path,
    version: Tuple[int, int],
) -> Optional[Exception]:
    # copy over normalized source file
    try:
        normalized_source = normalize_source(original_file.read_text(), version=version, replace_docstrings=True)

        if version[:2] <= (3, 7):
            normalized_source = add_dummy_decorators(normalized_source)
    except Exception as err:
        return err

    destination_file.parent.mkdir(parents=True, exist_ok=True)
    destination_file.write_text(normalized_source)

    # compile the copied file with the given version
    try:
        compile_version(
            destination_file.resolve(),
            destination_file.with_suffix(".pyc").resolve(),
            version,
        )
    except Exception as err:
        return err

    return None


def star_transfer_and_compile_file(args) -> Optional[Exception]:
    return transfer_and_compile_file(*args)


# samples data_request.total_files files from the given directory
# expects the directory to have the structure
# source_dir -> identifier -> file.py
def sample_directory_splits(
    data_request: DataRequest,
) -> Tuple[List[pathlib.Path], List[pathlib.Path], List[pathlib.Path]]:
    all_files: Set[pathlib.Path] = set()
    for identifier in data_request.source_path.iterdir():
        source_file = next(identifier.glob("*.py"), None)  # get the first python file from the identifier
        if source_file is not None:
            all_files.add(source_file)

    # sample batches until we have enough files to satisfy the data requests
    # this avoids running expensive tests on unsampled files
    clean_sample: Set[pathlib.Path] = set()
    while len(clean_sample) < data_request.total_files:
        remaining_files = data_request.total_files - len(clean_sample)
        sample_batch = random.sample(list(all_files), k=remaining_files)
        # add the acceptable files to the sample and remove them from the population
        to_add = set(candidate for candidate in sample_batch if candidate is not None)
        clean_sample.update(to_add)
        all_files -= to_add

    full_sample = iter(clean_sample)

    train = list(itertools.islice(full_sample, data_request.num_train))
    test = list(itertools.islice(full_sample, data_request.num_test))
    valid = list(itertools.islice(full_sample, data_request.num_valid))

    return train, test, valid


def prepare_single_directory_transfer_args(data_request: DataRequest, target_dir: pathlib.Path) -> List[Tuple[pathlib.Path, pathlib.Path]]:
    train, test, valid = sample_directory_splits(data_request)

    transfer_args = []
    for split_name, split_files in zip(("train", "test", "valid"), (train, test, valid)):
        for source_file in split_files:
            target_file = target_dir / split_name / f"{data_request.name}-{source_file.parent.name}" / source_file.name
            transfer_args.append((source_file, target_file))

    return transfer_args


# takes a list of DataRequests and a target directory
# makes train, test, and valid directories in the target directory with the normalized source files
def create_code_dataset(
    data_requests: List[DataRequest],
    target_dir: pathlib.Path,
    version: Tuple[int, int],
    logger: logging.Logger,
):
    with multiprocessing.Pool() as pool:
        # prepare a list of file transfers to execute
        logger.info(f"Sampling {', '.join(str(req.source_path.resolve()) for req in data_requests)}...")
        transfer_arg_lists = pool.starmap(
            prepare_single_directory_transfer_args,
            zip(data_requests, itertools.repeat(target_dir)),
        )
        # execute the file transfers
        versioned_transfer_arg_lists = [(source_file, target_file, version) for (source_file, target_file) in itertools.chain(*transfer_arg_lists)]
        logger.info(f"Normalizing and Compiling {len(versioned_transfer_arg_lists)} files...")
        for error in tqdm.tqdm(pool.imap_unordered(star_transfer_and_compile_file, versioned_transfer_arg_lists), total=len(versioned_transfer_arg_lists)):
            if error is not None:
                logger.debug(error)
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

import ast
import pathlib
import sys
from typing import Tuple


def version_str_to_tuple(version_str: str) -> tuple[int, int]:
    # a version string is a string like 3.9.2
    versions = [int(version) for version in version_str.split(".")]
    return tuple(versions[:2])


# must be run in python 3.9 or later for ast.unparse() support
# version defaults to whatever version this script is running in; needs to be set explicitly for backwards compatibility
# ast only supports versions 3.4 and later
def normalize_source(
    source: str,
    version: Tuple[int, int] = sys.version_info[0:2],
    replace_docstrings=False,
) -> str:
    """
    Parse the source code into an AST, then convert back to source.
    This has the following normalizing effects:
    1. whitespace is set according to the PEP standard
    2. each statement is on exactly one line
    3. # comments are removed (note: docstrings are not removed)

    :param str source: The source code to normalize
    :param tuple version: The (Major, Minor) version of python to parse with; must be at least (3, 4); defaults to
        same version as this script
    :param bool replace_docstrings: Replace all docstrings with 'pass'
    """
    tree = ast.parse(source, feature_version=version)
    if replace_docstrings:
        for node in ast.walk(tree):
            if isinstance(node, ast.Expr) and isinstance(node.value, ast.Str):
                node.value.s = "pass"
    return ast.unparse(tree)
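# Example (hypothetical input), illustrating the normalizing effects described in the
# docstring above:
#
#   normalize_source("x=1;y = x+ 1  # trailing comment")
#   # -> "x = 1\ny = x + 1"
#   # (PEP-style spacing, one statement per line, the comment is dropped)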
def normalize_source_file(
    source_file_path: str,
    cleaned_suffix: str = "-cleaned",
    version: tuple[int, int] = sys.version_info[0:2],
):
    """
    Normalizes the source code in a given file, then saves it to a '-cleaned' version in the same directory

    :param str source_file_path: The absolute or relative path to the source .py file
    :param str cleaned_suffix: The suffix to add to the cleaned file, typically left as default
    :param tuple version: The (Major, Minor) version of python to parse with; must be at least (3, 4); defaults to
        same version as this script
    """

    # add the cleaned_suffix to the output_path
    input_path = pathlib.Path(source_file_path).resolve()
    output_path = input_path.with_stem(f"{input_path.stem}{cleaned_suffix}")

    with open(input_path, "r") as source_file:
        normalized_source = normalize_source(source_file.read(), version=version)

    with open(output_path, "w") as cleaned_file:
        cleaned_file.write(normalized_source)

    return output_path
@@ -0,0 +1,60 @@
from io import BytesIO
from typing import Dict, List, Literal

from datasets import load_dataset
from huggingface_hub import HfApi

from .DatasetDescription import DatasetDescription

LOCAL_DATASET = Dict[Literal["train", "test", "valid"], List[str]]


def upload_single_dataset(data_files: LOCAL_DATASET, dataset_name: str, dataset_card: str):
    local_datasets = load_dataset("csv", data_files=data_files)
    local_datasets.push_to_hub(dataset_name, private=True)

    dataset_card_with_stats = dataset_card + f"\n\nDataset Statistics:\n\n```\n{local_datasets}\n```"

    api = HfApi()
    api.upload_file(
        path_or_fileobj=BytesIO(bytes(dataset_card_with_stats, "utf-8")),
        path_in_repo="README.md",
        repo_id=dataset_name,
        repo_type="dataset",
    )


def upload_dataset_to_huggingface(dataset_description: DatasetDescription):
    formatted_data_requests = "\n".join(f"{str(req.source_path.resolve())}: (train: {req.num_train}, test: {req.num_test}, valid: {req.num_valid})" for req in dataset_description.data_requests)
    dataset_card = f"""
# {dataset_description.name}

Created by the Syssec team @ UTD

Dataset Composition:

```
{formatted_data_requests}
```

Python version: `{".".join(map(str, dataset_description.version))}`
"""

    splits: List[Literal["train", "test", "valid"]] = [
        "train",
        "test",
        "valid",
    ]

    # collect data files
    segmentation_data_files: LOCAL_DATASET = {}
    statement_data_files: LOCAL_DATASET = {}
    for split in splits:
        segmentation_data_files[split] = [str(path.resolve()) for path in (dataset_description.csv_dir / split / "segmentation").glob("*.csv")]
        statement_data_files[split] = [str(path.resolve()) for path in (dataset_description.csv_dir / split / "statement").glob("*.csv")]

    # upload datasets
    segmentation_dataset_name = f"{dataset_description.huggingface_user}/segmentation-{dataset_description.name}"
    upload_single_dataset(segmentation_data_files, segmentation_dataset_name, dataset_card)
    statement_dataset_name = f"{dataset_description.huggingface_user}/statement-{dataset_description.name}"
    upload_single_dataset(statement_data_files, statement_dataset_name, dataset_card)
@@ -0,0 +1,66 @@
import json
import logging
import pathlib
from typing import Union
import click

from dataset_generation.bytecode2csv import create_csv_dataset
from dataset_generation.create_code_dataset import create_code_dataset
from dataset_generation.DatasetDescription import DataRequest, DatasetDescription
from dataset_generation.upload_raw_dataset import upload_dataset_to_huggingface
from pylingual.utils.get_logger import get_logger


def get_dataset_description_from_arg_json(json_path: str, logger: Union[logging.Logger, None] = None) -> DatasetDescription:
    json_file_path = pathlib.Path(json_path)

    if not json_file_path.exists():
        raise FileNotFoundError(f"{json_file_path} does not exist")

    if logger:
        logger.info(f"Loading dataset description from {json_file_path}...")

    with json_file_path.open() as json_file:
        dataset_description_dict = json.load(json_file)

    dataset_description_dict["data_requests"] = [DataRequest(**d) for d in dataset_description_dict["data_requests"]]
    return DatasetDescription(**dataset_description_dict)


@click.command(help="Samples, splits, processes, and uploads a given dataset described by JSON.")
@click.argument("json_path", type=str)
def main(json_path: str):
    logger = get_logger("prepare-dataset")

    dataset_description = get_dataset_description_from_arg_json(json_path, logger)
    logger.debug(dataset_description)

    if dataset_description.code_dir.exists():
        raise FileExistsError(f"{dataset_description.code_dir} already exists! The dataset name is probably already taken.")

    logger.info("Creating code dataset...")
    if not (dataset_description.data_requests and dataset_description.code_dir and dataset_description.version):
        logger.error("Dataset description is missing required fields")
        exit(1)
    create_code_dataset(
        dataset_description.data_requests,
        dataset_description.code_dir,
        dataset_description.version,
        logger,
    )

    # create csv dataset
    logger.info("Converting code dataset to csv...")
    create_csv_dataset(
        dataset_description.code_dir,
        dataset_description.csv_dir,
        dataset_description.data_requests,
        logger,
    )

    logger.info(f"Uploading {dataset_description.name} to HuggingFace...")
    upload_dataset_to_huggingface(dataset_description)


if __name__ == "__main__":
    main()
@@ -0,0 +1,24 @@
{
    "name": "sample_dataset_name",
    "version": [3, 6],
    "save_to_dir": "./save_dir/",
    "huggingface_user": "sample_user",

    "data_requests":
    [
        {
            "name": "dataset",
            "source_path": "./dataset",
            "num_train": 200,
            "num_test": 200,
            "num_valid": 200
        },
        {
            "name": "dataset2",
            "source_path": "./dataset2",
            "num_train": 200,
            "num_test": 200,
            "num_valid": 200
        }
    ]
}
@@ -0,0 +1,18 @@
{
    "base_repo_name": "sample_user/sample_segmenter_name",
    "dataset_repo_name": "sample_user/segmentation-sample_dataset_name",
    "pretrained_mlm_repo_name": "",
    "cache_dir": "./cache-dir/",
    "max_token_length": 512,
    "dataset_percentage": 100,
    "mlm_training_parameters": {
        "batch_size": 48,
        "epochs": 2,
        "learning_rate": 5e-5
    },
    "segmentation_training_parameters": {
        "batch_size": 48,
        "epochs": 2,
        "learning_rate": 2e-5
    }
}
@@ -0,0 +1,16 @@
{
    "base_repo_name": "sample_user/sample_name",
    "dataset_repo_name": "sample_user/statement-sample_dataset_name",
    "tokenizer_repo_name": "sample_user/sample_name-tok",
    "pretrained_seq2seq_repo_name": "Salesforce/codet5-base",
    "cache_dir": "./cache-dir/",
    "max_token_length": 256,
    "dataset_percentage": 100,
    "do_eval": true,
    "fp16": true,
    "statement_training_parameters": {
        "batch_size": 24,
        "epochs": 2,
        "learning_rate": 2e-5
    }
}
@@ -0,0 +1,74 @@
import json
import logging
import pathlib
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingParameters:
    batch_size: int
    epochs: int
    learning_rate: float


@dataclass
class SegmentationConfiguration:
    base_repo_name: str
    dataset_repo_name: str
    pretrained_mlm_repo_name: str
    cache_dir: pathlib.Path
    max_token_length: int
    dataset_percentage: int
    mlm_training_parameters: TrainingParameters
    segmentation_training_parameters: TrainingParameters

    @property
    def tokenizer_repo_name(self):
        return self.base_repo_name + "-tokenizer"

    @property
    def tokenizer_json_path(self):
        return self.cache_dir / "tokenizers" / self.tokenizer_repo_name / "tokenizer.json"

    @property
    def tokenized_dataset_repo_name(self):
        return self.dataset_repo_name + "-tokenized"

    @property
    def mlm_repo_name(self):
        return self.base_repo_name + "-mlm"

    @property
    def mlm_dir(self):
        return self.cache_dir / "models" / self.mlm_repo_name

    @property
    def segmenter_repo_name(self):
        return self.base_repo_name + "-segmenter"

    @property
    def segmenter_dir(self):
        return self.cache_dir / "models" / self.segmenter_repo_name

    @property
    def dataset_dir(self):
        return self.cache_dir / "datasets" / self.dataset_repo_name

    def __post_init__(self):
        self.cache_dir = pathlib.Path(self.cache_dir)
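    # For example, with the bundled sample (sample_jsons/py36-sample-segmentation.json),
    # base_repo_name "sample_user/sample_segmenter_name" and dataset_repo_name
    # "sample_user/segmentation-sample_dataset_name" yield the derived names:
    #   tokenizer_repo_name         -> "sample_user/sample_segmenter_name-tokenizer"
    #   mlm_repo_name               -> "sample_user/sample_segmenter_name-mlm"
    #   segmenter_repo_name         -> "sample_user/sample_segmenter_name-segmenter"
    #   tokenized_dataset_repo_name -> "sample_user/segmentation-sample_dataset_name-tokenized"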
def parse_segmentation_config_json(json_file_path: pathlib.Path, logger: Optional[logging.Logger] = None) -> SegmentationConfiguration:
    if not json_file_path.exists():
        raise FileNotFoundError(f"{json_file_path} does not exist")

    if logger:
        logger.info(f"Loading model description from {json_file_path}...")

    with json_file_path.open() as json_file:
        segmentation_config_dict = json.load(json_file)

    segmentation_config_dict["mlm_training_parameters"] = TrainingParameters(**segmentation_config_dict["mlm_training_parameters"])
    segmentation_config_dict["segmentation_training_parameters"] = TrainingParameters(**segmentation_config_dict["segmentation_training_parameters"])
    return SegmentationConfiguration(**segmentation_config_dict)
@@ -0,0 +1,152 @@
import ast
import functools
import os
import pathlib
import click

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from pylingual.segmentation.sliding_window import sliding_window
from transformers import PreTrainedTokenizerFast

bytecode_separator = " <SEP> "


def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
    tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name

    tokenizer_file = hf_hub_download(repo_id=tokenizer_repo_name, filename="tokenizer.json", token=True, cache_dir=str(tokenizer_dir))
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
    )

    return tokenizer


# we need to make sure we align all the labels with the proper words.
def align_labels_with_tokens(labels, word_ids):
    label_names = ["B", "I", "E"]
    id2label = {str(i): label for i, label in enumerate(label_names)}
    label2id = {v: k for k, v in id2label.items()}

    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else int(label2id[labels[word_id]])
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = int(label2id[labels[word_id]])
            new_labels.append(label)
    return new_labels


# the process function used to tokenize the dataset
def tokenize_and_align_labels(tokenizer: PreTrainedTokenizerFast, max_length: int, examples):
    MAX_WINDOW_LENGTH = 512
    STEP_SIZE = 128

    # parse the strings into lists to better work with the bytecode and boundaries
    parsed_bc = [(codeobj.split(" <SEP> "), ast.literal_eval(bounds)) for codeobj, bounds in zip(examples["bytecode"], examples["boundary"])]

    codeobj_tokens = []

    # count the tokens for each bytecode instruction in a codeobj
    for codeobj, bounds in parsed_bc:
        token_list = []

        for bc, bounds in zip(codeobj, bounds):
            token_list.append(((bc, bounds), len(tokenizer(bc)[0])))

        codeobj_tokens.append(token_list)

    windows = [sliding_window(codeobj, MAX_WINDOW_LENGTH, STEP_SIZE) for codeobj in codeobj_tokens]

    # remake examples using our windows
    examples["boundary"] = []
    examples["bytecode"] = []

    # go through each window
    for window in windows:
        for item in window:
            # where we will temporarily store our bytecode and bounds
            bytecode = []
            bounds = []

            for bc in item[0]:
                bytecode.append(bc[0])
                bounds.append(bc[1])

            # append it into examples
            examples["bytecode"].append(bytecode_separator.join(bytecode))
            examples["boundary"].append(str(bounds))

    tokenized_inputs = tokenizer(
        examples["bytecode"],
        truncation=True,
        max_length=max_length,
    )

    all_labels = examples["boundary"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        labels = labels.replace("'", "").strip("][").split(", ")
        word_ids = tokenized_inputs.word_ids(i)
        labels_len = len(labels)
        max_word_id = word_ids[-2]
        # some samples can raise errors due to incorrect tokenization; we patch the exceeded-length
        # issue and keep them here as noisy data.
        if max_word_id >= labels_len:
            new_labels.append([-100] * max_word_id)
        else:
            new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels

    return tokenized_inputs


def tokenize_segmentation_dataset(config: SegmentationConfiguration):
    raw_dataset = load_dataset(config.dataset_repo_name, token=True, cache_dir=str(config.dataset_dir))

    tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
    prepped_tokenize_and_align_labels = functools.partial(tokenize_and_align_labels, tokenizer, config.max_token_length)

    # tokenize input dataset
    column_names = raw_dataset["train"].column_names
    tokenized_datasets = raw_dataset.map(
        prepped_tokenize_and_align_labels,
        batched=True,
        remove_columns=column_names,
        num_proc=os.cpu_count(),
        desc="Tokenizing datasets",
    )

    tokenized_datasets.push_to_hub(
        config.tokenized_dataset_repo_name,
        private=True,
    )


@click.command(help="Script to tokenize the segmentation dataset given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
    json_file_path = pathlib.Path(json_path)
    segmentation_config = parse_segmentation_config_json(json_file_path)
    tokenize_segmentation_dataset(segmentation_config)


if __name__ == "__main__":
    main()
@@ -0,0 +1,195 @@
import logging
import os
import pathlib
import click

from datasets import load_dataset
from huggingface_hub import hf_hub_download, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from transformers import AutoModelForMaskedLM, DataCollatorForLanguageModeling, PreTrainedTokenizerFast, RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments

from pylingual.segmentation.sliding_window import sliding_window

bytecode_separator = " <SEP> "


def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
    tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name

    tokenizer_file = hf_hub_download(
        repo_id=tokenizer_repo_name,
        filename="tokenizer.json",
        token=True,
        cache_dir=str(tokenizer_dir),
    )
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
    )

    return tokenizer


def load_tokenized_train_dataset(
    dataset_repo_name: str,
    tokenizer: PreTrainedTokenizerFast,
    max_length: int,
    cache_dir: pathlib.Path,
):
    dataset_dir = cache_dir / "datasets" / dataset_repo_name
    raw_dataset = load_dataset(dataset_repo_name, token=True, cache_dir=dataset_dir, split="train")

    # tokenize the input data
    column_names = raw_dataset.column_names

    def tokenize(examples):
        # sliding window compatibility
        MAX_WINDOW_LENGTH = 512
        STEP_SIZE = 128

        # parse the strings into lists to better work with the bytecode and boundaries
        parsed_bc = [codeobj.split(" <SEP> ") for codeobj in examples["bytecode"]]

        codeobj_tokens = []

        # count the tokens for each bytecode instruction in a codeobj
        for codeobj in parsed_bc:
            token_list = []

            for bytecode in codeobj:
                token_list.append((bytecode, len(tokenizer(bytecode)[0])))

            codeobj_tokens.append(token_list)

        windows = [sliding_window(codeobj, MAX_WINDOW_LENGTH, STEP_SIZE) for codeobj in codeobj_tokens]

        # remake examples using our windows
        examples["bytecode"] = []

        # go through each window
        for window in windows:
            for item in window:
                # where we will temporarily store our bytecode and bounds
                bytecode = []

                for bc in item[0]:
                    bytecode.append(bc)

                # append to examples
                examples["bytecode"].append(bytecode_separator.join(bytecode))

        return tokenizer(examples["bytecode"], max_length=max_length, truncation=True)

    tokenized_dataset = raw_dataset.map(
        tokenize,
        batched=True,
        remove_columns=column_names,
        num_proc=os.cpu_count(),
        desc="Tokenizing datasets",
    )

    return tokenized_dataset


def load_pretrained_mlm(
    pretrained_mlm_repo_name: str,
    tokenizer_embedding_length: int,
    cache_dir: pathlib.Path,
) -> AutoModelForMaskedLM:
    # load a basic pretrained BERT model
    pretrained_mlm_dir = cache_dir / "models" / pretrained_mlm_repo_name
    model = AutoModelForMaskedLM.from_pretrained(pretrained_mlm_repo_name, cache_dir=str(pretrained_mlm_dir))

    # resize token embeddings to fit the model
    model.resize_token_embeddings(tokenizer_embedding_length)

    return model


def initialize_untrained_mlm(
    tokenizer_embedding_length: int,
    max_token_length: int,
) -> RobertaForMaskedLM:
    # initialize untrained RoBERTa model
    # most configuration options set to match https://huggingface.co/microsoft/codebert-base/blob/main/config.json for direct comparison
    model_config = RobertaConfig(
        max_position_embeddings=max_token_length,  # INPUT LENGTH LIMIT
        vocab_size=tokenizer_embedding_length,
        layer_norm_eps=1e-05,
        type_vocab_size=1,
    )
    model = RobertaForMaskedLM(model_config)

    return model


def train_mlm(config: SegmentationConfiguration):
    if repo_exists(config.base_repo_name):
        logging.error(f"{config.base_repo_name} already exists")
        exit(1)

    using_pretrained_model = bool(config.pretrained_mlm_repo_name)
    # train model, for now the configuration comes from a regular T5 translation model.
    training_args = TrainingArguments(
        output_dir=str(config.mlm_dir),
        num_train_epochs=config.mlm_training_parameters.epochs,
        per_device_train_batch_size=config.mlm_training_parameters.batch_size,
        save_steps=1000,
        save_total_limit=5,
        prediction_loss_only=True,
        push_to_hub=True,
        hub_model_id=config.mlm_repo_name,
        hub_private_repo=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=using_pretrained_model,  # only look for unused parameters in pretrained models
        remove_unused_columns=False,
    )

    tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)

    # Set DataCollator for MLM task, set the probability of masking.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

    if using_pretrained_model:
        pretrained_mlm = load_pretrained_mlm(config.pretrained_mlm_repo_name, len(tokenizer), config.cache_dir)
    else:
        pretrained_mlm = initialize_untrained_mlm(len(tokenizer), config.max_token_length + 2)

    tokenized_training_data = load_tokenized_train_dataset(config.dataset_repo_name, tokenizer, config.max_token_length, config.cache_dir)

    # Hugging face trainer: a Trainer class to fine-tune pretrained models
    trainer = Trainer(
        model=pretrained_mlm,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_training_data,
    )

    # Training
    trainer.train()

    if int(os.environ["LOCAL_RANK"]) == 0:
        # Save the model
        trainer.save_model(config.mlm_dir)

        trainer.push_to_hub(
            finetuned_from=config.pretrained_mlm_repo_name,
            dataset=config.dataset_repo_name,
            commit_message=f"Trained on {config.dataset_repo_name} using {config.tokenizer_repo_name}",
        )


@click.command(help="Training script for the masked language model pretraining for the segmentation model given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
    json_file_path = pathlib.Path(json_path)
    segmentation_config = parse_segmentation_config_json(json_file_path)
    train_mlm(segmentation_config)


if __name__ == "__main__":
    main()
@@ -0,0 +1,155 @@
import logging
import os
import pathlib
import click

import evaluate
import numpy as np
from datasets import ReadInstruction, load_dataset
from huggingface_hub import hf_hub_download, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification, PreTrainedTokenizerFast, Trainer, TrainingArguments

# two dictionaries, id2label and label2id, which contain the mappings from ID to label and vice versa.
label_names = ["B", "I", "E"]
id2label = {str(i): label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}


# compute_metrics: evaluate metric for training and evaluation.
def compute_metrics(eval_preds):
    metric = evaluate.load("seqeval")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]  # noqa: E741
    true_predictions = [[label_names[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }


def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
    tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name

    tokenizer_file = hf_hub_download(
        repo_id=tokenizer_repo_name,
        filename="tokenizer.json",
        token=True,
        cache_dir=str(tokenizer_dir),
    )
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
    )

    return tokenizer


def load_tokenized_train_and_valid_dataset(dataset_repo_name: str, cache_dir: pathlib.Path, dataset_percentage: int = 100):
    dataset_dir = cache_dir / "datasets" / dataset_repo_name
    # Load the tokenized dataset
    tokenized_train_dataset = load_dataset(
        dataset_repo_name,
        token=True,
        cache_dir=str(dataset_dir),
        split=ReadInstruction("train", to=dataset_percentage, unit="%"),
    )

    tokenized_validation_dataset = load_dataset(
        dataset_repo_name,
        token=True,
        cache_dir=str(dataset_dir),
        split="valid",
    )

    return tokenized_train_dataset, tokenized_validation_dataset


def train_segmentation_model(config: SegmentationConfiguration):
    if repo_exists(config.base_repo_name):
        logging.error(f"{config.base_repo_name} already exists")
        exit(1)
    # training arguments.
    training_args = TrainingArguments(
        output_dir=str(config.segmenter_dir),
        overwrite_output_dir=True,
        eval_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.segmentation_training_parameters.learning_rate,
        num_train_epochs=config.segmentation_training_parameters.epochs,
        per_device_train_batch_size=config.segmentation_training_parameters.batch_size,
        save_steps=1000,
        weight_decay=0.01,
        fp16=True,
        push_to_hub=True,
        hub_model_id=config.segmenter_repo_name,
        hub_private_repo=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=True,
        save_total_limit=5,
    )

    # load a basic pretrained BERT model
    model = AutoModelForTokenClassification.from_pretrained(
        pretrained_model_name_or_path=config.mlm_repo_name,
        id2label=id2label,
        label2id=label2id,
        token=True,
    )

    # Set DataCollator for DataCollatorForTokenClassification
    tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, max_length=config.max_token_length)

    (
        tokenized_train_dataset,
        tokenized_validation_dataset,
    ) = load_tokenized_train_and_valid_dataset(config.tokenized_dataset_repo_name, config.cache_dir, config.dataset_percentage)

    # Hugging face trainer: a Trainer class to fine-tune pretrained models
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_validation_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )

    # Training
    trainer.train()

    if int(os.environ["LOCAL_RANK"]) == 0:
        # Save the model
        trainer.save_model(str(config.segmenter_dir))

        trainer.push_to_hub(
            finetuned_from=config.mlm_repo_name,
            dataset=config.tokenized_dataset_repo_name,
            commit_message=f"Trained on {config.tokenized_dataset_repo_name} using {config.mlm_repo_name}",
        )


@click.command(help="Training script for the segmentation model given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
    json_file_path = pathlib.Path(json_path)
    segmentation_config = parse_segmentation_config_json(json_file_path)
    train_segmentation_model(segmentation_config)


if __name__ == "__main__":
    main()
@@ -0,0 +1,96 @@
import logging
import pathlib
import click

from datasets import ReadInstruction, load_dataset
from huggingface_hub import HfApi, create_repo, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors, trainers

special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]


def get_untrained_tokenizer() -> Tokenizer:
    # WordPiece tokenization for BERT.
    tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

    # The normalizer decomposes accented characters and strips the accents out.
    tokenizer.normalizer = normalizers.Sequence([normalizers.NFD(), normalizers.StripAccents()])

    # The pre-tokenizer splits on <SEP> tokens.
    tokenizer.pre_tokenizer = pre_tokenizers.Split("<SEP>", "removed")

    return tokenizer
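# For intuition: a hypothetical disassembly string such as
#   "LOAD_CONST , 1 <SEP> RETURN_VALUE"
# is split by the pre-tokenizer into the pieces "LOAD_CONST , 1 " and " RETURN_VALUE"
# (the "<SEP>" delimiter itself is removed), and WordPiece then tokenizes each piece.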
def post_training_configuration(tokenizer: Tokenizer):
|
||||
cls_token_id = tokenizer.token_to_id("[CLS]")
|
||||
sep_token_id = tokenizer.token_to_id("[SEP]")
|
||||
|
||||
# Set decoder for the tokenizer
|
||||
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
||||
|
||||
# For TemplateProcessing, we have to specify how to treat a single sentence and a pair of sentences.
tokenizer.post_processor = processors.TemplateProcessing(
|
||||
single="[CLS]:0 $A:0 [SEP]:0",
|
||||
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
|
||||
special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
|
||||
)
|
||||
|
||||
|
||||
def save_and_upload_tokenizer(
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer_json_path: pathlib.Path,
|
||||
tokenizer_repo_name: str,
|
||||
dataset_name: str,
|
||||
):
|
||||
# save the tokenizer locally
|
||||
tokenizer_json_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer.save(str(tokenizer_json_path.resolve()))
|
||||
|
||||
# upload tokenizer to huggingface
|
||||
api = HfApi()
|
||||
create_repo(tokenizer_repo_name, exist_ok=True, private=True)
|
||||
api.upload_file(
|
||||
path_in_repo="tokenizer.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.resolve()),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message=f"Trained tokenizer using {dataset_name}",
|
||||
)
|
||||
|
||||
|
||||
def train_tokenizer(config: SegmentationConfiguration):
|
||||
if repo_exists(config.base_repo_name):
|
||||
logging.error(f"{config.base_repo_name} has already exists")
|
||||
exit(1)
|
||||
|
||||
tokenizer = get_untrained_tokenizer()
|
||||
|
||||
train_dataset = load_dataset(
|
||||
config.dataset_repo_name,
|
||||
token=True,
|
||||
split=ReadInstruction("train", to=config.dataset_percentage, unit="%"),
|
||||
)["bytecode"]
|
||||
trainer = trainers.WordPieceTrainer(vocab_size=30000, special_tokens=special_tokens)
|
||||
tokenizer.train_from_iterator(train_dataset, trainer=trainer)
|
||||
|
||||
post_training_configuration(tokenizer)
|
||||
|
||||
save_and_upload_tokenizer(
|
||||
tokenizer,
|
||||
config.tokenizer_json_path,
|
||||
config.tokenizer_repo_name,
|
||||
config.dataset_repo_name,
|
||||
)
|
||||
|
||||
|
||||
@click.command(help="Training script for the bytecode tokenizer for the segmentation model given a segmentation json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
segmentation_config = parse_segmentation_config_json(json_file_path)
|
||||
train_tokenizer(segmentation_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
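As a sanity check, the trained tokenizer can be loaded back and run on a bytecode string. This is only a sketch: the file path and instruction text are placeholders, and the exact output depends on the learned WordPiece vocabulary.

```
# Minimal sketch: load the tokenizer saved by save_and_upload_tokenizer and encode
# one bytecode line. "tokenizer.json" and the instruction string are placeholders.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")
enc = tok.encode("LOAD_CONST 1<SEP>STORE_NAME x")
print(enc.tokens)  # [CLS]/[SEP] are added by the TemplateProcessing post-processor
```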
@@ -0,0 +1,18 @@
|
||||
# seq2seq

- train_tokenizer_auto.py:
  - trains the manual tokenizer
- tokenize_seq2seq.py:
  - tokenizes the dataset for the seq2seq model
- train_seq2seq.py:
  - fine-tunes the pretrained model
  - creates a sequence-to-sequence translation model
- StatementConfiguration.py:
  - defines the JSON format for statement translation training
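These scripts are normally driven by the top-level training pipeline, but each one also accepts the path to a statement JSON directly. Roughly (paths are placeholders, and train_seq2seq.py is meant to be launched through torchrun):

```
python train_tokenizer_auto.py <path to statement JSON>
python tokenize_seq2seq.py <path to statement JSON>
torchrun --nproc-per-node=1 train_seq2seq.py <path to statement JSON>
```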
# manual1

Contains JSONs mapping bytecode instructions and their configurations for use in training.
@@ -0,0 +1,59 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pathlib
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingParameters:
|
||||
batch_size: int
|
||||
epochs: int
|
||||
learning_rate: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatementConfiguration:
|
||||
base_repo_name: str
|
||||
dataset_repo_name: str
|
||||
tokenizer_repo_name: str
|
||||
pretrained_seq2seq_repo_name: str
|
||||
cache_dir: pathlib.Path
|
||||
max_token_length: int
|
||||
dataset_percentage: int
|
||||
do_eval: bool
|
||||
fp16: bool
|
||||
statement_training_parameters: TrainingParameters
|
||||
|
||||
@property
|
||||
def tokenized_dataset_repo_name(self):
|
||||
return self.dataset_repo_name + "-tokenized"
|
||||
|
||||
@property
|
||||
def statement_model_repo_name(self):
|
||||
return self.base_repo_name + "-statement"
|
||||
|
||||
@property
|
||||
def statement_model_dir(self):
|
||||
return self.cache_dir / "models" / self.statement_model_repo_name
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return self.statement_model_dir / "logs"
|
||||
|
||||
def __post_init__(self):
|
||||
self.cache_dir = pathlib.Path(self.cache_dir)
|
||||
|
||||
|
||||
def parse_statement_config_json(json_file_path: pathlib.Path, logger: logging.Logger = None) -> StatementConfiguration:
|
||||
if not json_file_path.exists():
|
||||
raise FileNotFoundError(f"{json_file_path} does not exist")
|
||||
|
||||
if logger:
|
||||
logger.info(f"Loading model description from {json_file_path}...")
|
||||
|
||||
with json_file_path.open() as json_file:
|
||||
statement_config_dict = json.load(json_file)
|
||||
|
||||
statement_config_dict["statement_training_parameters"] = TrainingParameters(**statement_config_dict["statement_training_parameters"])
|
||||
return StatementConfiguration(**statement_config_dict)
|
||||
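For reference, a minimal sketch of a configuration that parse_statement_config_json accepts. Every value below is a placeholder rather than a shipped sample, and the Hugging Face repo names are assumptions.

```
import json
import pathlib

from StatementConfiguration import parse_statement_config_json

# Placeholder values only; real configurations live in the repository's sample JSONs.
example = {
    "base_repo_name": "my-org/py312",                        # hypothetical
    "dataset_repo_name": "my-org/py312-statement-dataset",   # hypothetical
    "tokenizer_repo_name": "Salesforce/codet5-base",
    "pretrained_seq2seq_repo_name": "Salesforce/codet5-base",
    "cache_dir": "cache",
    "max_token_length": 512,
    "dataset_percentage": 100,
    "do_eval": False,
    "fp16": True,
    "statement_training_parameters": {"batch_size": 16, "epochs": 3, "learning_rate": 5e-5},
}
path = pathlib.Path("statement.json")
path.write_text(json.dumps(example, indent=2))
print(parse_statement_config_json(path).statement_model_repo_name)  # my-org/py312-statement
```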
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import pathlib
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import RobertaTokenizer
|
||||
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
|
||||
import functools
|
||||
|
||||
|
||||
def preprocess_function(tokenizer: RobertaTokenizer, max_token_length: int, input_key: str, examples: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Set up Huggingface tokenizers for both inputs and targets"""
|
||||
inputs = [ex if ex else "" for ex in examples[input_key]]
|
||||
targets = [ex if ex else "" for ex in examples["source"]]
|
||||
|
||||
return tokenizer(text=inputs, text_target=targets, max_length=max_token_length, truncation=True)
|
||||
|
||||
|
||||
def tokenize_seq2seq_dataset(config: StatementConfiguration):
|
||||
# ref: https://huggingface.co/Salesforce/codet5-base
|
||||
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
|
||||
raw_datasets = load_dataset(config.dataset_repo_name, token=True)
|
||||
|
||||
column_names = raw_datasets["train"].column_names
|
||||
input_key = "bytecode"
|
||||
prepped_preprocess_function = functools.partial(preprocess_function, tokenizer, config.max_token_length, input_key)
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
prepped_preprocess_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
num_proc=os.cpu_count(),
|
||||
desc="Tokenizing datasets",
|
||||
)
|
||||
|
||||
tokenized_datasets.push_to_hub(config.tokenized_dataset_repo_name, private=True)
|
||||
|
||||
|
||||
@click.command(help="Tokenization script for Statement Translation model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
tokenize_seq2seq_dataset(statement_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
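To illustrate what preprocess_function produces, here is a minimal sketch using the Salesforce/codet5-base tokenizer; the bytecode/source pair and max_length are made-up examples, not values from the real dataset.

```
# Sketch only: a single bytecode/source pair tokenized the same way as above.
from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("Salesforce/codet5-base")
batch = tok(
    text=["LOAD_CONST 1 <SEP> STORE_NAME x"],  # hypothetical bytecode string
    text_target=["x = 1"],                     # corresponding source line
    max_length=512,
    truncation=True,
)
print(batch["input_ids"][0][:10], batch["labels"][0])
```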
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,929 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"additional_special_tokens": [
|
||||
"<pad>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
"<unk>",
|
||||
"<mask>",
|
||||
"!",
|
||||
"\"",
|
||||
"#",
|
||||
"$",
|
||||
"%",
|
||||
"&",
|
||||
"'",
|
||||
"(",
|
||||
")",
|
||||
"*",
|
||||
"+",
|
||||
",",
|
||||
"-",
|
||||
".",
|
||||
"/",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
":",
|
||||
";",
|
||||
"<",
|
||||
"=",
|
||||
">",
|
||||
"?",
|
||||
"@",
|
||||
"A",
|
||||
"B",
|
||||
"C",
|
||||
"D",
|
||||
"E",
|
||||
"F",
|
||||
"G",
|
||||
"H",
|
||||
"I",
|
||||
"J",
|
||||
"K",
|
||||
"L",
|
||||
"M",
|
||||
"N",
|
||||
"O",
|
||||
"P",
|
||||
"Q",
|
||||
"R",
|
||||
"S",
|
||||
"T",
|
||||
"U",
|
||||
"V",
|
||||
"W",
|
||||
"X",
|
||||
"Y",
|
||||
"Z",
|
||||
"[",
|
||||
"\\",
|
||||
"]",
|
||||
"^",
|
||||
"_",
|
||||
"`",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
"g",
|
||||
"h",
|
||||
"i",
|
||||
"j",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"o",
|
||||
"p",
|
||||
"q",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"u",
|
||||
"v",
|
||||
"w",
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"{",
|
||||
"|",
|
||||
"}",
|
||||
"~",
|
||||
"Ġ",
|
||||
"-=",
|
||||
"<<",
|
||||
">>",
|
||||
":=",
|
||||
">=",
|
||||
"<=",
|
||||
"==",
|
||||
"!=",
|
||||
"+=",
|
||||
"//=",
|
||||
"**=",
|
||||
"/=",
|
||||
"//",
|
||||
"%=",
|
||||
"@=",
|
||||
"&=",
|
||||
"|=",
|
||||
"^=",
|
||||
">>=",
|
||||
"<<=",
|
||||
"*=",
|
||||
"()",
|
||||
"):",
|
||||
"~>>",
|
||||
"**",
|
||||
"<codeobj:",
|
||||
"<KWARG_PAD>",
|
||||
"E->",
|
||||
"<TAP_0>",
|
||||
"defaults",
|
||||
"args:",
|
||||
"vararg:",
|
||||
"<TAP_1>",
|
||||
"<START_LINE>",
|
||||
"<SUB>",
|
||||
"<TAP_UP>",
|
||||
"E-END",
|
||||
"<TAP_2>",
|
||||
"~~>",
|
||||
"<TAP_ST>",
|
||||
"<SEP>",
|
||||
"</SUB>",
|
||||
"<TAP_3>",
|
||||
"<TAP_4>",
|
||||
"<TAP_5>",
|
||||
"<TAP_6>",
|
||||
"<TAP_7>",
|
||||
"False",
|
||||
"None",
|
||||
"True",
|
||||
"and",
|
||||
"assert",
|
||||
"async",
|
||||
"await",
|
||||
"break",
|
||||
"class",
|
||||
"continue",
|
||||
"def",
|
||||
"del",
|
||||
"elif",
|
||||
"else",
|
||||
"else:",
|
||||
"except",
|
||||
"except:",
|
||||
"finally",
|
||||
"finally:",
|
||||
"for",
|
||||
"from",
|
||||
"global",
|
||||
"if",
|
||||
"import",
|
||||
"in",
|
||||
"is",
|
||||
"lambda",
|
||||
"nonlocal",
|
||||
"not",
|
||||
"or",
|
||||
"pass",
|
||||
"raise",
|
||||
"return",
|
||||
"try",
|
||||
"try:",
|
||||
"while",
|
||||
"with",
|
||||
"yield",
|
||||
"case",
|
||||
"as",
|
||||
"ASYNC_GEN_WRAP",
|
||||
"BEFORE_ASYNC_WITH",
|
||||
"BEFORE_WITH",
|
||||
"BEGIN_FINALLY",
|
||||
"BINARY_ADD",
|
||||
"BINARY_AND",
|
||||
"BINARY_FLOOR_DIVIDE",
|
||||
"BINARY_LSHIFT",
|
||||
"BINARY_MATRIX_MULTIPLY",
|
||||
"BINARY_MODULO",
|
||||
"BINARY_MULTIPLY",
|
||||
"BINARY_OP",
|
||||
"BINARY_OR",
|
||||
"BINARY_POWER",
|
||||
"BINARY_RSHIFT",
|
||||
"BINARY_SLICE",
|
||||
"BINARY_SUBSCR",
|
||||
"BINARY_SUBTRACT",
|
||||
"BINARY_TRUE_DIVIDE",
|
||||
"BINARY_XOR",
|
||||
"BREAK_LOOP",
|
||||
"BUILD_CONST_KEY_MAP",
|
||||
"BUILD_LIST",
|
||||
"BUILD_LIST_UNPACK",
|
||||
"BUILD_MAP",
|
||||
"BUILD_MAP_UNPACK",
|
||||
"BUILD_MAP_UNPACK_WITH_CALL",
|
||||
"BUILD_SET",
|
||||
"BUILD_SET_UNPACK",
|
||||
"BUILD_SLICE",
|
||||
"BUILD_STRING",
|
||||
"BUILD_TUPLE",
|
||||
"BUILD_TUPLE_UNPACK",
|
||||
"BUILD_TUPLE_UNPACK_WITH_CALL",
|
||||
"CACHE",
|
||||
"CALL_FINALLY",
|
||||
"CALL_FUNCTION",
|
||||
"CALL_FUNCTION_EX",
|
||||
"CALL_FUNCTION_KW",
|
||||
"CALL_METHOD",
|
||||
"CHECK_EG_MATCH",
|
||||
"CHECK_EXC_MATCH",
|
||||
"CLEANUP_THROW",
|
||||
"COMPARE_OP",
|
||||
"CONTAINS_OP",
|
||||
"CONTINUE_LOOP",
|
||||
"COPY",
|
||||
"COPY_DICT_WITHOUT_KEYS",
|
||||
"COPY_FREE_VARS",
|
||||
"DELETE_ATTR",
|
||||
"DELETE_DEREF",
|
||||
"DELETE_FAST",
|
||||
"DELETE_GLOBAL",
|
||||
"DELETE_NAME",
|
||||
"DELETE_SUBSCR",
|
||||
"DICT_MERGE",
|
||||
"DICT_UPDATE",
|
||||
"DUP_TOP",
|
||||
"DUP_TOP_TWO",
|
||||
"END_ASYNC_FOR",
|
||||
"END_FINALLY",
|
||||
"END_FOR",
|
||||
"END_SEND",
|
||||
"EXTENDED_ARG",
|
||||
"FOR_ITER",
|
||||
"FORMAT_VALUE",
|
||||
"GEN_START",
|
||||
"GET_AITER",
|
||||
"GET_ANEXT",
|
||||
"GET_AWAITABLE",
|
||||
"GET_ITER",
|
||||
"GET_LEN",
|
||||
"GET_YIELD_FROM_ITER",
|
||||
"IMPORT_FROM",
|
||||
"IMPORT_NAME",
|
||||
"IMPORT_STAR",
|
||||
"INPLACE_ADD",
|
||||
"INPLACE_AND",
|
||||
"INPLACE_FLOOR_DIVIDE",
|
||||
"INPLACE_LSHIFT",
|
||||
"INPLACE_MATRIX_MULTIPLY",
|
||||
"INPLACE_MODULO",
|
||||
"INPLACE_MULTIPLY",
|
||||
"INPLACE_OR",
|
||||
"INPLACE_POWER",
|
||||
"INPLACE_RSHIFT",
|
||||
"INPLACE_SUBTRACT",
|
||||
"INPLACE_TRUE_DIVIDE",
|
||||
"INPLACE_XOR",
|
||||
"INTERPRETER_EXIT",
|
||||
"IS_OP",
|
||||
"JUMP_ABSOLUTE",
|
||||
"JUMP_BACKWARD",
|
||||
"JUMP_BACKWARD_NO_INTERRUPT",
|
||||
"JUMP_FORWARD",
|
||||
"JUMP_IF_FALSE_OR_POP",
|
||||
"JUMP_IF_NOT_EXC_MATCH",
|
||||
"JUMP_IF_TRUE_OR_POP",
|
||||
"LIST_APPEND",
|
||||
"LIST_EXTEND",
|
||||
"LIST_TO_TUPLE",
|
||||
"LOAD_ASSERTION_ERROR",
|
||||
"LOAD_ATTR",
|
||||
"LOAD_BUILD_CLASS",
|
||||
"LOAD_CLASSDEREF",
|
||||
"LOAD_CLOSURE",
|
||||
"LOAD_CONST",
|
||||
"LOAD_DEREF",
|
||||
"LOAD_FAST",
|
||||
"LOAD_FAST_AND_CLEAR",
|
||||
"LOAD_FAST_CHECK",
|
||||
"LOAD_GLOBAL",
|
||||
"LOAD_LOCALS",
|
||||
"LOAD_METHOD",
|
||||
"LOAD_NAME",
|
||||
"LOAD_SUPER_ATTR",
|
||||
"MAKE_CELL",
|
||||
"MAKE_FUNCTION",
|
||||
"MAP_ADD",
|
||||
"MATCH_CLASS",
|
||||
"MATCH_KEYS",
|
||||
"MATCH_MAPPING",
|
||||
"MATCH_SEQUENCE",
|
||||
"NOP",
|
||||
"POP_BLOCK",
|
||||
"POP_EXCEPT",
|
||||
"POP_FINALLY",
|
||||
"POP_JUMP_FORWARD_IF_FALSE",
|
||||
"POP_JUMP_FORWARD_IF_NONE",
|
||||
"POP_JUMP_FORWARD_IF_NOT_NONE",
|
||||
"POP_JUMP_FORWARD_IF_TRUE",
|
||||
"POP_JUMP_IF_FALSE",
|
||||
"POP_JUMP_IF_NONE",
|
||||
"POP_JUMP_IF_NOT_NONE",
|
||||
"POP_JUMP_IF_TRUE",
|
||||
"POP_TOP",
|
||||
"PRECALL",
|
||||
"PREP_RERAISE_STAR",
|
||||
"PRINT_EXPR",
|
||||
"PUSH_EXC_INFO",
|
||||
"PUSH_NULL",
|
||||
"RAISE_VARARGS",
|
||||
"RERAISE",
|
||||
"RESERVED",
|
||||
"RESUME",
|
||||
"RETURN_CONST",
|
||||
"RETURN_GENERATOR",
|
||||
"RETURN_VALUE",
|
||||
"ROT_FOUR",
|
||||
"ROT_N",
|
||||
"ROT_THREE",
|
||||
"ROT_TWO",
|
||||
"SEND",
|
||||
"SET_ADD",
|
||||
"SET_UPDATE",
|
||||
"SETUP_ANNOTATIONS",
|
||||
"SETUP_ASYNC_WITH",
|
||||
"SETUP_EXCEPT",
|
||||
"SETUP_FINALLY",
|
||||
"SETUP_LOOP",
|
||||
"SETUP_WITH",
|
||||
"STORE_ATTR",
|
||||
"STORE_DEREF",
|
||||
"STORE_FAST",
|
||||
"STORE_GLOBAL",
|
||||
"STORE_NAME",
|
||||
"STORE_SLICE",
|
||||
"STORE_SUBSCR",
|
||||
"SWAP",
|
||||
"UNARY_INVERT",
|
||||
"UNARY_NEGATIVE",
|
||||
"UNARY_NOT",
|
||||
"UNARY_POSITIVE",
|
||||
"UNPACK_EX",
|
||||
"UNPACK_SEQUENCE",
|
||||
"WITH_CLEANUP_FINISH",
|
||||
"WITH_CLEANUP_START",
|
||||
"WITH_EXCEPT_START",
|
||||
"YIELD_FROM",
|
||||
"YIELD_VALUE",
|
||||
"<mask_0>",
|
||||
"<mask_1>",
|
||||
"<mask_2>",
|
||||
"<mask_3>",
|
||||
"<mask_4>",
|
||||
"<mask_5>",
|
||||
"<mask_6>",
|
||||
"<mask_7>",
|
||||
"<mask_8>",
|
||||
"<mask_9>",
|
||||
"<mask_10>",
|
||||
"<mask_11>",
|
||||
"<mask_12>",
|
||||
"<mask_13>",
|
||||
"<mask_14>",
|
||||
"<mask_15>",
|
||||
"<mask_16>",
|
||||
"<mask_17>",
|
||||
"<mask_18>",
|
||||
"<mask_19>",
|
||||
"<mask_20>",
|
||||
"<mask_21>",
|
||||
"<mask_22>",
|
||||
"<mask_23>",
|
||||
"<mask_24>",
|
||||
"<mask_25>",
|
||||
"<mask_26>",
|
||||
"<mask_27>",
|
||||
"<mask_28>",
|
||||
"<mask_29>",
|
||||
"<mask_30>",
|
||||
"<mask_31>",
|
||||
"<mask_32>",
|
||||
"<mask_33>",
|
||||
"<mask_34>",
|
||||
"<mask_35>",
|
||||
"<mask_36>",
|
||||
"<mask_37>",
|
||||
"<mask_38>",
|
||||
"<mask_39>",
|
||||
"<mask_40>",
|
||||
"<mask_41>",
|
||||
"<mask_42>",
|
||||
"<mask_43>",
|
||||
"<mask_44>",
|
||||
"<mask_45>",
|
||||
"<mask_46>",
|
||||
"<mask_47>",
|
||||
"<mask_48>",
|
||||
"<mask_49>",
|
||||
"<mask_50>",
|
||||
"<mask_51>",
|
||||
"<mask_52>",
|
||||
"<mask_53>",
|
||||
"<mask_54>",
|
||||
"<mask_55>",
|
||||
"<mask_56>",
|
||||
"<mask_57>",
|
||||
"<mask_58>",
|
||||
"<mask_59>",
|
||||
"<mask_60>",
|
||||
"<mask_61>",
|
||||
"<mask_62>",
|
||||
"<mask_63>",
|
||||
"<mask_64>",
|
||||
"<mask_65>",
|
||||
"<mask_66>",
|
||||
"<mask_67>",
|
||||
"<mask_68>",
|
||||
"<mask_69>",
|
||||
"<mask_70>",
|
||||
"<mask_71>",
|
||||
"<mask_72>",
|
||||
"<mask_73>",
|
||||
"<mask_74>",
|
||||
"<mask_75>",
|
||||
"<mask_76>",
|
||||
"<mask_77>",
|
||||
"<mask_78>",
|
||||
"<mask_79>",
|
||||
"<mask_80>",
|
||||
"<mask_81>",
|
||||
"<mask_82>",
|
||||
"<mask_83>",
|
||||
"<mask_84>",
|
||||
"<mask_85>",
|
||||
"<mask_86>",
|
||||
"<mask_87>",
|
||||
"<mask_88>",
|
||||
"<mask_89>",
|
||||
"<mask_90>",
|
||||
"<mask_91>",
|
||||
"<mask_92>",
|
||||
"<mask_93>",
|
||||
"<mask_94>",
|
||||
"<mask_95>",
|
||||
"<mask_96>",
|
||||
"<mask_97>",
|
||||
"<mask_98>",
|
||||
"<mask_99>",
|
||||
"<mask_100>",
|
||||
"<mask_101>",
|
||||
"<mask_102>",
|
||||
"<mask_103>",
|
||||
"<mask_104>",
|
||||
"<mask_105>",
|
||||
"<mask_106>",
|
||||
"<mask_107>",
|
||||
"<mask_108>",
|
||||
"<mask_109>",
|
||||
"<mask_110>",
|
||||
"<mask_111>",
|
||||
"<mask_112>",
|
||||
"<mask_113>",
|
||||
"<mask_114>",
|
||||
"<mask_115>",
|
||||
"<mask_116>",
|
||||
"<mask_117>",
|
||||
"<mask_118>",
|
||||
"<mask_119>",
|
||||
"<mask_120>",
|
||||
"<mask_121>",
|
||||
"<mask_122>",
|
||||
"<mask_123>",
|
||||
"<mask_124>",
|
||||
"<mask_125>",
|
||||
"<mask_126>",
|
||||
"<mask_127>",
|
||||
"<mask_128>",
|
||||
"<mask_129>",
|
||||
"<mask_130>",
|
||||
"<mask_131>",
|
||||
"<mask_132>",
|
||||
"<mask_133>",
|
||||
"<mask_134>",
|
||||
"<mask_135>",
|
||||
"<mask_136>",
|
||||
"<mask_137>",
|
||||
"<mask_138>",
|
||||
"<mask_139>",
|
||||
"<mask_140>",
|
||||
"<mask_141>",
|
||||
"<mask_142>",
|
||||
"<mask_143>",
|
||||
"<mask_144>",
|
||||
"<mask_145>",
|
||||
"<mask_146>",
|
||||
"<mask_147>",
|
||||
"<mask_148>",
|
||||
"<mask_149>",
|
||||
"<mask_150>",
|
||||
"<mask_151>",
|
||||
"<mask_152>",
|
||||
"<mask_153>",
|
||||
"<mask_154>",
|
||||
"<mask_155>",
|
||||
"<mask_156>",
|
||||
"<mask_157>",
|
||||
"<mask_158>",
|
||||
"<mask_159>",
|
||||
"<mask_160>",
|
||||
"<mask_161>",
|
||||
"<mask_162>",
|
||||
"<mask_163>",
|
||||
"<mask_164>",
|
||||
"<mask_165>",
|
||||
"<mask_166>",
|
||||
"<mask_167>",
|
||||
"<mask_168>",
|
||||
"<mask_169>",
|
||||
"<mask_170>",
|
||||
"<mask_171>",
|
||||
"<mask_172>",
|
||||
"<mask_173>",
|
||||
"<mask_174>",
|
||||
"<mask_175>",
|
||||
"<mask_176>",
|
||||
"<mask_177>",
|
||||
"<mask_178>",
|
||||
"<mask_179>",
|
||||
"<mask_180>",
|
||||
"<mask_181>",
|
||||
"<mask_182>",
|
||||
"<mask_183>",
|
||||
"<mask_184>",
|
||||
"<mask_185>",
|
||||
"<mask_186>",
|
||||
"<mask_187>",
|
||||
"<mask_188>",
|
||||
"<mask_189>",
|
||||
"<mask_190>",
|
||||
"<mask_191>",
|
||||
"<mask_192>",
|
||||
"<mask_193>",
|
||||
"<mask_194>",
|
||||
"<mask_195>",
|
||||
"<mask_196>",
|
||||
"<mask_197>",
|
||||
"<mask_198>",
|
||||
"<mask_199>",
|
||||
"<mask_200>",
|
||||
"<mask_201>",
|
||||
"<mask_202>",
|
||||
"<mask_203>",
|
||||
"<mask_204>",
|
||||
"<mask_205>",
|
||||
"<mask_206>",
|
||||
"<mask_207>",
|
||||
"<mask_208>",
|
||||
"<mask_209>",
|
||||
"<mask_210>",
|
||||
"<mask_211>",
|
||||
"<mask_212>",
|
||||
"<mask_213>",
|
||||
"<mask_214>",
|
||||
"<mask_215>",
|
||||
"<mask_216>",
|
||||
"<mask_217>",
|
||||
"<mask_218>",
|
||||
"<mask_219>",
|
||||
"<mask_220>",
|
||||
"<mask_221>",
|
||||
"<mask_222>",
|
||||
"<mask_223>",
|
||||
"<mask_224>",
|
||||
"<mask_225>",
|
||||
"<mask_226>",
|
||||
"<mask_227>",
|
||||
"<mask_228>",
|
||||
"<mask_229>",
|
||||
"<mask_230>",
|
||||
"<mask_231>",
|
||||
"<mask_232>",
|
||||
"<mask_233>",
|
||||
"<mask_234>",
|
||||
"<mask_235>",
|
||||
"<mask_236>",
|
||||
"<mask_237>",
|
||||
"<mask_238>",
|
||||
"<mask_239>",
|
||||
"<mask_240>",
|
||||
"<mask_241>",
|
||||
"<mask_242>",
|
||||
"<mask_243>",
|
||||
"<mask_244>",
|
||||
"<mask_245>",
|
||||
"<mask_246>",
|
||||
"<mask_247>",
|
||||
"<mask_248>",
|
||||
"<mask_249>",
|
||||
"<mask_250>",
|
||||
"<mask_251>",
|
||||
"<mask_252>",
|
||||
"<mask_253>",
|
||||
"<mask_254>",
|
||||
"<mask_255>",
|
||||
"<mask_256>",
|
||||
"<mask_257>",
|
||||
"<mask_258>",
|
||||
"<mask_259>",
|
||||
"<mask_260>",
|
||||
"<mask_261>",
|
||||
"<mask_262>",
|
||||
"<mask_263>",
|
||||
"<mask_264>",
|
||||
"<mask_265>",
|
||||
"<mask_266>",
|
||||
"<mask_267>",
|
||||
"<mask_268>",
|
||||
"<mask_269>",
|
||||
"<mask_270>",
|
||||
"<mask_271>",
|
||||
"<mask_272>",
|
||||
"<mask_273>",
|
||||
"<mask_274>",
|
||||
"<mask_275>",
|
||||
"<mask_276>",
|
||||
"<mask_277>",
|
||||
"<mask_278>",
|
||||
"<mask_279>",
|
||||
"<mask_280>",
|
||||
"<mask_281>",
|
||||
"<mask_282>",
|
||||
"<mask_283>",
|
||||
"<mask_284>",
|
||||
"<mask_285>",
|
||||
"<mask_286>",
|
||||
"<mask_287>",
|
||||
"<mask_288>",
|
||||
"<mask_289>",
|
||||
"<mask_290>",
|
||||
"<mask_291>",
|
||||
"<mask_292>",
|
||||
"<mask_293>",
|
||||
"<mask_294>",
|
||||
"<mask_295>",
|
||||
"<mask_296>",
|
||||
"<mask_297>",
|
||||
"<mask_298>",
|
||||
"<mask_299>",
|
||||
"<mask_300>",
|
||||
"<mask_301>",
|
||||
"<mask_302>",
|
||||
"<mask_303>",
|
||||
"<mask_304>",
|
||||
"<mask_305>",
|
||||
"<mask_306>",
|
||||
"<mask_307>",
|
||||
"<mask_308>",
|
||||
"<mask_309>",
|
||||
"<mask_310>",
|
||||
"<mask_311>",
|
||||
"<mask_312>",
|
||||
"<mask_313>",
|
||||
"<mask_314>",
|
||||
"<mask_315>",
|
||||
"<mask_316>",
|
||||
"<mask_317>",
|
||||
"<mask_318>",
|
||||
"<mask_319>",
|
||||
"<mask_320>",
|
||||
"<mask_321>",
|
||||
"<mask_322>",
|
||||
"<mask_323>",
|
||||
"<mask_324>",
|
||||
"<mask_325>",
|
||||
"<mask_326>",
|
||||
"<mask_327>",
|
||||
"<mask_328>",
|
||||
"<mask_329>",
|
||||
"<mask_330>",
|
||||
"<mask_331>",
|
||||
"<mask_332>",
|
||||
"<mask_333>",
|
||||
"<mask_334>",
|
||||
"<mask_335>",
|
||||
"<mask_336>",
|
||||
"<mask_337>",
|
||||
"<mask_338>",
|
||||
"<mask_339>",
|
||||
"<mask_340>",
|
||||
"<mask_341>",
|
||||
"<mask_342>",
|
||||
"<mask_343>",
|
||||
"<mask_344>",
|
||||
"<mask_345>",
|
||||
"<mask_346>",
|
||||
"<mask_347>",
|
||||
"<mask_348>",
|
||||
"<mask_349>",
|
||||
"<mask_350>",
|
||||
"<mask_351>",
|
||||
"<mask_352>",
|
||||
"<mask_353>",
|
||||
"<mask_354>",
|
||||
"<mask_355>",
|
||||
"<mask_356>",
|
||||
"<mask_357>",
|
||||
"<mask_358>",
|
||||
"<mask_359>",
|
||||
"<mask_360>",
|
||||
"<mask_361>",
|
||||
"<mask_362>",
|
||||
"<mask_363>",
|
||||
"<mask_364>",
|
||||
"<mask_365>",
|
||||
"<mask_366>",
|
||||
"<mask_367>",
|
||||
"<mask_368>",
|
||||
"<mask_369>",
|
||||
"<mask_370>",
|
||||
"<mask_371>",
|
||||
"<mask_372>",
|
||||
"<mask_373>",
|
||||
"<mask_374>",
|
||||
"<mask_375>",
|
||||
"<mask_376>",
|
||||
"<mask_377>",
|
||||
"<mask_378>",
|
||||
"<mask_379>",
|
||||
"<mask_380>",
|
||||
"<mask_381>",
|
||||
"<mask_382>",
|
||||
"<mask_383>",
|
||||
"<extra_id_99>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_0>",
|
||||
"match",
|
||||
"type",
|
||||
"HAVE_ARGUMENT",
|
||||
"CALL_INTRINSIC_1",
|
||||
"CALL_INTRINSIC_2",
|
||||
"JUMP_NO_INTERRUPT",
|
||||
"nargs",
|
||||
"vargs",
|
||||
"compare",
|
||||
"name",
|
||||
"const",
|
||||
"local"
|
||||
],
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"cls_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"mask_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<mask>",
|
||||
"lstrip": true,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"model_max_length": 512,
|
||||
"pad_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"sep_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"tokenizer_class": "RobertaTokenizer",
|
||||
"trim_offsets": true,
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import click
|
||||
|
||||
from datasets import ReadInstruction, load_dataset
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
from transformers import (
|
||||
DataCollatorForSeq2Seq,
|
||||
RobertaTokenizer,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
def load_tokenized_train_dataset(dataset_repo_name: str, dataset_percentage: int):
|
||||
# Load the tokenized dataset
|
||||
tokenized_train_dataset = load_dataset(
|
||||
dataset_repo_name,
|
||||
token=True,
|
||||
split=ReadInstruction("train", to=dataset_percentage, unit="%"),
|
||||
)
|
||||
return tokenized_train_dataset
|
||||
|
||||
|
||||
def train_statement_model(config: StatementConfiguration):
|
||||
# Load the model; Salesforce/codet5-base is a pretrained model for code generation tasks.
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
|
||||
model = T5ForConditionalGeneration.from_pretrained(config.pretrained_seq2seq_repo_name)
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
|
||||
|
||||
model_dir = str(config.statement_model_dir)
|
||||
model_repo_name = config.statement_model_repo_name
|
||||
|
||||
train_args = Seq2SeqTrainingArguments(
|
||||
output_dir=model_dir,
|
||||
learning_rate=config.statement_training_parameters.learning_rate,
|
||||
per_device_train_batch_size=config.statement_training_parameters.batch_size,
|
||||
per_device_eval_batch_size=config.statement_training_parameters.batch_size,
|
||||
weight_decay=0.01,
|
||||
fp16=config.fp16,
|
||||
logging_dir=str(config.log_dir),
|
||||
report_to="tensorboard",
|
||||
logging_strategy="steps",
|
||||
logging_steps=1000,
|
||||
save_strategy="steps",
|
||||
save_steps=10000,
|
||||
save_total_limit=2,
|
||||
num_train_epochs=config.statement_training_parameters.epochs,
|
||||
predict_with_generate=True,
|
||||
push_to_hub=True,
|
||||
hub_model_id=model_repo_name,
|
||||
hub_private_repo=True,
|
||||
ddp_backend="nccl",
|
||||
ddp_find_unused_parameters=False,
|
||||
)
|
||||
|
||||
tokenized_train_dataset = load_tokenized_train_dataset(config.tokenized_dataset_repo_name, config.dataset_percentage)
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=train_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=tokenized_train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
trainer.train()
|
||||
duration = str(timedelta(seconds=time.time() - start))
|
||||
|
||||
if int(os.environ["LOCAL_RANK"]) == 0:
|
||||
# upload the latest version of the model to the Model Hub on Huggingface
|
||||
trainer.save_model(str(config.statement_model_dir))
|
||||
# this command returns the URL of the commit it just did
|
||||
trainer.push_to_hub(
|
||||
commit_message=duration,
|
||||
finetuned_from=config.pretrained_seq2seq_repo_name,
|
||||
dataset=config.tokenized_dataset_repo_name,
|
||||
)
|
||||
|
||||
|
||||
@click.command(help="Training script for the statement translation model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
train_statement_model(statement_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
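Once trained and pushed, the statement model can be used for inference roughly as follows; the repo name and bytecode string are placeholders, not the published models.

```
# Sketch only: translate one instruction group into a source line with the trained model.
from transformers import RobertaTokenizer, T5ForConditionalGeneration

repo = "my-org/py312-statement"  # placeholder for config.statement_model_repo_name
tok = RobertaTokenizer.from_pretrained(repo)
model = T5ForConditionalGeneration.from_pretrained(repo)
ids = tok("LOAD_CONST 1 <SEP> STORE_NAME x", return_tensors="pt").input_ids
out = model.generate(ids, max_length=64)
print(tok.decode(out[0], skip_special_tokens=True))
```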
@@ -0,0 +1,93 @@
|
||||
import logging
|
||||
import pathlib
|
||||
import click
|
||||
|
||||
from datasets import ReadInstruction, load_dataset
|
||||
from huggingface_hub import HfApi, repo_exists
|
||||
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
|
||||
from tokenizers import Tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_untrained_tokenizer(tokenizer_repo_name: str) -> AutoTokenizer:
|
||||
tokenizer_dir = pathlib.Path(__file__).parent / tokenizer_repo_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def save_and_upload_tokenizer(
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer_json_path: pathlib.Path,
|
||||
tokenizer_repo_name: str,
|
||||
dataset_name: str,
|
||||
):
|
||||
# Save the tokenizer locally
|
||||
tokenizer.save_pretrained(str(tokenizer_json_path.parent.resolve()))
|
||||
|
||||
# Upload files to Hugging Face Hub
|
||||
api = HfApi()
|
||||
api.create_repo(tokenizer_repo_name, exist_ok=True, private=True)
|
||||
api.upload_file(
|
||||
path_in_repo="tokenizer_config.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "tokenizer_config.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message=f"Trained tokenizer using {dataset_name}",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="vocab.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "vocab.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted vocabulary from tokenizer",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="merges.txt",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "merges.txt"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted merges from tokenizer",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="tokenizer.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "tokenizer.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted tokenizer",
|
||||
)
|
||||
api.upload_file(
|
||||
path_in_repo="special_tokens_map.json",
|
||||
path_or_fileobj=str(tokenizer_json_path.parent / "special_tokens_map.json"),
|
||||
repo_id=tokenizer_repo_name,
|
||||
commit_message="Extracted special tokens map",
|
||||
)
|
||||
|
||||
|
||||
def train_tokenizer(config: StatementConfiguration, tokenizer_json_path: pathlib.Path):
|
||||
if repo_exists(config.base_repo_name):
|
||||
logging.error(f"{config.base_repo_name} has already exists")
|
||||
exit(1)
|
||||
|
||||
tokenizer = get_untrained_tokenizer("tokenizer")
|
||||
|
||||
train_dataset = load_dataset(
|
||||
config.dataset_repo_name,
|
||||
token=True,
|
||||
split=ReadInstruction("train", to=config.dataset_percentage, unit="%"),
|
||||
)["bytecode"]
|
||||
|
||||
tokenizer = tokenizer.train_new_from_iterator(train_dataset, vocab_size=30000)
|
||||
save_and_upload_tokenizer(
|
||||
tokenizer,
|
||||
tokenizer_json_path,
|
||||
config.tokenizer_repo_name,
|
||||
config.dataset_repo_name,
|
||||
)
|
||||
|
||||
|
||||
@click.command(help="Training script for the bytecode tokenizer for the statement model given a statement json.")
|
||||
@click.argument("json_path", type=str)
|
||||
def main(json_path: str):
|
||||
json_file_path = pathlib.Path(json_path)
|
||||
statement_config = parse_statement_config_json(json_file_path)
|
||||
train_tokenizer(statement_config, json_file_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,117 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import click
|
||||
|
||||
from pylingual.utils.get_logger import get_logger
|
||||
|
||||
|
||||
def train_segmentation(segmentation_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400):
|
||||
segmentation_root = pathlib.Path(__file__).parent / "segmentation"
|
||||
|
||||
# train tokenizer
|
||||
logger.info("training tokenizer...")
|
||||
subprocess.run(["python", segmentation_root / "train_tokenizer.py", segmentation_config_path])
|
||||
|
||||
# train mlm (single gpu to avoid conflicts with local tokenized data)
|
||||
logger.info("training masked language model...")
|
||||
subprocess.run(
|
||||
[
|
||||
"torchrun",
|
||||
f"--nnodes={nnodes}",
|
||||
f"--nproc-per-node={nproc_per_node}",
|
||||
"--rdzv-backend=c10d",
|
||||
f"--rdzv-endpoint=localhost:{rdzv_port}",
|
||||
segmentation_root / "train_mlm.py",
|
||||
segmentation_config_path,
|
||||
],
|
||||
env=dict(os.environ, NCCL_P2P_DISABLE="1"),
|
||||
)
|
||||
|
||||
# tokenize dataset
|
||||
logger.info("tokenizing segmentation dataset...")
|
||||
subprocess.run(["python", segmentation_root / "tokenize_seg.py", segmentation_config_path])
|
||||
|
||||
# train segmentation model (multi-GPU via torchrun)
logger.info("training segmentation model...")
|
||||
subprocess.run(
|
||||
[
|
||||
"torchrun",
|
||||
f"--nnodes={nnodes}",
|
||||
f"--nproc-per-node={nproc_per_node}",
|
||||
"--rdzv-backend=c10d",
|
||||
f"--rdzv-endpoint=localhost:{rdzv_port}",
|
||||
segmentation_root / "train_seg.py",
|
||||
segmentation_config_path,
|
||||
],
|
||||
env=dict(os.environ, NCCL_P2P_DISABLE="1"),
|
||||
)
|
||||
|
||||
|
||||
def train_statement(statement_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400):
|
||||
statement_root = pathlib.Path(__file__).parent / "statement"
|
||||
|
||||
# train the statement tokenizer
subprocess.run(["python", statement_root / "train_tokenizer_auto.py", statement_config_path])
|
||||
|
||||
# tokenize statement dataset with salesforce tokenizer
|
||||
logger.info("tokenizing statement dataset...")
|
||||
subprocess.run(["python", statement_root / "tokenize_seq2seq.py", statement_config_path])
|
||||
|
||||
# train statement model (multi-GPU via torchrun)
logger.info("training statement model...")
|
||||
subprocess.run(
|
||||
[
|
||||
"torchrun",
|
||||
f"--nnodes={nnodes}",
|
||||
f"--nproc-per-node={nproc_per_node}",
|
||||
"--rdzv-backend=c10d",
|
||||
f"--rdzv-endpoint=localhost:{rdzv_port}",
|
||||
statement_root / "train_seq2seq.py",
|
||||
statement_config_path,
|
||||
],
|
||||
env=dict(os.environ, NCCL_P2P_DISABLE="1"),
|
||||
)
|
||||
|
||||
|
||||
@click.command(help="Full tokenization and training pipeline for the segmentation and statement translation models.")
|
||||
@click.option("--segmentation", type=str, default=None, help="The path to the segmentation model description JSON file.")
|
||||
@click.option("--statement", type=str, default=None, help="The path to the statement model description JSON file.")
|
||||
@click.option("--nnodes", type=int, default=1, help="Torchrun nnodes arg")
|
||||
@click.option("--nproc_per_node", type=int, default=1, help="Torchrun nproc_per_node arg")
|
||||
@click.option("--rdzv_port", "-p", type=int, default=29400, help="Port to use for torchrun rendezvous endpoint")
|
||||
def main(segmentation: str, statement: str, nnodes: int, nproc_per_node: int, rdzv_port: int):
|
||||
logger = get_logger("train-models")
|
||||
|
||||
### LOAD JSON
|
||||
logger.info("Training pipeline starting...")
|
||||
logger.info("Loading dataset description JSON files...")
|
||||
|
||||
### CONFIG_PATHS
|
||||
segmentation_config_path = pathlib.Path(segmentation).resolve() if segmentation is not None else None
|
||||
statement_config_path = pathlib.Path(statement).resolve() if statement is not None else None
|
||||
|
||||
logger.info("Dataset description JSON files loaded!")
|
||||
|
||||
### TRAIN SEGMENTATION
|
||||
if segmentation_config_path is not None:
|
||||
logger.info("Segmentation model training starting...")
|
||||
train_segmentation(segmentation_config_path, logger, nnodes, nproc_per_node, rdzv_port)
|
||||
logger.info("Segmentation model training complete!")
|
||||
else:
|
||||
logger.warning("Segmentation model configuration json path not provided in --segmentation; skipping segmentation model training...")
|
||||
|
||||
### TRAIN STATEMENT
|
||||
if statement_config_path is not None:
|
||||
logger.info("Statement model training starting...")
|
||||
train_statement(statement_config_path, logger, nnodes, nproc_per_node, rdzv_port)
|
||||
logger.info("Statement model training complete!")
|
||||
else:
|
||||
logger.warning("Statement model configuration json path not provided in --statement; skipping statement model training...")
|
||||
|
||||
logger.info("Training pipeline complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||