mirror of https://github.com/syssec-utd/pylingual.git
synced 2026-05-10 18:39:03 -07:00
rename
@@ -0,0 +1,50 @@
from dataclasses import dataclass
import pathlib

from typing import Tuple, List


@dataclass
class DataRequest:
    name: str
    source_path: pathlib.Path
    num_train: int
    num_test: int
    num_valid: int

    @property
    def total_files(self):
        return self.num_train + self.num_test + self.num_valid

    def __post_init__(self):
        self.source_path = pathlib.Path(self.source_path)
        if not self.source_path.exists():
            raise FileNotFoundError(f"{self.source_path} for DataRequest {self.name} does not exist")

        if self.num_train < 0:
            raise ValueError(f"Training sample count for DataRequest {self.name} must be non-negative")
        if self.num_test < 0:
            raise ValueError(f"Testing sample count for DataRequest {self.name} must be non-negative")
        if self.num_valid < 0:
            raise ValueError(f"Validation sample count for DataRequest {self.name} must be non-negative")


@dataclass
class DatasetDescription:
    name: str
    version: Tuple[int, int]
    save_to_dir: pathlib.Path
    huggingface_user: str
    data_requests: List[DataRequest]

    @property
    def code_dir(self):
        return self.save_to_dir / self.name / "code"

    @property
    def csv_dir(self):
        return self.save_to_dir / self.name / "csv"

    def __post_init__(self):
        self.save_to_dir = pathlib.Path(self.save_to_dir)
        self.version = tuple(self.version)
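
# Illustrative layout (hypothetical values): with save_to_dir="out" and name="v1",
# code_dir resolves to out/v1/code and csv_dir to out/v1/csv.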
@@ -0,0 +1,3 @@
from .create_code_dataset import transfer_and_compile_file

__all__ = ["transfer_and_compile_file"]
@@ -0,0 +1,216 @@
import csv
import itertools
import logging
import multiprocessing
import pathlib
import re
import signal
from typing import Callable, Tuple

import tqdm
from pylingual.editable_bytecode import PYCFile

from pylingual.masking.ast_masker import DUMMY_DECORATOR
from pylingual.masking.model_disasm import fix_jump_targets
from .DatasetDescription import DataRequest
from pylingual.masking.model_disasm import create_global_masker, mask_source

bytecode_separator = " <SEP> "
source_separator = " <SEP> "
CSV_SGMT_HEADER = ["source", "bytecode", "boundary", "file"]
CSV_STMT_HEADER = ["source", "bytecode", "file"]
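
# Segmentation rows pair <SEP>-joined source lines with <SEP>-joined instructions
# plus a B/I/E boundary string; statement rows pair one masked source line with
# its instructions (see bytecode2csv below).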


def create_csv_dataset(code_dataset_path: pathlib.Path, csv_dataset_path: pathlib.Path, data_requests: list[DataRequest], logger: logging.Logger = None):
    progress_bar = tqdm.tqdm(total=sum(request.total_files for request in data_requests))
    for split in ("train", "test", "valid"):
        if logger:
            logger.info(f"Converting the {split} split to CSV...")
        write_csvs(code_dataset_path / split, csv_dataset_path / split, logger, progress_bar=progress_bar)


def write_csvs(source_path: pathlib.Path, csv_output_path: pathlib.Path, logger: logging.Logger = None, max_csv_rows: int = 30000, progress_bar: tqdm.tqdm = None):
    # validate output directory
    if csv_output_path.exists():
        if not csv_output_path.is_dir():
            raise OSError("CSV output path is not a directory")
    else:
        csv_output_path.mkdir(parents=True)

    ##### csv write wrappers to preserve csv row limit

    def csv_writer(file_prefix: str, csv_header: list) -> Callable:
        out_dir = csv_output_path.joinpath(file_prefix)
        out_dir.mkdir(exist_ok=True)

        for csv_idx in itertools.count():
            new_path = out_dir.joinpath(f"{file_prefix}_{csv_idx}.csv")
            new_path.touch()
            if logger:
                logger.info(f"Creating new csv {new_path.resolve()}...")
            with new_path.open(mode="w") as csv_file:
                writer = csv.writer(csv_file)
                writer.writerow(csv_header)
                for writer in itertools.repeat(writer, max_csv_rows):
                    yield writer.writerow
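
    # Each yielded callable writes one row; after max_csv_rows rows the outer
    # loop rolls over to a fresh numbered csv, so callers can zip rows against
    # this generator without tracking file limits themselves.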

    segmentation_writer = csv_writer("segmentation", CSV_SGMT_HEADER)
    statement_writer = csv_writer("statement", CSV_STMT_HEADER)

    # create dirs
    code_dirs = (child for child in source_path.iterdir() if child.is_dir())

    def bytecode2csv_args():
        for dir in code_dirs:
            py_path = next(dir.glob("*.py"), None)
            pyc_path = next(dir.glob("*.pyc"), None)
            if None in (py_path, pyc_path):
                logging.debug(f"PY or PYC file not found in {dir}")
                continue
            else:
                yield (py_path, pyc_path)

    num_fails = 0
    with multiprocessing.Pool() as pool:
        for result in pool.imap_unordered(bytecode2csv_exception_wrapper, bytecode2csv_args()):
            if isinstance(result, Exception):
                num_fails += 1
                if logger:  # the wrapped Exception already names the offending paths
                    logger.debug(f"ERR: {result}\nTYPE ERR: {type(result)}\n")
                continue

            (segmentation_rows, statement_rows) = result
            for row, writerow in zip(segmentation_rows, segmentation_writer):
                writerow(row)
            for row, writerow in zip(statement_rows, statement_writer):
                writerow(row)

            if progress_bar:
                progress_bar.update()
                progress_bar.set_postfix({"num_fails": num_fails})

    if logger:
        logger.info(f"NUMBER OF FAILS !!! {num_fails}")


def timeout_handler(signum, frame):
    raise TimeoutError()


def bytecode2csv_exception_wrapper(paths: Tuple[pathlib.Path, pathlib.Path]) -> Tuple[list, list] | Exception:
    # SIGALRM-based timeouts are only available on Unix platforms
    signal.signal(signal.SIGALRM, timeout_handler)
    try:
        signal.alarm(30)  # set 30 second timeout
        results = bytecode2csv(*paths)
        signal.alarm(0)  # success; disable timer
        return results
    except Exception as error:
        signal.alarm(0)  # disable timer in case another exception triggered the fail
        return Exception(f"{type(error)}: {error} in file {paths}")


def bytecode2csv(py_path: pathlib.Path, pyc_path: pathlib.Path) -> tuple[list, list]:
    """Creates segmentation and statement csv rows for the given bytecode and source file"""
    segmentation_rows = []
    statement_rows = []

    pyc = PYCFile(str(pyc_path.resolve()))
    if pyc.version == (3, 10):
        pyc.replace_duplicated_returns10(py_path.read_text().split("\n"))
    elif pyc.version == (3, 12):
        pyc.replace_duplicated_returns12(py_path.read_text().split("\n"))
    global_masker = create_global_masker(pyc)

    masked_source_text = mask_source(py_path, global_masker, pyc.version)
    masked_source_lines = masked_source_text.split("\n")

    # filter out dummy decorators added in <= 3.7
    dummy_lnos = []
    if pyc.version <= (3, 7):
        # remove dummy decorators from bytecode
        pyc._patch_dummy_decorator(dummy_decorator_name=DUMMY_DECORATOR)
        try:  # if no functions are in source, then dummy will not exist
            dummy_decorator_line = f"@{global_masker.mask(DUMMY_DECORATOR)}"
        except KeyError:
            dummy_decorator_line = None
        dummy_lnos = [lno + 1 for lno, source in enumerate(masked_source_lines) if source.strip() == dummy_decorator_line]

    seen_lines = set()

    # create rows for each bytecode
    for bc in pyc.iter_bytecodes():
        # we ignore comprehensions, hoisted later
        if bc.is_comprehension:
            continue

        # attempt to filter lines
        lno_insts = bc.get_lno_insts(previously_seen_lines=seen_lines)

        # create line num : model disasm view of insts
        lno_model_view_insts = {lno: [global_masker.get_model_view(inst) for inst in line_insts] for lno, line_insts in lno_insts.items()}
        seen_lines.update(lno_model_view_insts.keys())

        # segment source
        if pyc.version <= (3, 7):
            segmented_source_lines = []
            for line_num in lno_model_view_insts:
                if not line_num:
                    segmented_source_lines.append("")
                elif line_num in dummy_lnos:
                    segmented_source_lines.append(masked_source_lines[line_num].strip())
                else:
                    segmented_source_lines.append(masked_source_lines[line_num - 1].strip())
        else:
            segmented_source_lines = [masked_source_lines[line_num - 1].strip() if line_num else "" for line_num in lno_model_view_insts.keys()]  # -1 to convert from line num to index in array

        model_disasm_text = bytecode_separator.join(val for val in itertools.chain(*lno_model_view_insts.values()))

        if len(segmented_source_lines) != len(lno_model_view_insts):
            raise ValueError("Length mismatch between segmented source and segmented bytecodes")

        # create bytecode segmentation
        boundaries = []
        for bc_line in lno_model_view_insts.values():
            if len(bc_line) == 1:
                bounds = "B"
            elif len(bc_line) >= 2:
                bounds = "B" + "I" * (len(bc_line) - 2) + "E"
            else:
                raise ValueError("Unexpected amount of bytecodes segmented into a line")
            boundaries.extend(list(bounds))
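
        # e.g. a line with one instruction is tagged "B"; a line with four
        # instructions is tagged "BIIE" (begin, inner, inner, end)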

        # append rows
        segmentation_rows.append([source_separator.join(segmented_source_lines), model_disasm_text, boundaries, str(py_path)])
        for segmented_source, bytecodes in zip(segmented_source_lines, lno_model_view_insts.values()):
            # skip empty lines
            if not segmented_source or segmented_source == "None":
                continue
            # skip fillers
            if segmented_source in ("pass", "...") and ("RETURN_VALUE" in bytecodes or "RETURN_CONST , None" in bytecodes):
                continue
            # skip string-only lines that aren't docstrings
            if (segmented_source.startswith("'") or segmented_source.startswith('"')) and not any("__doc__" in b for b in bytecodes):
                continue
            if segmented_source.startswith("elif "):
                segmented_source = segmented_source[2:]

            joined_bytecode = bytecode_separator.join(bytecodes)

            # DUCT-TAPE; skip samples where model has to guess masks
            source_masks = set(re.findall(r"<mask_\d+>", segmented_source))
            bytecode_masks = set(re.findall(r"<mask_\d+>", joined_bytecode))
            if not source_masks <= bytecode_masks:
                continue

            # normalize source mask order for statements
            # replace mask values to start at 0 and count up
            mask_regex = re.compile(r"(?<=<mask_)\d+(?=>)")
            masks = mask_regex.findall(joined_bytecode)
            mask_order = [x for i, x in enumerate(masks) if masks.index(x) == i]
            normalized_mask_bytecode = mask_regex.sub(lambda x: str(mask_order.index(x.group(0))), joined_bytecode)
            normalized_mask_source = mask_regex.sub(lambda x: str(mask_order.index(x.group(0))), segmented_source)
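
            # e.g. if the bytecode mentions <mask_7> then <mask_3>, both views
            # are rewritten to use <mask_0> and <mask_1> respectively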

            # normalize jump targets
            normalized_mask_bytecode = fix_jump_targets(normalized_mask_bytecode)

            statement_rows.append([normalized_mask_source, normalized_mask_bytecode, str(py_path)])

    return (segmentation_rows, statement_rows)
@@ -0,0 +1,114 @@
import itertools
import logging
import multiprocessing
import pathlib
import random
from typing import List, Optional, Set, Tuple

import tqdm

from .DatasetDescription import DataRequest
from pylingual.utils.generate_bytecode import compile_version
from .normalize_source import normalize_source
from pylingual.masking.ast_masker import add_dummy_decorators


def transfer_and_compile_file(
    original_file: pathlib.Path,
    destination_file: pathlib.Path,
    version: Tuple[int, int],
) -> Optional[Exception]:
    # copy over normalized source file
    try:
        normalized_source = normalize_source(original_file.read_text(), version=version, replace_docstrings=True)

        if version[:2] <= (3, 7):
            normalized_source = add_dummy_decorators(normalized_source)
    except Exception as err:
        return err

    destination_file.parent.mkdir(parents=True, exist_ok=True)
    destination_file.write_text(normalized_source)

    # compile the copied file with the given version
    try:
        compile_version(
            destination_file.resolve(),
            destination_file.with_suffix(".pyc").resolve(),
            version,
        )
    except Exception as err:
        return err

    return None
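
# Illustrative call (hypothetical paths): returns None on success, or the
# exception raised while normalizing or compiling:
#   err = transfer_and_compile_file(pathlib.Path("raw/foo.py"), pathlib.Path("out/foo.py"), (3, 9))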


def star_transfer_and_compile_file(args) -> Optional[Exception]:
    return transfer_and_compile_file(*args)


# samples data_request.total_files files from the given directory
# expects the directory to have the structure
# source_dir -> identifier -> file.py
def sample_directory_splits(
    data_request: DataRequest,
) -> Tuple[List[pathlib.Path], List[pathlib.Path], List[pathlib.Path]]:
    all_files: Set[pathlib.Path] = set()
    for identifier in data_request.source_path.iterdir():
        source_file = next(identifier.glob("*.py"), None)  # get the first python file from the identifier
        if source_file is not None:
            all_files.add(source_file)

    # sample batches until we have enough files to satisfy the data requests
    # this avoids running expensive tests on unsampled files
    clean_sample: Set[pathlib.Path] = set()
    while len(clean_sample) < data_request.total_files:
        remaining_files = data_request.total_files - len(clean_sample)
        sample_batch = random.sample(list(all_files), k=remaining_files)
        # add the acceptable files to the sample and remove them from the population
        to_add = set(candidate for candidate in sample_batch if candidate is not None)
        clean_sample.update(to_add)
        all_files -= to_add

    full_sample = iter(clean_sample)

    train = list(itertools.islice(full_sample, data_request.num_train))
    test = list(itertools.islice(full_sample, data_request.num_test))
    valid = list(itertools.islice(full_sample, data_request.num_valid))

    return train, test, valid
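
# The sampled paths are consumed in iterator order, e.g. with num_train=8,
# num_test=1, num_valid=1 the first 8 of the 10 paths go to train, the next
# to test, and the last to valid.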


def prepare_single_directory_transfer_args(data_request: DataRequest, target_dir: pathlib.Path) -> List[Tuple[pathlib.Path, pathlib.Path]]:
    train, test, valid = sample_directory_splits(data_request)

    transfer_args = []
    for split_name, split_files in zip(("train", "test", "valid"), (train, test, valid)):
        for source_file in split_files:
            target_file = target_dir / split_name / f"{data_request.name}-{source_file.parent.name}" / source_file.name
            transfer_args.append((source_file, target_file))

    return transfer_args
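
# e.g. a sampled file datasets/stdlib/<id>/foo.py is destined for
# <target_dir>/train/<request name>-<id>/foo.py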


# takes a list of DataRequests and a target directory
# makes train, test, and valid directories in the target directory with the normalized source files
def create_code_dataset(
    data_requests: List[DataRequest],
    target_dir: pathlib.Path,
    version: Tuple[int, int],
    logger: logging.Logger,
):
    with multiprocessing.Pool() as pool:
        # prepare a list of file transfers to execute
        logger.info(f"Sampling {', '.join(str(req.source_path.resolve()) for req in data_requests)}...")
        transfer_arg_lists = pool.starmap(
            prepare_single_directory_transfer_args,
            zip(data_requests, itertools.repeat(target_dir)),
        )
        # execute the file transfers
        versioned_transfer_arg_lists = [(source_file, target_file, version) for (source_file, target_file) in itertools.chain(*transfer_arg_lists)]
        logger.info(f"Normalizing and compiling {len(versioned_transfer_arg_lists)} files...")
        for error in tqdm.tqdm(pool.imap_unordered(star_transfer_and_compile_file, versioned_transfer_arg_lists), total=len(versioned_transfer_arg_lists)):
            if error is not None:
                logger.debug(error)
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import ast
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def version_str_to_tuple(version_str: str) -> tuple[int, int]:
|
||||
# a version string is a string like 3.9.2
|
||||
versions = [int(version) for version in version_str.split(".")]
|
||||
return tuple(versions[:2])
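
# e.g. version_str_to_tuple("3.9.2") == (3, 9)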


# must be run in python 3.9 or later for ast.unparse() support
# version defaults to whatever version this script is running in; needs to be set explicitly for backwards compatibility
# ast only supports versions 3.4 and later
def normalize_source(
    source: str,
    version: Tuple[int, int] = sys.version_info[0:2],
    replace_docstrings=False,
) -> str:
    """
    Parse the source code into an AST, then convert back to source.
    This has the following normalizing effects:
    1. whitespace is set according to the PEP standard
    2. each statement is on exactly one line
    3. # comments are removed (note: docstrings are not removed)

    :param str source: The source code to normalize
    :param tuple version: The (Major, Minor) version of python to parse with; must be at least (3, 4); defaults to
        same version as this script
    :param bool replace_docstrings: Replace all docstrings with 'pass'
    """
    tree = ast.parse(source, feature_version=version)
    if replace_docstrings:
        for node in ast.walk(tree):
            # ast.Str is deprecated since 3.8 (removed in 3.12); on the
            # supported interpreters it still matches string constants
            if isinstance(node, ast.Expr) and isinstance(node.value, ast.Str):
                node.value.s = "pass"
    return ast.unparse(tree)
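
# e.g. normalize_source("x=1  # note\nif x: y=2") returns
# "x = 1\nif x:\n    y = 2" (comment dropped, one statement per line)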


def normalize_source_file(
    source_file_path: str,
    cleaned_suffix: str = "-cleaned",
    version: tuple[int, int] = sys.version_info[0:2],
):
    """
    Normalizes the source code in a given file, then saves it to a '-cleaned' version in the same directory

    :param str source_file_path: The absolute or relative path to the source .py file
    :param str cleaned_suffix: The suffix to add to the cleaned file, typically left as default
    :param tuple version: The (Major, Minor) version of python to parse with; must be at least (3, 4); defaults to
        same version as this script
    """

    # add the cleaned_suffix to the output_path
    input_path = pathlib.Path(source_file_path).resolve()
    output_path = input_path.with_stem(f"{input_path.stem}{cleaned_suffix}")

    with open(input_path, "r") as source_file:
        normalized_source = normalize_source(source_file.read(), version=version)

    with open(output_path, "w") as cleaned_file:
        cleaned_file.write(normalized_source)

    return output_path
@@ -0,0 +1,60 @@
from io import BytesIO
from typing import Dict, List, Literal

from datasets import load_dataset
from huggingface_hub import HfApi

from .DatasetDescription import DatasetDescription

LOCAL_DATASET = Dict[Literal["train", "test", "valid"], List[str]]


def upload_single_dataset(data_files: LOCAL_DATASET, dataset_name: str, dataset_card: str):
    local_datasets = load_dataset("csv", data_files=data_files)
    local_datasets.push_to_hub(dataset_name, private=True)

    dataset_card_with_stats = dataset_card + f"\n\nDataset Statistics:\n\n```\n{local_datasets}\n```"

    api = HfApi()
    api.upload_file(
        path_or_fileobj=BytesIO(bytes(dataset_card_with_stats, "utf-8")),
        path_in_repo="README.md",
        repo_id=dataset_name,
        repo_type="dataset",
    )
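
# data_files maps each split name to its csv paths, e.g.
#   {"train": [".../segmentation_0.csv", ...], "test": [...], "valid": [...]}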


def upload_dataset_to_huggingface(dataset_description: DatasetDescription):
    formatted_data_requests = "\n".join(f"{str(req.source_path.resolve())}: (train: {req.num_train}, test: {req.num_test}, valid: {req.num_valid})" for req in dataset_description.data_requests)
    dataset_card = f"""
# {dataset_description.name}

Created by the Syssec team @ UTD

Dataset Composition:

```
{formatted_data_requests}
```

Python version: `{".".join(map(str, dataset_description.version))}`
"""

    splits: List[Literal["train", "test", "valid"]] = [
        "train",
        "test",
        "valid",
    ]

    # collect data files
    segmentation_data_files: LOCAL_DATASET = {}
    statement_data_files: LOCAL_DATASET = {}
    for split in splits:
        segmentation_data_files[split] = [str(path.resolve()) for path in (dataset_description.csv_dir / split / "segmentation").glob("*.csv")]
        statement_data_files[split] = [str(path.resolve()) for path in (dataset_description.csv_dir / split / "statement").glob("*.csv")]

    # upload datasets
    segmentation_dataset_name = f"{dataset_description.huggingface_user}/segmentation-{dataset_description.name}"
    upload_single_dataset(segmentation_data_files, segmentation_dataset_name, dataset_card)
    statement_dataset_name = f"{dataset_description.huggingface_user}/statement-{dataset_description.name}"
    upload_single_dataset(statement_data_files, statement_dataset_name, dataset_card)