commit 046e80cdd1 (parent b2439eee3e)
caandt, 2025-03-13 16:56:36 -05:00
27 changed files
@@ -0,0 +1,48 @@
# Model Training
PyLingual's accuracy depends on having accurate segmentation and statement models [^1]. The segmentation model divides a list of bytecode instructions into groups, one per source statement. The statement model translates each group of instructions into source code. The instructions for training these models are as follows:
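For intuition, here is a minimal sketch of that grouping using only the standard `dis` module (illustrative, not part of the training scripts; it assumes CPython 3.12 or earlier, where `Instruction.starts_line` carries the line number):
```
import dis

def f():
    x = 1
    return x + 2

# Bucket each instruction under the source line that starts it; this is the
# correspondence the segmentation model learns to predict from raw bytecode.
groups, current = {}, None
for inst in dis.get_instructions(f):
    if inst.starts_line is not None:
        current = inst.starts_line
    groups.setdefault(current, []).append(inst.opname)
print(groups)  # e.g. {4: ['LOAD_CONST', 'STORE_FAST'], 5: ['LOAD_FAST', ...]}
```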
## Dataset generation
First install [pyenv](https://github.com/pyenv/pyenv) and the required Python versions for the dataset. Create a dataset JSON file based on the sample (`sample_jsons/py36-sample-data.json`).
The dataset directory should be structured like so, with only one `.py` file per directory:
```
dataset
├── 0
│   └── file.py
├── 1
│   └── file.py
...
├── 999
│   └── file.py
└── 1000
└── file.py
```
The names of the inner directories and files do not matter. Then create the dataset:
```
python prepare_dataset.py <path to JSON>
```
## Segmentation model
Create a segmentation model JSON file based on the sample (`sample_jsons/py36-sample-segmentation.json`). Then train the model:
```
python train_models.py --segmentation <path to JSON>
```
## Statement model
Create a statement model JSON file based on the sample (`sample_jsons/py36-sample-statement.json`). Then train the model:
```
python train_models.py --statement <path to JSON>
```
Once the models are trained, update `../pylingual/decompiler_config.yaml` (or create a separate config file) by replacing the old models with the newly trained ones, as sketched below.
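A minimal sketch of that swap, assuming the config maps version strings to model repo names (the key names below are hypothetical; mirror whatever the real YAML uses):
```
import yaml  # requires pyyaml

with open("../pylingual/decompiler_config.yaml") as f:
    config = yaml.safe_load(f)
# hypothetical keys; check the actual file layout
config["3.6"]["segmenter"] = "sample_user/sample_segmenter_name-segmenter"
config["3.6"]["statement"] = "sample_user/sample_name-statement"
with open("my_decompiler_config.yaml", "w") as f:
    yaml.safe_dump(config, f)
```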
[^1]: [pylingual models](https://huggingface.co/syssec-utd).
@@ -0,0 +1,173 @@
import contextlib
import difflib
import json
import multiprocessing
import os
import warnings
from collections import defaultdict
from enum import Enum
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
import click
import rich
from rich.progress import track
from pylingual.control_flow_reconstruction.cfg import CFG
from pylingual.control_flow_reconstruction.structure import bc_to_cft
from pylingual.main import print_result
from pylingual.control_flow_reconstruction.source import SourceContext
from pylingual.editable_bytecode import PYCFile
from pylingual.equivalence_check import compare_pyc
from pylingual.utils.version import PythonVersion
from pylingual.utils.generate_bytecode import compile_version, CompileError
from dataset_generation.normalize_source import normalize_source
class Result(Enum):
Success = "success"
Failure = "failure"
Error = "error"
CompileError = "compile_error"
def edit_pyc_lines(pyc: PYCFile, src_lines: list[str]):
if pyc.version == (3, 10):
pyc.replace_duplicated_returns10(src_lines)
elif pyc.version == (3, 12):
pyc.replace_duplicated_returns12(src_lines)
seen_lines = set()
# multiple instructions can start the same lno, but the segmentation model will only assign the lno to the first one
for bc in pyc.iter_bytecodes():
if bc.is_comprehension:
continue
# create a dict of line num : [bytecodes composing line]
lno_bytecodes = bc.get_lno_insts(previously_seen_lines=seen_lines)
seen_lines.update(lno_bytecodes.keys())
for lno, line_insts in lno_bytecodes.items():
line_insts[0].starts_line = lno
for inst in line_insts[1:]:
inst.starts_line = None
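# e.g. if instructions 10-12 all report line 7, only instruction 10 keeps
# starts_line = 7 and instructions 11-12 get starts_line = None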
def run(file: Path, out_dir: Path, version: PythonVersion, print=False):
try:
out_dir = get_unused(out_dir / file.stem, False)
out_dir.mkdir(parents=True)
if file.is_dir():
file = next(file.iterdir())
in_src = normalize_source(file.read_text(), replace_docstrings=True)
src_lines = in_src.split("\n")
in_path = out_dir / "a.py"
in_path.write_text(in_src, encoding="utf-8")
in_pyc = out_dir / "a.pyc"
compile_version(in_path, in_pyc, version)
pyc = PYCFile(in_pyc)
edit_pyc_lines(pyc, src_lines)
cfts = {bc.codeobj: bc_to_cft(bc) for bc in pyc.iter_bytecodes()}
out_src = normalize_source(str(SourceContext(pyc, src_lines, cfts)))
out_path = out_dir / "b.py"
out_path.write_text(out_src, encoding="utf-8")
out_pyc = out_dir / "b.pyc"
compile_version(out_path, out_pyc, version)
result = compare_pyc(in_pyc, out_pyc)
if print:
print_result(f"Equivalance results for {file}", result)
return Result.Success if all(x.success for x in result) else Result.Failure, file
except (CompileError, SyntaxError):
return Result.CompileError, file
except Exception:
rich.get_console().print_exception()
return Result.Error, file
class NoPool:
# drop-in stand-in for multiprocessing.Pool when running single-process
imap_unordered = map
def print_diff(a: Path, b: Path):
a_lines = a.read_text().split("\n")
b_lines = b.read_text().split("\n")
console = rich.console.Console(highlight=False)
for line in difflib.unified_diff(a_lines, b_lines, str(a), str(b)):
style = "red" if line[0] == "-" else "green" if line[0] == "+" else "blue" if line[0] == "@" else ""
console.print(line, style=style)
def get_unused(a: Path, always_suffix=True):
# if suffixing is optional and the path is free, use it as-is
if not always_suffix and not a.exists():
return a
stem = a.stem
i = 0
while True:
a = a.with_stem(f"{stem}_{i}")
i += 1
if not a.exists():
return a
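# example invocation (the script filename here is illustrative):
#   python cflow_test.py my_sources/ -v 3.12 -p 8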
@click.command(help="Run the control-flow reconstructor")
@click.argument("input", type=Path)
@click.argument("output", type=str, default="")
@click.option("-v", "--version", type=PythonVersion, default=PythonVersion((3, 12)), help="Python version to compile as")
@click.option("-p", "--processes", type=int, default=os.cpu_count(), help="Number of processes")
@click.option("-d", "--prefix", type=Path, default=Path("/tmp/cflow_test"), help="Base dir for all output")
@click.option("-g", "--graph", is_flag=False, flag_value="graph", help="Enable CFG visualization")
@click.option("-f", "--graph-format", default="jpg", help="Output format supported by pydot")
def main(input: Path, output: str, version: PythonVersion, graph: str | None, prefix: Path, processes: int, graph_format: str):
warnings.filterwarnings("ignore")
print = rich.get_console().print
if graph:
CFG.enable_graphing(prefix / graph, graph_format)
if input.is_file() and input.suffix == ".py":
if output:
out = contextlib.nullcontext(output)
else:
out = TemporaryDirectory()
with out as o:
o = Path(o)
results = run(input, o, version)[0]
if results in [Result.CompileError, Result.Error]:
print(results)
else:
print_diff(o / input.stem / "a.py", o / input.stem / "b.py")  # run() names the directory after the file stem
else:
if not output:
out_dir = get_unused(prefix / input.stem)
else:
out_dir = prefix / output
print(f"Saving results to {out_dir}")
results = defaultdict(list)
f = partial(run, out_dir=out_dir, version=version)
if input.is_dir():
files = list(input.iterdir())
else:
files = list(map(Path, input.read_text().strip().split("\n")))
if processes > 1:
pool = multiprocessing.Pool(processes=processes)
else:
pool = contextlib.nullcontext(NoPool)
with pool as p:
for result, input in track(p.imap_unordered(f, files), total=len(files)):
results[result].append(input)
for res in Result:
print(f"{res}: {len(results[res])}")
total = sum(len(x) for x in results.values())
if total:
print(f"{len(results[Result.Success])} / {total} succeeded ({len(results[Result.Success]) / total:.3%})")
res = json.dumps({k.value: list(map(str, v)) for k, v in results.items()})
(out_dir / "results.json").write_text(res)
if __name__ == "__main__":
main()
@@ -0,0 +1,50 @@
from dataclasses import dataclass
import pathlib
from typing import Tuple, List
@dataclass
class DataRequest:
name: str
source_path: pathlib.Path
num_train: int
num_test: int
num_valid: int
@property
def total_files(self):
return self.num_train + self.num_test + self.num_valid
def __post_init__(self):
self.source_path = pathlib.Path(self.source_path)
if not self.source_path.exists():
raise FileNotFoundError(f"{self.source_path} for DataRequest {self.name} does not exist")
if self.num_train < 0:
raise ValueError(f"Training sample count for DataRequest {self.name} must be non-negative")
if self.num_test < 0:
raise ValueError(f"Testing sample count for DataRequest {self.name} must be non-negative")
if self.num_valid < 0:
raise ValueError(f"Validation sample count for DataRequest {self.name} must be non-negative")
@dataclass
class DatasetDescription:
name: str
version: Tuple[int, int]
save_to_dir: pathlib.Path
huggingface_user: str
data_requests: List[DataRequest]
@property
def code_dir(self):
return self.save_to_dir / self.name / "code"
@property
def csv_dir(self):
return self.save_to_dir / self.name / "csv"
def __post_init__(self):
self.save_to_dir = pathlib.Path(self.save_to_dir)
self.version = tuple(self.version)
@@ -0,0 +1,3 @@
from .create_code_dataset import transfer_and_compile_file
__all__ = ["transfer_and_compile_file"]
@@ -0,0 +1,216 @@
import csv
import itertools
import logging
import multiprocessing
import pathlib
import re
import signal
from typing import Callable, Tuple
import tqdm
from pylingual.editable_bytecode import PYCFile
from pylingual.masking.ast_masker import DUMMY_DECORATOR
from pylingual.masking.model_disasm import fix_jump_targets
from .DatasetDescription import DataRequest
from pylingual.masking.model_disasm import create_global_masker, mask_source
bytecode_separator = " <SEP> "
source_separator = " <SEP> "
CSV_SGMT_HEADER = ["source", "bytecode", "boundary", "file"]
CSV_STMT_HEADER = ["source", "bytecode", "file"]
def create_csv_dataset(code_dataset_path: pathlib.Path, csv_dataset_path: pathlib.Path, data_requests: list[DataRequest], logger: logging.Logger = None):
progress_bar = tqdm.tqdm(total=sum([request.total_files for request in data_requests]))
for split in ("train", "test", "valid"):
if logger:
logger.info(f"Converting the {split} split to CSV...")
write_csvs(code_dataset_path / split, csv_dataset_path / split, logger, progress_bar=progress_bar)
def write_csvs(source_path: pathlib.Path, csv_output_path: pathlib.Path, logger: logging.Logger = None, max_csv_rows: int = 30000, progress_bar: tqdm.tqdm = None):
# validate output directory
if csv_output_path.exists():
if not csv_output_path.is_dir():
raise OSError("CSV output path is not a directory")
else:
csv_output_path.mkdir(parents=True)
##### csv write wrappers to preserve csv row limit
def csv_writer(file_prefix: str, csv_header: list) -> Callable:
out_dir = csv_output_path.joinpath(file_prefix)
out_dir.mkdir(exist_ok=True)
for csv_idx in itertools.count():
new_path = out_dir.joinpath(f"{file_prefix}_{csv_idx}.csv")
new_path.touch()
if logger:
logger.info(f"Creating new csv {new_path.resolve()}...")
with new_path.open(mode="w") as csv_file:
writer = csv.writer(csv_file)
writer.writerow(csv_header)
for writer in itertools.repeat(writer, max_csv_rows):
yield writer.writerow
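# each value pulled from this generator is a writerow bound to the current
# csv file; after max_csv_rows rows it rolls over to a new numbered csv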
segmentation_writer = csv_writer("segmentation", CSV_SGMT_HEADER)
statement_writer = csv_writer("statement", CSV_STMT_HEADER)
# create dirs
code_dirs = (child for child in source_path.iterdir() if child.is_dir())
def bytecode2csv_args():
for dir in code_dirs:
py_path = next(dir.glob("*.py"), None)
pyc_path = next(dir.glob("*.pyc"), None)
if None in (py_path, pyc_path):
logging.debug(f"PY or PYC file not found in {dir}")
continue
else:
yield (py_path, pyc_path)
num_fails = 0
with multiprocessing.Pool() as pool:
for result in pool.imap_unordered(bytecode2csv_exception_wrapper, bytecode2csv_args()):
if isinstance(result, Exception):
num_fails += 1
logger.debug(f"DIR: {dir}\nERR: {result}\nTYPE ERR: {type(result)}\n")
continue
(segmentation_rows, statement_rows) = result
for row, writerow in zip(segmentation_rows, segmentation_writer):
writerow(row)
for row, writerow in zip(statement_rows, statement_writer):
writerow(row)
if progress_bar:
progress_bar.update()
progress_bar.set_postfix({"num_fails": num_fails})
logger.info(f"NUMBER OF FAILS !!! {num_fails}")
def timeout_handler(signum, frame):
raise TimeoutError()
def bytecode2csv_exception_wrapper(paths: Tuple[pathlib.Path, pathlib.Path]) -> Tuple[list, list] | Exception:
signal.signal(signal.SIGALRM, timeout_handler)
try:
signal.alarm(30) # set 30 second timeout
results = bytecode2csv(*paths)
signal.alarm(0) # success; disable timer
return results
except Exception as error:
signal.alarm(0) # disable timer in case another exception triggered the fail
return Exception(f"{type(error)}: {error} in file {paths}")
def bytecode2csv(py_path: pathlib.Path, pyc_path: pathlib.Path) -> tuple[list, list]:
"""Creates segmentation and statement csv rows for given bytecode and source file"""
segmentation_rows = []
statement_rows = []
pyc = PYCFile(str(pyc_path.resolve()))
if pyc.version == (3, 10):
pyc.replace_duplicated_returns10(py_path.read_text().split("\n"))
elif pyc.version == (3, 12):
pyc.replace_duplicated_returns12(py_path.read_text().split("\n"))
global_masker = create_global_masker(pyc)
masked_source_text = mask_source(py_path, global_masker, pyc.version)
masked_source_lines = masked_source_text.split("\n")
# filter out dummy decorators added in <= 3.7
dummy_lnos = []
if pyc.version <= (3, 7):
# remove dummy decorators from bytecode
pyc._patch_dummy_decorator(dummy_decorator_name=DUMMY_DECORATOR)
try: # if no functions are in source, then dummy will not exist
dummy_decorator_line = f"@{global_masker.mask(DUMMY_DECORATOR)}"
except KeyError:
dummy_decorator_line = None
dummy_lnos = [lno + 1 for lno, source in enumerate(masked_source_lines) if source.strip() == dummy_decorator_line]
seen_lines = set()
# create rows for each bytecode
for bc in pyc.iter_bytecodes():
# we ignore comprehensions, hoisted later
if bc.is_comprehension:
continue
# attempt to filter lines
lno_insts = bc.get_lno_insts(previously_seen_lines=seen_lines)
# create line num : model disasm view of insts
lno_model_view_insts = {lno: [global_masker.get_model_view(inst) for inst in line_insts] for lno, line_insts in lno_insts.items()}
seen_lines.update(lno_model_view_insts.keys())
# segment source
if pyc.version <= (3, 7):
segmented_source_lines = []
for line_num in lno_model_view_insts:
if not line_num:
segmented_source_lines.append("")
elif line_num in dummy_lnos:
segmented_source_lines.append(masked_source_lines[line_num].strip())
else:
segmented_source_lines.append(masked_source_lines[line_num - 1].strip())
else:
segmented_source_lines = [masked_source_lines[line_num - 1].strip() if line_num else "" for line_num in lno_model_view_insts.keys()] # -1 to convert from line num to index in array
model_disasm_text = bytecode_separator.join(val for val in itertools.chain(*lno_model_view_insts.values()))
if len(segmented_source_lines) != len(lno_model_view_insts):
raise ValueError("Length mismatch between segmented source and segmented bytecodes")
# create bytecode segmentation
boundaries = []
for bc_line in lno_model_view_insts.values():
if len(bc_line) == 1:
bounds = "B"
elif len(bc_line) >= 2:
bounds = "B" + "I" * (len(bc_line) - 2) + "E"
else:
raise ValueError("Unexpected amount of bytecodes segmented into a line")
boundaries.extend(list(bounds))
# append rows
segmentation_rows.append([source_separator.join(segmented_source_lines), model_disasm_text, boundaries, str(py_path)])
for segmented_source, bytecodes in zip(segmented_source_lines, lno_model_view_insts.values()):
# skip empty lines
if not segmented_source or segmented_source == "None":
continue
# skip fillers
if segmented_source in ("pass", "...") and ("RETURN_VALUE" in bytecodes or "RETURN_CONST , None" in bytecodes):
continue
# skip string-only lines that aren't docstrings
if (segmented_source.startswith("'") or segmented_source.startswith('"')) and not any("__doc__" in b for b in bytecodes):
continue
if segmented_source.startswith("elif "):
segmented_source = segmented_source[2:]
joined_bytecode = bytecode_separator.join(bytecodes)
# DUCT-TAPE; skip samples where model has to guess masks
source_masks = set(re.findall(r"<mask_\d+>", segmented_source))
bytecode_masks = set(re.findall(r"<mask_\d+>", joined_bytecode))
if not source_masks <= bytecode_masks:
continue
# normalize source mask order for statements
# replace mask values to start at 0 and count up
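# e.g. bytecode "<mask_7> LOAD_FAST <mask_3>" with source "<mask_3>" becomes
# "<mask_0> LOAD_FAST <mask_1>" with source "<mask_1>"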
mask_regex = re.compile(r"(?<=<mask_)\d+(?=>)")
masks = mask_regex.findall(joined_bytecode)
mask_order = [x for i, x in enumerate(masks) if masks.index(x) == i]
normalized_mask_bytecode = mask_regex.sub(lambda x: str(mask_order.index(x.group(0))), joined_bytecode)
normalized_mask_source = mask_regex.sub(lambda x: str(mask_order.index(x.group(0))), segmented_source)
# normalize jump targets
normalized_mask_bytecode = fix_jump_targets(normalized_mask_bytecode)
statement_rows.append([normalized_mask_source, normalized_mask_bytecode, str(py_path)])
return (segmentation_rows, statement_rows)
@@ -0,0 +1,114 @@
import itertools
import logging
import multiprocessing
import pathlib
import random
from typing import List, Optional, Set, Tuple
import tqdm
from .DatasetDescription import DataRequest
from pylingual.utils.generate_bytecode import compile_version
from .normalize_source import normalize_source
from pylingual.masking.ast_masker import add_dummy_decorators
def transfer_and_compile_file(
original_file: pathlib.Path,
destination_file: pathlib.Path,
version: Tuple[int, int],
) -> Optional[Exception]:
# copy over normalized source file
try:
normalized_source = normalize_source(original_file.read_text(), version=version, replace_docstrings=True)
if version[:2] <= (3, 7):
normalized_source = add_dummy_decorators(normalized_source)
except Exception as err:
return err
destination_file.parent.mkdir(parents=True, exist_ok=True)
destination_file.write_text(normalized_source)
# compile the copied file with the given version
try:
compile_version(
destination_file.resolve(),
destination_file.with_suffix(".pyc").resolve(),
version,
)
except Exception as err:
return err
return None
def star_transfer_and_compile_file(args) -> Optional[Exception]:
return transfer_and_compile_file(*args)
# samples num_files files from the given directory
# expects the directory to have the structure
# source_dir -> identifier -> file.py
def sample_directory_splits(
data_request: DataRequest,
) -> Tuple[List[pathlib.Path], List[pathlib.Path], List[pathlib.Path]]:
all_files: Set[pathlib.Path] = set()
for identifier in data_request.source_path.iterdir():
source_file = next(identifier.glob("*.py"), None) # get the first python file from the identifier
if source_file is not None:
all_files.add(source_file)
# sample batches until we have enough files to satisfy the data requests
# this avoids running expensive tests on unsampled files
clean_sample: Set[pathlib.Path] = set()
while len(clean_sample) < data_request.total_files:
remaining_files = data_request.total_files - len(clean_sample)
sample_batch = random.sample(list(all_files), k=remaining_files)
# add the acceptable files to the sample and remove them from the population
to_add = set(candidate for candidate in sample_batch if candidate is not None)
clean_sample.update(to_add)
all_files -= to_add
full_sample = iter(clean_sample)
train = list(itertools.islice(full_sample, data_request.num_train))
test = list(itertools.islice(full_sample, data_request.num_test))
valid = list(itertools.islice(full_sample, data_request.num_valid))
return train, test, valid
def prepare_single_directory_transfer_args(data_request: DataRequest, target_dir: pathlib.Path) -> List[Tuple[pathlib.Path, pathlib.Path]]:
train, test, valid = sample_directory_splits(data_request)
transfer_args = []
for split_name, split_files in zip(("train", "test", "valid"), (train, test, valid)):
for source_file in split_files:
target_file = target_dir / split_name / f"{data_request.name}-{source_file.parent.name}" / source_file.name
transfer_args.append((source_file, target_file))
return transfer_args
# takes a list of DataRequests and a target directory
# makes train, test, and valid directories in the target directory with the normalized source files
def create_code_dataset(
data_requests: List[DataRequest],
target_dir: pathlib.Path,
version: Tuple[int, int],
logger: logging.Logger,
):
with multiprocessing.Pool() as pool:
# prepare a list of file transfers to execute
logger.info(f"Sampling {', '.join(str(req.source_path.resolve()) for req in data_requests)}...")
transfer_arg_lists = pool.starmap(
prepare_single_directory_transfer_args,
zip(data_requests, itertools.repeat(target_dir)),
)
# execute the file transfers
versioned_transfer_arg_lists = [(source_file, target_file, version) for (source_file, target_file) in itertools.chain(*transfer_arg_lists)]
logger.info(f"Normalizing and Compiling {len(versioned_transfer_arg_lists)} files...")
for error in tqdm.tqdm(pool.imap_unordered(star_transfer_and_compile_file, versioned_transfer_arg_lists), total=len(versioned_transfer_arg_lists)):
if error is not None:
logger.debug(error)
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
import ast
import pathlib
import sys
from typing import Tuple
def version_str_to_tuple(version_str: str) -> tuple[int, int]:
# a version string is a string like 3.9.2
versions = [int(version) for version in version_str.split(".")]
return tuple(versions[:2])
# must be run in python 3.9 or later for ast.unparse() support
# version defaults to whatever version this script is running in; needs to be set explicitly for backwards compatibility
# ast only supports versions 3.4 and later
def normalize_source(
source: str,
version: Tuple[int, int] = sys.version_info[0:2],
replace_docstrings=False,
) -> str:
"""
Parse the source code into an AST, then convert back to source.
This has the following normalizing effects:
1. whitespace is set according to the PEP standard
2. each statement is on exactly one line
3. # comments are removed (note: docstrings are not removed)
:param str source: The source code to normalize
:param tuple version: The (Major, Minor) version of python to parse with; must be at least (3, 4); defaults to
same version as this script
:param bool replace_docstrings: Replace all docstrings with 'pass'
"""
tree = ast.parse(source, feature_version=version)
if replace_docstrings:
for node in ast.walk(tree):
# ast.Str is deprecated (removed in newer CPython); match string constants directly
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant) and isinstance(node.value.value, str):
node.value.value = "pass"
return ast.unparse(tree)
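# example:
#   normalize_source("x=1 ;  y=2  # comment")  ->  "x = 1\ny = 2"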
def normalize_source_file(
source_file_path: str,
cleaned_suffix: str = "-cleaned",
version: tuple[int, int] = sys.version_info[0:2],
):
"""
Normalizes the source code in a given file, then saves it to a '-cleaned' version in the same directory
:param str source_file_path: The absolute or relative path to the source .py file
:param str cleaned_suffix: The suffix to add to the cleaned file, typically left as default
:param tuple version: The (Major, Minor) version of python to parse with; must be at least (3, 4); defaults to
same version as this script
"""
# add the cleaned_suffix to the output_path
input_path = pathlib.Path(source_file_path).resolve()
output_path = input_path.with_stem(f"{input_path.stem}{cleaned_suffix}")
with open(input_path, "r") as source_file:
normalized_source = normalize_source(source_file.read(), version=version)
with open(output_path, "w") as cleaned_file:
cleaned_file.write(normalized_source)
return output_path
@@ -0,0 +1,60 @@
from io import BytesIO
from typing import Dict, List, Literal
from datasets import load_dataset
from huggingface_hub import HfApi
from .DatasetDescription import DatasetDescription
LOCAL_DATASET = Dict[Literal["train", "test", "valid"], List[str]]
def upload_single_dataset(data_files: LOCAL_DATASET, dataset_name: str, dataset_card: str):
local_datasets = load_dataset("csv", data_files=data_files)
local_datasets.push_to_hub(dataset_name, private=True)
dataset_card_with_stats = dataset_card + f"\n\nDataset Statistics:\n\n```\n{local_datasets}\n```"
api = HfApi()
api.upload_file(
path_or_fileobj=BytesIO(bytes(dataset_card_with_stats, "utf-8")),
path_in_repo="README.md",
repo_id=dataset_name,
repo_type="dataset",
)
def upload_dataset_to_huggingface(dataset_description: DatasetDescription):
formatted_data_requests = "\n".join(f"{str(req.source_path.resolve())}: (train: {req.num_train}, test: {req.num_test}, valid: {req.num_valid})" for req in dataset_description.data_requests)
dataset_card = f"""
# {dataset_description.name}
Created by the Syssec team @ UTD
Dataset Composition:
```
{formatted_data_requests}
```
Python version: `{".".join(map(str, dataset_description.version))}`
"""
splits: List[Literal["train", "test", "valid"]] = [
"train",
"test",
"valid",
]
# collect data files
segmentation_data_files: LOCAL_DATASET = {}
statement_data_files: LOCAL_DATASET = {}
for split in splits:
segmentation_data_files[split] = [str(path.resolve()) for path in (dataset_description.csv_dir / split / "segmentation").glob("*.csv")]
statement_data_files[split] = [str(path.resolve()) for path in (dataset_description.csv_dir / split / "statement").glob("*.csv")]
# upload datasets
segmentation_dataset_name = f"{dataset_description.huggingface_user}/segmentation-{dataset_description.name}"
upload_single_dataset(segmentation_data_files, segmentation_dataset_name, dataset_card)
statement_dataset_name = f"{dataset_description.huggingface_user}/statement-{dataset_description.name}"
upload_single_dataset(statement_data_files, statement_dataset_name, dataset_card)
@@ -0,0 +1,66 @@
import json
import logging
import pathlib
from typing import Union
import click
from dataset_generation.bytecode2csv import create_csv_dataset
from dataset_generation.create_code_dataset import create_code_dataset
from dataset_generation.DatasetDescription import DataRequest, DatasetDescription
from dataset_generation.upload_raw_dataset import upload_dataset_to_huggingface
from pylingual.utils.get_logger import get_logger
def get_dataset_description_from_arg_json(json_path: str, logger: Union[logging.Logger, None] = None) -> DatasetDescription:
json_file_path = pathlib.Path(json_path)
if not json_file_path.exists():
raise FileNotFoundError(f"{json_file_path} does not exist")
if logger:
logger.info(f"Loading dataset description from {json_file_path}...")
with json_file_path.open() as json_file:
dataset_description_dict = json.load(json_file)
dataset_description_dict["data_requests"] = [DataRequest(**d) for d in dataset_description_dict["data_requests"]]
return DatasetDescription(**dataset_description_dict)
@click.command(help="Samples, splits, processes, and uploads a given dataset described by JSON.")
@click.argument("json_path", type=str)
def main(json_path: str):
logger = get_logger("prepare-dataset")
dataset_description = get_dataset_description_from_arg_json(json_path, logger)
logger.debug(dataset_description)
if dataset_description.code_dir.exists():
raise FileExistsError(f"{dataset_description.code_dir} already exists! The dataset name is probably already taken.")
logger.info("Creating code dataset...")
if not (dataset_description.data_requests and dataset_description.code_dir and dataset_description.version):
logger.error("Dataset description is missing required fields")
exit(1)
create_code_dataset(
dataset_description.data_requests,
dataset_description.code_dir,
dataset_description.version,
logger,
)
# create csv dataset
logger.info("Converting code dataset to csv...")
create_csv_dataset(
dataset_description.code_dir,
dataset_description.csv_dir,
dataset_description.data_requests,
logger,
)
logger.info(f"Uploading {dataset_description.name} to HuggingFace...")
upload_dataset_to_huggingface(dataset_description)
if __name__ == "__main__":
main()
@@ -0,0 +1,24 @@
{
"name": "sample_dataset_name",
"version": [3, 6],
"save_to_dir": "./save_dir/",
"huggingface_user": "sample_user",
"data_requests":
[
{
"name": "dataset",
"source_path": "./dataset",
"num_train": 200,
"num_test": 200,
"num_valid": 200
},
{
"name": "dataset2",
"source_path": "./dataset2",
"num_train": 200,
"num_test": 200,
"num_valid": 200
}
]
}
@@ -0,0 +1,18 @@
{
"base_repo_name": "sample_user/sample_segmenter_name",
"dataset_repo_name": "sample_user/segmentation-sample_dataset_name",
"pretrained_mlm_repo_name": "",
"cache_dir": "./cache-dir/",
"max_token_length": 512,
"dataset_percentage": 100,
"mlm_training_parameters": {
"batch_size": 48,
"epochs": 2,
"learning_rate": 5e-5
},
"segmentation_training_parameters": {
"batch_size": 48,
"epochs": 2,
"learning_rate": 2e-5
}
}
@@ -0,0 +1,16 @@
{
"base_repo_name": "sample_user/sample_name",
"dataset_repo_name": "sample_user/statement-sample_dataset_name",
"tokenizer_repo_name": "sample_user/sample_name-tok",
"pretrained_seq2seq_repo_name": "Salesforce/codet5-base",
"cache_dir": "./cache-dir/",
"max_token_length": 256,
"dataset_percentage": 100,
"do_eval": true,
"fp16": true,
"statement_training_parameters": {
"batch_size": 24,
"epochs": 2,
"learning_rate": 2e-5
}
}
@@ -0,0 +1,74 @@
import json
import logging
import pathlib
from dataclasses import dataclass
from typing import Optional
@dataclass
class TrainingParameters:
batch_size: int
epochs: int
learning_rate: float
@dataclass
class SegmentationConfiguration:
base_repo_name: str
dataset_repo_name: str
pretrained_mlm_repo_name: str
cache_dir: pathlib.Path
max_token_length: int
dataset_percentage: int
mlm_training_parameters: TrainingParameters
segmentation_training_parameters: TrainingParameters
@property
def tokenizer_repo_name(self):
return self.base_repo_name + "-tokenizer"
@property
def tokenizer_json_path(self):
return self.cache_dir / "tokenizers" / self.tokenizer_repo_name / "tokenizer.json"
@property
def tokenized_dataset_repo_name(self):
return self.dataset_repo_name + "-tokenized"
@property
def mlm_repo_name(self):
return self.base_repo_name + "-mlm"
@property
def mlm_dir(self):
return self.cache_dir / "models" / self.mlm_repo_name
@property
def segmenter_repo_name(self):
return self.base_repo_name + "-segmenter"
@property
def segmenter_dir(self):
return self.cache_dir / "models" / self.segmenter_repo_name
@property
def dataset_dir(self):
return self.cache_dir / "datasets" / self.dataset_repo_name
def __post_init__(self):
self.cache_dir = pathlib.Path(self.cache_dir)
def parse_segmentation_config_json(json_file_path: pathlib.Path, logger: Optional[logging.Logger] = None) -> SegmentationConfiguration:
if not json_file_path.exists():
raise FileNotFoundError(f"{json_file_path} does not exist")
if logger:
logger.info(f"Loading model description from {json_file_path}...")
with json_file_path.open() as json_file:
segmentation_config_dict = json.load(json_file)
segmentation_config_dict["mlm_training_parameters"] = TrainingParameters(**segmentation_config_dict["mlm_training_parameters"])
segmentation_config_dict["segmentation_training_parameters"] = TrainingParameters(**segmentation_config_dict["segmentation_training_parameters"])
return SegmentationConfiguration(**segmentation_config_dict)
@@ -0,0 +1,152 @@
import ast
import functools
import os
import pathlib
import click
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from pylingual.segmentation.sliding_window import sliding_window
from transformers import PreTrainedTokenizerFast
bytecode_separator = " <SEP> "
def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name
tokenizer_file = hf_hub_download(repo_id=tokenizer_repo_name, filename="tokenizer.json", token=True, cache_dir=str(tokenizer_dir))
tokenizer = PreTrainedTokenizerFast(
tokenizer_file=tokenizer_file,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
)
return tokenizer
# we need to align the per-word B/I/E labels with the tokenizer's subword tokens
def align_labels_with_tokens(labels, word_ids):
label_names = ["B", "I", "E"]
id2label = {str(i): label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}
new_labels = []
current_word = None
for word_id in word_ids:
if word_id != current_word:
# Start of a new word!
current_word = word_id
label = -100 if word_id is None else int(label2id[labels[word_id]])
new_labels.append(label)
elif word_id is None:
# Special token
new_labels.append(-100)
else:
# Same word as previous token
label = int(label2id[labels[word_id]])
new_labels.append(label)
return new_labels
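# e.g. labels ["B", "I", "E"] with word_ids [None, 0, 0, 1, 2, None]
# -> [-100, 0, 0, 1, 2, -100]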
# the map function used to tokenize the dataset
def tokenize_and_align_labels(tokenizer: PreTrainedTokenizerFast, max_length: int, examples):
MAX_WINDOW_LENGTH = 512
STEP_SIZE = 128
# parse the strings into lists to better work with the bytecode and boundaries
parsed_bc = [(codeobj.split(" <SEP> "), ast.literal_eval(bounds)) for codeobj, bounds in zip(examples["bytecode"], examples["boundary"])]
codeobj_tokens = []
# count the tokens for each bytecode instruction in a codeobj
for codeobj, bounds in parsed_bc:
token_list = []
for bc, bound in zip(codeobj, bounds):
token_list.append(((bc, bound), len(tokenizer(bc)[0])))
codeobj_tokens.append(token_list)
windows = [sliding_window(codeobj, MAX_WINDOW_LENGTH, STEP_SIZE) for codeobj in codeobj_tokens]
# remake examples using our windows
examples["boundary"] = []
examples["bytecode"] = []
# go through each window
for window in windows:
for item in window:
# where we will temporarily store our bytecode and bounds
bytecode = []
bounds = []
for bc in item[0]:
bytecode.append(bc[0])
bounds.append(bc[1])
# append it into examples
examples["bytecode"].append(bytecode_separator.join(bytecode))
examples["boundary"].append(str(bounds))
tokenized_inputs = tokenizer(
examples["bytecode"],
truncation=True,
max_length=max_length,
)
all_labels = examples["boundary"]
new_labels = []
for i, labels in enumerate(all_labels):
labels = labels.replace("'", "").strip("][").split(", ")
word_ids = tokenized_inputs.word_ids(i)
labels_len = len(labels)
max_word_id = word_ids[-2]
# some samples tokenize to more word ids than they have labels due to
# tokenization quirks; mask every label (-100) and keep them as noisy data
if max_word_id >= labels_len:
new_labels.append([-100] * max_word_id)
else:
new_labels.append(align_labels_with_tokens(labels, word_ids))
tokenized_inputs["labels"] = new_labels
return tokenized_inputs
def tokenize_segmentation_dataset(config: SegmentationConfiguration):
raw_dataset = load_dataset(config.dataset_repo_name, token=True, cache_dir=str(config.dataset_dir))
tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
prepped_tokenize_and_align_labels = functools.partial(tokenize_and_align_labels, tokenizer, config.max_token_length)
# tokenize input dataset
column_names = raw_dataset["train"].column_names
tokenized_datasets = raw_dataset.map(
prepped_tokenize_and_align_labels,
batched=True,
remove_columns=column_names,
num_proc=os.cpu_count(),
desc="Tokenizing datasets",
)
tokenized_datasets.push_to_hub(
config.tokenized_dataset_repo_name,
private=True,
)
@click.command(help="Script to tokenize the segmentation dataset given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
segmentation_config = parse_segmentation_config_json(json_file_path)
tokenize_segmentation_dataset(segmentation_config)
if __name__ == "__main__":
main()
@@ -0,0 +1,195 @@
import logging
import os
import pathlib
import click
from datasets import load_dataset
from huggingface_hub import hf_hub_download, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from transformers import AutoModelForMaskedLM, DataCollatorForLanguageModeling, PreTrainedTokenizerFast, RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments
from pylingual.segmentation.sliding_window import sliding_window
bytecode_separator = " <SEP> "
def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name
tokenizer_file = hf_hub_download(
repo_id=tokenizer_repo_name,
filename="tokenizer.json",
token=True,
cache_dir=str(tokenizer_dir),
)
tokenizer = PreTrainedTokenizerFast(
tokenizer_file=tokenizer_file,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
)
return tokenizer
def load_tokenized_train_dataset(
dataset_repo_name: str,
tokenizer: PreTrainedTokenizerFast,
max_length: int,
cache_dir: pathlib.Path,
):
dataset_dir = cache_dir / "datasets" / dataset_repo_name
raw_dataset = load_dataset(dataset_repo_name, token=True, cache_dir=dataset_dir, split="train")
# tokenize the input data
column_names = raw_dataset.column_names
def tokenize(examples):
# sliding window compatibility
MAX_WINDOW_LENGTH = 512
STEP_SIZE = 128
# parse the strings into lists to better work with the bytecode and boundaries
parsed_bc = [codeobj.split(" <SEP> ") for codeobj in examples["bytecode"]]
codeobj_tokens = []
# count the tokens for each bytecode instruction in a codeobj
for codeobj in parsed_bc:
token_list = []
for bytecode in codeobj:
token_list.append((bytecode, len(tokenizer(bytecode)[0])))
codeobj_tokens.append(token_list)
windows = [sliding_window(codeobj, MAX_WINDOW_LENGTH, STEP_SIZE) for codeobj in codeobj_tokens]
# remake examples using our windows
examples["bytecode"] = []
# go through each window
for window in windows:
for item in window:
# where we will temporarily store our bytecode
bytecode = []
for bc in item[0]:
bytecode.append(bc)
# append to examples
examples["bytecode"].append(bytecode_separator.join(bytecode))
return tokenizer(examples["bytecode"], max_length=max_length, truncation=True)
tokenized_dataset = raw_dataset.map(
tokenize,
batched=True,
remove_columns=column_names,
num_proc=os.cpu_count(),
desc="Tokenizing datasets",
)
return tokenized_dataset
def load_pretrained_mlm(
pretrained_mlm_repo_name: str,
tokenizer_embedding_length: int,
cache_dir: pathlib.Path,
) -> AutoModelForMaskedLM:
# load a basic pretrained BERT model
pretrained_mlm_dir = cache_dir / "models" / pretrained_mlm_repo_name
model = AutoModelForMaskedLM.from_pretrained(pretrained_mlm_repo_name, cache_dir=str(pretrained_mlm_dir))
# resize token embeddings to fit the model
model.resize_token_embeddings(tokenizer_embedding_length)
return model
def initialize_untrained_mlm(
tokenizer_embedding_length: int,
max_token_length: int,
) -> RobertaForMaskedLM:
# initialize untrained RoBERTa model
# most configuration options set to match https://huggingface.co/microsoft/codebert-base/blob/main/config.json for direct comparison
model_config = RobertaConfig(
max_position_embeddings=max_token_length, # INPUT LENGTH LIMIT
vocab_size=tokenizer_embedding_length,
layer_norm_eps=1e-05,
type_vocab_size=1,
)
model = RobertaForMaskedLM(model_config)
return model
def train_mlm(config: SegmentationConfiguration):
if repo_exists(config.base_repo_name):
logging.error(f"{config.base_repo_name} has already exists")
exit(1)
using_pretrained_model = bool(config.pretrained_mlm_repo_name)
# train the model; for now the configuration is borrowed from a regular T5 translation recipe
training_args = TrainingArguments(
output_dir=str(config.mlm_dir),
num_train_epochs=config.mlm_training_parameters.epochs,
per_device_train_batch_size=config.mlm_training_parameters.batch_size,
save_steps=1000,
save_total_limit=5,
prediction_loss_only=True,
push_to_hub=True,
hub_model_id=config.mlm_repo_name,
hub_private_repo=True,
ddp_backend="nccl",
ddp_find_unused_parameters=using_pretrained_model, # only look for unused parameters in pretrained models
remove_unused_columns=False,
)
tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
# Set DataCollator for MLM task, set the probability of masking.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
if using_pretrained_model:
pretrained_mlm = load_pretrained_mlm(config.pretrained_mlm_repo_name, len(tokenizer), config.cache_dir)
else:
pretrained_mlm = initialize_untrained_mlm(len(tokenizer), config.max_token_length + 2)  # +2: RoBERTa reserves two position slots for the padding offset
tokenized_training_data = load_tokenized_train_dataset(config.dataset_repo_name, tokenizer, config.max_token_length, config.cache_dir)
# Hugging face trainer: a Trainer class to fine-tune pretrained models
trainer = Trainer(
model=pretrained_mlm,
args=training_args,
data_collator=data_collator,
train_dataset=tokenized_training_data,
)
# Training
trainer.train()
if int(os.environ["LOCAL_RANK"]) == 0:
# Save the model
trainer.save_model(config.mlm_dir)
trainer.push_to_hub(
finetuned_from=config.pretrained_mlm_repo_name,
dataset=config.dataset_repo_name,
commit_message=f"Trained on {config.dataset_repo_name} using {config.tokenizer_repo_name}",
)
@click.command(help="Training script for the masked language model pretraining for the segmentation model given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
segmentation_config = parse_segmentation_config_json(json_file_path)
train_mlm(segmentation_config)
if __name__ == "__main__":
main()
@@ -0,0 +1,155 @@
import logging
import os
import pathlib
import click
import evaluate
import numpy as np
from datasets import ReadInstruction, load_dataset
from huggingface_hub import hf_hub_download, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification, PreTrainedTokenizerFast, Trainer, TrainingArguments
# two dictionaries, id2label and label2id, which contain the mappings from ID to label and vice versa.
label_names = ["B", "I", "E"]
id2label = {str(i): label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}
# compute_metrics: evaluate metric for training and evaluation.
def compute_metrics(eval_preds):
metric = evaluate.load("seqeval")
logits, labels = eval_preds
predictions = np.argmax(logits, axis=-1)
# Remove ignored index (special tokens) and convert to labels
true_labels = [[label_names[l] for l in label if l != -100] for label in labels]  # noqa: E741
true_predictions = [[label_names[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)]
all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
return {
"precision": all_metrics["overall_precision"],
"recall": all_metrics["overall_recall"],
"f1": all_metrics["overall_f1"],
"accuracy": all_metrics["overall_accuracy"],
}
def load_tokenizer(tokenizer_repo_name: str, cache_dir: pathlib.Path) -> PreTrainedTokenizerFast:
tokenizer_dir = cache_dir / "tokenizers" / tokenizer_repo_name
tokenizer_file = hf_hub_download(
repo_id=tokenizer_repo_name,
filename="tokenizer.json",
token=True,
cache_dir=str(tokenizer_dir),
)
tokenizer = PreTrainedTokenizerFast(
tokenizer_file=tokenizer_file,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
)
return tokenizer
def load_tokenized_train_and_valid_dataset(dataset_repo_name: str, cache_dir: pathlib.Path, dataset_percentage: int = 100):
dataset_dir = cache_dir / "datasets" / dataset_repo_name
# Load the tokenized dataset
tokenized_train_dataset = load_dataset(
dataset_repo_name,
token=True,
cache_dir=str(dataset_dir),
split=ReadInstruction("train", to=dataset_percentage, unit="%"),
)
tokenized_validation_dataset = load_dataset(
dataset_repo_name,
token=True,
cache_dir=str(dataset_dir),
split="valid",
)
return tokenized_train_dataset, tokenized_validation_dataset
def train_segmentation_model(config: SegmentationConfiguration):
if repo_exists(config.base_repo_name):
logging.error(f"{config.base_repo_name} has already exists")
exit(1)
# training arguments.
training_args = TrainingArguments(
output_dir=str(config.segmenter_dir),
overwrite_output_dir=True,
eval_strategy="epoch",
logging_strategy="epoch",
save_strategy="epoch",
learning_rate=config.segmentation_training_parameters.learning_rate,
num_train_epochs=config.segmentation_training_parameters.epochs,
per_device_train_batch_size=config.segmentation_training_parameters.batch_size,
save_steps=1000,
weight_decay=0.01,
fp16=True,
push_to_hub=True,
hub_model_id=config.segmenter_repo_name,
hub_private_repo=True,
ddp_backend="nccl",
ddp_find_unused_parameters=True,
save_total_limit=5,
)
# load a basic pretrained BERT model
model = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=config.mlm_repo_name,
id2label=id2label,
label2id=label2id,
token=True,
)
# Set DataCollator for DataCollatorForTokenClassification
tokenizer = load_tokenizer(config.tokenizer_repo_name, config.cache_dir)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, max_length=config.max_token_length)
(
tokenized_train_dataset,
tokenized_validation_dataset,
) = load_tokenized_train_and_valid_dataset(config.tokenized_dataset_repo_name, config.cache_dir, config.dataset_percentage)
# Hugging face trainer: a Trainer class to fine-tune pretrained models
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=tokenized_train_dataset,
eval_dataset=tokenized_validation_dataset,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
)
# Training
trainer.train()
if int(os.environ["LOCAL_RANK"]) == 0:
# Save the model
trainer.save_model(str(config.segmenter_dir))
trainer.push_to_hub(
finetuned_from=config.mlm_repo_name,
dataset=config.tokenized_dataset_repo_name,
commit_message=f"Trained on {config.tokenized_dataset_repo_name} using {config.mlm_repo_name}",
)
@click.command(help="Training script for the segmentation model given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
segmentation_config = parse_segmentation_config_json(json_file_path)
train_segmentation_model(segmentation_config)
if __name__ == "__main__":
main()
@@ -0,0 +1,96 @@
import logging
import pathlib
import click
from datasets import ReadInstruction, load_dataset
from huggingface_hub import HfApi, create_repo, repo_exists
from SegmentationConfiguration import SegmentationConfiguration, parse_segmentation_config_json
from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors, trainers
special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
def get_untrained_tokenizer() -> Tokenizer:
# WordPiece tokenization for BERT.
tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
# The normalizer decomposes accented characters (NFD) and strips the accents out.
tokenizer.normalizer = normalizers.Sequence([normalizers.NFD(), normalizers.StripAccents()])
# The pre-tokenizer splits on <SEP> tokens.
tokenizer.pre_tokenizer = pre_tokenizers.Split("<SEP>", "removed")
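# e.g. "LOAD_CONST , 1 <SEP> RETURN_VALUE" pre-tokenizes into the pieces
# "LOAD_CONST , 1 " and " RETURN_VALUE" (illustrative)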
return tokenizer
def post_training_configuration(tokenizer: Tokenizer):
cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
# Set decoder for the tokenizer
tokenizer.decoder = decoders.WordPiece(prefix="##")
# For TemplateProcessing, we specify how to treat a single sequence and a pair of sequences.
tokenizer.post_processor = processors.TemplateProcessing(
single="[CLS]:0 $A:0 [SEP]:0",
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
)
def save_and_upload_tokenizer(
tokenizer: Tokenizer,
tokenizer_json_path: pathlib.Path,
tokenizer_repo_name: str,
dataset_name: str,
):
# save the tokenizer locally
tokenizer_json_path.parent.mkdir(parents=True, exist_ok=True)
tokenizer.save(str(tokenizer_json_path.resolve()))
# upload tokenizer to huggingface
api = HfApi()
create_repo(tokenizer_repo_name, exist_ok=True, private=True)
api.upload_file(
path_in_repo="tokenizer.json",
path_or_fileobj=str(tokenizer_json_path.resolve()),
repo_id=tokenizer_repo_name,
commit_message=f"Trained tokenizer using {dataset_name}",
)
def train_tokenizer(config: SegmentationConfiguration):
if repo_exists(config.base_repo_name):
logging.error(f"{config.base_repo_name} has already exists")
exit(1)
tokenizer = get_untrained_tokenizer()
train_dataset = load_dataset(
config.dataset_repo_name,
token=True,
split=ReadInstruction("train", to=config.dataset_percentage, unit="%"),
)["bytecode"]
trainer = trainers.WordPieceTrainer(vocab_size=30000, special_tokens=special_tokens)
tokenizer.train_from_iterator(train_dataset, trainer=trainer)
post_training_configuration(tokenizer)
save_and_upload_tokenizer(
tokenizer,
config.tokenizer_json_path,
config.tokenizer_repo_name,
config.dataset_repo_name,
)
@click.command(help="Training script for the bytecode tokenizer for the segmentation model given a segmentation json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
segmentation_config = parse_segmentation_config_json(json_file_path)
train_tokenizer(segmentation_config)
if __name__ == "__main__":
main()
@@ -0,0 +1,18 @@
# seq2seq
- train_tokenizer_auto.py:
  - trains the manual tokenizer
- tokenize_seq2seq.py:
  - tokenizes the dataset for the seq2seq model
- train_seq2seq.py:
  - fine-tunes the pretrained model
  - creates a sequence-to-sequence translation model
- StatementConfiguration.py:
  - defines the JSON format for statement translation training
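A plausible invocation order, assuming each script takes the statement JSON path the way the Click commands elsewhere in this commit do:
```
python train_tokenizer_auto.py <path to statement JSON>
python tokenize_seq2seq.py <path to statement JSON>
python train_seq2seq.py <path to statement JSON>
```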
# manual1
Contains JSONs mapping bytecode instructions and their configurations for use in training.
@@ -0,0 +1,59 @@
from dataclasses import dataclass
import pathlib
import json
import logging
@dataclass
class TrainingParameters:
batch_size: int
epochs: int
learning_rate: float
@dataclass
class StatementConfiguration:
base_repo_name: str
dataset_repo_name: str
tokenizer_repo_name: str
pretrained_seq2seq_repo_name: str
cache_dir: pathlib.Path
max_token_length: int
dataset_percentage: int
do_eval: bool
fp16: bool
statement_training_parameters: TrainingParameters
@property
def tokenized_dataset_repo_name(self):
return self.dataset_repo_name + "-tokenized"
@property
def statement_model_repo_name(self):
return self.base_repo_name + "-statement"
@property
def statement_model_dir(self):
return self.cache_dir / "models" / self.statement_model_repo_name
@property
def log_dir(self):
return self.statement_model_dir / "logs"
def __post_init__(self):
self.cache_dir = pathlib.Path(self.cache_dir)
def parse_statement_config_json(json_file_path: pathlib.Path, logger: logging.Logger = None) -> StatementConfiguration:
if not json_file_path.exists():
raise FileNotFoundError(f"{json_file_path} does not exist")
if logger:
logger.info(f"Loading model description from {json_file_path}...")
with json_file_path.open() as json_file:
statement_config_dict = json.load(json_file)
statement_config_dict["statement_training_parameters"] = TrainingParameters(**statement_config_dict["statement_training_parameters"])
return StatementConfiguration(**statement_config_dict)
@@ -0,0 +1,51 @@
import os
from typing import Any
import click
import pathlib
from datasets import load_dataset
from transformers import RobertaTokenizer
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
import functools
def preprocess_function(tokenizer: RobertaTokenizer, max_token_length: int, input_key: str, examples: dict[str, Any]) -> dict[str, Any]:
"""Set up Huggingface tokenizers for both inputs and targets"""
inputs = [ex if ex else "" for ex in examples[input_key]]
targets = [ex if ex else "" for ex in examples["source"]]
return tokenizer(text=inputs, text_target=targets, max_length=max_token_length, truncation=True)
def tokenize_seq2seq_dataset(config: StatementConfiguration):
# ref: https://huggingface.co/Salesforce/codet5-base
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
raw_datasets = load_dataset(config.dataset_repo_name, token=True)
column_names = raw_datasets["train"].column_names
input_key = "bytecode"
prepped_preprocess_function = functools.partial(preprocess_function, tokenizer, config.max_token_length, input_key)
tokenized_datasets = raw_datasets.map(
prepped_preprocess_function,
batched=True,
remove_columns=column_names,
num_proc=os.cpu_count(),
desc="Tokenizing datasets",
)
tokenized_datasets.push_to_hub(config.tokenized_dataset_repo_name, private=True)
@click.command(help="Tokenization script for Statement Translation model given a statement json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
statement_config = parse_statement_config_json(json_file_path)
tokenize_seq2seq_dataset(statement_config)
if __name__ == "__main__":
main()
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,929 @@
{
"add_prefix_space": false,
"additional_special_tokens": [
"<pad>",
"<s>",
"</s>",
"<unk>",
"<mask>",
"!",
"\"",
"#",
"$",
"%",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
":",
";",
"<",
"=",
">",
"?",
"@",
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
"[",
"\\",
"]",
"^",
"_",
"`",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
"{",
"|",
"}",
"~",
"Ġ",
"-=",
"<<",
">>",
":=",
">=",
"<=",
"==",
"!=",
"+=",
"//=",
"**=",
"/=",
"//",
"%=",
"@=",
"&=",
"|=",
"^=",
">>=",
"<<=",
"*=",
"()",
"):",
"~>>",
"**",
"<codeobj:",
"<KWARG_PAD>",
"E->",
"<TAP_0>",
"defaults",
"args:",
"vararg:",
"<TAP_1>",
"<START_LINE>",
"<SUB>",
"<TAP_UP>",
"E-END",
"<TAP_2>",
"~~>",
"<TAP_ST>",
"<SEP>",
"</SUB>",
"<TAP_3>",
"<TAP_4>",
"<TAP_5>",
"<TAP_6>",
"<TAP_7>",
"False",
"None",
"True",
"and",
"assert",
"async",
"await",
"break",
"class",
"continue",
"def",
"del",
"elif",
"else",
"else:",
"except",
"except:",
"finally",
"finally:",
"for",
"from",
"global",
"if",
"import",
"in",
"is",
"lambda",
"nonlocal",
"not",
"or",
"pass",
"raise",
"return",
"try",
"try:",
"while",
"with",
"yield",
"case",
"as",
"ASYNC_GEN_WRAP",
"BEFORE_ASYNC_WITH",
"BEFORE_WITH",
"BEGIN_FINALLY",
"BINARY_ADD",
"BINARY_AND",
"BINARY_FLOOR_DIVIDE",
"BINARY_LSHIFT",
"BINARY_MATRIX_MULTIPLY",
"BINARY_MODULO",
"BINARY_MULTIPLY",
"BINARY_OP",
"BINARY_OR",
"BINARY_POWER",
"BINARY_RSHIFT",
"BINARY_SLICE",
"BINARY_SUBSCR",
"BINARY_SUBTRACT",
"BINARY_TRUE_DIVIDE",
"BINARY_XOR",
"BREAK_LOOP",
"BUILD_CONST_KEY_MAP",
"BUILD_LIST",
"BUILD_LIST_UNPACK",
"BUILD_MAP",
"BUILD_MAP_UNPACK",
"BUILD_MAP_UNPACK_WITH_CALL",
"BUILD_SET",
"BUILD_SET_UNPACK",
"BUILD_SLICE",
"BUILD_STRING",
"BUILD_TUPLE",
"BUILD_TUPLE_UNPACK",
"BUILD_TUPLE_UNPACK_WITH_CALL",
"CACHE",
"CALL_FINALLY",
"CALL_FUNCTION",
"CALL_FUNCTION_EX",
"CALL_FUNCTION_KW",
"CALL_METHOD",
"CHECK_EG_MATCH",
"CHECK_EXC_MATCH",
"CLEANUP_THROW",
"COMPARE_OP",
"CONTAINS_OP",
"CONTINUE_LOOP",
"COPY",
"COPY_DICT_WITHOUT_KEYS",
"COPY_FREE_VARS",
"DELETE_ATTR",
"DELETE_DEREF",
"DELETE_FAST",
"DELETE_GLOBAL",
"DELETE_NAME",
"DELETE_SUBSCR",
"DICT_MERGE",
"DICT_UPDATE",
"DUP_TOP",
"DUP_TOP_TWO",
"END_ASYNC_FOR",
"END_FINALLY",
"END_FOR",
"END_SEND",
"EXTENDED_ARG",
"FOR_ITER",
"FORMAT_VALUE",
"GEN_START",
"GET_AITER",
"GET_ANEXT",
"GET_AWAITABLE",
"GET_ITER",
"GET_LEN",
"GET_YIELD_FROM_ITER",
"IMPORT_FROM",
"IMPORT_NAME",
"IMPORT_STAR",
"INPLACE_ADD",
"INPLACE_AND",
"INPLACE_FLOOR_DIVIDE",
"INPLACE_LSHIFT",
"INPLACE_MATRIX_MULTIPLY",
"INPLACE_MODULO",
"INPLACE_MULTIPLY",
"INPLACE_OR",
"INPLACE_POWER",
"INPLACE_RSHIFT",
"INPLACE_SUBTRACT",
"INPLACE_TRUE_DIVIDE",
"INPLACE_XOR",
"INTERPRETER_EXIT",
"IS_OP",
"JUMP_ABSOLUTE",
"JUMP_BACKWARD",
"JUMP_BACKWARD_NO_INTERRUPT",
"JUMP_FORWARD",
"JUMP_IF_FALSE_OR_POP",
"JUMP_IF_NOT_EXC_MATCH",
"JUMP_IF_TRUE_OR_POP",
"LIST_APPEND",
"LIST_EXTEND",
"LIST_TO_TUPLE",
"LOAD_ASSERTION_ERROR",
"LOAD_ATTR",
"LOAD_BUILD_CLASS",
"LOAD_CLASSDEREF",
"LOAD_CLOSURE",
"LOAD_CONST",
"LOAD_DEREF",
"LOAD_FAST",
"LOAD_FAST_AND_CLEAR",
"LOAD_FAST_CHECK",
"LOAD_GLOBAL",
"LOAD_LOCALS",
"LOAD_METHOD",
"LOAD_NAME",
"LOAD_SUPER_ATTR",
"MAKE_CELL",
"MAKE_FUNCTION",
"MAP_ADD",
"MATCH_CLASS",
"MATCH_KEYS",
"MATCH_MAPPING",
"MATCH_SEQUENCE",
"NOP",
"POP_BLOCK",
"POP_EXCEPT",
"POP_FINALLY",
"POP_JUMP_FORWARD_IF_FALSE",
"POP_JUMP_FORWARD_IF_NONE",
"POP_JUMP_FORWARD_IF_NOT_NONE",
"POP_JUMP_FORWARD_IF_TRUE",
"POP_JUMP_IF_FALSE",
"POP_JUMP_IF_NONE",
"POP_JUMP_IF_NOT_NONE",
"POP_JUMP_IF_TRUE",
"POP_TOP",
"PRECALL",
"PREP_RERAISE_STAR",
"PRINT_EXPR",
"PUSH_EXC_INFO",
"PUSH_NULL",
"RAISE_VARARGS",
"RERAISE",
"RESERVED",
"RESUME",
"RETURN_CONST",
"RETURN_GENERATOR",
"RETURN_VALUE",
"ROT_FOUR",
"ROT_N",
"ROT_THREE",
"ROT_TWO",
"SEND",
"SET_ADD",
"SET_UPDATE",
"SETUP_ANNOTATIONS",
"SETUP_ASYNC_WITH",
"SETUP_EXCEPT",
"SETUP_FINALLY",
"SETUP_LOOP",
"SETUP_WITH",
"STORE_ATTR",
"STORE_DEREF",
"STORE_FAST",
"STORE_GLOBAL",
"STORE_NAME",
"STORE_SLICE",
"STORE_SUBSCR",
"SWAP",
"UNARY_INVERT",
"UNARY_NEGATIVE",
"UNARY_NOT",
"UNARY_POSITIVE",
"UNPACK_EX",
"UNPACK_SEQUENCE",
"WITH_CLEANUP_FINISH",
"WITH_CLEANUP_START",
"WITH_EXCEPT_START",
"YIELD_FROM",
"YIELD_VALUE",
"<mask_0>",
"<mask_1>",
"<mask_2>",
"<mask_3>",
"<mask_4>",
"<mask_5>",
"<mask_6>",
"<mask_7>",
"<mask_8>",
"<mask_9>",
"<mask_10>",
"<mask_11>",
"<mask_12>",
"<mask_13>",
"<mask_14>",
"<mask_15>",
"<mask_16>",
"<mask_17>",
"<mask_18>",
"<mask_19>",
"<mask_20>",
"<mask_21>",
"<mask_22>",
"<mask_23>",
"<mask_24>",
"<mask_25>",
"<mask_26>",
"<mask_27>",
"<mask_28>",
"<mask_29>",
"<mask_30>",
"<mask_31>",
"<mask_32>",
"<mask_33>",
"<mask_34>",
"<mask_35>",
"<mask_36>",
"<mask_37>",
"<mask_38>",
"<mask_39>",
"<mask_40>",
"<mask_41>",
"<mask_42>",
"<mask_43>",
"<mask_44>",
"<mask_45>",
"<mask_46>",
"<mask_47>",
"<mask_48>",
"<mask_49>",
"<mask_50>",
"<mask_51>",
"<mask_52>",
"<mask_53>",
"<mask_54>",
"<mask_55>",
"<mask_56>",
"<mask_57>",
"<mask_58>",
"<mask_59>",
"<mask_60>",
"<mask_61>",
"<mask_62>",
"<mask_63>",
"<mask_64>",
"<mask_65>",
"<mask_66>",
"<mask_67>",
"<mask_68>",
"<mask_69>",
"<mask_70>",
"<mask_71>",
"<mask_72>",
"<mask_73>",
"<mask_74>",
"<mask_75>",
"<mask_76>",
"<mask_77>",
"<mask_78>",
"<mask_79>",
"<mask_80>",
"<mask_81>",
"<mask_82>",
"<mask_83>",
"<mask_84>",
"<mask_85>",
"<mask_86>",
"<mask_87>",
"<mask_88>",
"<mask_89>",
"<mask_90>",
"<mask_91>",
"<mask_92>",
"<mask_93>",
"<mask_94>",
"<mask_95>",
"<mask_96>",
"<mask_97>",
"<mask_98>",
"<mask_99>",
"<mask_100>",
"<mask_101>",
"<mask_102>",
"<mask_103>",
"<mask_104>",
"<mask_105>",
"<mask_106>",
"<mask_107>",
"<mask_108>",
"<mask_109>",
"<mask_110>",
"<mask_111>",
"<mask_112>",
"<mask_113>",
"<mask_114>",
"<mask_115>",
"<mask_116>",
"<mask_117>",
"<mask_118>",
"<mask_119>",
"<mask_120>",
"<mask_121>",
"<mask_122>",
"<mask_123>",
"<mask_124>",
"<mask_125>",
"<mask_126>",
"<mask_127>",
"<mask_128>",
"<mask_129>",
"<mask_130>",
"<mask_131>",
"<mask_132>",
"<mask_133>",
"<mask_134>",
"<mask_135>",
"<mask_136>",
"<mask_137>",
"<mask_138>",
"<mask_139>",
"<mask_140>",
"<mask_141>",
"<mask_142>",
"<mask_143>",
"<mask_144>",
"<mask_145>",
"<mask_146>",
"<mask_147>",
"<mask_148>",
"<mask_149>",
"<mask_150>",
"<mask_151>",
"<mask_152>",
"<mask_153>",
"<mask_154>",
"<mask_155>",
"<mask_156>",
"<mask_157>",
"<mask_158>",
"<mask_159>",
"<mask_160>",
"<mask_161>",
"<mask_162>",
"<mask_163>",
"<mask_164>",
"<mask_165>",
"<mask_166>",
"<mask_167>",
"<mask_168>",
"<mask_169>",
"<mask_170>",
"<mask_171>",
"<mask_172>",
"<mask_173>",
"<mask_174>",
"<mask_175>",
"<mask_176>",
"<mask_177>",
"<mask_178>",
"<mask_179>",
"<mask_180>",
"<mask_181>",
"<mask_182>",
"<mask_183>",
"<mask_184>",
"<mask_185>",
"<mask_186>",
"<mask_187>",
"<mask_188>",
"<mask_189>",
"<mask_190>",
"<mask_191>",
"<mask_192>",
"<mask_193>",
"<mask_194>",
"<mask_195>",
"<mask_196>",
"<mask_197>",
"<mask_198>",
"<mask_199>",
"<mask_200>",
"<mask_201>",
"<mask_202>",
"<mask_203>",
"<mask_204>",
"<mask_205>",
"<mask_206>",
"<mask_207>",
"<mask_208>",
"<mask_209>",
"<mask_210>",
"<mask_211>",
"<mask_212>",
"<mask_213>",
"<mask_214>",
"<mask_215>",
"<mask_216>",
"<mask_217>",
"<mask_218>",
"<mask_219>",
"<mask_220>",
"<mask_221>",
"<mask_222>",
"<mask_223>",
"<mask_224>",
"<mask_225>",
"<mask_226>",
"<mask_227>",
"<mask_228>",
"<mask_229>",
"<mask_230>",
"<mask_231>",
"<mask_232>",
"<mask_233>",
"<mask_234>",
"<mask_235>",
"<mask_236>",
"<mask_237>",
"<mask_238>",
"<mask_239>",
"<mask_240>",
"<mask_241>",
"<mask_242>",
"<mask_243>",
"<mask_244>",
"<mask_245>",
"<mask_246>",
"<mask_247>",
"<mask_248>",
"<mask_249>",
"<mask_250>",
"<mask_251>",
"<mask_252>",
"<mask_253>",
"<mask_254>",
"<mask_255>",
"<mask_256>",
"<mask_257>",
"<mask_258>",
"<mask_259>",
"<mask_260>",
"<mask_261>",
"<mask_262>",
"<mask_263>",
"<mask_264>",
"<mask_265>",
"<mask_266>",
"<mask_267>",
"<mask_268>",
"<mask_269>",
"<mask_270>",
"<mask_271>",
"<mask_272>",
"<mask_273>",
"<mask_274>",
"<mask_275>",
"<mask_276>",
"<mask_277>",
"<mask_278>",
"<mask_279>",
"<mask_280>",
"<mask_281>",
"<mask_282>",
"<mask_283>",
"<mask_284>",
"<mask_285>",
"<mask_286>",
"<mask_287>",
"<mask_288>",
"<mask_289>",
"<mask_290>",
"<mask_291>",
"<mask_292>",
"<mask_293>",
"<mask_294>",
"<mask_295>",
"<mask_296>",
"<mask_297>",
"<mask_298>",
"<mask_299>",
"<mask_300>",
"<mask_301>",
"<mask_302>",
"<mask_303>",
"<mask_304>",
"<mask_305>",
"<mask_306>",
"<mask_307>",
"<mask_308>",
"<mask_309>",
"<mask_310>",
"<mask_311>",
"<mask_312>",
"<mask_313>",
"<mask_314>",
"<mask_315>",
"<mask_316>",
"<mask_317>",
"<mask_318>",
"<mask_319>",
"<mask_320>",
"<mask_321>",
"<mask_322>",
"<mask_323>",
"<mask_324>",
"<mask_325>",
"<mask_326>",
"<mask_327>",
"<mask_328>",
"<mask_329>",
"<mask_330>",
"<mask_331>",
"<mask_332>",
"<mask_333>",
"<mask_334>",
"<mask_335>",
"<mask_336>",
"<mask_337>",
"<mask_338>",
"<mask_339>",
"<mask_340>",
"<mask_341>",
"<mask_342>",
"<mask_343>",
"<mask_344>",
"<mask_345>",
"<mask_346>",
"<mask_347>",
"<mask_348>",
"<mask_349>",
"<mask_350>",
"<mask_351>",
"<mask_352>",
"<mask_353>",
"<mask_354>",
"<mask_355>",
"<mask_356>",
"<mask_357>",
"<mask_358>",
"<mask_359>",
"<mask_360>",
"<mask_361>",
"<mask_362>",
"<mask_363>",
"<mask_364>",
"<mask_365>",
"<mask_366>",
"<mask_367>",
"<mask_368>",
"<mask_369>",
"<mask_370>",
"<mask_371>",
"<mask_372>",
"<mask_373>",
"<mask_374>",
"<mask_375>",
"<mask_376>",
"<mask_377>",
"<mask_378>",
"<mask_379>",
"<mask_380>",
"<mask_381>",
"<mask_382>",
"<mask_383>",
"<extra_id_99>",
"<extra_id_98>",
"<extra_id_97>",
"<extra_id_96>",
"<extra_id_95>",
"<extra_id_94>",
"<extra_id_93>",
"<extra_id_92>",
"<extra_id_91>",
"<extra_id_90>",
"<extra_id_89>",
"<extra_id_88>",
"<extra_id_87>",
"<extra_id_86>",
"<extra_id_85>",
"<extra_id_84>",
"<extra_id_83>",
"<extra_id_82>",
"<extra_id_81>",
"<extra_id_80>",
"<extra_id_79>",
"<extra_id_78>",
"<extra_id_77>",
"<extra_id_76>",
"<extra_id_75>",
"<extra_id_74>",
"<extra_id_73>",
"<extra_id_72>",
"<extra_id_71>",
"<extra_id_70>",
"<extra_id_69>",
"<extra_id_68>",
"<extra_id_67>",
"<extra_id_66>",
"<extra_id_65>",
"<extra_id_64>",
"<extra_id_63>",
"<extra_id_62>",
"<extra_id_61>",
"<extra_id_60>",
"<extra_id_59>",
"<extra_id_58>",
"<extra_id_57>",
"<extra_id_56>",
"<extra_id_55>",
"<extra_id_54>",
"<extra_id_53>",
"<extra_id_52>",
"<extra_id_51>",
"<extra_id_50>",
"<extra_id_49>",
"<extra_id_48>",
"<extra_id_47>",
"<extra_id_46>",
"<extra_id_45>",
"<extra_id_44>",
"<extra_id_43>",
"<extra_id_42>",
"<extra_id_41>",
"<extra_id_40>",
"<extra_id_39>",
"<extra_id_38>",
"<extra_id_37>",
"<extra_id_36>",
"<extra_id_35>",
"<extra_id_34>",
"<extra_id_33>",
"<extra_id_32>",
"<extra_id_31>",
"<extra_id_30>",
"<extra_id_29>",
"<extra_id_28>",
"<extra_id_27>",
"<extra_id_26>",
"<extra_id_25>",
"<extra_id_24>",
"<extra_id_23>",
"<extra_id_22>",
"<extra_id_21>",
"<extra_id_20>",
"<extra_id_19>",
"<extra_id_18>",
"<extra_id_17>",
"<extra_id_16>",
"<extra_id_15>",
"<extra_id_14>",
"<extra_id_13>",
"<extra_id_12>",
"<extra_id_11>",
"<extra_id_10>",
"<extra_id_9>",
"<extra_id_8>",
"<extra_id_7>",
"<extra_id_6>",
"<extra_id_5>",
"<extra_id_4>",
"<extra_id_3>",
"<extra_id_2>",
"<extra_id_1>",
"<extra_id_0>",
"match",
"type",
"HAVE_ARGUMENT",
"CALL_INTRINSIC_1",
"CALL_INTRINSIC_2",
"JUMP_NO_INTERRUPT",
"nargs",
"vargs",
"compare",
"name",
"const",
"local"
],
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": true,
"cls_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"errors": "replace",
"mask_token": {
"__type": "AddedToken",
"content": "<mask>",
"lstrip": true,
"normalized": true,
"rstrip": false,
"single_word": false
},
"model_max_length": 512,
"pad_token": {
"__type": "AddedToken",
"content": "<pad>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"sep_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"tokenizer_class": "RobertaTokenizer",
"trim_offsets": true,
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}
+93
View File
@@ -0,0 +1,93 @@
import os
import pathlib
import time
from datetime import timedelta
import click
from datasets import ReadInstruction, load_dataset
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
from transformers import (
DataCollatorForSeq2Seq,
RobertaTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
T5ForConditionalGeneration,
)
def load_tokenized_train_dataset(dataset_repo_name: str, dataset_percentage: int):
# Load the tokenized dataset
tokenized_train_dataset = load_dataset(
dataset_repo_name,
token=True,
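        # ReadInstruction selects only the first dataset_percentage percent of the train split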
split=ReadInstruction("train", to=dataset_percentage, unit="%"),
)
return tokenized_train_dataset
def train_statement_model(config: StatementConfiguration):
    # load the model; Salesforce/codet5-base is a seq2seq model pretrained for code generation
tokenizer = RobertaTokenizer.from_pretrained(config.tokenizer_repo_name)
model = T5ForConditionalGeneration.from_pretrained(config.pretrained_seq2seq_repo_name)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
model_dir = str(config.statement_model_dir)
model_repo_name = config.statement_model_repo_name
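    # training arguments; checkpoints are pushed to a private Hugging Face Hub repo during training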
train_args = Seq2SeqTrainingArguments(
output_dir=model_dir,
learning_rate=config.statement_training_parameters.learning_rate,
per_device_train_batch_size=config.statement_training_parameters.batch_size,
per_device_eval_batch_size=config.statement_training_parameters.batch_size,
weight_decay=0.01,
fp16=config.fp16,
logging_dir=str(config.log_dir),
report_to="tensorboard",
logging_strategy="steps",
logging_steps=1000,
save_strategy="steps",
save_steps=10000,
save_total_limit=2,
num_train_epochs=config.statement_training_parameters.epochs,
predict_with_generate=True,
push_to_hub=True,
hub_model_id=model_repo_name,
hub_private_repo=True,
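        # distributed data parallel over NCCL; skipping the unused-parameter scan avoids its per-step overhead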
ddp_backend="nccl",
ddp_find_unused_parameters=False,
)
tokenized_train_dataset = load_tokenized_train_dataset(config.tokenized_dataset_repo_name, config.dataset_percentage)
trainer = Seq2SeqTrainer(
model=model,
args=train_args,
data_collator=data_collator,
train_dataset=tokenized_train_dataset,
tokenizer=tokenizer,
)
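    # time the full training run; the duration is used as the Hub commit message below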
start = time.time()
trainer.train()
duration = str(timedelta(seconds=time.time() - start))
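    # LOCAL_RANK is set by torchrun; only the main process (rank 0) saves and uploads the model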
if int(os.environ["LOCAL_RANK"]) == 0:
        # save the final model locally, then upload it to the Hugging Face Hub
        trainer.save_model(str(config.statement_model_dir))
        # push_to_hub returns the URL of the commit it creates
trainer.push_to_hub(
commit_message=duration,
finetuned_from=config.pretrained_seq2seq_repo_name,
dataset=config.tokenized_dataset_repo_name,
)
@click.command(help="Training script for the statement translation model given a statement json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
statement_config = parse_statement_config_json(json_file_path)
train_statement_model(statement_config)
if __name__ == "__main__":
main()
+93
View File
@@ -0,0 +1,93 @@
import logging
import pathlib
import click
from datasets import ReadInstruction, load_dataset
from huggingface_hub import HfApi, repo_exists
from StatementConfiguration import StatementConfiguration, parse_statement_config_json
from tokenizers import Tokenizer
from transformers import AutoTokenizer
def get_untrained_tokenizer(tokenizer_repo_name: str) -> AutoTokenizer:
tokenizer_dir = pathlib.Path(__file__).parent / tokenizer_repo_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
return tokenizer
def save_and_upload_tokenizer(
tokenizer: Tokenizer,
tokenizer_json_path: pathlib.Path,
tokenizer_repo_name: str,
dataset_name: str,
):
# Save the tokenizer locally
tokenizer.save_pretrained(str(tokenizer_json_path.parent.resolve()))
# Upload files to Hugging Face Hub
api = HfApi()
api.create_repo(tokenizer_repo_name, exist_ok=True, private=True)
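    # each upload_file call creates a separate commit in the tokenizer repo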
api.upload_file(
path_in_repo="tokenizer_config.json",
path_or_fileobj=str(tokenizer_json_path.parent / "tokenizer_config.json"),
repo_id=tokenizer_repo_name,
commit_message=f"Trained tokenizer using {dataset_name}",
)
api.upload_file(
path_in_repo="vocab.json",
path_or_fileobj=str(tokenizer_json_path.parent / "vocab.json"),
repo_id=tokenizer_repo_name,
commit_message="Extracted vocabulary from tokenizer",
)
api.upload_file(
path_in_repo="merges.txt",
path_or_fileobj=str(tokenizer_json_path.parent / "merges.txt"),
repo_id=tokenizer_repo_name,
commit_message="Extracted merges from tokenizer",
)
api.upload_file(
path_in_repo="tokenizer.json",
path_or_fileobj=str(tokenizer_json_path.parent / "tokenizer.json"),
repo_id=tokenizer_repo_name,
commit_message="Extracted tokenizer",
)
api.upload_file(
path_in_repo="special_tokens_map.json",
path_or_fileobj=str(tokenizer_json_path.parent / "special_tokens_map.json"),
repo_id=tokenizer_repo_name,
commit_message="Extracted special tokens map",
)
def train_tokenizer(config: StatementConfiguration, tokenizer_json_path: pathlib.Path):
if repo_exists(config.base_repo_name):
        logging.error(f"{config.base_repo_name} already exists")
exit(1)
tokenizer = get_untrained_tokenizer("tokenizer")
train_dataset = load_dataset(
config.dataset_repo_name,
token=True,
split=ReadInstruction("train", to=config.dataset_percentage, unit="%"),
)["bytecode"]
tokenizer = tokenizer.train_new_from_iterator(train_dataset, vocab_size=30000)
save_and_upload_tokenizer(
tokenizer,
tokenizer_json_path,
config.tokenizer_repo_name,
config.dataset_repo_name,
)
@click.command(help="Training script for the bytecode tokenizer for the statement model given a statement json.")
@click.argument("json_path", type=str)
def main(json_path: str):
json_file_path = pathlib.Path(json_path)
statement_config = parse_statement_config_json(json_file_path)
train_tokenizer(statement_config, json_file_path)
if __name__ == "__main__":
main()
+117
View File
@@ -0,0 +1,117 @@
import logging
import os
import pathlib
import subprocess
import click
from pylingual.utils.get_logger import get_logger
def train_segmentation(segmentation_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400):
segmentation_root = pathlib.Path(__file__).parent / "segmentation"
# train tokenizer
logger.info("training tokenizer...")
subprocess.run(["python", segmentation_root / "train_tokenizer.py", segmentation_config_path])
    # train masked language model
logger.info("training masked language model...")
subprocess.run(
[
"torchrun",
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{rdzv_port}",
segmentation_root / "train_mlm.py",
segmentation_config_path,
],
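        # NCCL_P2P_DISABLE=1 turns off peer-to-peer GPU transfers, a common workaround for NCCL hangs on some hosts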
env=dict(os.environ, NCCL_P2P_DISABLE="1"),
)
# tokenize dataset
logger.info("tokenizing segmentation dataset...")
subprocess.run(["python", segmentation_root / "tokenize_seg.py", segmentation_config_path])
    # train segmentation model (distributed via torchrun)
logger.info("training segmentation model...")
subprocess.run(
[
"torchrun",
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{rdzv_port}",
segmentation_root / "train_seg.py",
segmentation_config_path,
],
env=dict(os.environ, NCCL_P2P_DISABLE="1"),
)
def train_statement(statement_config_path: pathlib.Path, logger: logging.Logger, nnodes: int = 1, nproc_per_node: int = 1, rdzv_port: int = 29400):
statement_root = pathlib.Path(__file__).parent / "statement"
    # train the bytecode tokenizer for the statement model
subprocess.run(["python", statement_root / "train_tokenizer_auto.py", statement_config_path])
    # tokenize the statement dataset with the Salesforce tokenizer
logger.info("tokenizing statement dataset...")
subprocess.run(["python", statement_root / "tokenize_seq2seq.py", statement_config_path])
    # train statement model (distributed via torchrun)
logger.info("training statement model...")
subprocess.run(
[
"torchrun",
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{rdzv_port}",
statement_root / "train_seq2seq.py",
statement_config_path,
],
env=dict(os.environ, NCCL_P2P_DISABLE="1"),
)
@click.command(help="Full tokenization and training pipeline for the segmentation and statement translation models.")
@click.option("--segmentation", type=str, default=None, help="The path to the segmentation model description JSON file.")
@click.option("--statement", type=str, default=None, help="The path to the statement model description JSON file.")
@click.option("--nnodes", type=int, default=1, help="Torchrun nnodes arg")
@click.option("--nproc_per_node", type=int, default=1, help="Torchrun nproc_per_node arg")
@click.option("--rdzv_port", "-p", type=int, default=29400, help="Port to use for torchrun rendezvous endpoint")
def main(segmentation: str, statement: str, nnodes: int, nproc_per_node: int, rdzv_port: int):
logger = get_logger("train-models")
### LOAD JSON
logger.info("Training pipeline starting...")
logger.info("Loading dataset description JSON files...")
### CONFIG_PATHS
segmentation_config_path = pathlib.Path(segmentation).resolve() if segmentation is not None else None
statement_config_path = pathlib.Path(statement).resolve() if statement is not None else None
logger.info("Dataset description JSON files loaded!")
### TRAIN SEGMENTATION
if segmentation_config_path is not None:
logger.info("Segmentation model training starting...")
train_segmentation(segmentation_config_path, logger, nnodes, nproc_per_node, rdzv_port)
logger.info("Segmentation model training complete!")
else:
logger.warning("Segmentation model configuration json path not provided in --segmentation; skipping segmentation model training...")
### TRAIN STATEMENT
if statement_config_path is not None:
logger.info("Statement model training starting...")
train_statement(statement_config_path, logger, nnodes, nproc_per_node, rdzv_port)
logger.info("Statement model training complete!")
else:
logger.warning("Statement model configuration json path not provided in --statement; skipping statement model training...")
logger.info("Training pipeline complete!")
if __name__ == "__main__":
main()