mirror of
https://github.com/syssec-utd/pylingual.git
synced 2026-05-10 18:39:03 -07:00
Merge pull request #115 from syssec-utd/dev-python-3.14
Draft Python 3.14 support
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
dataset/
|
||||
cache-dir/
|
||||
model-jsons/
|
||||
venv/
|
||||
.venv/
|
||||
*.pyc
|
||||
@@ -18,3 +20,4 @@ mise.toml
|
||||
dist/
|
||||
decompiled_*/
|
||||
decompiled_*.py
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../" }
|
||||
# pylingual = { path = "../", editable = true }
|
||||
# ///
|
||||
|
||||
import contextlib
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../../" }
|
||||
# pylingual = { path = "../../", editable = true }
|
||||
# ///
|
||||
|
||||
import csv
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../../" }
|
||||
# pylingual = { path = "../../", editable = true }
|
||||
# ///
|
||||
|
||||
import itertools
|
||||
|
||||
@@ -35,8 +35,8 @@ def normalize_source(
|
||||
tree = ast.parse(source, feature_version=version)
|
||||
if replace_docstrings:
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Str):
|
||||
node.value.s = "pass"
|
||||
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant) and isinstance(node.value.value, str):
|
||||
node.value.value = "pass"
|
||||
return ast.unparse(tree)
|
||||
|
||||
|
||||
|
||||
+121
-53
@@ -10,6 +10,7 @@ from datetime import datetime
|
||||
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationResult:
|
||||
success: set[Path]
|
||||
@@ -18,34 +19,34 @@ class EvaluationResult:
|
||||
error: set[Path]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, list[Path]]) -> 'EvaluationResult':
|
||||
def from_dict(cls, data: dict[str, list[Path]]) -> "EvaluationResult":
|
||||
return cls(
|
||||
success = set(data.get('success', [])),
|
||||
failure = set(data.get('failure', [])),
|
||||
compile_error = set(data.get('compile_error', [])),
|
||||
error = set(data.get('error', [])),
|
||||
success=set(data.get("success", [])),
|
||||
failure=set(data.get("failure", [])),
|
||||
compile_error=set(data.get("compile_error", [])),
|
||||
error=set(data.get("error", [])),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def import_json(cls, json_path: Path) -> 'EvaluationResult':
|
||||
def import_json(cls, json_path: Path) -> "EvaluationResult":
|
||||
with json_path.open("r") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
def export_json(self, json_path: Path):
|
||||
jsonable_dict = {
|
||||
'success': sorted(self.success),
|
||||
'failure': sorted(self.failure),
|
||||
'compile_error': sorted(self.compile_error),
|
||||
'error': sorted(self.error),
|
||||
"success": sorted(self.success),
|
||||
"failure": sorted(self.failure),
|
||||
"compile_error": sorted(self.compile_error),
|
||||
"error": sorted(self.error),
|
||||
}
|
||||
with json_path.open("w") as f:
|
||||
json.dump(jsonable_dict, f, indent=2)
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(set.intersection(self.success, self.failure, self.compile_error, self.error)) == 0, 'Malformed evaluation result. Paths appear in multiple categories.'
|
||||
assert len(set.intersection(self.success, self.failure, self.compile_error, self.error)) == 0, "Malformed evaluation result. Paths appear in multiple categories."
|
||||
|
||||
|
||||
# --- Constants and Configuration ---
|
||||
@@ -55,16 +56,28 @@ HARNESS_DIR = PROJECT_ROOT / ".eval_harness"
|
||||
CACHE_DIR = HARNESS_DIR / "results_cache"
|
||||
LOCAL_WORKSPACE = HARNESS_DIR / "local"
|
||||
|
||||
SUPPORTED_PYTHON_VERSIONS = ('3.6', '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13')
|
||||
SUPPORTED_PYTHON_VERSIONS = (
|
||||
"3.6",
|
||||
"3.7",
|
||||
"3.8",
|
||||
"3.9",
|
||||
"3.10",
|
||||
"3.11",
|
||||
"3.12",
|
||||
"3.13",
|
||||
"3.14",
|
||||
)
|
||||
|
||||
# Rich console for pretty printing
|
||||
console = Console()
|
||||
|
||||
|
||||
def _get_cache_path(commit_hash: str, eval_file_list_path: Path, python_version: str) -> Path:
|
||||
cache_path = CACHE_DIR / python_version / commit_hash / eval_file_list_path.with_suffix('.json').name
|
||||
cache_path = CACHE_DIR / python_version / commit_hash / eval_file_list_path.with_suffix(".json").name
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return cache_path
|
||||
|
||||
|
||||
def run_command(command, cwd=None, capture_output=False, text=True):
|
||||
"""A helper to run a shell command and handle errors."""
|
||||
try:
|
||||
@@ -96,7 +109,7 @@ def get_head_commit_hash():
|
||||
return run_command(["git", "rev-parse", "--short", "HEAD"], capture_output=True).stdout.strip()
|
||||
|
||||
|
||||
def setup_workspace(workspace_path: Path, version_name: str, commit_hash: str = ''):
|
||||
def setup_workspace(workspace_path: Path, version_name: str, commit_hash: str = ""):
|
||||
"""Prepares a clean workspace for an evaluation run."""
|
||||
console.print(f"\n[bold cyan]Setting up '{version_name}' workspace...[/bold cyan]")
|
||||
|
||||
@@ -108,7 +121,6 @@ def setup_workspace(workspace_path: Path, version_name: str, commit_hash: str =
|
||||
code_dir = workspace_path / "code"
|
||||
venv_dir = workspace_path / "venv"
|
||||
# Handle OS-specific executable paths
|
||||
pip_executable = venv_dir / "Scripts" / "pip.exe" if sys.platform == "win32" else venv_dir / "bin" / "pip"
|
||||
|
||||
# 1. Get the source code
|
||||
if version_name == "local":
|
||||
@@ -117,7 +129,7 @@ def setup_workspace(workspace_path: Path, version_name: str, commit_hash: str =
|
||||
shutil.copytree(
|
||||
PROJECT_ROOT,
|
||||
code_dir,
|
||||
ignore=shutil.ignore_patterns(".git", ".eval_harness", "__pycache__", "*.pyc", ".idea"),
|
||||
ignore=shutil.ignore_patterns(".git", ".eval_harness", "__pycache__", "*.pyc", ".idea", ".venv"),
|
||||
)
|
||||
else:
|
||||
console.print(f" -> Exporting code from {version_name} ({commit_hash})...")
|
||||
@@ -126,13 +138,28 @@ def setup_workspace(workspace_path: Path, version_name: str, commit_hash: str =
|
||||
git_archive_command = f"git archive {commit_hash} | tar -x -C {code_dir}"
|
||||
run_command(git_archive_command, cwd=PROJECT_ROOT)
|
||||
|
||||
# 2. Create virtual environment
|
||||
console.print(f" -> Creating virtual environment at [italic]{venv_dir}[/italic]...")
|
||||
run_command([sys.executable, "-m", "venv", str(venv_dir)])
|
||||
# 2. Setup environment using uv
|
||||
console.print(" -> Setting up environment with uv...")
|
||||
|
||||
# 3. Install dependencies
|
||||
console.print(" -> Installing project dependencies...")
|
||||
run_command([str(pip_executable), "install", "-e", "."], cwd=code_dir, capture_output=True)
|
||||
# Check if we should use 'uv sync' (new style) or 'uv pip' (legacy)
|
||||
has_uv_lock = (code_dir / "uv.lock").exists()
|
||||
has_pyproject = (code_dir / "pyproject.toml").exists()
|
||||
|
||||
if has_uv_lock and has_pyproject:
|
||||
console.print(" -> Found uv.lock, using [bold]uv sync[/bold]...")
|
||||
# uv sync creates the venv in .venv by default
|
||||
run_command(["uv", "sync"], cwd=code_dir)
|
||||
venv_dir = code_dir / ".venv"
|
||||
else:
|
||||
console.print(" -> Legacy setup, using [bold]uv pip install[/bold]...")
|
||||
# Create venv explicitly
|
||||
run_command(["uv", "venv", str(venv_dir)], cwd=code_dir)
|
||||
|
||||
# Determine python executable in the new venv for pip install
|
||||
venv_python = venv_dir / "Scripts" / "python.exe" if sys.platform == "win32" else venv_dir / "bin" / "python"
|
||||
|
||||
# Install dependencies
|
||||
run_command(["uv", "pip", "install", "-e", ".", "--python", str(venv_python)], cwd=code_dir)
|
||||
|
||||
return code_dir, venv_dir
|
||||
|
||||
@@ -143,10 +170,10 @@ def run_evaluation(workspace_path: Path, venv_dir: Path, input_file: Path, pytho
|
||||
console.print(f"\n[bold green]Running evaluation for '{version_name}' on Python {python_version}...[/bold green]")
|
||||
|
||||
code_dir = workspace_path / "code"
|
||||
output_dir = workspace_path / "output" / python_version # Use a sub-dir for version-specific output
|
||||
output_dir = workspace_path / "output" / python_version # Use a sub-dir for version-specific output
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
results_file = output_dir / python_version / f"{input_file.stem}_0" / "results.json"
|
||||
|
||||
|
||||
# Clean previous results for this version if they exist
|
||||
if results_file.exists():
|
||||
results_file.unlink()
|
||||
@@ -172,7 +199,14 @@ def run_evaluation(workspace_path: Path, venv_dir: Path, input_file: Path, pytho
|
||||
|
||||
return EvaluationResult.import_json(results_file)
|
||||
|
||||
def compare_and_report(commit_results: EvaluationResult, local_results: EvaluationResult, report_path: Path, compare_to_commit: str, python_version: str):
|
||||
|
||||
def compare_and_report(
|
||||
commit_results: EvaluationResult,
|
||||
local_results: EvaluationResult,
|
||||
report_path: Path,
|
||||
compare_to_commit: str,
|
||||
python_version: str,
|
||||
):
|
||||
"""Compares two sets of results and prints a detailed report to console and a file."""
|
||||
with report_path.open("w", encoding="utf-8") as f:
|
||||
report_console = Console(file=f)
|
||||
@@ -197,7 +231,7 @@ def compare_and_report(commit_results: EvaluationResult, local_results: Evaluati
|
||||
commit_map = {path: cat for cat, paths in commit_dict.items() for path in paths}
|
||||
local_map = {path: cat for cat, paths in local_dict.items() for path in paths}
|
||||
all_paths = set(commit_map.keys()) | set(local_map.keys())
|
||||
|
||||
|
||||
movement_matrix = {cat: {cat2: 0 for cat2 in categories} for cat in categories}
|
||||
for path in all_paths:
|
||||
from_cat = commit_map.get(path)
|
||||
@@ -216,11 +250,11 @@ def compare_and_report(commit_results: EvaluationResult, local_results: Evaluati
|
||||
if from_cat == to_cat:
|
||||
style = "blue"
|
||||
elif from_cat == "success":
|
||||
style = "bold red" # Regression from success
|
||||
style = "bold red" # Regression from success
|
||||
elif to_cat == "success":
|
||||
style = "bold green" # Improvement to success
|
||||
style = "bold green" # Improvement to success
|
||||
else:
|
||||
style = "tan" # Side-move
|
||||
style = "tan" # Side-move
|
||||
row.append(f"[{style}]{'+' if from_cat != to_cat else ''}{count}[/{style}]")
|
||||
table.add_row(*row)
|
||||
|
||||
@@ -233,10 +267,7 @@ def compare_and_report(commit_results: EvaluationResult, local_results: Evaluati
|
||||
if from_cat == to_cat:
|
||||
continue
|
||||
|
||||
moved_paths = sorted([
|
||||
p for p in all_paths
|
||||
if commit_map.get(p) == from_cat and local_map.get(p) == to_cat
|
||||
])
|
||||
moved_paths = sorted([p for p in all_paths if commit_map.get(p) == from_cat and local_map.get(p) == to_cat])
|
||||
|
||||
if not moved_paths:
|
||||
continue
|
||||
@@ -255,7 +286,7 @@ def compare_and_report(commit_results: EvaluationResult, local_results: Evaluati
|
||||
for p in moved_paths:
|
||||
console.print(f"- {p}")
|
||||
report_console.print(f"- {p}", soft_wrap=True)
|
||||
|
||||
|
||||
# 3. New and Removed Items
|
||||
new_items = sorted([p for p in all_paths if commit_map.get(p) is None])
|
||||
removed_items = sorted([p for p in all_paths if local_map.get(p) is None])
|
||||
@@ -268,7 +299,7 @@ def compare_and_report(commit_results: EvaluationResult, local_results: Evaluati
|
||||
line = format_func(item)
|
||||
console.print(line)
|
||||
report_console.print(line)
|
||||
|
||||
|
||||
print_list_section(
|
||||
"\n[bold blue]New Items[/bold blue]",
|
||||
new_items,
|
||||
@@ -283,12 +314,41 @@ def compare_and_report(commit_results: EvaluationResult, local_results: Evaluati
|
||||
|
||||
console.print(f"\n-> Comparison report saved to [italic]{report_path}[/italic]")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--input-file', required=True, type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=Path), help='Path to the input file listing test cases.')
|
||||
@click.option('--python-version', 'python_versions', multiple=True, type=str, help='Python version to evaluate. Can be specified multiple times. Defaults to all supported versions.', default=SUPPORTED_PYTHON_VERSIONS)
|
||||
@click.option('--compare-to-commit', type=str, help='The git commit hash to compare to. Defaults to HEAD.', default='HEAD')
|
||||
@click.option('--no-cache', is_flag=True, default=False, help='Force re-evaluation of the comparison commit for all specified Python versions.')
|
||||
def main(input_file: Path, python_versions: list[str], compare_to_commit: str, no_cache: bool):
|
||||
@click.option(
|
||||
"--input-file",
|
||||
required=True,
|
||||
type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=Path),
|
||||
help="Path to the input file listing test cases.",
|
||||
)
|
||||
@click.option(
|
||||
"--python-version",
|
||||
"python_versions",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help="Python version to evaluate. Can be specified multiple times. Defaults to all supported versions.",
|
||||
default=SUPPORTED_PYTHON_VERSIONS,
|
||||
)
|
||||
@click.option(
|
||||
"--compare-to-commit",
|
||||
type=str,
|
||||
help="The git commit hash to compare to. Defaults to HEAD.",
|
||||
default="HEAD",
|
||||
)
|
||||
@click.option(
|
||||
"--no-cache",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Force re-evaluation of the comparison commit for all specified Python versions.",
|
||||
)
|
||||
@click.option(
|
||||
"--no-cleanup",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Do not clean up workspaces after evaluation.",
|
||||
)
|
||||
def main(input_file: Path, python_versions: list[str], compare_to_commit: str, no_cache: bool, no_cleanup: bool):
|
||||
"""
|
||||
An evaluation framework to compare the performance of the current project
|
||||
state against a previous git commit.
|
||||
@@ -297,11 +357,11 @@ def main(input_file: Path, python_versions: list[str], compare_to_commit: str, n
|
||||
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
commit_version = compare_to_commit
|
||||
if compare_to_commit.lower() == 'head':
|
||||
if compare_to_commit.lower() == "head":
|
||||
compare_to_commit = get_head_commit_hash()
|
||||
console.print(f"[bold green]Resolved HEAD to commit {compare_to_commit}.[/bold green]")
|
||||
else:
|
||||
compare_to_commit = compare_to_commit[:7].lower() # shorten and lowercase for consistency
|
||||
compare_to_commit = compare_to_commit[:7].lower() # shorten and lowercase for consistency
|
||||
|
||||
COMMIT_WORKSPACE = HARNESS_DIR / compare_to_commit
|
||||
|
||||
@@ -322,7 +382,7 @@ def main(input_file: Path, python_versions: list[str], compare_to_commit: str, n
|
||||
else:
|
||||
if commit_venv_dir is None:
|
||||
_, commit_venv_dir = setup_workspace(COMMIT_WORKSPACE, commit_version, compare_to_commit)
|
||||
|
||||
|
||||
assert commit_venv_dir is not None
|
||||
commit_results = run_evaluation(COMMIT_WORKSPACE, commit_venv_dir, input_file, python_version)
|
||||
commit_results.export_json(cached_result_file)
|
||||
@@ -330,7 +390,7 @@ def main(input_file: Path, python_versions: list[str], compare_to_commit: str, n
|
||||
|
||||
# --- Local Evaluation ---
|
||||
local_results = run_evaluation(LOCAL_WORKSPACE, local_venv_dir, input_file, python_version)
|
||||
|
||||
|
||||
# --- Save Local Results Artifact ---
|
||||
local_artifact_path = CACHE_DIR / python_version / f"local_results_{run_timestamp}.json"
|
||||
local_results.export_json(local_artifact_path)
|
||||
@@ -338,14 +398,22 @@ def main(input_file: Path, python_versions: list[str], compare_to_commit: str, n
|
||||
|
||||
# --- Comparison ---
|
||||
report_artifact_path = CACHE_DIR / python_version / f"comparison_report_{run_timestamp}.txt"
|
||||
compare_and_report(commit_results, local_results, report_artifact_path, compare_to_commit, python_version)
|
||||
compare_and_report(
|
||||
commit_results,
|
||||
local_results,
|
||||
report_artifact_path,
|
||||
compare_to_commit,
|
||||
python_version,
|
||||
)
|
||||
|
||||
# --- Final Cleanup ---
|
||||
console.print("\n[bold]Cleaning up workspaces...[/bold]")
|
||||
if commit_venv_dir is not None:
|
||||
shutil.rmtree(COMMIT_WORKSPACE)
|
||||
shutil.rmtree(LOCAL_WORKSPACE)
|
||||
console.print("Done.")
|
||||
if not no_cleanup:
|
||||
console.print("\n[bold]Cleaning up workspaces...[/bold]")
|
||||
if commit_venv_dir is not None:
|
||||
shutil.rmtree(COMMIT_WORKSPACE)
|
||||
shutil.rmtree(LOCAL_WORKSPACE)
|
||||
console.print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import pathlib
|
||||
import tqdm
|
||||
import click
|
||||
import csv
|
||||
from typing import TextIO
|
||||
from datetime import datetime
|
||||
|
||||
import signal
|
||||
|
||||
from pylingual.decompiler import decompile
|
||||
|
||||
|
||||
@click.command(help="Evaluation script for pylingual")
@click.argument("pyc_list", type=click.File("r"))
@click.argument("out_dir", type=click.Path(file_okay=False, dir_okay=True, writable=True, path_type=pathlib.Path))
def main(pyc_list: TextIO, out_dir: pathlib.Path):
    """Decompile every .pyc listed in ``pyc_list`` and record the outcome.

    Args:
        pyc_list: Open text file with one .pyc path per line. Blank lines
            are ignored.
        out_dir: Base output directory. A timestamped ``pylingual-...``
            sub-directory is created inside it for this run.

    Outputs written under the run directory:
        evaluation_results.csv: one "FILE" row per .pyc plus one row per
            equivalence result of each successfully decompiled file.
        decompilation_results/: per-file decompiler output directories.
        elapsed_time.txt: wall-clock duration and file-level success rate.
    """
    start_time = datetime.now()

    def timeout_handler(signum, frame):
        # Converted into an ordinary exception so the per-file ``except``
        # below records a timeout like any other decompiler failure.
        raise TimeoutError()

    signal.signal(signal.SIGALRM, timeout_handler)

    out_dir = out_dir / f"pylingual-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

    # Skip blank lines: an empty entry would otherwise become Path(".") and
    # be counted as a decompiler error.
    pyc_files = [pathlib.Path(stripped) for line in pyc_list if (stripped := line.strip())]

    out_dir.mkdir(parents=True, exist_ok=True)
    # Loop-invariant, so create it once instead of once per input file.
    decompiler_results_dir = out_dir / "decompilation_results"
    decompiler_results_dir.mkdir(parents=True, exist_ok=True)
    evaluation_results_file = out_dir / "evaluation_results.csv"

    total_files_succeeded = 0
    total_files_attempted = 0

    # The context manager guarantees the CSV is flushed and closed even if
    # the run is interrupted or an unexpected error escapes the loop.
    with evaluation_results_file.open("w", newline="") as evaluation_results_stream:
        evaluation_writer = csv.DictWriter(evaluation_results_stream, fieldnames=["pyc_file", "py_file", "identifier", "success", "category", "notes"])
        evaluation_writer.writeheader()

        # decompile all the pyc files
        for pyc_file in (evaluation_progress := tqdm.tqdm(pyc_files)):
            # The parent directory name doubles as the per-file output folder
            # and as the identifier; using pathlib avoids an IndexError on
            # bare filenames and works with any path separator.
            parent_name = pyc_file.parent.name
            target_out_dir = decompiler_results_dir / parent_name
            identifier = parent_name

            # update progress bar
            if total_files_attempted > 0:
                evaluation_progress.set_postfix(
                    {
                        "file_success": f"{total_files_succeeded}/{total_files_attempted} ({total_files_succeeded / total_files_attempted:.2%})",
                    }
                )

            total_files_attempted += 1

            # decompile the file
            try:
                signal.alarm(300)  # 5-minute timeout for decompiling one file
                try:
                    py_file = decompile(pyc_file, target_out_dir)
                finally:
                    signal.alarm(0)  # always disable the timer, success or failure
            except Exception as err:
                evaluation_writer.writerow({"pyc_file": pyc_file, "py_file": "", "identifier": "FILE", "success": False, "category": "DECOMPILER ERROR", "notes": repr(err)})
                continue

            if all(result.success for result in py_file.equivalence_results):
                evaluation_writer.writerow({"pyc_file": pyc_file, "py_file": parent_name, "identifier": "FILE", "success": True, "category": "Equal", "notes": ""})
                total_files_succeeded += 1
            else:
                evaluation_writer.writerow({"pyc_file": pyc_file, "py_file": parent_name, "identifier": "FILE", "success": False, "category": "Different", "notes": ""})

            # Per-result rows omit "category"; DictWriter fills it with its
            # default empty restval.
            evaluation_writer.writerows({"pyc_file": pyc_file, "py_file": parent_name, "identifier": identifier, "success": result.success, "notes": ""} for result in py_file.equivalence_results)

    elapsed_time = datetime.now() - start_time

    # Guard against an empty input list, which would otherwise divide by zero.
    success_rate = total_files_succeeded / total_files_attempted if total_files_attempted else 0.0

    with open(out_dir / "elapsed_time.txt", "w") as time_file:
        time_file.write(f"Elapsed Time: {str(elapsed_time)}\n")
        time_file.write(f"File success: {total_files_succeeded}/{total_files_attempted} {success_rate:.2%}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../" }
|
||||
# pylingual = { path = "../", editable = true }
|
||||
# ///
|
||||
|
||||
import json
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../../" }
|
||||
# pylingual = { path = "../../", editable = true }
|
||||
# ///
|
||||
|
||||
import ast
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../../" }
|
||||
# pylingual = { path = "../../", editable = true }
|
||||
# ///
|
||||
|
||||
import logging
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# "pylingual",
|
||||
# ]
|
||||
# [tool.uv.sources]
|
||||
# pylingual = { path = "../" }
|
||||
# pylingual = { path = "../", editable = true }
|
||||
# ///
|
||||
|
||||
import logging
|
||||
@@ -27,7 +27,8 @@ def train_segmentation(segmentation_config_path: pathlib.Path, logger: logging.L
|
||||
logger.info("training masked language model...")
|
||||
subprocess.run(
|
||||
[
|
||||
"uv", "run",
|
||||
"uv",
|
||||
"run",
|
||||
"torchrun",
|
||||
f"--nnodes={nnodes}",
|
||||
f"--nproc-per-node={nproc_per_node}",
|
||||
@@ -47,7 +48,8 @@ def train_segmentation(segmentation_config_path: pathlib.Path, logger: logging.L
|
||||
logger.info("training segmentation model...")
|
||||
subprocess.run(
|
||||
[
|
||||
"uv", "run",
|
||||
"uv",
|
||||
"run",
|
||||
"torchrun",
|
||||
f"--nnodes={nnodes}",
|
||||
f"--nproc-per-node={nproc_per_node}",
|
||||
@@ -74,7 +76,8 @@ def train_statement(statement_config_path: pathlib.Path, logger: logging.Logger,
|
||||
logger.info("training statement model...")
|
||||
subprocess.run(
|
||||
[
|
||||
"uv", "run",
|
||||
"uv",
|
||||
"run",
|
||||
"torchrun",
|
||||
f"--nnodes={nnodes}",
|
||||
f"--nproc-per-node={nproc_per_node}",
|
||||
|
||||
@@ -50,7 +50,7 @@ class CFG(DiGraph_CFT):
|
||||
if source is not None and n.inst.starts_line is not None:
|
||||
n.inst.source_line = source[n.inst.starts_line - 1]
|
||||
else:
|
||||
n.inst.source_line = ''
|
||||
n.inst.source_line = ""
|
||||
|
||||
for _a, _b, _p in self.edges(data=True):
|
||||
self[_a][_b]["kind"] = EdgeKind(_p["type"])
|
||||
@@ -155,7 +155,9 @@ class CFG(DiGraph_CFT):
|
||||
def _create_dominator_tree(self):
|
||||
self._dt = nx.create_empty_copy(self)
|
||||
self._dt.add_edges_from(nx.immediate_dominators(self, self.start).items())
|
||||
self._dt.remove_edge(self.start, self.start)
|
||||
# In NetworkX 3.5 and below, the start node dominates itself; this was changed in NetworkX 3.6
|
||||
if self._dt.has_edge(self.start, self.start):
|
||||
self._dt.remove_edge(self.start, self.start)
|
||||
self._dr = nx.transitive_closure_dag(self._dt.reverse())
|
||||
|
||||
def dominates(self, node_a, node_b):
|
||||
@@ -192,7 +194,10 @@ class CFG(DiGraph_CFT):
|
||||
self.add_edges_from((n, self.end, EdgeKind.Meta.prop()) for n in self.nodes if all(e["kind"] is EdgeKind.Exception for _, _, e in self.out_edges(n, data=True)))
|
||||
pdt = nx.create_empty_copy(self)
|
||||
pdt.add_edges_from((B, A) for A, B in nx.immediate_dominators(self.reverse(), self.end).items())
|
||||
pdt.remove_edge(self.end, self.end)
|
||||
|
||||
# In NetworkX 3.5 and below, the root of the dominator computation (here self.end, because the graph is reversed) dominates itself; this was changed in NetworkX 3.6
|
||||
if pdt.has_edge(self.end, self.end):
|
||||
pdt.remove_edge(self.end, self.end)
|
||||
pdr = nx.transitive_closure_dag(pdt)
|
||||
postdominates = lambda A, B: pdr.has_edge(A, B) or A == B
|
||||
control_dependent = lambda A, B: 0 < sum(postdominates(A, succ) for succ in self.successors(B)) < self.out_degree(B)
|
||||
|
||||
@@ -67,7 +67,8 @@ class SourceContext:
|
||||
for bc in self.pyc.iter_bytecodes():
|
||||
cft = self.cfts[bc.codeobj]
|
||||
if bc.codeobj.co_flags & inspect.CO_NEWLOCALS:
|
||||
if bc.codeobj.co_consts and isinstance(bc.codeobj.co_consts[0], str):
|
||||
has_docstring_314 = bool(bc.codeobj.co_flags & 2**26) # inspect.CO_HAS_DOCSTRING flag introduced in Python 3.14
|
||||
if (self.pyc.version < (3, 14) or has_docstring_314) and bc.codeobj.co_consts and isinstance(bc.codeobj.co_consts[0], str):
|
||||
doc = use_escape_sequences(bc.codeobj.co_consts[0])
|
||||
cft.add_header(f'"""{doc}"""')
|
||||
if bc.codeobj.co_flags & (inspect.CO_GENERATOR | inspect.CO_ASYNC_GENERATOR):
|
||||
|
||||
@@ -66,7 +66,7 @@ class RemoveUnreachable(ControlFlowTemplate):
|
||||
return node
|
||||
|
||||
|
||||
@register_template(0, 0, (3, 12), (3, 13))
|
||||
@register_template(0, 0, *versions_from(3, 12))
|
||||
class JumpTemplate(ControlFlowTemplate):
|
||||
template = T(
|
||||
body=~N("jump", None).with_cond(without_instructions("CLEANUP_THROW")),
|
||||
@@ -89,7 +89,7 @@ class JumpTemplate(ControlFlowTemplate):
|
||||
|
||||
try_match = make_try_match(
|
||||
{
|
||||
EdgeKind.Fall: "tail",
|
||||
EdgeKind.Jump: "tail",
|
||||
EdgeKind.TrueJump: "block",
|
||||
},
|
||||
"body",
|
||||
|
||||
@@ -1,14 +1,32 @@
|
||||
from ..cft import ControlFlowTemplate, EdgeKind, MetaTemplate, register_template
|
||||
from ..utils import E, T, N, defer_source_to, has_some_lines, run_is, has_no_lines, with_instructions, exact_instructions, has_instval, starting_instructions, to_indented_source, make_try_match, without_top_level_instructions, ending_instructions
|
||||
from ..utils import (
|
||||
E,
|
||||
T,
|
||||
N,
|
||||
defer_source_to,
|
||||
has_some_lines,
|
||||
run_is,
|
||||
has_no_lines,
|
||||
versions_from,
|
||||
with_instructions,
|
||||
exact_instructions,
|
||||
has_instval,
|
||||
starting_instructions,
|
||||
to_indented_source,
|
||||
make_try_match,
|
||||
without_top_level_instructions,
|
||||
ending_instructions,
|
||||
)
|
||||
from .Loop import BreakTemplate, ContinueTemplate
|
||||
|
||||
|
||||
class EarlyRet(ControlFlowTemplate):
|
||||
template = T(
|
||||
pop_block=~N("early_ret", None).with_in_deg(1),
|
||||
early_ret=N(E.meta("end")).with_cond(ending_instructions("RETURN_VALUE")).with_cond(has_no_lines).with_in_deg(1),
|
||||
end=N(None).of_type(MetaTemplate),
|
||||
)
|
||||
|
||||
|
||||
try_match = make_try_match({EdgeKind.Meta: "end"}, "pop_block", "early_ret")
|
||||
|
||||
to_indented_source = defer_source_to("pop_block")
|
||||
@@ -19,7 +37,10 @@ class IfElse(ControlFlowTemplate):
|
||||
template = T(
|
||||
if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")),
|
||||
if_body=~N.tail().of_subtemplate(EarlyRet) | ~N(None).with_in_deg(1).of_type(BreakTemplate, ContinueTemplate) | ~N("tail.").with_in_deg(1),
|
||||
else_body=~N.tail().of_subtemplate(EarlyRet) | ~N("tail.").with_in_deg(1).of_type(BreakTemplate, ContinueTemplate) | ~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1) | ~N("tail").with_cond(has_some_lines).with_in_deg(1),
|
||||
else_body=~N.tail().of_subtemplate(EarlyRet)
|
||||
| ~N("tail.").with_in_deg(1).of_type(BreakTemplate, ContinueTemplate)
|
||||
| ~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1)
|
||||
| ~N("tail").with_cond(has_some_lines).with_in_deg(1),
|
||||
tail=N.tail(),
|
||||
)
|
||||
|
||||
@@ -57,7 +78,6 @@ class IfJumpElse(ControlFlowTemplate):
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@register_template(1, 40)
|
||||
class IfElseJump(ControlFlowTemplate):
|
||||
template = T(
|
||||
@@ -80,7 +100,7 @@ class IfElseJump(ControlFlowTemplate):
|
||||
"""
|
||||
|
||||
|
||||
@register_template(1, 39, (3, 12), (3, 13))
|
||||
@register_template(1, 39, *versions_from(3, 12))
|
||||
class IfElseLoop(ControlFlowTemplate):
|
||||
template = T(
|
||||
if_header=~N("else_body", "if_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")),
|
||||
@@ -107,7 +127,11 @@ class IfElseLoop(ControlFlowTemplate):
|
||||
class IfThen(ControlFlowTemplate):
|
||||
template = T(
|
||||
if_header=~N("if_body", "tail").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER", "JUMP_IF_NOT_EXC_MATCH")),
|
||||
if_body=~N.tail().with_in_deg(1).of_type(BreakTemplate, ContinueTemplate) | ~N("tail").with_in_deg(1) | ~N("tail.").with_in_deg(1).with_cond(run_is(2)) | ~N.tail().with_in_deg(1).with_cond(exact_instructions("LOAD_CONST","RETURN_VALUE"), exact_instructions("POP_TOP", "LOAD_CONST","RETURN_VALUE")) | ~N.tail().with_in_deg(1).with_cond(ending_instructions("POP_TOP", "RERAISE")),
|
||||
if_body=~N.tail().with_in_deg(1).of_type(BreakTemplate, ContinueTemplate)
|
||||
| ~N("tail").with_in_deg(1)
|
||||
| ~N("tail.").with_in_deg(1).with_cond(run_is(2))
|
||||
| ~N.tail().with_in_deg(1).with_cond(exact_instructions("LOAD_CONST", "RETURN_VALUE"), exact_instructions("POP_TOP", "LOAD_CONST", "RETURN_VALUE"))
|
||||
| ~N.tail().with_in_deg(1).with_cond(ending_instructions("POP_TOP", "RERAISE")),
|
||||
tail=N.tail(),
|
||||
)
|
||||
|
||||
@@ -125,7 +149,7 @@ class IfThen(ControlFlowTemplate):
|
||||
class Assertion(ControlFlowTemplate):
|
||||
template = T(
|
||||
assertion=~N("fail", "tail"),
|
||||
fail=+N().with_cond(starting_instructions("LOAD_ASSERTION_ERROR"), has_instval("LOAD_GLOBAL", argval="AssertionError")).with_cond(has_no_lines),
|
||||
fail=+N().with_cond(starting_instructions("LOAD_ASSERTION_ERROR"), has_instval("LOAD_GLOBAL", argval="AssertionError"), has_instval("LOAD_COMMON_CONSTANT", argval=0)).with_cond(has_no_lines),
|
||||
tail=N.tail(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ..cft import ControlFlowTemplate, EdgeKind, MetaTemplate, register_template
|
||||
from ..utils import E, T, N, defer_source_to, ending_instructions, exact_instructions, no_back_edges, to_indented_source, make_try_match
|
||||
from ..utils import E, T, N, defer_source_to, ending_instructions, exact_instructions, no_back_edges, starting_instructions, to_indented_source, make_try_match
|
||||
|
||||
|
||||
@register_template(0, 0)
|
||||
@@ -52,3 +52,19 @@ class Generator3_12(ControlFlowTemplate):
|
||||
{entry}
|
||||
{body}
|
||||
"""
|
||||
|
||||
@register_template(0, 0, (3, 14))
|
||||
class Generator3_14(ControlFlowTemplate):
|
||||
template = T(
|
||||
entry=N(E.exc("gen_cleanup")).with_cond(starting_instructions("RETURN_GENERATOR", "POP_TOP")),
|
||||
gen_cleanup=N(E.meta("end")).with_cond(exact_instructions("CALL_INTRINSIC_1", "RERAISE")),
|
||||
end=N().of_type(MetaTemplate),
|
||||
)
|
||||
|
||||
try_match = make_try_match({EdgeKind.Fall: "end"}, "entry","gen_cleanup")
|
||||
|
||||
@to_indented_source
|
||||
def to_indented_source():
|
||||
"""
|
||||
{entry}
|
||||
"""
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..utils import (
|
||||
is_not_type,
|
||||
no_back_edges,
|
||||
versions_below,
|
||||
versions_except,
|
||||
versions_from,
|
||||
ending_instructions,
|
||||
has_no_lines,
|
||||
@@ -29,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
|
||||
@register_template(0, 1)
|
||||
@register_template(0, 1, *versions_below(3, 14))
|
||||
class ForLoop(ControlFlowTemplate):
|
||||
template = T(
|
||||
for_iter=~N("for_body", "tail"),
|
||||
@@ -46,6 +47,23 @@ class ForLoop(ControlFlowTemplate):
|
||||
{for_body}
|
||||
"""
|
||||
|
||||
@register_template(0, 1, *versions_from(3, 14))
|
||||
class ForLoop3_14(ControlFlowTemplate):
|
||||
template = T(
|
||||
for_iter=~N("for_body", "tail").with_cond(with_top_level_instructions("FOR_ITER")),
|
||||
for_body=~N("for_iter").with_in_deg(1),
|
||||
tail=N.tail().with_cond(is_not_type(LoopElse)),
|
||||
)
|
||||
|
||||
try_match = make_try_match({EdgeKind.Fall: "tail"}, "for_iter", "for_body")
|
||||
|
||||
@to_indented_source
|
||||
def to_indented_source():
|
||||
"""
|
||||
{for_iter}
|
||||
{for_body}
|
||||
"""
|
||||
|
||||
|
||||
@register_template(0, 1)
|
||||
class ForElseLoop(ControlFlowTemplate):
|
||||
@@ -85,7 +103,7 @@ class LoopedReturn(ControlFlowTemplate):
|
||||
{for_body}
|
||||
"""
|
||||
|
||||
@register_template(0, 2, *versions_below(3, 10))
|
||||
@register_template(0, 2, *versions_except((3, 10), (3, 11), (3, 12), (3, 13)))
|
||||
class SelfLoop3_6(ControlFlowTemplate):
|
||||
template = T(
|
||||
loop_body=~N("loop_body", None)
|
||||
@@ -181,7 +199,7 @@ class WhileElse3_10(ControlFlowTemplate):
|
||||
"""
|
||||
|
||||
|
||||
@register_template(0, 1, (3, 12), (3, 13))
|
||||
@register_template(0, 1, *versions_from(3, 12))
|
||||
class WhileElse3_12(ControlFlowTemplate):
|
||||
template = T(
|
||||
while_header=~N("while_body", "else_body"),
|
||||
@@ -244,6 +262,49 @@ class WhileIfElseLoop(ControlFlowTemplate):
|
||||
"""
|
||||
|
||||
|
||||
@register_template(1, 39, *versions_from(3, 14))
|
||||
class WhileLoop3_14(ControlFlowTemplate):
|
||||
template = T(
|
||||
if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER", "NOP")),
|
||||
else_body=~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1),
|
||||
if_body=~N("if_header").with_in_deg(1),
|
||||
tail=N.tail(),
|
||||
)
|
||||
|
||||
try_match = make_try_match({EdgeKind.Fall: "tail"}, "if_header", "if_body", "else_body")
|
||||
|
||||
@to_indented_source
|
||||
def to_indented_source():
|
||||
"""
|
||||
{if_header}
|
||||
{if_body}
|
||||
{else_body?else:}
|
||||
{else_body}
|
||||
"""
|
||||
|
||||
|
||||
@register_template(1, 39, *versions_from(3, 14))
|
||||
class WhileTrueLoop3_14(ControlFlowTemplate):
|
||||
template = T(
|
||||
if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")).with_cond(starting_instructions("NOP")),
|
||||
else_body=~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1),
|
||||
if_body=~N("if_header").with_in_deg(1),
|
||||
tail=N.tail(),
|
||||
)
|
||||
|
||||
try_match = make_try_match({EdgeKind.Fall: "tail"}, "if_header", "if_body", "else_body")
|
||||
|
||||
@to_indented_source
|
||||
def to_indented_source():
|
||||
"""
|
||||
while True:
|
||||
{if_header}
|
||||
{else_body}
|
||||
{if_body?else:}
|
||||
{if_body}
|
||||
"""
|
||||
|
||||
|
||||
@register_template(0, 3)
|
||||
class InlinedComprehensionTemplate(ControlFlowTemplate):
|
||||
template = T(
|
||||
@@ -461,7 +522,7 @@ class FixLoop(ControlFlowTemplate):
|
||||
# Only remove edge if there are more than 2 incoming edges to avoid breaking other control flow structures
|
||||
if cfg.in_degree(succ) > 2:
|
||||
cfg.remove_edge(break_node, succ)
|
||||
elif cfg.in_degree(succ) <= 2:
|
||||
elif cfg.in_degree(succ) <= 2 and succ != cfg.end:
|
||||
# Broken: conflicting cases where removing the edge would strand blocks
|
||||
# but also match correctly to valid Break statement nodes
|
||||
continue
|
||||
|
||||
@@ -352,7 +352,7 @@ class Decompiler:
|
||||
logger.info(f"Checking decompilation for {self.name}...")
|
||||
src = self.tmpfile()
|
||||
pyc = self.tmpfile()
|
||||
src.write_text(source)
|
||||
src.write_text(source, encoding='utf-8')
|
||||
try:
|
||||
compile_version(src, pyc, self.version)
|
||||
except CompileError as e:
|
||||
@@ -478,6 +478,6 @@ def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Pat
|
||||
logger.info("Decompilation complete")
|
||||
logger.info(f"{result.calculate_success_rate():.2%} code object success rate")
|
||||
if save_to:
|
||||
save_to.write_text(result.decompiled_source)
|
||||
save_to.write_text(result.decompiled_source, encoding='utf-8')
|
||||
logger.info(f"Result saved to {save_to}")
|
||||
return result
|
||||
|
||||
@@ -85,3 +85,14 @@ v3.13:
|
||||
REPO: syssec-utd/py313-pylingual-v1.1-statement
|
||||
REVISION: main
|
||||
TOKENIZER: syssec-utd/py313-pylingual-v1.1-tok
|
||||
|
||||
v3.14:
|
||||
SEGMENTATION_MODEL:
|
||||
REPO: syssec-utd/py314-pylingual-v3-segmenter
|
||||
REVISION: main
|
||||
TOKENIZER: syssec-utd/py314-pylingual-v3-tokenizer
|
||||
|
||||
STATEMENT_MODEL:
|
||||
REPO: syssec-utd/py314-pylingual-v3-statement
|
||||
REVISION: main
|
||||
TOKENIZER: syssec-utd/py314-pylingual-v3-tok
|
||||
|
||||
@@ -88,9 +88,13 @@ class EditableBytecode:
|
||||
if self.version >= (3, 13):
|
||||
self.fix_make_function_argval()
|
||||
|
||||
low_information_instruction_blacklist = ["RESUME", "EXTENDED_ARG", "CACHE", "PRECALL", "MAKE_CELL"]
|
||||
low_information_instruction_blacklist = ["RESUME", "EXTENDED_ARG", "CACHE", "PRECALL", "MAKE_CELL", "NOT_TAKEN", "COPY_FREE_VARS"]
|
||||
self.remove_instructions({inst for inst in self.instructions if inst.opname in low_information_instruction_blacklist})
|
||||
|
||||
# inline __annotate__ functions, which were added in python 3.14
|
||||
if self.version >= (3, 14):
|
||||
self.inline_annotate_functions()
|
||||
|
||||
# updates attribute of instructions that contains information about the exception table
|
||||
self._add_inst_exception_attrs()
|
||||
|
||||
@@ -115,6 +119,187 @@ class EditableBytecode:
|
||||
else:
|
||||
inst.argval = 0
|
||||
|
||||
def inline_annotate_functions(self) -> None:
    """In Python 3.14, type annotations are stored in implicit __annotate__ functions. This function inlines them.

    Three load patterns are recognised and rewritten in place:

    * the ``MODULE_ANNOTATE_FUNC_LOAD_SEQUENCE`` at the very start of this code object,
    * the ``CLASS_ANNOTATE_FUNC_LOAD_SEQUENCE`` ending just before the two instructions that
      finish with ``STORE_NAME __static_attributes__``,
    * any ``LOAD_CONST`` whose constant is an ``__annotate__`` code object starting with
      ``ANNOTATE_FUNC_PREAMBLE`` (function-definition annotations).

    The instructions of the ``__annotate__`` body are spliced into this bytecode and the
    instructions that loaded the function are removed. Every consumed code object is replaced
    by ``None`` in ``co_consts`` rather than deleted, so the indices of the other constants do
    not move.

    Raises:
        AssertionError: if an ``__annotate__`` body does not have the expected layout
            ("Misaligned 3.14 annotation inlining") or a matched load sequence does not
            actually load an ``__annotate__`` function.
    """
    # (opname, argval) pairs
    # fmt: off
    # appears at the start of all __annotate__ functions
    ANNOTATE_FUNC_PREAMBLE = (
        ("LOAD_FAST_BORROW", "format"),
        ("LOAD_SMALL_INT", 2),
        ("COMPARE_OP", ">"),
        ("POP_JUMP_IF_FALSE", 12),
        ("LOAD_COMMON_CONSTANT", 1),  # NotImplementedError
        ("RAISE_VARARGS", 1),  # exception instance
    )

    # appears at the start of a <module> code object that has annotations
    MODULE_ANNOTATE_FUNC_LOAD_SEQUENCE = (
        ("LOAD_CONST", "__annotate__"),  # code object
        ("MAKE_FUNCTION", 0),
        ("STORE_NAME", "__annotate__"),
        ("BUILD_SET", 0),
        ("STORE_NAME", "__conditional_annotations__"),
    )

    # appears at the end of class object that has annotations, right before storing the static attributes
    CLASS_ANNOTATE_FUNC_LOAD_SEQUENCE = (
        ("LOAD_FAST_BORROW", "__classdict__"),
        ("LOAD_FAST_BORROW", "__conditional_annotations__"),
        ("BUILD_TUPLE", 2),
        ("LOAD_CONST", "__annotate__"),  # code object
        ("MAKE_FUNCTION", 8),
        ("SET_FUNCTION_ATTRIBUTE", 8),  # closure
        ("STORE_NAME", "__annotate_func__"),
    )
    # fmt: on

    def try_read_annotate_func(codeobj) -> EditableBytecode | None:
        """Return ``codeobj`` wrapped as EditableBytecode if it is an ``__annotate__`` function, else None.

        A code object qualifies when it is named ``__annotate__`` and its first instructions
        match ``ANNOTATE_FUNC_PREAMBLE`` exactly (opname and argval).
        """
        if not iscode(codeobj):
            return None
        if not codeobj.co_name == "__annotate__":
            return None
        target_bc = EditableBytecode(codeobj, self.opcode, self.version)
        target_preamble = tuple((inst.opname, inst.argval) for inst in target_bc.instructions[: len(ANNOTATE_FUNC_PREAMBLE)])
        if target_preamble != ANNOTATE_FUNC_PREAMBLE:
            return None
        return target_bc

    # this applies to __annotate__ functions at the top of the <module> codeobj
    def is_annotate_func_and_get_conditional_annotation_map(codeobj) -> tuple[bool, dict[int, list[Inst]]]:
        """Split an ``__annotate__`` body into per-annotation instruction runs.

        Returns ``(False, {})`` when ``codeobj`` is not an ``__annotate__`` function. Otherwise
        returns ``(True, mapping)`` where each key is the integer identifier loaded in front of
        a guarded annotation and each value is the instruction run to inline for it. Runs that
        start with a load of ``__classdict__`` are concatenated under the key ``-1``.
        """
        annotate_bytecode = try_read_annotate_func(codeobj)
        if annotate_bytecode is None:
            return (False, dict())

        # searching for
        # LOAD_SMALL_INT <identifier> (could be load const)
        # LOAD_GLOBAL __conditional_annotations__
        # CONTAINS_OP in
        # POP_JUMP_IF_FALSE (to next conditional annotation)
        # ... type expression
        # COPY 2
        # LOAD_CONST (variable name)
        # STORE_SUBSCR

        conditional_annotation_map: dict[int, list[Inst]] = dict()

        # instruction 3 is the preamble's POP_JUMP_IF_FALSE (already verified by try_read_annotate_func)
        cursor_idx = annotate_bytecode.instructions.index(annotate_bytecode[3].target.next_instructions[0])  # start of the first annotation definition
        while (cursor_inst := annotate_bytecode[cursor_idx]).opname != "RETURN_VALUE":
            if cursor_inst.opname.startswith("LOAD_") and cursor_inst.argval == "__classdict__":
                annotation_identifier = -1
                # the run ends with (and includes) the next STORE_SUBSCR
                next_cursor_idx = next(idx + 1 for idx in range(cursor_idx, len(annotate_bytecode)) if annotate_bytecode[idx].opname == "STORE_SUBSCR")
                # append to whatever was already collected under -1; the __classdict__ load itself is skipped
                inlinable_insts = (conditional_annotation_map.get(-1) or list()) + annotate_bytecode[cursor_idx + 1 : next_cursor_idx]
            elif cursor_inst.opname.startswith("LOAD_") and isinstance(cursor_inst.argval, int):
                annotation_identifier = cursor_inst.argval

                # per the pattern above, the guard jump is 3 instructions after the identifier load
                conditional_annotation_guard_jump = annotate_bytecode[cursor_idx + 3]
                next_cursor_idx = annotate_bytecode.instructions.index(conditional_annotation_guard_jump.target)

                # everything between the guard jump and its target is the annotation itself
                inlinable_insts = annotate_bytecode[cursor_idx + 4 : next_cursor_idx]
            else:
                raise AssertionError("Misaligned 3.14 annotation inlining")

            # to help translation, we replace COPY 2 with LOAD_GLOBAL (__annotations__)
            # (COPY 2 is third from the end: COPY 2, LOAD_CONST <name>, STORE_SUBSCR)
            copy_annotations_inst = inlinable_insts[-3]
            assert copy_annotations_inst.opname == "COPY" and copy_annotations_inst.argval == 2, "Misaligned 3.14 annotation inlining"
            load_global_annotations = self.new_instruction(
                opname="LOAD_GLOBAL",
                opcode=self.opcode.LOAD_GLOBAL,
                optype="name",
                inst_size=copy_annotations_inst.inst_size,
                arg=-1,
                argval="__annotations__",
                argrepr="__annotations__",
                has_arg=True,
                offset=copy_annotations_inst.offset,
                starts_line=None,
                is_jump_target=False,
                has_extended_arg=False,
            )
            inlinable_insts[-3] = load_global_annotations

            conditional_annotation_map[annotation_identifier] = inlinable_insts
            cursor_idx = next_cursor_idx

        return (True, conditional_annotation_map)

    # this applies to __annotate__ functions that are decorating a child function
    def is_annotate_func_and_get_inlinable_insts(codeobj) -> tuple[bool, list[Inst]]:
        """Return ``(True, body)`` for an ``__annotate__`` function, where ``body`` is everything between the preamble and the final instruction; ``(False, [])`` otherwise."""
        annotate_bytecode = try_read_annotate_func(codeobj)
        if annotate_bytecode is None:
            return (False, [])
        return (True, annotate_bytecode.instructions[len(ANNOTATE_FUNC_PREAMBLE) : -1])  # skip the RETURN_VALUE at the end

    # iterate over all instructions
    # replace any load_consts that load __annotate__ functions with the function's instructions, minus the return and the prefix
    # keys are (index, instructions to remove), values are (list of instructions to insert)
    inline_dict: dict[tuple[int, tuple[Inst, ...]], list[Inst]] = {}
    jump_target_mapping: dict[Inst, Inst] = {}
    conditional_annotation_map: dict[int, list[Inst]] = {}

    # eat the <module> level conditional __annotate__ object
    # (code-object argvals are compared by their co_name so they can match the string in the sequence)
    my_module_preamble = tuple((inst.opname, getattr(inst.argval, "co_name", inst.argval)) for inst in self.instructions[: len(MODULE_ANNOTATE_FUNC_LOAD_SEQUENCE)])
    if my_module_preamble == MODULE_ANNOTATE_FUNC_LOAD_SEQUENCE:
        sanity_check, conditional_annotation_map = is_annotate_func_and_get_conditional_annotation_map(self.instructions[0].argval)
        assert sanity_check, "Improperly matched annotation function load sequence"
        # blank the consumed code object in place so the other co_consts indices stay valid
        self.co_consts[self.instructions[0].arg] = None
        self.remove_instructions(self.instructions[: len(MODULE_ANNOTATE_FUNC_LOAD_SEQUENCE)])

    # handle class object __annotate_func__
    store_static_attributes = next((inst for inst in self.instructions if inst.opname == "STORE_NAME" and inst.argval == "__static_attributes__"), None)
    if store_static_attributes:
        # the load sequence, if present, ends one instruction before the STORE_NAME
        preamble_end_idx = self.instructions.index(store_static_attributes) - 1
        my_class_preamble = tuple((inst.opname, getattr(inst.argval, "co_name", inst.argval)) for inst in self.instructions[preamble_end_idx - len(CLASS_ANNOTATE_FUNC_LOAD_SEQUENCE) : preamble_end_idx])
        if my_class_preamble == CLASS_ANNOTATE_FUNC_LOAD_SEQUENCE:
            # preamble_end_idx - 4 is the LOAD_CONST __annotate__ entry of the 7-instruction sequence
            sanity_check, class_conditional_annotation_map = is_annotate_func_and_get_conditional_annotation_map(self.instructions[preamble_end_idx - 4].argval)
            assert sanity_check, "Improperly matched annotation function load sequence"
            conditional_annotation_map |= class_conditional_annotation_map
            self.co_consts[self.instructions[preamble_end_idx - 4].arg] = None
            self.remove_instructions(self.instructions[preamble_end_idx - len(CLASS_ANNOTATE_FUNC_LOAD_SEQUENCE) : preamble_end_idx])

    for idx, inst in enumerate(self.instructions):
        # handle function definition annotations
        if inst.opname == "LOAD_CONST":
            is_annotate_func, inlinable_insts = is_annotate_func_and_get_inlinable_insts(inst.argval)
            if not is_annotate_func:
                continue

            # replace
            # LOAD_CONST __annotate__
            # MAKE_FUNCTION
            inline_dict[(idx, tuple(self.instructions[idx : idx + 2]))] = inlinable_insts
            # jumps that pointed at the removed LOAD_CONST must land on the first inlined instruction
            jump_target_mapping[inst] = inlinable_insts[0]

        # handle inline variable annotations
        elif inst.opname in ("LOAD_NAME", "LOAD_DEREF") and inst.argval == "__conditional_annotations__":
            load_annotation_identifier = self.instructions[idx + 1]
            if not (load_annotation_identifier.opname in ("LOAD_CONST", "LOAD_SMALL_INT") and load_annotation_identifier.argval in conditional_annotation_map):
                continue
            inlinable_insts = conditional_annotation_map[load_annotation_identifier.argval]
            # replace
            # LOAD_NAME __conditional_annotations__
            # LOAD_SMALL_INT <identifier> (could be load const)
            # SET_ADD 1
            # POP_TOP
            inline_dict[(idx, tuple(self.instructions[idx : idx + 4]))] = inlinable_insts
            jump_target_mapping[inst] = inlinable_insts[0]

        # handle __classdict__ annotations
        elif inst.opname == "STORE_DEREF" and inst.argval == "__conditional_annotations__" and -1 in conditional_annotation_map:
            # NOTE(review): here the key index is the STORE_DEREF itself while the removed run starts
            # one instruction earlier, so the insertion point computed below (idx + 2) is one
            # instruction later than "directly after the removed run" -- confirm this is intended.
            inline_dict[(idx, tuple(self.instructions[idx - 1 : idx + 1]))] = conditional_annotation_map[-1]
            # consume the entry so it is only inlined once
            del conditional_annotation_map[-1]

    # insert first (indices still refer to the unmodified list), then retarget jumps, then delete the replaced runs
    self.insert_insts({idx + len(insts_to_remove): insts_to_insert for (idx, insts_to_remove), insts_to_insert in inline_dict.items()})
    self._change_jump_targets(jump_target_mapping)
    self.remove_instructions(set(itertools.chain.from_iterable(insts_to_remove for (idx, insts_to_remove) in inline_dict.keys())))

    # remove the __annotate__ functions from co_consts, but don't impact co_consts offsets
    # we don't remove these from child_bytecodes because this runs before child_bytecodes are populated
    for idx, removed_insts in inline_dict.keys():
        if (inst := removed_insts[0]).opname == "LOAD_CONST":
            assert self.co_consts[inst.arg] == inst.argval
            self.co_consts[inst.arg] = None
|
||||
|
||||
def get_recursive_length(self):
    """Return the length of this bytecode plus the recursive length of every descendant."""
    total = len(self)
    for child in self.child_bytecodes:
        total += child.get_recursive_length()
    return total
|
||||
@@ -404,16 +589,16 @@ class EditableBytecode:
|
||||
for bc in self.iter_bytecodes():
|
||||
patch(bc)
|
||||
|
||||
def _change_jump_targets(self, from_inst: Inst, to_inst: Inst):
|
||||
"""Changes the targets of any instructions jumping to "from_inst" to "to_inst".
|
||||
def _change_jump_targets(self, jump_target_mapping: dict[Inst, Inst]):
|
||||
"""Changes the targets of any instructions jumping to any of the keys in jump_target_mapping to the corresponding values.
|
||||
Before:
|
||||
InstA --> InstB
|
||||
After _change_jump_targets(InstB, InstC):
|
||||
After _change_jump_targets({InstB: InstC}):
|
||||
InstA --> !!InstC!!
|
||||
"""
|
||||
for i, inst in enumerate(self):
|
||||
if inst.is_jump and inst.target == from_inst:
|
||||
self[i]._target = to_inst
|
||||
if inst.is_jump and inst.target in jump_target_mapping:
|
||||
self[i]._target = jump_target_mapping[inst.target]
|
||||
|
||||
def collapse_unconditional_jumps(self):
|
||||
"""Causes unnecessary unconditional jumps to "collapse" into a single jump."""
|
||||
@@ -583,6 +768,75 @@ class EditableBytecode:
|
||||
|
||||
return len(to_remove)
|
||||
|
||||
def insert_insts(self, insert_dict: dict[int, list[Inst]]) -> int:
    """Inserts the specified lists of instructions at each specified index.

    Args:
        insert_dict: maps an index into ``self.instructions`` (as it is *before* any
            insertion) to the instructions to place at that index. An index equal to
            ``len(self.instructions)`` appends.

    Returns:
        The total number of instructions inserted; 0 if every list is empty (in which
        case nothing is modified and the indices are not validated).

    Raises:
        IndexError: if any index is negative or greater than ``len(self.instructions)``.
    """

    if not any(insert_dict.values()):
        return 0

    if any(idx < 0 or idx > len(self.instructions) for idx in insert_dict.keys()):
        raise IndexError("Index out of range")

    self.regenerate()
    self._edited = True

    # store an instruction-based copy of the exception table to make offset fixing easier at the end
    temp_exception_table = {self.get_by_offset(start): (self.get_by_offset(end), self.get_by_offset(target)) for start, (end, target) in self.exception_table.items()}
    # NOTE(review): placeholder is a dict but the populated value below is a list; only its truthiness is used later
    temp_named_exception_table = dict()
    if self.named_exception_table:
        temp_named_exception_table = [(self.get_by_offset(e.start), self.get_by_offset(e.end), self.get_by_offset(e.target), e.depth, e.lasti) for e in self.named_exception_table]

    # insert instructions
    # go from small to large indices, and track how the indices change as we insert
    inserted_count = 0
    for idx, insts in sorted(insert_dict.items(), key=lambda x: x[0]):
        # shift the caller's index by everything already inserted in front of it
        real_idx = idx + inserted_count
        self.instructions = self.instructions[:real_idx] + insts + self.instructions[real_idx:]
        inserted_count += len(insts)

    to_insert = [inst for insts in insert_dict.values() for inst in insts]
    for inst in to_insert:
        # set bytecode backpointer
        inst.bytecode = self

        # add names and consts
        # (register the argval in this bytecode's table if missing, then point arg at its index)
        if inst.optype == "const":
            if inst.argval not in self.co_consts:
                self.co_consts.append(inst.argval)
            inst.arg = self.co_consts.index(inst.argval)
        elif inst.optype == "name":
            if inst.argval not in self.co_names:
                self.co_names.append(inst.argval)
            inst.arg = self.co_names.index(inst.argval)
        elif inst.optype == "free":
            # NOTE(review): "free" operands are registered in co_varnames here -- confirm this is the intended table
            if inst.argval not in self.co_varnames:
                self.co_varnames.append(inst.argval)
            inst.arg = self.co_varnames.index(inst.argval)

    self._edited = True
    self.regenerate()  # recalculate offsets

    # fix jump target argval and argrepr
    # first clear every flag, then re-mark only instructions that are actually jumped to
    for inst in self.instructions:
        inst.is_jump_target = False
    for inst in self.instructions:
        if not inst.is_jump:
            continue
        inst.argval = inst.target.offset
        inst.argrepr = f"to {inst.argval}"
        inst.target.is_jump_target = True

    # fix exception table offsets
    # (the saved instruction objects now carry their post-insertion offsets)
    self.exception_table = {start.offset: (end.offset, target.offset) for start, (end, target) in temp_exception_table.items()}
    if temp_named_exception_table:
        # also delete entries that will never trigger (end is non-inclusive)
        self.named_exception_table = [_ExceptionTableEntry(start.offset, end.offset, target.offset, depth, lasti) for (start, end, target, depth, lasti) in temp_named_exception_table if start.offset < end.offset]
        self._add_inst_exception_attrs()

    self._edited = True

    return len(to_insert)
|
||||
|
||||
def new_instruction(self, *args, **kwargs):
    """Creates a new instruction for use with this EditableBytecode object. This function does NOT automatically insert the instruction.

    All positional and keyword arguments are forwarded unchanged to ``Inst`` after ``self``,
    which is passed as the first argument.
    """
    return Inst(self, *args, **kwargs)
|
||||
@@ -830,9 +1084,10 @@ class EditableBytecode:
|
||||
instruction_before = self[i.start - 1] if i.start is not None and i.start > 0 else None
|
||||
instruction_after = self[i.stop] if i.stop is not None and i.stop <= len(self) else None
|
||||
|
||||
jump_target_mapping = {}
|
||||
for j, inst in enumerate(insts):
|
||||
if isinstance(value, (list, tuple)) and len(insts) == len(value):
|
||||
self._change_jump_targets(inst, value[j])
|
||||
jump_target_mapping[inst] = value[j]
|
||||
else:
|
||||
new_target = instruction_before or instruction_after
|
||||
if not new_target and len(value) > 0:
|
||||
@@ -841,8 +1096,9 @@ class EditableBytecode:
|
||||
pass # They should have used __del__
|
||||
|
||||
if new_target:
|
||||
self._change_jump_targets(inst, new_target)
|
||||
jump_target_mapping[inst] = new_target
|
||||
|
||||
self._change_jump_targets(jump_target_mapping)
|
||||
self.instructions[i] = value
|
||||
self._edited = True
|
||||
|
||||
@@ -856,11 +1112,12 @@ class EditableBytecode:
|
||||
instruction_before = self[i.start - 1] if i.start is not None and i.start > 0 else None
|
||||
instruction_after = self[i.stop] if i.stop is not None and i.stop <= len(self) else None
|
||||
|
||||
jump_target_mapping = {}
|
||||
for inst in insts:
|
||||
new_target = instruction_before or instruction_after
|
||||
|
||||
if new_target:
|
||||
self._change_jump_targets(inst, new_target)
|
||||
jump_target_mapping[inst] = new_target
|
||||
self._change_jump_targets(jump_target_mapping)
|
||||
|
||||
del self.instructions[i]
|
||||
self._edited = True
|
||||
|
||||
@@ -15,7 +15,7 @@ class PYCFile(EditableBytecode):
|
||||
def __init__(self, source, name_prefix=None):
|
||||
self.pyc_path = None
|
||||
self.source = source
|
||||
source_tuple = (None, None, None, None, None, None, None)
|
||||
source_tuple = (None, None, None, None, None, None, None, None)
|
||||
if isinstance(source, bytes):
|
||||
source = BytesIO(source)
|
||||
source_tuple = load_module_from_file_object(source)
|
||||
@@ -33,6 +33,7 @@ class PYCFile(EditableBytecode):
|
||||
self.ispypy,
|
||||
self.source_size,
|
||||
self.sip_hash,
|
||||
self.file_offsets,
|
||||
) = source_tuple
|
||||
|
||||
self.version = PythonVersion(version)
|
||||
|
||||
@@ -4,3 +4,4 @@ from .remove_docstrings import remove_docstrings
|
||||
from .remove_nop import remove_nop
|
||||
from .fix_indirect_jump import fix_indirect_jump
|
||||
from .replace_firstlno import replace_firstlno
|
||||
from .replace_borrow import replace_borrow
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from ..EditableBytecode import EditableBytecode
|
||||
|
||||
def replace_borrow(bytecode: "EditableBytecode") -> None:
    """Rewrite borrowed-reference fast loads to their plain ``LOAD_FAST`` forms.

    ``LOAD_FAST_BORROW`` becomes ``LOAD_FAST`` and ``LOAD_FAST_BORROW_LOAD_FAST_BORROW``
    becomes ``LOAD_FAST_LOAD_FAST``. Both ``opname`` and ``opcode`` of each matching
    instruction are updated in place; no other field is touched. Bytecode older than
    3.14 is left untouched.

    Args:
        bytecode: the bytecode whose instructions are patched in place.
    """
    if bytecode.version < (3, 14):
        return

    # borrowed opname -> replacement opname (also the attribute name on bytecode.opcode)
    replacements = {
        "LOAD_FAST_BORROW": "LOAD_FAST",
        "LOAD_FAST_BORROW_LOAD_FAST_BORROW": "LOAD_FAST_LOAD_FAST",
    }

    for inst in bytecode.instructions:
        new_name = replacements.get(inst.opname)
        if new_name is None:
            continue
        inst.opname = new_name
        # resolved lazily so the opcode table is only consulted for opcodes actually present
        inst.opcode = getattr(bytecode.opcode, new_name)
|
||||
@@ -4,7 +4,14 @@ import itertools
|
||||
|
||||
|
||||
def replace_firstlno(bytecode: EditableBytecode):
|
||||
to_replace = next((load_const for load_const, store_name in itertools.pairwise(bytecode.instructions) if load_const.opname == "LOAD_CONST" and store_name.opname == "STORE_NAME" and store_name.argval == "__firstlineno__"), None)
|
||||
to_replace = next(
|
||||
(load_const for load_const, store_name in itertools.pairwise(bytecode.instructions) if load_const.opname in ["LOAD_CONST", "LOAD_SMALL_INT"] and store_name.opname == "STORE_NAME" and store_name.argval == "__firstlineno__"), None
|
||||
)
|
||||
if to_replace is not None:
|
||||
to_replace.argval = 0
|
||||
to_replace.argrepr = "0"
|
||||
|
||||
if bytecode.version >= (3, 14):
|
||||
to_replace.opname = "LOAD_SMALL_INT"
|
||||
to_replace.opcode = bytecode.opcode.LOAD_SMALL_INT
|
||||
to_replace.optype = "??"
|
||||
|
||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
import networkx as nx
|
||||
from pylingual.control_flow_reconstruction.cfg import CFG
|
||||
from pylingual.editable_bytecode import EditableBytecode, Inst, PYCFile
|
||||
from pylingual.editable_bytecode.bytecode_patches import fix_indirect_jump, fix_unreachable, remove_extended_arg, remove_nop, replace_firstlno
|
||||
from pylingual.editable_bytecode.bytecode_patches import fix_indirect_jump, fix_unreachable, remove_extended_arg, remove_nop, replace_firstlno, replace_borrow
|
||||
from pylingual.editable_bytecode.control_flow_graph import bytecode_to_control_flow_graph
|
||||
|
||||
|
||||
@@ -191,8 +191,8 @@ def compare_pyc(pyc_a: PYCFile | Path, pyc_b: PYCFile | Path) -> list[TestResult
|
||||
pyc_a = pyc_a.copy() if isinstance(pyc_a, PYCFile) else PYCFile(pyc_a)
|
||||
pyc_b = pyc_b.copy() if isinstance(pyc_b, PYCFile) else PYCFile(pyc_b)
|
||||
|
||||
pyc_a.apply_patches([remove_extended_arg, remove_nop, fix_indirect_jump, fix_unreachable, remove_extended_arg, replace_firstlno])
|
||||
pyc_b.apply_patches([remove_extended_arg, remove_nop, fix_indirect_jump, fix_unreachable, remove_extended_arg, replace_firstlno])
|
||||
pyc_a.apply_patches([remove_extended_arg, remove_nop, fix_indirect_jump, fix_unreachable, remove_extended_arg, replace_firstlno, replace_borrow])
|
||||
pyc_b.apply_patches([remove_extended_arg, remove_nop, fix_indirect_jump, fix_unreachable, remove_extended_arg, replace_firstlno, replace_borrow])
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
import ast
|
||||
import sys
|
||||
import re
|
||||
import copy
|
||||
from pylingual.masking.global_masker import Masker
|
||||
from pylingual.utils.version import PythonVersion
|
||||
|
||||
_RUNTIME_PYTHON_VERSION = PythonVersion(sys.version_info)
|
||||
|
||||
class customUnparser(ast._Unparser):
|
||||
if _RUNTIME_PYTHON_VERSION >= (3, 14):
|
||||
from _ast_unparse import Unparser
|
||||
else:
|
||||
Unparser = ast._Unparser
|
||||
|
||||
|
||||
class customUnparser(Unparser):
|
||||
def __init__(self, masker: Masker, **kwargs):
|
||||
ast._Unparser.__init__(self, **kwargs)
|
||||
Unparser.__init__(self, **kwargs)
|
||||
self.masker = masker
|
||||
|
||||
def visit_Constant(self, node):
|
||||
@@ -28,7 +36,10 @@ class customUnparser(ast._Unparser):
|
||||
|
||||
def visit_FormattedValue(self, node):
|
||||
def unparse_inner(inner):
|
||||
unparser = type(self)(self.masker, _avoid_backslashes=True)
|
||||
if _RUNTIME_PYTHON_VERSION <= (3, 11):
|
||||
unparser = type(self)(self.masker, _avoid_backslashes=True)
|
||||
else:
|
||||
unparser = type(self)(self.masker)
|
||||
unparser.set_precedence(ast._Precedence.TEST.next(), inner)
|
||||
return unparser.visit(inner)
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ class Masker:
|
||||
blacklist = [
|
||||
"__doc__",
|
||||
"__annotations__",
|
||||
"__conditional_annotations__",
|
||||
"__annotate__",
|
||||
"__qualname__",
|
||||
"__class__",
|
||||
"return", # for return annotations
|
||||
@@ -69,7 +71,7 @@ class Masker:
|
||||
|
||||
def mask(self, tok):
|
||||
"""Mask a token, must be in the global_table."""
|
||||
return self.global_tab[tok] if not any(tok == t and type(tok) == type(t) for t in self.blacklist) else tok
|
||||
return self.global_tab[tok] if not any(tok == t and type(tok) is type(t) for t in self.blacklist) else tok
|
||||
|
||||
def unmask(self, value):
|
||||
"""Unmask a token, value must be a metatoken value in the global_table; or this function will fail loudly"""
|
||||
@@ -101,6 +103,8 @@ class Masker:
|
||||
func_info.append("annotations")
|
||||
if bool(flags_make_func & 0b1000): # b_free_vars
|
||||
func_info.append("closures")
|
||||
if bool(flags_make_func & 0b10000): # annotate function
|
||||
func_info.append("annotations-func")
|
||||
|
||||
# flags from the target code object
|
||||
flags_co = int(target_co.co_flags)
|
||||
@@ -219,6 +223,9 @@ class Masker:
|
||||
if inst.is_jump:
|
||||
jump_direction_indicator = "v~>" if inst.target.offset > inst.offset else "^~>"
|
||||
view = f"{inst.opname} {inst.argrepr} {jump_direction_indicator}"
|
||||
elif inst.opname == "LOAD_SMALL_INT":
|
||||
# Treat LOAD_SMALL_INT (X) like LOAD_CONST
|
||||
view = f"{inst.opname} , {self.mask(inst.argval)}"
|
||||
elif inst.optype is None or inst.optype == "??" or inst.optype == "encoded_arg":
|
||||
# don't mask IS_OP args
|
||||
view = f"{inst.opname} , {inst.argrepr if inst.argrepr else inst.argval}"
|
||||
|
||||
@@ -39,11 +39,21 @@ def create_global_masker(bytecode: EditableBytecode) -> Masker:
|
||||
|
||||
# create consts
|
||||
consts = list(deepcopy(bc_co.co_consts))
|
||||
|
||||
# add LOAD_SMALL_INT values to consts (3.14+)
|
||||
if bc.version >= (3, 14):
|
||||
for inst in bc.instructions:
|
||||
if inst.opname == "LOAD_SMALL_INT": # duplicate consts will be filtered out later
|
||||
consts.append(inst.argval)
|
||||
|
||||
while consts:
|
||||
const = consts.pop(0)
|
||||
# Don't mask None
|
||||
if const is None:
|
||||
continue
|
||||
# Don't needlessly increment the global_idx
|
||||
if const in global_tab:
|
||||
continue
|
||||
if type(const) in (list, tuple, frozenset, set):
|
||||
consts.extend(const)
|
||||
else:
|
||||
@@ -63,6 +73,13 @@ def create_global_masker(bytecode: EditableBytecode) -> Masker:
|
||||
global_tab.update({free: f"<mask_{global_idx}>"})
|
||||
global_idx += 1
|
||||
|
||||
if bc.version >= (3, 11):
|
||||
for cell in bc_co.co_cellvars:
|
||||
if cell in global_tab:
|
||||
continue
|
||||
global_tab.update({cell: f"<mask_{global_idx}>"})
|
||||
global_idx += 1
|
||||
|
||||
for local in bc_co.co_varnames:
|
||||
if local in global_tab:
|
||||
continue
|
||||
|
||||
+36
-6
@@ -6,8 +6,13 @@ import logging
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from functools import lru_cache
|
||||
|
||||
from pylingual.masking.model_disasm import fix_jump_targets, normalize_masks, restore_masks
|
||||
from pylingual.masking.model_disasm import (
|
||||
fix_jump_targets,
|
||||
normalize_masks,
|
||||
restore_masks,
|
||||
)
|
||||
from pylingual.utils.lists import flatten
|
||||
from pylingual.utils.tracked_list import TrackedDataset, TRANSLATION_STEP
|
||||
from pylingual.utils.version import PythonVersion
|
||||
@@ -43,7 +48,12 @@ class CacheTranslator:
|
||||
self.cache.move_to_end(item)
|
||||
return self.cache[item]
|
||||
|
||||
def _translate_and_decode(self, translation_requests: TrackedDataset | list[str], batch_size: int = 32, **kwargs) -> list[str]:
|
||||
def _translate_and_decode(
|
||||
self,
|
||||
translation_requests: TrackedDataset | list[str],
|
||||
batch_size: int = 32,
|
||||
**kwargs,
|
||||
) -> list[str]:
|
||||
# return_tensors=True prevents standard postprocessing which skips special tokens
|
||||
translation_result = self.translator(translation_requests, return_tensors=True, batch_size=batch_size, **kwargs)
|
||||
decoded_results = []
|
||||
@@ -79,7 +89,10 @@ class CacheTranslator:
|
||||
normalized_args = [normalize_masks(fix_jump_targets(x)) for x in args]
|
||||
|
||||
# New are those not in the local cache
|
||||
new = TrackedDataset(TRANSLATION_STEP, list({norm for norm, _ in normalized_args if norm not in self.cache}))
|
||||
new = TrackedDataset(
|
||||
TRANSLATION_STEP,
|
||||
list({norm for norm, _ in normalized_args if norm not in self.cache}),
|
||||
)
|
||||
|
||||
# Now, "new" has been updated to those not in local
|
||||
for arg, result in zip(new.x, self._translate_with_backoff(new)):
|
||||
@@ -91,7 +104,12 @@ class CacheTranslator:
|
||||
return results
|
||||
|
||||
|
||||
def load_models(config_file: Path = Path("pylingual/decompiler_config.yaml"), version: PythonVersion = PythonVersion(3.9), token=False) -> tuple[transformers.Pipeline, CacheTranslator]:
|
||||
@lru_cache(maxsize=1)
|
||||
def load_models(
|
||||
config_file: Path = Path("pylingual/decompiler_config.yaml"),
|
||||
version: PythonVersion = PythonVersion(3.9),
|
||||
token=False,
|
||||
) -> tuple[transformers.Pipeline, CacheTranslator]:
|
||||
logger.info(f"Loading models for {version}...")
|
||||
with config_file.open() as f:
|
||||
config = yaml.safe_load(f)
|
||||
@@ -124,12 +142,24 @@ def load_models(config_file: Path = Path("pylingual/decompiler_config.yaml"), ve
|
||||
else:
|
||||
logger.warning("Using CPU for models")
|
||||
device = torch.device("cpu")
|
||||
segmenter = transformers.pipeline("token-classification", model=segmentation_model, tokenizer=segmentation_tokenizer, aggregation_strategy="none", device=device)
|
||||
segmenter = transformers.pipeline(
|
||||
"token-classification",
|
||||
model=segmentation_model,
|
||||
tokenizer=segmentation_tokenizer,
|
||||
aggregation_strategy="none",
|
||||
device=device,
|
||||
)
|
||||
#########################################
|
||||
# Sequence translation model components #
|
||||
#########################################
|
||||
translation_model = transformers.T5ForConditionalGeneration.from_pretrained(stmt_config["REPO"], revision=stmt_config["REVISION"], token=token)
|
||||
translation_tokenizer = transformers.RobertaTokenizer.from_pretrained(stmt_config["TOKENIZER"], token=token)
|
||||
translator = transformers.TranslationPipeline(model=translation_model, tokenizer=translation_tokenizer, max_length=512, truncation=False, device=device)
|
||||
translator = transformers.TranslationPipeline(
|
||||
model=translation_model,
|
||||
tokenizer=translation_tokenizer,
|
||||
max_length=512,
|
||||
truncation=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return segmenter, CacheTranslator(translator)
|
||||
|
||||
@@ -12,7 +12,7 @@ import shutil
|
||||
from pylingual.utils.version import PythonVersion
|
||||
|
||||
|
||||
UV_VERSIONS = {PythonVersion((3, x)) for x in range(8, 14)}
|
||||
UV_VERSIONS = {PythonVersion((3, x)) for x in range(8, 15)}
|
||||
|
||||
|
||||
class CompileError(Exception):
|
||||
@@ -35,12 +35,12 @@ def _compile_native(py_file: str, out_file: str):
|
||||
def _compile_uv(py_file: str, out_file: str, version: PythonVersion):
|
||||
compile_cmd = f"import py_compile, sys; assert sys.version_info[:2] == {version.as_tuple()!r}; py_compile.compile({py_file!r}, cfile={out_file!r})"
|
||||
|
||||
cmd = ["uvx", "--python", version.as_str(), "python", "-c", compile_cmd]
|
||||
cmd = ["uvx", "--no-config", "--python", version.as_str(), "python", "-c", compile_cmd]
|
||||
|
||||
output = subprocess.run(cmd, shell=False, capture_output=True, text=True, env={**os.environ, "PYTHONWARNINGS": "ignore"})
|
||||
|
||||
# Ignore stderr messages from uv downloading versions on demand
|
||||
stderr = re.sub(r'Downloading .+\n', '', output.stderr)
|
||||
stderr = re.sub(r"\s*Download(ing|ed)\s.+\n", "", output.stderr)
|
||||
if stderr:
|
||||
raise CompileError(stderr)
|
||||
|
||||
@@ -94,7 +94,3 @@ def compile_version(py_file, out_file, version):
|
||||
_compile_uv(py_file=py_file, out_file=out_file, version=version)
|
||||
else:
|
||||
_compile_pyenv(py_file=py_file, out_file=out_file, version=version)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
supported_tuples = [(3, x) for x in range(6, 14)]
|
||||
version_str = {f"{x[0]}{x[1]}": x for x in supported_tuples} | {f"{x[0]}.{x[1]}": x for x in supported_tuples}
|
||||
supported_tuples = [(3, x) for x in range(6, 15)]
|
||||
version_str = {f"{x[0]}{x[1]}": x for x in supported_tuples} | {
|
||||
f"{x[0]}.{x[1]}": x for x in supported_tuples
|
||||
}
|
||||
|
||||
|
||||
class PythonVersion:
|
||||
|
||||
+8
-3
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "pylingual"
|
||||
version = "0.1.0"
|
||||
description = "A Python bytecode decompilation tool, supporting versions 3.6 - 3.13"
|
||||
description = "A Python bytecode decompilation tool, supporting versions 3.6 - 3.14"
|
||||
authors = [{ name = "syssec-utd" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.12"
|
||||
@@ -28,10 +28,12 @@ dependencies = [
|
||||
"tqdm",
|
||||
"rich",
|
||||
"seqeval",
|
||||
"transformers==4.46.1",
|
||||
"transformers",
|
||||
"transformers[torch]",
|
||||
"xdis>=6.1.4",
|
||||
"click",
|
||||
"evaluate",
|
||||
"tensorboardx>=2.6.4",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -73,7 +75,7 @@ exclude = [
|
||||
"site-packages",
|
||||
"venv",
|
||||
]
|
||||
target-version = "py311"
|
||||
target-version = "py314"
|
||||
line-length = 240
|
||||
indent-width = 4
|
||||
|
||||
@@ -100,3 +102,6 @@ docstring-code-line-length = "dynamic"
|
||||
members = [
|
||||
"pylingual/tools",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
xdis = { git = "https://github.com/jdw170000/python-xdis", rev = "python-3.14" }
|
||||
|
||||
Reference in New Issue
Block a user