init new cflow

The old system was very messy and hard to write templates for. This new
system still follows the same approach to control-flow reconstruction,
but is much simpler.

Only 3.12 has been worked on so far, but other versions will come soon.
This commit is contained in:
caandt
2025-03-10 20:23:19 -05:00
parent b689c4b317
commit 7b986e6fda
21 changed files with 2162 additions and 243 deletions
+1
View File
@@ -16,3 +16,4 @@ __pycache__/
mise.toml
dist/
decompiled_*/
decompiled_*.py
@@ -0,0 +1,123 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING
from pathlib import Path
import networkx as nx
import pydot
from pylingual.editable_bytecode import EditableBytecode
from pylingual.utils.lists import flatten
from .cft import ControlFlowTemplate, EdgeKind, InstTemplate, MetaTemplate
from .templates.Block import BlockTemplate
if TYPE_CHECKING:
DiGraph_CFT = nx.DiGraph[ControlFlowTemplate]
else:
DiGraph_CFT = nx.DiGraph
class CFG(DiGraph_CFT):
    """Control flow graph whose nodes are ControlFlowTemplates.

    The graph starts with one InstTemplate per instruction (see from_graph)
    and is progressively condensed by template matching until, ideally, a
    single node remains.
    """

    bytecode: EditableBytecode  # bytecode this graph was built from
    i: int  # iteration counter, used to name debug renderings
    start: ControlFlowTemplate  # synthetic entry node (MetaTemplate "start")
    end: ControlFlowTemplate  # synthetic exit node (MetaTemplate "end")
    iteration_graphs: list[list[str | list]]  # stack of buffered dot renderings, one list per open speculation
    run: int  # index of the template run currently being applied

    @staticmethod
    def from_graph(cfg: nx.DiGraph, bytecode: EditableBytecode) -> CFG:
        """Wrap a raw instruction-level digraph into a CFG and normalize it."""
        self = CFG(cfg)
        self.bytecode = bytecode
        self.i = 0
        self.start = MetaTemplate("start", bytecode.codeobj)
        self.end = MetaTemplate("end", bytecode.codeobj)
        self.iteration_graphs = []
        self.run = 0
        # rendering is expensive: stub it out unless debugging is enabled
        if "DEBUG_CFLOW" not in os.environ:
            self.visualize = lambda dir="": None
            self.layout_nodes = lambda: None
        InstTemplate.match_all(self)
        # translate the raw "type" edge attribute into an EdgeKind under "kind"
        for _a, _b, _p in self.edges(data=True):
            self[_a][_b]["kind"] = EdgeKind(_p["type"])
        # the entry node is the one holding the lowest instruction offset
        root_node = min([x for x in self.nodes], key=lambda x: x.get_instructions()[0].offset)
        self.add_nodes_from([self.start, self.end])
        self.add_edge(self.start, root_node, kind=EdgeKind.Meta)
        # every instruction node with no successors flows into the synthetic end
        self.add_edges_from((node, self.end, EdgeKind.Meta.prop()) for node in self.nodes if isinstance(node, InstTemplate) and self.out_degree(node) == 0)
        BlockTemplate.match_all(self)
        return self

    def iterate(self):
        """Record one reduction step; the counter only advances outside speculation."""
        if not self.iteration_graphs:
            self.i += 1
        self.visualize()

    def speculate(self):
        """Open a speculation frame; renderings buffer there until resolved."""
        self.iteration_graphs.append([])

    def drop_graphs(self):
        """Discard the innermost speculation frame (speculative match failed)."""
        self.iteration_graphs.pop()

    def ordered_iter(self):
        """Yield nodes in DFS postorder from start, visiting higher offsets first."""
        self._create_dominator_tree()
        return nx.dfs_postorder_nodes(self, source=self.start, sort_neighbors=lambda nodes: sorted(nodes, key=lambda x: x.offset, reverse=True))

    def apply_graphs(self):
        """Commit the innermost speculation frame; render it if it was the outermost."""
        graphs = self.iteration_graphs.pop()
        if self.iteration_graphs:
            self.iteration_graphs[-1].append(graphs)
        else:
            for x in flatten(graphs):
                g = pydot.graph_from_dot_data(x)[0]
                g.set_prog(["neato", "-n"])
                g.write_png("/tmp/graph/" + g.get_name().replace('"', ""))

    def layout_nodes(self):
        """Compute a dot layout and cache each template's position in its _pos."""
        relabeled = nx.convert_node_labels_to_integers(self, label_attribute="template")  # type: ignore
        root = next(i for i in relabeled.nodes if relabeled.nodes[i]["template"] == self.start)
        for i, pos in nx.nx_pydot.pydot_layout(relabeled, prog="dot", root=root).items():
            relabeled.nodes[i]["template"]._pos = [pos]

    def node_by_offset(self, offset: int):
        """Return the node with the given offset (StopIteration if absent)."""
        return next(x for x in self.nodes if x.offset == offset)

    def _create_dominator_tree(self):
        # _dt is the immediate-dominator tree; _dr is the transitive closure of
        # its reverse, so dominates() becomes a single edge lookup.
        self._dt = nx.create_empty_copy(self)
        self._dt.add_edges_from(nx.immediate_dominators(self, self.start).items())
        self._dt.remove_edge(self.start, self.start)
        self._dr = nx.transitive_closure_dag(self._dt.reverse())

    def dominates(self, node_a, node_b):
        """True if node_a dominates node_b (a node dominates itself)."""
        return self._dr.has_edge(node_a, node_b) or node_a == node_b

    def visualize(self, dir="/tmp/graph"):
        """Render the current graph with pydot; buffered while speculating."""
        for n in self.nodes:
            self.nodes[n]["label"] = repr(n)
        if not self.start._pos:
            self.layout_nodes()
        i = "-".join([str(self.i)] + [str(len(x)) for x in self.iteration_graphs])
        out = Path(f"{dir}/{self.bytecode.name}_{self.bytecode.version[1]}_{i}.png")
        dot = pydot.Dot(out.name, splines=True)
        dot.set_prog(["neato", "-n"])
        nodes = {}
        for node, data in self.nodes.data():
            nodes[node] = pydot.Node(str(hash(node)), label=repr(node).replace("\n", "\\l").replace("\t", "| ") + "\\l", fontname="Noto Sans", labeljust="l", shape="box", pos=node.pos())
            dot.add_node(nodes[node])
        for a, b, data in self.edges.data():
            dot.add_edge(pydot.Edge(nodes[a], nodes[b], **data, label=data["kind"].value, color=data["kind"].color(), fontname="Noto Sans", labeljust="l"))
        if not self.iteration_graphs:
            dot.write_png(out)
        else:
            self.iteration_graphs[-1].append(dot.to_string())
@@ -0,0 +1,472 @@
from __future__ import annotations
from pylingual.control_flow_reconstruction.source import SourceLine, SourceContext
from pylingual.editable_bytecode import EditableBytecode, Inst
from pylingual.editable_bytecode.utils import comprehension_names
import networkx as nx
from abc import ABC, abstractmethod
from types import NoneType
from typing import TYPE_CHECKING, Callable, TypeAlias, TypeVar, override
from collections import defaultdict
from enum import Enum
from xdis import Code3, iscode
if TYPE_CHECKING:
from pylingual.control_flow_reconstruction.cfg import CFG
CFT: TypeAlias = "ControlFlowTemplate"
C = TypeVar("C", bound=ControlFlowTemplate)
def indent_str(string: str, tabs: int = 1) -> str:
    """Indent every non-empty line of `string` by `tabs` tab characters.

    Empty lines are dropped; trailing whitespace on each line is stripped.
    """
    prefix = "\t" * tabs
    indented = [prefix + part.rstrip() for part in string.split("\n") if part]
    return "\n".join(indented)
class EdgeKind(Enum):
    """Concrete labels carried on CFG edges (stored under the "kind" attribute)."""

    Fall = "natural"
    Jump = "jump"
    TrueJump = "true_jump"
    FalseJump = "false_jump"
    Exception = "exception"
    Meta = "meta"

    def prop(self):
        """This kind as an edge-attribute dict, ready for networkx add_edge*."""
        return dict(kind=self)

    def __str__(self):
        return self.value

    def color(self):
        """Color used when rendering an edge of this kind."""
        if self in (EdgeKind.TrueJump, EdgeKind.FalseJump):
            return "green"
        if self is EdgeKind.Exception:
            return "red"
        if self is EdgeKind.Meta:
            return "blue"
        return "black"
class EdgeCategory(Enum):
Natural = "natural"
Conditional = "conditional"
Exception = "exception"
Meta = "meta"
@staticmethod
def from_kind(kind: EdgeKind):
kind = EdgeKind(kind)
if kind in [EdgeKind.Fall, EdgeKind.Jump]:
return EdgeCategory.Natural
if kind in [EdgeKind.TrueJump, EdgeKind.FalseJump]:
return EdgeCategory.Conditional
return EdgeCategory(kind.value)
class NodeMatcher(ABC):
    """Interface for matchers that test a single CFG node."""

    name: str  # name of the template slot this matcher binds

    @abstractmethod
    def try_match(self, cfg: CFG, node: CFT | None) -> tuple[CFT | None, list[tuple[str, CFT | None]] | None]:
        """
        Checks if the node `node` is valid for this matcher.
        If successful, returns `node` (possibly modified) and a list of `(name, node)` pairs to check, otherwise `None`.
        """
        ...
class EdgeMatcher(ABC):
    """Interface for matchers that test a single out-edge of a node."""

    name: str  # name of the template slot the edge's target binds

    @abstractmethod
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        """
        Checks if the edge `(node_a, node_b)` is valid for this matcher.
        If successful, returns `node_b` (could be different) and the name of the node that should be checked with it or `''` if no node should be matched, otherwise `None`.
        """
        ...
def out_edge_dict(cfg: CFG, node: CFT) -> dict[EdgeCategory, CFT | None]:
    """Map each EdgeCategory to the successor of `node` along that category.

    Missing categories read as None (defaultdict); if several out-edges share
    a category, the last one enumerated wins, as in the original loop form.
    """
    pairs = (
        (EdgeCategory.from_kind(prop["kind"]), dst)
        for _, dst, prop in cfg.out_edges(node, data=True)
    )
    return defaultdict(NoneType, pairs)
class Template:
    """A named subgraph pattern: a root slot plus one NodeMatcher per slot."""

    def __init__(self, root: str, nodes: dict[str, NodeMatcher]):
        self.root = root
        self.nodes = nodes

    def try_match(self, cfg: CFG, node: CFT) -> dict[str, CFT | None] | None:
        """
        Checks if a subgraph rooted at `node` is valid for this matcher.
        If successful, returns a mapping from node names to nodes, otherwise `None`.
        """
        mapping: dict[str, CFT | None] = {}
        # worklist of (slot name, candidate node) pairs still to verify
        stack: list[tuple[str, CFT | None]] = [(self.root, node)]
        while stack:
            template_node, cfg_node = stack.pop()
            if template_node in mapping:
                # a slot reached twice must bind the same node both times
                if mapping[template_node] != cfg_node:
                    return None
                else:
                    continue
            cfg_node, x = self.nodes[template_node].try_match(cfg, cfg_node)
            if x is None:
                return None
            mapping[template_node] = cfg_node
            # matcher may demand further (name, node) pairs to check
            stack.extend(x)
        return mapping
class ConditionalNodeMatcher(NodeMatcher):
    """
    Matches the inner `NodeMatcher` only if the condition is true
    """

    def __init__(self, inner: NodeMatcher, cond: Callable[[CFG, CFT | None], bool]):
        self.inner = inner
        self.cond = cond

    @override
    def try_match(self, cfg: CFG, node: CFT | None) -> tuple[CFT | None, list[tuple[str, CFT | None]] | None]:
        # delegate only when the predicate holds; otherwise report failure
        if self.cond(cfg, node):
            return self.inner.try_match(cfg, node)
        return node, None
class OptionalNodeMatcher(NodeMatcher):
    """
    Matches None or the inner `NodeMatcher`
    """

    def __init__(self, inner: NodeMatcher):
        self.inner = inner

    @override
    def try_match(self, cfg: CFG, node: CFT | None) -> tuple[CFT | None, list[tuple[str, CFT | None]] | None]:
        # an absent node is an automatic (empty) success
        if node is not None:
            return self.inner.try_match(cfg, node)
        return node, []
class AnyNodeMatcher(NodeMatcher):
    """
    Matches the first applicable NodeMatcher, if any
    """

    def __init__(self, *inner: NodeMatcher):
        self.inner = inner

    @override
    def try_match(self, cfg: CFG, node: CFT | None) -> tuple[CFT | None, list[tuple[str, CFT | None]] | None]:
        # first-match-wins over the alternatives, in declaration order
        for candidate in self.inner:
            result = candidate.try_match(cfg, node)
            if result[1] is not None:
                return result
        return node, None
class SubtemplateNodeMatcher(NodeMatcher):
    """
    Only tries to match the inner `NodeMatcher` if the template successfully matches.
    `revert_on_fail` should be used for the corresponding CFTs try_match
    """

    def __init__(self, inner: NodeMatcher, template: type[CFT]):
        self.inner = inner
        self.template = template

    @override
    def try_match(self, cfg: CFG, node: CFT | None) -> tuple[CFT | None, list[tuple[str, CFT | None]] | None]:
        if node is None:
            return node, None
        # NOTE(review): the graph is NOT actually copied before the speculative
        # sub-match — `copy` aliases `cfg`, so failure relies on the caller
        # using revert_on_fail as the docstring demands.
        # copy = cfg.copy()
        copy = cfg
        cfg.speculate()
        if (new_node := self.template.try_match(copy, node)) is not None:
            new_node, x = self.inner.try_match(copy, new_node)
            if x is not None:
                # commit the speculation frame opened above
                cfg.apply_graphs()
                return new_node, x
        # sub-template or inner matcher failed: discard the speculation frame
        cfg.drop_graphs()
        return node, None
class NodeTemplate(NodeMatcher):
    """
    Matches a node if all of its edges match the matcher's corresponding `EdgeMatcher`
    """

    def __init__(self, edges: dict[EdgeCategory, EdgeMatcher]):
        self.edges = edges

    @override
    def try_match(self, cfg: CFG, node: CFT | None) -> tuple[CFT | None, list[tuple[str, CFT | None]] | None]:
        # a missing or already-consumed node cannot match
        if node is None or node not in cfg.nodes:
            return node, None
        successors = out_edge_dict(cfg, node)
        pending: list[tuple[str, CFT | None]] = []
        for category, matcher in self.edges.items():
            result = matcher.try_match(cfg, node, successors[category])
            if result is None:
                return node, None
            matched, slot = result
            # an empty slot name means the edge is accepted but binds nothing
            if slot:
                pending.append((slot, matched))
        return node, pending
class EdgeTemplate(EdgeMatcher):
    """
    Matches an edge `(a, b)` if `b` is not None
    Assigns `b` to the node with name `name`
    """

    def __init__(self, name: str):
        self.name = name

    @override
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        # fail when the edge is absent; otherwise bind its target to our slot
        if node_b is None:
            return None
        return (node_b, self.name)
class OptionalEdge(EdgeMatcher):
    """
    Matches any edge `(a, b)`, even if `b` is None
    Assigns `b` to the node with name `name` if `b` is not None
    """

    def __init__(self, name: str):
        self.name = name

    @override
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        # always succeeds; only binds the slot when the target actually exists
        return (node_b, self.name if node_b is not None else "")
class OptExcEdge(EdgeMatcher):
    """Exception-edge matcher that tolerates two special cases without binding.

    Accepts (without binding a slot) when there is no exception edge but the
    node is a lone JUMP_BACKWARD, or when the recorded edge is a Meta edge;
    otherwise binds the target to `name` (possibly as None).
    """

    def __init__(self, name: str):
        self.name = name

    @override
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        # no exception successor, and the node is only a JUMP_BACKWARD: accept silently
        if node_b is None and all(x.opname == "JUMP_BACKWARD" for x in node_a.get_instructions()):
            return (node_b, "")
        # an exception slot occupied by a Meta edge: accept silently as well
        if node_b is not None and cfg.get_edge_data(node_a, node_b, {}).get("kind") is EdgeKind.Meta:
            return (node_b, "")
        return (node_b, self.name)
class NoEdge(EdgeMatcher):
    """
    Matches an edge `(a, b)` if `b` is None (i.e. there is no edge)
    """

    edge = ""

    @override
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        # succeed (binding nothing) only when the edge is absent
        return (node_b, "") if node_b is None else None
class ExitableEdge(EdgeMatcher):
    """Edge matcher for targets that may instead exit the function.

    When the edge is absent, a Meta successor (or, failing that, a node with
    no natural/conditional successors at all) is accepted without binding;
    otherwise the target — even None — is bound to `name`.
    """

    def __init__(self, name: str):
        self.name = name

    @override
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        if node_b is None:
            d = out_edge_dict(cfg, node_a)
            # a Meta out-edge stands in for the missing target
            if d[EdgeCategory.Meta] is not None:
                return (d[EdgeCategory.Meta], "")
            # no way forward at all: treat the synthetic end node as the target
            if d[EdgeCategory.Natural] is None and d[EdgeCategory.Conditional] is None:
                return (cfg.end, "")
        return (node_b, self.name)
class RaiseOutEdge(EdgeMatcher):
    """Edge matcher for nodes that end by raising (RERAISE / RAISE_VARARGS).

    Fails outright if `node_a` does not end in a raise; with no target, a Meta
    successor is accepted without binding; otherwise the target is bound to `name`.
    """

    def __init__(self, name: str):
        self.name = name

    @override
    def try_match(self, cfg: CFG, node_a: CFT, node_b: CFT | None) -> tuple[CFT | None, str] | None:
        # only applies to nodes whose last instruction raises
        if node_a.get_instructions()[-1].opname not in ["RERAISE", "RAISE_VARARGS"]:
            return None
        if node_b is None:
            d = out_edge_dict(cfg, node_a)
            # a Meta out-edge stands in for the missing target
            if d[EdgeCategory.Meta] is not None:
                return (d[EdgeCategory.Meta], "")
        return (node_b, self.name)
class ControlFlowTemplate(ABC):
    """Base class for all reconstructed control flow nodes.

    A template owns the named member nodes it condensed (`members`), tracks
    the bytecode offset it starts at, and knows how to render itself back to
    indented source.
    """

    members: dict[str, CFT | None]  # slot name -> condensed child node (None if the slot was optional and absent)
    template: Template  # class-level pattern used by try_match implementations
    offset: int  # offset of the first instruction covered
    header_lines: list[SourceLine]  # extra lines (docstring, global/nonlocal, ...) emitted before the body
    blame: Code3  # code object this template's source is attributed to
    _pos: list[tuple[float, float]]  # cached layout positions for visualization

    def __init__(self, members: dict[str, CFT | None]):
        self.members = members
        # inherit offset/blame from the first present member
        first = next(x for x in members.values() if x is not None)
        self.offset = first.offset
        self.header_lines = []
        self.blame = first.blame
        self._pos = sum((x._pos for x in members.values() if x is not None), start=[])

    def pos(self):
        """Centroid of the member positions, in pydot's pinned "x,y!" syntax."""
        avg_x = sum(x for x, _ in self._pos) / len(self._pos)
        avg_y = sum(y for _, y in self._pos) / len(self._pos)
        return f"{avg_x},{avg_y}!"

    def __getattr__(self, name: str) -> CFT:
        # attribute access falls back to members; absent optional members are
        # replaced by a placeholder MetaTemplate so rendering never sees None
        x = self.members[name]
        if x is not None:
            return x
        return MetaTemplate(f"{name} (empty)", self.blame)

    @classmethod
    @abstractmethod
    def try_match(cls, cfg: CFG, node: CFT) -> CFT | None:
        """
        Tries to match this template starting at `node`. Returns the new node if the match was successful.
        Modifies `cfg` on success.
        """
        ...

    @abstractmethod
    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        """
        Returns the source code for this template, recursively calling into its children to create the full source code.
        """
        ...

    @override
    def __repr__(self) -> str:
        name = type(self).__name__
        components = indent_str(",\n".join(f"{k}={repr(v)}" for k, v in self.members.items()))
        return f"{name}[\n{components}]"

    def get_instructions(self) -> list[Inst]:
        """All instructions covered by this template's members, in member order."""
        return [i for m in self.members.values() if m is not None for i in m.get_instructions()]

    def line(self, s: str, i: int = 0, child: Code3 | None = None, meta: bool = False):
        """Wrap a single non-empty source string as a one-element SourceLine list."""
        assert s
        return [SourceLine(s, i, self.blame, child, meta)]

    def add_header(self, s: str, meta: bool = False):
        """Append a header line emitted before this template's body."""
        self.header_lines.extend(self.line(s, meta=meta))
class InstTemplate(ControlFlowTemplate):
    """Leaf template wrapping a single bytecode instruction."""

    def __init__(self, inst: Inst):
        self.inst = inst
        self.offset = self.inst.offset
        self.blame = inst.bytecode.codeobj
        self.header_lines = []
        self._pos = []

    @staticmethod
    def match_all(cfg):
        """Replace every raw Inst node in `cfg` with an InstTemplate, in place."""
        mapping = {node: InstTemplate(node) for node in cfg.nodes if isinstance(node, Inst)}
        nx.relabel_nodes(cfg, mapping, copy=False)

    @override
    @classmethod
    def try_match(cls, cfg, node):
        # InstTemplates are created up front by match_all, never pattern-matched
        raise NotImplementedError

    @override
    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        # emit the source line this instruction starts, if any and non-blank
        lines = [] if self.inst.starts_line is None or not source.lines[self.inst.starts_line - 1] else self.line(source.lines[self.inst.starts_line - 1])
        if self.inst.opname == "LOAD_CONST" and iscode(self.inst.argval):
            # LOAD_CONST of a code object marks a nested def/class body: emit a
            # placeholder line carrying the child code object (comprehensions excluded)
            if self.inst.argval in source.cfts and self.inst.argval.co_name not in comprehension_names:  # type: ignore
                lines.append(SourceLine("", 1, self.inst.argval, self.inst.argval))
        return lines

    @override
    def get_instructions(self):
        return [self.inst]

    @override
    def __repr__(self):
        # "<offset: OPNAME [arg [(argrepr-or-name)]]>", prefixed with the line number if known
        x = None
        if self.inst.arg is None:
            x = f"<{self.inst.offset}: {self.inst.opname}>"
        elif not self.inst.argrepr:
            x = f"<{self.inst.offset}: {self.inst.opname} {self.inst.arg}>"
        elif self.inst.opname == "LOAD_CONST":
            arg = self.inst.bytecode.co_consts[self.inst.arg]  # type: ignore
            if isinstance(arg, EditableBytecode):
                x = f"<{self.inst.offset}: {self.inst.opname} {self.inst.arg} ({arg.name})>"
        if x is None:
            x = f"<{self.inst.offset}: {self.inst.opname} {self.inst.arg} ({self.inst.argrepr})>"
        if self.inst.starts_line is not None:
            return f"[{self.inst.starts_line}] {x}"
        return x
class MetaTemplate(ControlFlowTemplate):
    """Synthetic node with no instructions (e.g. the CFG's start/end markers)."""

    def __init__(self, name: str, blame: Code3):
        self.name = name
        self.offset = -1  # sorts before any real instruction offset
        self.header_lines = []
        self._pos = []
        self.blame = blame

    @override
    @classmethod
    def try_match(cls, cfg: CFG, node: ControlFlowTemplate) -> ControlFlowTemplate | None:
        # MetaTemplates are inserted explicitly, never pattern-matched
        raise NotImplementedError

    @override
    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        return self.line(f"# meta: {self.name}", meta=True)

    @override
    def get_instructions(self):
        return []

    @override
    def __repr__(self):
        return f"MetaTemplate[{self.name}]"
# Registry of templates applicable to every Python version, keyed by run index.
template_dict: dict[int, list[tuple[type[ControlFlowTemplate], int]]] = defaultdict(list)
# Same structure, but keyed first by (major, minor) version.
version_specific_template_dict: dict[tuple[int, int], dict[int, list[tuple[type[ControlFlowTemplate], int]]]] = defaultdict(lambda: defaultdict(list))


def register_template(run: int, priority: int, *versions: tuple[int, int]):
    """
    Register a control flow template to be used in run `run` with priority `priority`.
    If no versions are given, the template is used for all versions.
    """

    def deco(template: type[C]) -> type[C]:
        # pick the registries this template belongs to, then record it once in each
        targets = [template_dict] if not versions else [version_specific_template_dict[v] for v in versions]
        for target in targets:
            target[run].append((template, priority))
        return template

    return deco


def get_template_runs(version: tuple[int, int]) -> list[list[type[ControlFlowTemplate]]]:
    """Collect the templates applicable to `version`, grouped by run index and
    ordered by ascending priority within each run."""
    runs: dict[int, list[tuple[type[ControlFlowTemplate], int]]] = defaultdict(list)
    for run in (template_dict | version_specific_template_dict[version]).keys():
        runs[run].extend(template_dict[run])
        runs[run].extend(version_specific_template_dict[version][run])
    return [[entry[0] for entry in sorted(runs[run], key=lambda e: e[1])] for run in sorted(runs)]
@@ -0,0 +1,161 @@
from __future__ import annotations
import itertools
import keyword
import inspect
import ast
from typing import TYPE_CHECKING, Generator, NamedTuple
from xdis import Code3
from pylingual.editable_bytecode import PYCFile
from pylingual.editable_bytecode.EditableBytecode import EditableBytecode
from pylingual.utils.use_escape_sequences import use_escape_sequences
from pylingual.utils.version import PythonVersion
if TYPE_CHECKING:
from .cft import ControlFlowTemplate
def indent_lines(lines: list[SourceLine], i: int = 1) -> list[SourceLine]:
    """Return copies of `lines` with each line's indent increased by `i`."""
    return [line._replace(indent=line.indent + i) for line in lines]
class SourceLine(NamedTuple):
    """One reconstructed line of source, tagged with its provenance."""

    line: str  # the source text, without indentation
    indent: int  # indentation level (in units, not characters)
    blame: Code3  # code object this line was reconstructed from
    child: Code3 | None = None  # nested code object defined by this line, if any
    meta: bool = False  # True for synthetic bookkeeping lines

    def with_line(self, line: str):
        """Copy of this SourceLine with only the text replaced."""
        return SourceLine(line, self.indent, self.blame, self.child, self.meta)
def sanitize_lines(lines: list[str]) -> list[str]:
    """Normalize candidate source lines for matching.

    Each line is stripped, a leading "elif " is rewritten to "if ", and bare
    control statements that carry no matchable content are blanked out.
    """
    blanked = ("break", "continue", "except:", "while True:")
    result = []
    for raw in lines:
        line = raw.strip()
        if line.startswith("elif "):
            line = line[2:]  # "elif c:" -> "if c:"
        result.append("" if line in blanked else line)
    return result
def fake_header(co: Code3):
    """Synthesize a plausible header line for a code object whose real header
    could not be recovered, based only on its name and flags."""
    safe_name = co.co_name
    if not safe_name.isidentifier() or keyword.iskeyword(safe_name):
        safe_name = "_"
    flags = co.co_flags
    if flags & inspect.CO_ASYNC_GENERATOR:
        return f"async def {safe_name}():"
    if flags & inspect.CO_NEWLOCALS:
        return f"def {safe_name}():"
    return f"class {safe_name}:"
def valid_header(line: SourceLine, version: PythonVersion):
    """True if `line` parses as a block header (i.e. "<line>pass" is valid
    syntax) under the grammar of the given Python version."""
    try:
        ast.parse(f"{line.line}pass", feature_version=version.as_tuple())
    except Exception:
        return False
    return True
class SourceContext:
    """Holds everything needed to render the reconstructed program: the pyc,
    the candidate source lines, and the per-code-object templates."""

    def __init__(self, pyc: PYCFile, lines: list[str], cfts: dict[Code3, ControlFlowTemplate]):
        self.pyc = pyc
        self.lines = sanitize_lines(lines)
        self.cfts = cfts
        # memoized to_indented_source results per template
        self.cache: dict[ControlFlowTemplate, list[SourceLine]] = {}
        self.header_lines: list[SourceLine] = []
        # templates whose own lines are suppressed in the output
        self.purged_cfts: list[ControlFlowTemplate] = []
        self.init_header()

    def init_header(self):
        """Attach docstring/yield/global/nonlocal header lines to every template."""
        for bc in self.pyc.iter_bytecodes():
            cft = self.cfts[bc.codeobj]
            if bc.codeobj.co_flags & inspect.CO_NEWLOCALS:
                # a leading string constant is taken to be the docstring
                if bc.codeobj.co_consts and isinstance(bc.codeobj.co_consts[0], str):
                    doc = use_escape_sequences(bc.codeobj.co_consts[0])
                    cft.add_header(f'"""{doc}"""')
            if bc.codeobj.co_flags & (inspect.CO_GENERATOR | inspect.CO_ASYNC_GENERATOR):
                # generator flag with no visible yield: force generator-ness
                if not any(self.lines[i.starts_line - 1].strip().startswith("yield ") or self.lines[i.starts_line - 1].strip() == "yield" for i in cft.get_instructions() if i.starts_line is not None):
                    cft.add_header("if False: yield")
            for global_var in bc.globals:
                cft.add_header(f"global {global_var}")
            # only emit `nonlocal` when some enclosing scope also treats the
            # name as nonlocal
            parent_nonlocal = set()
            parent = bc.parent
            while parent:
                parent_nonlocal |= parent.nonlocals
                parent = parent.parent
            for nonlocal_var in bc.nonlocals:
                if nonlocal_var in parent_nonlocal:
                    cft.add_header(f"nonlocal {nonlocal_var}")

    def __getitem__(self, template: ControlFlowTemplate | tuple[ControlFlowTemplate, int]):
        """Rendered (header + body) lines of `template`, optionally indented:
        `source[cft]` or `source[cft, extra_indent]`."""
        if isinstance(template, tuple):
            template, indent = template
        else:
            indent = 0
        if template not in self.cache:
            self.cache[template] = template.to_indented_source(self)
        if indent:
            return indent_lines(template.header_lines + self.cache[template], indent)
        return template.header_lines + self.cache[template]

    def source_lines_of(self, cft: ControlFlowTemplate, i=0) -> Generator[SourceLine]:
        """Yield the lines of `cft`, recursing into nested code objects.

        NOTE(review): reconstructed nesting — when `cft` is purged, its own
        lines are dropped and each child is introduced by the previous (real)
        header line if it parses, else by a synthesized fake header; confirm
        against the original layout.
        """
        lines = self[cft, i]
        purged = cft in self.purged_cfts
        prev = None
        for line in lines:
            if line.child:
                if purged:
                    if prev and valid_header(prev, self.pyc.version):
                        yield prev
                    else:
                        yield SourceLine(fake_header(line.child), line.indent - 1, line.child)
                yield from self.source_lines_of(self.cfts[line.child], line.indent)
            elif not purged:
                yield line
            prev = line

    def purge(self, co: Code3):
        """Suppress the body of `co`'s template in subsequent rendering."""
        self.purged_cfts.append(self.cfts[co])

    def source_lines(self):
        """Produce the final line list: reorder the module prefix (docstring,
        __future__ imports) and insert `pass` into otherwise-empty blocks."""
        def is_prefix(x: SourceLine):
            return x.line.startswith(("from __future__ import ", "__doc__ = ", "global ", "nonlocal ", '"""'))

        def priority(x: SourceLine):
            # docstring first, then __future__ imports, then the rest
            if x.line.startswith(("__doc__ = ", '"""')):
                return 0
            if x.line.startswith("from __future__ import "):
                return 1
            return 2

        lines = self.header_lines + list(self.source_lines_of(self.cfts[self.pyc.codeobj]))
        prefix = [x.with_line(x.line[10:]) if x.line.startswith("__doc__ = ") else x for x in sorted(itertools.takewhile(is_prefix, lines), key=priority)]
        lines[: len(prefix)] = prefix
        # insert pass in empty blocks
        colon_line = None
        new_lines = []
        for x in lines:
            if colon_line is not None:
                # a line at or above the header's indent means the block was empty
                if x.indent <= colon_line.indent:
                    new_lines.append(SourceLine("pass", colon_line.indent + 1, colon_line.blame))
                if not x.meta:
                    colon_line = None
            if x.line.endswith(":"):
                colon_line = x
            new_lines.append(x)
        if colon_line is not None:
            new_lines.append(SourceLine("pass", colon_line.indent + 1, colon_line.blame))
        return new_lines

    def __str__(self):
        return "\n".join(" " * x.indent + x.line for x in self.source_lines())

    def update_cft(self, bc: EditableBytecode, template: ControlFlowTemplate):
        """Replace the template for `bc` and invalidate cached renderings up
        the parent chain.

        NOTE(review): the loop condition `x.parent is not None` stops before
        invalidating the root code object's cache entry — confirm intended.
        """
        x = bc
        while x.parent is not None:
            del self.cache[self.cfts[x.codeobj]]
            x = x.parent
        self.cfts[bc.codeobj] = template

    def update_lines(self, lines: list[str]):
        """Swap in a new candidate line list and drop all cached renderings."""
        self.lines = sanitize_lines(lines)
        self.cache.clear()
@@ -0,0 +1,38 @@
import pdb
from pylingual.editable_bytecode import EditableBytecode
from pylingual.editable_bytecode.control_flow_graph import bytecode_to_control_flow_graph
import networkx as nx
from .cfg import CFG
from .cft import ControlFlowTemplate, get_template_runs, MetaTemplate
def iteration(cfg: CFG, runs: list[list[type[ControlFlowTemplate]]]):
    """Attempt a single reduction: walk each run's templates over the graph in
    dominator-aware order and stop at the first successful match.

    Returns True if some template matched (the graph changed), else False.
    """
    for run_index, templates in enumerate(runs):
        cfg.run = run_index
        for node in cfg.ordered_iter():
            for template in templates:
                if template.try_match(cfg, node):
                    return True
    return False
def bc_to_cft(bc: EditableBytecode):
    """Build the control flow graph for `bc` and reduce it to a single template."""
    graph = bytecode_to_control_flow_graph(bc)
    return structure_control_flow(graph, bc)
def structure_control_flow(cfg: nx.DiGraph, bytecode: EditableBytecode) -> ControlFlowTemplate:
    """Repeatedly apply the registered templates until the graph collapses to
    one node; returns a MetaTemplate marker if no template applies anymore."""
    cfg = CFG.from_graph(cfg, bytecode)
    runs = get_template_runs(bytecode.version[:2])
    try:
        while len(cfg) > 1:
            if not iteration(cfg, runs):
                # no template matched but multiple nodes remain: give up
                return MetaTemplate("\x1b[31mirreducible cflow\x1b[0m", bytecode.codeobj)
    except Exception:
        # NOTE(review): pdb.xpm() is a pdbpp extension (post-mortem at the
        # raise point); plain stdlib pdb has no xpm attribute.
        pdb.xpm()  # type: ignore
        raise
    return next(iter(cfg.nodes))
@@ -0,0 +1,111 @@
from __future__ import annotations
from typing import TYPE_CHECKING, override
from itertools import chain
from pylingual.editable_bytecode import Inst
from ..cft import ControlFlowTemplate, EdgeKind, SourceContext, SourceLine, register_template, EdgeCategory, out_edge_dict, MetaTemplate, indent_str
from ..utils import E, N, T, defer_source_to, remove_nodes
if TYPE_CHECKING:
from pylingual.control_flow_reconstruction.cfg import CFG
@register_template(100, 0)
class EndTemplate(ControlFlowTemplate):
    """Final reduction: collapses start -> body -> end (all linked by Meta
    edges) into a single node once everything else has been matched."""

    template = T(
        start=N(E.meta("body")).of_type(MetaTemplate),
        body=N(E.meta("end")),
        end=N.tail().of_type(MetaTemplate).with_in_deg(1),
    )

    @override
    @classmethod
    def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
        # only ever applies at the graph's designated start node
        if node is not cfg.start:
            return None
        mapping = cls.template.try_match(cfg, node)
        if mapping is None:
            return None
        template = cls(mapping)
        remove_nodes(cfg, mapping, "start", "body", "end")
        cfg.add_node(template)
        # the collapsed node becomes both entry and exit of the graph
        cfg.start = template
        cfg.end = template
        return template

    to_indented_source = defer_source_to("body")
@register_template(0, 20)
@register_template(2, 20)
class BlockTemplate(ControlFlowTemplate):
    """A straight-line sequence of nodes merged into one, flattening any
    nested BlockTemplates."""

    members: list[ControlFlowTemplate]  # ordered sequence of merged nodes

    def __init__(self, members: list[ControlFlowTemplate]):
        self.members = members  # type: ignore
        self.offset = members[0].offset if members else -1
        self._pos = sum((x._pos for x in members), start=[])
        self.header_lines = []
        self.blame = members[0].blame

    @staticmethod
    def match_all(cfg: CFG):
        """Greedily merge every straight-line run in the graph, with per-node
        iteration callbacks suppressed until the sweep finishes."""
        it, cfg.iterate = cfg.iterate, lambda: None
        for node in list(cfg.nodes):
            # skip synthetic nodes and nodes already consumed by an earlier merge
            if isinstance(node, MetaTemplate) or node not in cfg.nodes:
                continue
            BlockTemplate.try_match(cfg, node)
        cfg.iterate = it
        cfg.iterate()

    @override
    @classmethod
    def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
        """Walk natural fall-through successors from `node`, collecting a
        maximal chain, then replace the chain with one BlockTemplate."""
        members: list[ControlFlowTemplate] = []
        out = out_edge_dict(cfg, node)
        # all chain members must share the first node's exception successor
        exc = out[EdgeCategory.Exception]
        current = node
        while True:
            if out[EdgeCategory.Exception] != exc:
                break
            # later members must have no other predecessors
            if current != node and cfg.in_degree(current) > 1:  # type: ignore
                break
            # cycle guard
            if current in members:
                break
            members.append(current)
            next = out[EdgeCategory.Natural]
            if next is None:
                break
            # outside run 2, only true fall-through edges (not jumps) may be merged
            if cfg.get_edge_data(current, next).get("kind") != EdgeKind.Fall and cfg.run != 2:
                break
            if out[EdgeCategory.Conditional] is not None:
                break
            out = out_edge_dict(cfg, next)
            current = next
        if len(members) < 2:
            return None
        # flatten nested BlockTemplates into a single member list
        template = BlockTemplate([x for m in members for x in (m.members if isinstance(m, BlockTemplate) else [m])])
        in_edges = [(src, template, prop) for src, _, prop in cfg.in_edges(node, data=True) if src not in members]
        # edges back into the chain become self-loops on the new node
        out_edges = [(template, template, prop) if dst in members else (template, dst, prop) for _, dst, prop in cfg.out_edges(members[-1], data=True)]
        cfg.remove_nodes_from(members)
        cfg.add_node(template)
        cfg.add_edges_from(chain(in_edges, out_edges))
        cfg.iterate()
        return template

    @override
    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        return list(chain.from_iterable(source[m] for m in self.members))

    @override
    def get_instructions(self) -> list[Inst]:
        insts: list[Inst] = []
        for member in self.members:
            insts.extend(member.get_instructions())
        return insts

    @override
    def __repr__(self) -> str:
        components = indent_str("\n".join(repr(member) for member in self.members))
        return f"BlockTemplate[\n{components}]"
@@ -0,0 +1,102 @@
from ..cft import ControlFlowTemplate, EdgeKind, register_template
from ..utils import T, N, defer_source_to, run_is, starting_instructions, to_indented_source, make_try_match, without_top_level_instructions
@register_template(1, 40)
class IfElse(ControlFlowTemplate):
    """if/else where both branches rejoin at a common tail.

    The header condition excludes with/except/for-loop headers via the
    instruction filter.
    """

    template = T(
        if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")),
        if_body=~N("tail.").with_in_deg(1),
        else_body=~N("tail.").with_in_deg(1),
        tail=N.tail(),
    )
    try_match = make_try_match({EdgeKind.Fall: "tail"}, "if_header", "if_body", "else_body")

    # the docstring below is the source template consumed by @to_indented_source
    @to_indented_source
    def to_indented_source():
        """
        {if_header}
        {if_body}
        {else_body?else:}
        {else_body}
        """
@register_template(1, 41)
@register_template(2, 41)
class IfThen(ControlFlowTemplate):
    """if without an else: the body falls into the same tail the header
    conditionally jumps to. In run 2 the body may also end elsewhere."""

    template = T(
        if_header=~N("if_body", "tail").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")),
        if_body=~N("tail").with_in_deg(1) | ~N("tail.").with_in_deg(1).with_cond(run_is(2)),
        tail=N.tail(),
    )
    try_match = make_try_match({EdgeKind.Fall: "tail"}, "if_header", "if_body")

    # the docstring below is the source template consumed by @to_indented_source
    @to_indented_source
    def to_indented_source():
        """
        {if_header}
        {if_body}
        """
@register_template(0, 39)
class Assertion(ControlFlowTemplate):
    """assert statement: a conditional whose failure branch starts with
    LOAD_ASSERTION_ERROR. Only the assertion node contributes source."""

    template = T(
        assertion=~N("fail", "tail"),
        fail=+N().with_cond(starting_instructions("LOAD_ASSERTION_ERROR")),
        tail=N.tail(),
    )
    try_match = make_try_match({EdgeKind.Fall: "tail"}, "assertion", "fail")
    to_indented_source = defer_source_to("assertion")
@register_template(1, 46)
class ShortCircuitAnd(ControlFlowTemplate):
    """Two chained conditionals `A` and `B` that share the same failure target
    (`tail`) and success body — i.e. `A and B`."""

    template = T(
        A=~N("B", "tail"),
        B=~N("body", "tail").with_in_deg(1),
        body=~N.tail(),
        tail=N.tail(),
    )
    try_match = make_try_match(
        {
            EdgeKind.Fall: "body",
            EdgeKind.FalseJump: "tail",
        },
        "A",
        "B",
    )

    # the docstring below is the source template consumed by @to_indented_source
    @to_indented_source
    def to_indented_source():
        """
        {A}
        {B}
        """
@register_template(1, 45)
class ShortCircuitOr(ControlFlowTemplate):
    """Chained conditionals where `A`'s second successor is the body itself —
    i.e. `A or B`. Only `A` contributes source.

    NOTE(review): the condensed edge mapping (Fall -> body,
    FalseJump -> tail) is identical to ShortCircuitAnd's — confirm this is
    intended for the `or` shape.
    """

    template = T(
        A=~N("B", "body"),
        B=~N("body", "tail").with_in_deg(1),
        body=~N.tail(),
        tail=N.tail(),
    )
    try_match = make_try_match(
        {
            EdgeKind.Fall: "body",
            EdgeKind.FalseJump: "tail",
        },
        "A",
        "B",
    )
    to_indented_source = defer_source_to("A")
@@ -0,0 +1,272 @@
from itertools import chain
from typing import override
from .Block import BlockTemplate
from .Conditional import IfElse, IfThen
from ..cft import ControlFlowTemplate, EdgeCategory, EdgeKind, InstTemplate, SourceLine, SourceContext, register_template
from ..utils import E, N, T, condense_mapping, defer_source_to, ending_instructions, exact_instructions, no_back_edges, revert_on_fail, starting_instructions, to_indented_source, make_try_match, versions_from
# Shared matcher: a node consisting of exactly COPY, POP_EXCEPT, RERAISE —
# used by every handler template below as the "re-raise" exit node.
reraise = +N().with_cond(exact_instructions("COPY", "POP_EXCEPT", "RERAISE"))
class Except3_11(ControlFlowTemplate):
    """Dispatcher over the 3.11+ handler shapes.

    A lone RERAISE node is accepted as-is; otherwise the typed handler
    (ExceptExc3_11) is tried before the bare one (BareExcept3_11).
    Implicitly returns None when nothing applies.
    """

    @classmethod
    @override
    def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
        if [x.opname for x in node.get_instructions()] == ["RERAISE"]:
            return node
        if x := ExceptExc3_11.try_match(cfg, node):
            return x
        if x := BareExcept3_11.try_match(cfg, node):
            return x
@register_template(0, 0, *versions_from(3, 12))
class Try3_12(ControlFlowTemplate):
    """try/except (no else) for 3.12+: the try body's exception edge leads
    into a handler matched as an Except3_11 subtemplate."""

    template = T(
        try_header=N("try_body"),
        try_body=N("tail.", None, "except_body"),
        except_body=N("tail.", None, "reraise").with_in_deg(1).of_subtemplate(Except3_11),
        reraise=reraise,
        tail=N.tail(),
    )
    # revert_on_fail is required because of_subtemplate mutates the graph speculatively
    try_match = revert_on_fail(
        make_try_match(
            {
                EdgeKind.Fall: "tail",
            },
            "try_header",
            "try_body",
            "except_body",
            "reraise",
        )
    )

    # the docstring below is the source template consumed by @to_indented_source
    @to_indented_source
    def to_indented_source():
        """
        {try_header}
        {try_body}
        {except_body}
        """
@register_template(0, 0, *versions_from(3, 12))
class TryElse3_12(ControlFlowTemplate):
    """try/except/else for 3.12+: like Try3_12 but the try body falls into a
    distinct else block before reaching the tail."""

    template = T(
        try_header=N("try_body"),
        try_body=N("try_else.", None, "except_body"),
        except_body=N("tail.", None, "reraise").with_in_deg(1).of_subtemplate(Except3_11),
        try_else=~N("tail.").with_in_deg(1),
        reraise=reraise,
        tail=N.tail(),
    )
    # revert_on_fail is required because of_subtemplate mutates the graph speculatively
    try_match = revert_on_fail(
        make_try_match(
            {
                EdgeKind.Fall: "tail",
            },
            "try_header",
            "try_body",
            "except_body",
            "try_else",
            "reraise",
        )
    )

    # the docstring below is the source template consumed by @to_indented_source
    @to_indented_source
    def to_indented_source():
        """
        {try_header}
        {try_body}
        {except_body}
        else:
        {try_else}
        """
class BareExcept3_11(Except3_11):
    """Bare `except:` handler: a body followed by a footer beginning with
    POP_EXCEPT; both may raise into the shared reraise node."""

    template = T(
        except_body=N("except_footer", None, "reraise"),
        except_footer=~N("tail.").with_in_deg(1).with_cond(starting_instructions("POP_EXCEPT")),
        reraise=reraise,
        tail=N.tail(),
    )
    try_match = make_try_match(
        {
            EdgeKind.Fall: "tail",
            EdgeKind.Exception: "reraise",
        },
        "except_body",
        "except_footer",
    )

    # the docstring below is the source template consumed by @to_indented_source
    @to_indented_source
    def to_indented_source():
        """
        except:
        {except_body}
        {except_footer}
        """
class ExcBody3_11(ControlFlowTemplate):
    """Handler body: tries the named-exception shape first; if it does not
    apply, the node itself is already an acceptable body."""

    @classmethod
    @override
    def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
        if x := NamedExc3_11.try_match(cfg, node):
            return x
        return node
class NamedExcTail3_11(ControlFlowTemplate):
    """Tail of a named-exception handler: an optional SWAP node condensed
    into the tail (exception out-edges filtered). Falls back to the node
    unchanged if the SWAP shape is absent."""

    template = T(
        SWAP=N("tail", None, "reraise").with_cond(exact_instructions("SWAP")),
        reraise=reraise,
        tail=N.tail(),
    )

    @classmethod
    def _try_match(cls, cfg, node):
        # strict match of the SWAP shape; None when it does not apply
        mapping = cls.template.try_match(cfg, node)
        if mapping is None:
            return None
        return condense_mapping(cls, cfg, mapping, "SWAP", "tail", out_filter=[EdgeCategory.Exception])

    @classmethod
    @override
    def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
        # lenient wrapper: an unmatched node is returned as-is
        if x := cls._try_match(cfg, node):
            return x
        return node

    to_indented_source = defer_source_to("tail")
class NamedExc3_11(ExcBody3_11):
    """
    Body of an `except ... as name:` handler (3.11+): the exception is
    stored to a local/global name, the body runs, and a fixed cleanup
    sequence (LOAD_CONST / STORE / DELETE / RERAISE) unbinds the name.
    """

    template = T(
        STORE=N("body", None, "reraise").with_cond(exact_instructions("STORE_FAST"), exact_instructions("STORE_NAME")),
        body=N("tail.", None, "cleanup"),
        cleanup=N(E.exc("reraise")).with_cond(exact_instructions("LOAD_CONST", "STORE_FAST", "DELETE_FAST", "RERAISE"), exact_instructions("LOAD_CONST", "STORE_NAME", "DELETE_NAME", "RERAISE")),
        reraise=reraise,
        tail=N.tail().of_subtemplate(NamedExcTail3_11),
    )
    try_match = make_try_match({EdgeKind.Fall: "tail", EdgeKind.Exception: "reraise"}, "STORE", "body", "cleanup")
    # Only the handler body contributes source; store/cleanup are
    # interpreter bookkeeping.
    to_indented_source = defer_source_to("body")
class ExceptExc3_11(Except3_11):
    """
    Typed `except SomeType:` handler (3.11+).

    A header ending in CHECK_EXC_MATCH + conditional jump guards the
    handler body; on a failed match control proceeds to `no_match`, which
    is itself matched as a further handler chain.
    """

    template = T(
        except_header=N("except_body", "no_match", "reraise").with_cond(ending_instructions("CHECK_EXC_MATCH", "POP_JUMP_FORWARD_IF_FALSE"), ending_instructions("CHECK_EXC_MATCH", "POP_JUMP_IF_FALSE")),
        except_body=N("except_footer.", None, "reraise").of_subtemplate(ExcBody3_11).with_in_deg(1),
        no_match=N("tail?", None, "reraise").with_in_deg(1).of_subtemplate(Except3_11),
        except_footer=~N("tail.").with_in_deg(1).with_cond(starting_instructions("POP_EXCEPT")),
        reraise=reraise,
        tail=N.tail(),
    )
    try_match = revert_on_fail(
        make_try_match(
            {
                EdgeKind.Fall: "tail",
                EdgeKind.Exception: "reraise",
            },
            "except_header",
            "except_body",
            "except_footer",
            "no_match",
        )
    )

    # NOTE(review): docstring indentation reconstructed (renderer stripped
    # it); assumes the header renders the `except X:` line itself and the
    # chained `no_match` handler emits its own header at this level.
    @to_indented_source
    def to_indented_source():
        """
        {except_header}
            {except_body}
            {except_footer}
        {no_match}
        """
@register_template(0, 50)
@register_template(2, 50)
class TryFinally3_12(ControlFlowTemplate):
    """
    try/finally (3.12+).

    The compiler duplicates the finally suite: once on the normal path
    (`finally_body`) and once on the exception path (`fail_body`).
    Matching aligns the two copies to find where the duplicated code ends
    so that only one copy is emitted under `finally:`.
    """

    # Variant 1: a plain try body raising into the duplicated fail path.
    template = T(
        try_header=N("try_body"),
        try_body=N("finally_body", None, "fail_body"),
        finally_body=~N("tail.").with_in_deg(1).with_cond(no_back_edges),
        fail_body=N(E.exc("reraise")),
        reraise=reraise,
        tail=N.tail(),
    )
    # Variant 2: the protected region is an already-matched try/except or
    # try/except/else node.
    template2 = T(
        try_except=N("finally_body", None, "fail_body").of_type(Try3_12, TryElse3_12),
        finally_body=~N("tail.").with_in_deg(1).with_cond(no_back_edges),
        fail_body=N(E.exc("reraise")),
        reraise=reraise,
        tail=N.tail(),
    )

    @staticmethod
    def find_finally_cutoff(mapping):
        # Align the normal-path block against the exception-path block and
        # return the index of the last finally_body member that is mirrored
        # in fail_body, or None when the two paths cannot be aligned.
        f = mapping["finally_body"]
        g = mapping["fail_body"]
        # Reject when the exception-path copy has its own line numbers —
        # presumably the duplicated copy carries none; TODO confirm.
        if any(x.starts_line is not None for x in g.get_instructions()):
            return None
        if not isinstance(f, BlockTemplate):
            f = BlockTemplate([f])
        if not isinstance(g, BlockTemplate):
            g = BlockTemplate([g])
        # Strip the exception-path bookkeeping so both copies align.
        if isinstance(g.members[0], InstTemplate) and g.members[0].inst.opname == "PUSH_EXC_INFO":
            g.members.pop(0)
        if isinstance(g.members[-1], InstTemplate) and g.members[-1].inst.opname == "RERAISE":
            g.members.pop()
        x = None
        for x, y in zip(f.members, g.members):
            # Conditionals may be matched as different templates on the two
            # paths; any IfThen/IfElse pair is treated as equal.
            if all(type(a) in [IfThen, IfElse] for a in (x, y)):
                continue
            if type(x) is not type(y):
                return None
        # `x` is the last compared member; an empty zip leaves it None.
        return x and f.members.index(x)

    # Index of the last finally_body member that belongs inside the
    # `finally:` suite; later members are emitted after the statement.
    cutoff: int

    @classmethod
    @override
    def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
        mapping = cls.template.try_match(cfg, node)
        if mapping is None:
            mapping = cls.template2.try_match(cfg, node)
            if mapping is None:
                return None
            # Rename so both variants share the condensing path below.
            mapping["try_header"] = mapping.pop("try_except")
        cutoff = cls.find_finally_cutoff(mapping)
        if cutoff is None:
            if cfg.run == 2:
                # Last-resort pass: emit the whole normal path inside the
                # finally suite.
                cutoff = 9999
            else:
                return None
        template = condense_mapping(cls, cfg, mapping, "try_header", "try_body", "finally_body", "fail_body", "reraise")
        template.cutoff = cutoff
        return template

    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        header = source[self.try_header]
        body = source[self.try_body, 1]
        if isinstance(self.finally_body, BlockTemplate):
            # Members up to the cutoff go inside the finally: suite, the
            # rest are emitted after the whole statement.
            i = self.cutoff + 1
            in_finally = source[BlockTemplate(self.finally_body.members[:i]), 1] if i > 0 else []
            after = source[BlockTemplate(self.finally_body.members[i:])] if i < len(self.finally_body.members) else []
        else:
            in_finally = source[self.finally_body, 1]
            after = []
        return list(chain(header, body, self.line("finally:"), in_finally, after))
@@ -0,0 +1,39 @@
from ..cft import ControlFlowTemplate, EdgeKind, MetaTemplate, register_template
from ..utils import E, T, N, defer_source_to, exact_instructions, no_back_edges, to_indented_source, make_try_match
@register_template(0, 0)
class Await3_12(ControlFlowTemplate):
    """
    The bytecode resume loop emitted for `await` (3.12): a SEND /
    YIELD_VALUE / JUMP_BACKWARD_NO_INTERRUPT cycle plus the CLEANUP_THROW
    path, condensed into the node that produced the awaitable.
    """

    template = T(
        awaited=N("SEND", None, "gen_cleanup").with_cond(no_back_edges),
        SEND=N("YIELD_VALUE", "JUMP_BACK_NO_INT", "gen_cleanup").with_in_deg(2).with_cond(exact_instructions("SEND")),
        YIELD_VALUE=N("JUMP_BACK_NO_INT", None, "CLEANUP_THROW").with_in_deg(1).with_cond(exact_instructions("YIELD_VALUE")),
        JUMP_BACK_NO_INT=N("SEND", None, "gen_cleanup").with_in_deg(2).with_cond(exact_instructions("JUMP_BACKWARD_NO_INTERRUPT")),
        CLEANUP_THROW=N("JUMP_BACK", None, "gen_cleanup").with_in_deg(1).with_cond(exact_instructions("CLEANUP_THROW")),
        JUMP_BACK=N("tail").with_in_deg(1).with_cond(exact_instructions("JUMP_BACKWARD"), exact_instructions("JUMP_BACKWARD_NO_INTERRUPT")),
        gen_cleanup=~N.tail(),
        tail=N.tail(),
    )
    try_match = make_try_match({EdgeKind.Fall: "tail", EdgeKind.Exception: "gen_cleanup"}, "awaited", "SEND", "YIELD_VALUE", "JUMP_BACK_NO_INT", "CLEANUP_THROW", "JUMP_BACK")
    # The await machinery contributes no statements of its own; it reads as
    # part of the awaited expression.
    to_indented_source = defer_source_to("awaited")
@register_template(0, 0)
class Generator3_12(ControlFlowTemplate):
    """
    Top-level structure of a generator/coroutine code object (3.12):
    the RETURN_GENERATOR / POP_TOP prologue, the body, and the
    CALL_INTRINSIC_1 / RERAISE cleanup path, all flowing into the meta
    end node.
    """

    template = T(
        entry=N("body").with_cond(exact_instructions("RETURN_GENERATOR", "POP_TOP")),
        body=N(E.exc("gen_cleanup"), E.meta("end?")),
        gen_cleanup=N(E.meta("end")).with_cond(exact_instructions("CALL_INTRINSIC_1", "RERAISE")),
        end=N().of_type(MetaTemplate),
    )
    try_match = make_try_match({EdgeKind.Fall: "end"}, "entry", "body", "gen_cleanup")

    @to_indented_source
    def to_indented_source():
        """
        {entry}
        {body}
        """
@@ -0,0 +1,60 @@
from ..cft import ControlFlowTemplate, EdgeKind, register_template
from ..utils import (
T,
N,
defer_source_to,
starting_instructions,
to_indented_source,
make_try_match,
)
@register_template(0, 1)
class ForLoop(ControlFlowTemplate):
    """
    A for loop: an iterator header with a body that jumps back to it,
    exiting to the tail when the iterator is exhausted.
    """

    template = T(
        for_iter=~N("for_body", "tail"),
        for_body=~N("for_iter").with_in_deg(1),
        tail=N.tail(),
    )
    try_match = make_try_match({EdgeKind.Fall: "tail"}, "for_iter", "for_body")

    # NOTE(review): docstring indentation reconstructed (renderer stripped
    # it); assumes for_iter renders the `for ...:` line itself.
    @to_indented_source
    def to_indented_source():
        """
        {for_iter}
            {for_body}
        """
@register_template(0, 2)
class SelfLoop(ControlFlowTemplate):
    """
    A node whose natural successor is itself: emitted as an infinite
    `while True:` loop.
    """

    template = T(loop_body=~N("loop_body", None))
    try_match = make_try_match({}, "loop_body")

    @to_indented_source
    def to_indented_source():
        """
        while True:
            {loop_body}
        """
@register_template(0, 3)
class InlinedComprehensionTemplate(ControlFlowTemplate):
    """
    Comprehension code with a SWAP / POP_TOP / SWAP cleanup path on the
    exception edge.  The name suggests this targets comprehensions inlined
    into the enclosing code object (3.12+ behavior) — only the
    comprehension itself contributes source.
    """

    template = T(
        comp=N("tail", None, "cleanup"),
        cleanup=+N().with_in_deg(1).with_cond(starting_instructions("SWAP", "POP_TOP", "SWAP")),
        tail=~N.tail(),
    )
    try_match = make_try_match(
        {
            EdgeKind.Fall: "tail",
        },
        "comp",
        "cleanup",
    )
    to_indented_source = defer_source_to("comp")
@@ -0,0 +1,41 @@
from typing import override
from ..cft import ControlFlowTemplate, EdgeKind, register_template
from ..utils import T, N, exact_instructions, starting_instructions, to_indented_source, make_try_match, versions_from
class WithCleanup3_12(ControlFlowTemplate):
    """
    The exception-path cleanup emitted for a `with` block (3.12/3.13):
    WITH_EXCEPT_START's result decides between re-raising (RERAISE) and
    swallowing the exception (POP_TOP).  The whole structure is interpreter
    bookkeeping and is invisible in reconstructed source.
    """

    template = T(
        start=N("reraise", "poptop", "exc").with_cond(
            exact_instructions("PUSH_EXC_INFO", "WITH_EXCEPT_START", "POP_JUMP_IF_TRUE"),  # 3.12
            exact_instructions("PUSH_EXC_INFO", "WITH_EXCEPT_START", "TO_BOOL", "POP_JUMP_IF_TRUE"),  # 3.13
        ),
        reraise=N(None, None, "exc").with_cond(exact_instructions("RERAISE")).with_in_deg(1),
        poptop=N("tail", None, "exc").with_cond(exact_instructions("POP_TOP")).with_in_deg(1),
        exc=+N().with_cond(exact_instructions("COPY", "POP_EXCEPT", "RERAISE")).with_in_deg(3),
        tail=~N.tail().with_cond(starting_instructions("POP_EXCEPT", "POP_TOP", "POP_TOP")).with_in_deg(1),
    )
    try_match = make_try_match({}, "start", "reraise", "poptop", "exc", "tail")

    @override
    def to_indented_source(self, source):
        # Pure bytecode bookkeeping: emits no source lines.
        return []
@register_template(0, 10, *versions_from(3, 12))
class With3_12(ControlFlowTemplate):
    """
    A `with` statement (3.12+): the setup node, the protected body, and
    the already-matched exception-cleanup subgraph (`WithCleanup3_12`).
    """

    template = T(
        setup_with=~N("with_body", None),
        with_body=N("normal_cleanup", None, "exc_cleanup").with_in_deg(1),
        exc_cleanup=N.tail().of_subtemplate(WithCleanup3_12).with_in_deg(1),
        normal_cleanup=~N.tail(),
    )
    try_match = make_try_match({EdgeKind.Fall: "normal_cleanup"}, "setup_with", "with_body", "exc_cleanup")

    # NOTE(review): docstring indentation reconstructed (renderer stripped
    # it); assumes setup_with renders the `with ...:` line itself.
    @to_indented_source
    def to_indented_source():
        """
        {setup_with}
            {with_body}
        """
@@ -0,0 +1,4 @@
from pathlib import Path

# Re-export every template module in this package; __all__ is computed from
# the .py files on disk so newly added template files are picked up
# automatically by the star-import below.
__all__ = [x.stem for x in Path(__file__).parent.glob("*.py") if x.stem != "__init__"]

from . import *
@@ -0,0 +1,496 @@
from __future__ import annotations
from functools import partial
from itertools import chain
import textwrap
import pdb
import sys
from typing import TYPE_CHECKING, Callable, TypeVar, override
from pylingual.utils.version import supported_versions
from .cft import (
AnyNodeMatcher,
ConditionalNodeMatcher,
ControlFlowTemplate,
EdgeCategory,
EdgeKind,
EdgeMatcher,
EdgeTemplate,
ExitableEdge,
InstTemplate,
OptExcEdge,
RaiseOutEdge,
SourceContext,
SourceLine,
SubtemplateNodeMatcher,
Template,
NoEdge,
NodeMatcher,
NodeTemplate,
OptionalEdge,
OptionalNodeMatcher,
)
if TYPE_CHECKING:
from pylingual.control_flow_reconstruction.cfg import CFG
C = TypeVar("C", bound=ControlFlowTemplate)
no_edge = NoEdge()
def has_in_degree(n: int):
    """Node condition: the node exists and has exactly `n` incoming edges."""

    def check_in_degree(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if node is None:
            return False
        return cfg.in_degree(node) == n

    return check_in_degree
def exact_instructions(*opnames: str):
    """Node condition: the node's instruction opnames equal `opnames` exactly, in order."""

    def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if node is None:
            return False
        found = tuple(inst.opname for inst in node.get_instructions())
        return found == opnames

    return check_instructions
def starting_instructions(*opnames: str):
    """Node condition: the node's first `len(opnames)` instructions match `opnames` in order."""

    def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if node is None:
            return False
        prefix = node.get_instructions()[: len(opnames)]
        return tuple(inst.opname for inst in prefix) == opnames

    return check_instructions
def ending_instructions(*opnames: str):
    """
    Node condition: the node's final instructions match `opnames` in order.

    False for a missing node.  With an empty `opnames` the condition is
    trivially True: the original slice `[-len(opnames):]` degenerated to
    `[-0:]` — the *whole* instruction list — and wrongly compared it
    against the empty tuple.
    """

    def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if node is None:
            return False
        if not opnames:
            # An empty suffix always matches; guard against the [-0:] slice.
            return True
        return tuple(x.opname for x in node.get_instructions()[-len(opnames):]) == opnames

    return check_instructions
def without_instructions(*opnames: str):
    """Node condition: none of the node's instructions use an opname in `opnames`."""

    def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if node is None:
            return False
        return not any(inst.opname in opnames for inst in node.get_instructions())

    return check_instructions
def without_top_level_instructions(*opnames: str):
    """
    Node condition: no *direct* instruction of the node uses an opname in
    `opnames`.  Instructions buried inside already-matched subtemplates of
    a block are not inspected; non-instruction nodes trivially pass.
    """
    from .templates.Block import BlockTemplate

    def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if isinstance(node, BlockTemplate):
            # Only the block's immediate InstTemplate members count.
            return not any(member.inst.opname in opnames for member in node.members if isinstance(member, InstTemplate))
        if isinstance(node, InstTemplate):
            return node.inst.opname not in opnames
        return True

    return check_instructions
def has_type(*template_type: type[ControlFlowTemplate]):
    """Node condition: the node is an instance of one of `template_type`."""

    def check_type(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        # isinstance with a tuple accepts any of the listed types; a None
        # node never matches, so missing nodes are rejected for free.
        return isinstance(node, template_type)

    return check_type
def no_back_edges(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
    """Node condition: no successor of the node dominates it, i.e. the node starts no back edge."""
    if node is None:
        return True
    return all(not cfg.dominates(succ, node) for succ in cfg.successors(node))
def run_is(n: int):
    """Node condition: the reconstruction is currently on run number `n` (the node itself is ignored)."""

    def check_run(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        return cfg.run == n

    return check_run
_AUTO_EXC = "_EXC"
def T(root: str | None = None, **nodes: N | NodeTemplate) -> Template:
    """
    Convenience function for creating `Template`s

    If `root` is None, the first node in `nodes` is used

    Node values may be `N` builders (compiled here via `_build`) or
    already-built `NodeTemplate`s.  If any builder requested an automatic
    outer exception handler (`~N(...)` / `+N(...)`), an optional tail node
    named `_EXC` is appended to the template to receive it.
    """
    assert _AUTO_EXC not in nodes
    if any(x._auto_exc for x in nodes.values() if isinstance(x, N)):
        nodes[_AUTO_EXC] = N.tail().optional()
    if root is None:
        root = next(iter(nodes))
    return Template(root, {k: v._build(k) if isinstance(v, N) else v for k, v in nodes.items()})
if TYPE_CHECKING:
    # Signature shared by all node-condition predicates in this module.
    NodeCondition = Callable[[CFG, ControlFlowTemplate | None], bool]

# Edge categories in the positional order used by `N(...)`:
# (natural, conditional, exception, meta).
_ec = [EdgeCategory.Natural, EdgeCategory.Conditional, EdgeCategory.Exception, EdgeCategory.Meta]
# Default edge set: explicitly "no edge" in every category.
_no_edges = {k: no_edge for k in _ec}


def _to_edge_dict(*edges: tuple[EdgeCategory, EdgeMatcher] | str | None) -> dict[EdgeCategory, EdgeMatcher]:
    # Positional strings/None are assigned to categories in `_ec` order;
    # explicit (category, matcher) tuples are taken as-is.
    return dict(x if isinstance(x, tuple) else (_ec[i], E._(x)) for i, x in enumerate(edges))
class N:
    """
    `NodeTemplate` builder class

    Builders are configured fluently (e.g. `N("a", "b").with_in_deg(1)`)
    and compiled into `NodeMatcher`s by `T()` via `_build`.
    """

    # Edge matchers keyed by category (natural, conditional, exception, meta).
    _edges: dict[EdgeCategory, EdgeMatcher]
    # Extra predicates the candidate node must satisfy.
    _conds: list[NodeCondition]
    # Whether the node may be absent from the graph.
    _is_optional: bool
    # Whether this node asked for the template's automatic outer exception
    # handler node (set by ~N / +N, consumed by T()).
    _auto_exc: bool
    # Template to match rooted at the node before acceptance, if any.
    _subtemplate: type[ControlFlowTemplate] | None

    def __init__(self, *edges: tuple[EdgeCategory, EdgeMatcher] | str | None):
        self._edges = _no_edges | _to_edge_dict(*edges)
        # A "name." edge is exitable; such a node may also leave through a
        # meta edge, so drop the "no meta edge" default.
        if any(x.endswith(".") for x in edges if isinstance(x, str)):
            del self._edges[EdgeCategory.Meta]
        self._conds = []
        self._is_optional = False
        self._auto_exc = False
        self._subtemplate = None

    def __invert__(self) -> N:
        """
        This node is connected to outer exception handler if one exists. An outer exception handler node will be automatically added to the template.
        """
        self._edges[EdgeCategory.Exception] = OptExcEdge(_AUTO_EXC)
        self._auto_exc = True
        return self

    def __pos__(self) -> N:
        """
        This node raises an exception to either an outer exception handler, or out of the codeobject. An outer exception handler node will be automatically added to the template.
        """
        self._edges[EdgeCategory.Exception] = RaiseOutEdge(_AUTO_EXC)
        if EdgeCategory.Meta in self._edges:
            del self._edges[EdgeCategory.Meta]
        self._auto_exc = True
        return self

    @staticmethod
    def tail() -> N:
        """
        Create a node and do not check any out edges from it.
        """
        x = N()
        x._edges = {}
        return x

    def optional(self) -> N:
        """
        The node is optional.
        """
        self._is_optional = True
        return self

    def with_in_deg(self, n: int, *n2: int) -> N:
        """
        The node must have in-degree `n` or any of the in-degrees in `n2`.
        """
        if not n2:
            self._conds.append(has_in_degree(n))
        else:
            self._conds.append(lambda cfg, node: node is not None and cfg.in_degree(node) in (n, *n2))
        return self

    def of_type(self, *template_type: type[ControlFlowTemplate]) -> N:
        """
        The node must be any template in `template_type`.
        """
        self._conds.append(has_type(*template_type))
        return self

    def of_subtemplate(self, template_type: type[ControlFlowTemplate]) -> N:
        """
        When matching a node, first try to match `template_type` rooted at the node, and only accept if the template successfully matched.
        """
        self._subtemplate = template_type
        return self

    def with_cond(self, cond: NodeCondition, *or_conds: NodeCondition) -> N:
        """
        The node must match `cond` or any of the conditions in `or_conds`.
        """
        if not or_conds:
            self._conds.append(cond)
        else:
            self._conds.append(lambda cfg, node: any(f(cfg, node) for f in (cond, *or_conds)))
        return self

    def __or__(self, o: N) -> N:
        """
        Match either this node or the other node.
        """
        return _Ns(self, o)

    def _all_conds(self, cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        # Conjunction of all registered conditions (used by _build when
        # there is more than one).
        return all(c(cfg, node) for c in self._conds)

    def _build(self, name: str) -> NodeMatcher:
        # Compile into a NodeMatcher, wrapping the bare NodeTemplate in
        # decorator matchers: subtemplate, then condition(s), then
        # optionality.  The name is extended per layer for debug output.
        x = NodeTemplate(self._edges)
        x.name = name
        if self._subtemplate:
            x = SubtemplateNodeMatcher(x, self._subtemplate)
            name += ".subtemplate"
            x.name = name
        if len(self._conds) == 1:
            x = ConditionalNodeMatcher(x, self._conds[0])
            name += ".condition"
            x.name = name
        elif self._conds:
            x = ConditionalNodeMatcher(x, self._all_conds)
            name += ".condition"
            x.name = name
        if self._is_optional:
            x = OptionalNodeMatcher(x)
            name += ".optional"
            x.name = name
        return x
class _Ns(N):
    """
    Disjunction of `N` builders created by `N.__or__`: matches a node when
    any of its alternatives matches.

    NOTE(review): methods inherited from `N` but not overridden here
    (`of_subtemplate`, `__invert__`, `__pos__`) are not distributed over
    the alternatives and will misbehave on a disjunction; avoid them.
    """

    def __init__(self, a: N, b: N):
        self.nodes = [a, b]
        # N.__init__ is deliberately not called (this builder only wraps
        # others), but `_auto_exc` must still exist because T() reads it
        # on every N it receives; the original omitted it and crashed with
        # AttributeError whenever an `a | b` node appeared in a template.
        self._auto_exc = a._auto_exc or b._auto_exc

    @override
    def _build(self, name) -> NodeMatcher:
        return AnyNodeMatcher(*(x._build(name + ".any") for x in self.nodes))

    @override
    def optional(self) -> N:
        for node in self.nodes:
            node._is_optional = True
        return self

    @override
    def with_in_deg(self, n: int, *n2: int) -> N:
        self.nodes = [node.with_in_deg(n, *n2) for node in self.nodes]
        return self

    @override
    def of_type(self, *template_type: type[ControlFlowTemplate]) -> N:
        self.nodes = [n.of_type(*template_type) for n in self.nodes]
        return self

    @override
    def with_cond(self, cond: NodeCondition, *or_conds: NodeCondition) -> N:
        self.nodes = [n.with_cond(cond, *or_conds) for n in self.nodes]
        return self

    @override
    def __or__(self, o: N) -> N:
        if isinstance(o, _Ns):
            self.nodes.extend(o.nodes)
        else:
            self.nodes.append(o)
        # Keep the auto-exception request in sync with the new alternative.
        self._auto_exc = self._auto_exc or o._auto_exc
        return self
class E:
    """
    Namespace for edge convenience functions.

    Node-name suffixes select the edge matcher type:
    "x?" optional edge, "x." exitable edge, "x^" raise-out edge,
    plain "x" a required edge, and None no edge at all.
    """

    @staticmethod
    def _(x: str | None) -> EdgeMatcher:
        # Translate a suffix-annotated node name into its edge matcher.
        if x is None:
            return no_edge
        if x[-1] == "?":
            return OptionalEdge(x[:-1])
        if x[-1] == ".":
            return ExitableEdge(x[:-1])
        if x[-1] == "^":
            return RaiseOutEdge(x[:-1])
        return EdgeTemplate(x)

    @staticmethod
    def nat(n: str | None):
        """Natural (fall-through) edge to node `n`."""
        return (EdgeCategory.Natural, E._(n))

    @staticmethod
    def cond(n: str | None):
        """Conditional-jump edge to node `n`."""
        return (EdgeCategory.Conditional, E._(n))

    @staticmethod
    def exc(n: str | None):
        """Exception edge to node `n`."""
        return (EdgeCategory.Exception, E._(n))

    @staticmethod
    def meta(n: str | None):
        """Meta edge to node `n`."""
        return (EdgeCategory.Meta, E._(n))
def remove_nodes(cfg: CFG, mapping: dict[str, ControlFlowTemplate | None], *nodes: str):
    """
    Remove from `cfg` every node that `mapping` binds for a name in `nodes`.

    Unbound names and falsy (None) bindings are skipped.  The original
    looked each name up twice and filtered the same truthiness twice
    (`if mapping.get(n)` plus `filter(None, ...)`); one pass suffices.
    """
    cfg.remove_nodes_from(filter(None, map(mapping.get, nodes)))
def _line(line: str) -> Callable[[ControlFlowTemplate, SourceContext], list[SourceLine]]:
    """
    Compile one template-docstring line into a source-emitting function.

    Plain text becomes a literal line at its indent level; "{name}"
    splices in the named member's source; "{name?text}" emits the literal
    `text` only when member `name` exists and yields source.  Indent level
    is one per 4 leading spaces.
    """
    x = line.lstrip(" ")
    indent = (len(line) - len(x)) // 4
    if x[0] == "{":
        end = x.index("}")
        t = x[1:end]
        if "?" in t:
            # "{name?text}": the member's own source is evaluated only for
            # truthiness; what gets emitted is the literal `text`.
            s = t[t.index("?") + 1 :]
            t = t[: t.index("?")]
            return partial(lambda self, source, indent, t, s: self.line(s, indent) if self.members[t] is not None and source[self.members[t], indent] else [], indent=indent, t=t, s=s)
        # "{name}": splice the member's source at this indent, or nothing
        # when the member is absent.
        return partial(lambda self, source, indent, t: source[self.members[t], indent] if self.members[t] is not None else [], indent=indent, t=t)
    return lambda self, source: self.line(x, indent)
def to_indented_source(f: Callable[[], None]):
    """
    "Compile" a function's docstring into an indented source function

    Indentation must be 4 spaces

    Each docstring line is compiled once by `_line`; the returned function
    concatenates every compiled line's output in order.
    """
    assert f.__doc__ is not None and "\t" not in f.__doc__
    src = [_line(x) for x in textwrap.dedent(f.__doc__).strip().split("\n")]

    def to_indented_source(self: ControlFlowTemplate, source: SourceContext) -> list[SourceLine]:
        return list(chain.from_iterable(x(self, source) for x in src))

    return to_indented_source
def defer_source_to(n: str):
    """Build a `to_indented_source` that emits only the member named `n`, or nothing when it is absent."""

    def to_indented_source(self: ControlFlowTemplate, source: SourceContext) -> list[SourceLine]:
        member = self.members[n]
        return source[member] if member is not None else []

    return to_indented_source
def condense_mapping(
    cls: type[C], cfg: CFG, mapping: dict[str, ControlFlowTemplate | None], *nodes: str, in_edges: dict[ControlFlowTemplate, dict] | None = None, out_edges: dict[ControlFlowTemplate, dict] | None = None, out_filter: tuple[EdgeCategory, ...] = ()
) -> C:
    """
    Collapse the mapped nodes named in `nodes` into a single new `cls` node.

    In/out edges of the condensed region are recomputed from the graph
    unless supplied.  Edges to `cfg.end` — or a region with no out edges at
    all — become meta edges, and out-edge categories listed in `out_filter`
    are dropped.  The CFG is advanced one iteration afterwards.

    Fixes vs. original: `out_filter` defaulted to a mutable `[]`, and the
    edge comprehensions called `cfg.in_edges(None)` / `cfg.out_edges(None)`
    for unbound members (networkx then yields every edge in the graph
    before the trailing `n is not None` filter discarded them); the None
    filter is now applied up front.
    """
    in_template = {x: mapping.get(x) for x in nodes}
    template = cls(in_template)
    # `reversed` preserves the original precedence: when several members
    # share an endpoint, the earlier member's edge properties win.
    members = [n for n in reversed(in_template.values()) if n is not None]
    if in_edges is None:
        in_edges = {src: prop for n in members for src, _, prop in cfg.in_edges(n, data=True) if src not in in_template.values()}
    if out_edges is None:
        out_edges = {dst: prop for n in members for _, dst, prop in cfg.out_edges(n, data=True) if dst not in in_template.values()}
    if cfg.end in out_edges:
        out_edges[cfg.end] = EdgeKind.Meta.prop()
    if not out_edges:
        # A region with no successors still needs a path to the meta end.
        out_edges[cfg.end] = EdgeKind.Meta.prop()
    if out_filter:
        out_edges = {k: v for k, v in out_edges.items() if EdgeCategory.from_kind(v["kind"]) not in out_filter}
    remove_nodes(cfg, mapping, *in_template)
    cfg.add_node(template)
    cfg.add_edges_from((src, template, prop) for src, prop in in_edges.items())
    cfg.add_edges_from((template, dst, prop) for dst, prop in out_edges.items())
    cfg.iterate()
    return template
def make_try_match(out_edges: dict[EdgeKind, str], *nodes: str):
    """
    Make a `try_match` method for a `ControlFlowTemplate`.

    Matches `cls.template`, condenses all nodes in `nodes`, and creates a new node.

    `out_edges` maps each desired out-edge kind to the template node the
    edge should target; unbound names are skipped.  If the template
    requested an automatic outer exception handler (`_EXC`) and no
    explicit exception edge was produced, one is added to that handler.
    """

    @classmethod
    @override
    def try_match(cls: type[ControlFlowTemplate], cfg: CFG, node: ControlFlowTemplate) -> ControlFlowTemplate | None:
        mapping = cls.template.try_match(cfg, node)
        if mapping is None:
            return None
        edges: dict[ControlFlowTemplate, dict] = {mapping[name]: kind.prop() for kind, name in out_edges.items() if mapping.get(name) is not None}  # type: ignore
        if mapping.get(_AUTO_EXC) is not None and all(e["kind"] != EdgeKind.Exception for e in edges.values()):
            edges[mapping[_AUTO_EXC]] = EdgeKind.Exception.prop()  # type: ignore
        return condense_mapping(cls, cfg, mapping, *nodes, out_edges=edges)

    return try_match
def revert_on_fail(f: Callable[[type[ControlFlowTemplate], CFG, ControlFlowTemplate], ControlFlowTemplate | None] | classmethod):
"""
Make a `ControlFlowTemplate`'s `try_match` method restore the CFG to before the method call if the match fails.
"""
if isinstance(f, classmethod):
f = f.__func__
@classmethod
@override
def try_match(cls: type[ControlFlowTemplate], cfg: CFG, node: ControlFlowTemplate) -> ControlFlowTemplate | None:
copy = cfg.copy()
if (ret := f(cls, cfg, node)) is not None:
return ret
cfg.clear()
cfg.update(copy)
return try_match
def _check_break_condition(cfg: CFG, node: ControlFlowTemplate | None, offset: int | None, i: int | None, name: str | None):
if offset is not None and (not node or node.offset != offset):
return False
if i is not None and cfg.i != i:
return False
if name is not None and cfg.bytecode.name != name:
return False
return True
def _hook(f, offset, i, name):
    """
    Wrap condition/match function `f` so a pdb breakpoint is armed just
    before it runs whenever the break conditions match (see `hook_template`
    / `hook_node`).
    """

    def hooked(cfg: CFG, node: ControlFlowTemplate | None):
        if _check_break_condition(cfg, node, offset, i, name):
            # Hand control to an interactive debugger; current state is
            # printed first so the session starts with context.
            p = pdb.Pdb()
            p.quitting = False
            p.botframe = None
            p.stopframe = None
            print(f"{cfg.i = }\n{cfg.bytecode.name = }\nnode.offset = {node and node.offset}\n{node = }")
            sys.settrace(p.trace_dispatch)
        return f(cfg, node)

    return hooked
def hook_template(offset: int | None = None, i: int | None = None, name: str | None = None):
    """
    Hook a `ControlFlowTemplate`'s `try_match` method to set a breakpoint before running when certain conditions are met.
    """

    def deco(template: type[C]):
        # Replace try_match in place; _hook forwards the call after
        # (conditionally) arming the debugger.
        hooked = _hook(template.try_match, offset, i, name)
        template.try_match = hooked
        return template

    return deco
def hook_node(node: str, offset: int | None = None, i: int | None = None, name: str | None = None):
    """
    In this `ControlFlowTemplate`, hook the node named `node`'s `try_match` method to set a breakpoint before running when certain conditions are met.
    """

    def deco(template: type[C]):
        # Same mechanism as hook_template, but applied to one node matcher
        # inside the class's template.
        matcher = template.template.nodes[node]
        matcher.try_match = _hook(matcher.try_match, offset, i, name)
        return template

    return deco
def versions_above(major: int, minor: int):
    """Yield, as tuples, every supported version strictly newer than (major, minor)."""
    return map(lambda v: v.as_tuple(), filter(lambda v: v > (major, minor), supported_versions))
def versions_from(major: int, minor: int):
    """Yield, as tuples, every supported version at or above (major, minor)."""
    return map(lambda v: v.as_tuple(), filter(lambda v: v >= (major, minor), supported_versions))
def versions_below(major: int, minor: int):
    """Yield, as tuples, every supported version strictly older than (major, minor)."""
    return map(lambda v: v.as_tuple(), filter(lambda v: v < (major, minor), supported_versions))
def versions_until(major: int, minor: int):
    """Yield, as tuples, every supported version at or below (major, minor)."""
    return map(lambda v: v.as_tuple(), filter(lambda v: v <= (major, minor), supported_versions))
def versions_except(*versions: tuple[int, int]):
    """Yield, as tuples, every supported version not listed in `versions`."""
    return map(lambda v: v.as_tuple(), filter(lambda v: v not in versions, supported_versions))
+152 -168
View File
@@ -1,14 +1,33 @@
from __future__ import annotations
from xdis import Code3
Code3.__eq__ = (
lambda self, o: isinstance(o, Code3)
and self.co_argcount == o.co_argcount
and self.co_nlocals == o.co_nlocals
and self.co_flags == o.co_flags
and self.co_code == o.co_code
and self.co_consts == o.co_consts
and self.co_names == o.co_names
and self.co_varnames == o.co_varnames
and self.co_filename == o.co_filename
and self.co_name == o.co_name
and self.co_stacksize == o.co_stacksize
and self.co_firstlineno == o.co_firstlineno
and self.co_freevars == o.co_freevars
and self.co_cellvars == o.co_cellvars
and self.co_kwonlyargcount == o.co_kwonlyargcount
)
Code3.__hash__ = lambda self: hash(self.co_code)
import datetime
import functools
import importlib.resources
import itertools
import keyword
import logging
import re
import tempfile
import shutil
import sys
from dataclasses import dataclass
from pathlib import Path
@@ -16,11 +35,12 @@ from typing import TYPE_CHECKING
from xdis.magics import magicint2version
from pylingual.control_flow_reconstruction.cflow import bytecode_to_indented_source
from pylingual.control_flow_reconstruction.reconstruct_control_indentation import reconstruct_source
from pylingual.control_flow_reconstruction.source import SourceContext, SourceLine
from pylingual.control_flow_reconstruction.structure import bc_to_cft
from pylingual.control_flow_reconstruction.cft import MetaTemplate
from pylingual.equivalence_check import TestResult, compare_pyc
from pylingual.models import CacheTranslator, load_models
from pylingual.utils.generate_bytecode import CompileError, compile_version
from pylingual.utils.generate_bytecode import CompileError, compile_version, has_pyenv
from pylingual.masking.model_disasm import create_global_masker, restore_masked_source_text
from pylingual.editable_bytecode import PYCFile
from pylingual.segmentation.segmentation_search_strategies import get_top_k_predictions, m_deep_top_k, naive_confidence_priority, filter_subwords
@@ -37,8 +57,6 @@ logger = logging.getLogger(__name__)
bytecode_separator = " <SEP> "
lno_regex = re.compile(r"(?<=line )\d+")
def_regex = re.compile(r"(?<=def ).+?(?=\()")
class_regex = re.compile(r"(?<=class ).+?(?=:|\()")
def has_comp_error(results: list[TestResult]) -> bool:
@@ -50,34 +68,29 @@ class DecompilerResult:
"""
Dataclass containing relevant results from decompiling a pyc
:param decompiled_source: str containing the decompiler output
:param equivalence_results: list of internal bytecode comparison results
:param original_pyc: path to original pyc
:param decompiled_source: path to decompiled source
:param out_dir: directory where decompiler output and internal steps are written
:param original_pyc: original pyc
:param version: python version of pyc
"""
decompiled_source: str
equivalence_results: list[TestResult]
original_pyc: Path
decompiled_source: Path
out_dir: Path
original_pyc: PYCFile
version: PythonVersion
def calculate_success_rate(self) -> float:
if not self.equivalence_results:
return 0
return sum(1 for x in self.equivalence_results if x.success) / len(self.equivalence_results) * 100
return sum(1 for x in self.equivalence_results if x.success) / len(self.equivalence_results)
class Decompiler:
"""
You probably want to use decompile() instead.
Decompiles a PYC file after masking bytecode, segmenting bytecode, and translating bytecode back into source statements, then reconstructs the control flow.
Additionally saves the decompiled file into the specified output directory.
:param pyc: The PYCFile loaded into memory
:param out_dir: The output directory where decompilation results will be stored
:param segmenter: The loaded segmentation model
:param translator: The loaded translation model
:param version: The python version
@@ -85,151 +98,109 @@ class Decompiler:
:param trust_lnotab: Decides whether or not to use line number information
"""
def __init__(self, pyc: PYCFile, out_dir: Path, segmenter: transformers.Pipeline, translator: CacheTranslator, version: PythonVersion, top_k=10, trust_lnotab=False):
def __init__(self, pyc: PYCFile, segmenter: transformers.Pipeline, translator: CacheTranslator, version: PythonVersion, top_k=10, trust_lnotab=False):
self.pyc = pyc
self.file = pyc.pyc_path
self.out_dir = out_dir
self.pyc.copy()
self.name = pyc.pyc_path.name if pyc.pyc_path is not None else repr(pyc)
self.segmenter = segmenter
self.translator = translator
self.version = version
self.out_dir.mkdir(parents=True, exist_ok=True)
self.top_k = top_k
self.highest_k_used = 0
self.tmpn = 0
self.trust_lnotab = trust_lnotab
self.header = "# Decompiled with PyLingual (https://pylingual.io)\n"
try:
self.header += (
f"# Internal filename: {self.pyc.codeobj.co_filename}\n"
f"# Bytecode version: {magicint2version[self.pyc.magic]} ({self.pyc.magic})\n"
f"# Source timestamp: {datetime.datetime.fromtimestamp(self.pyc.timestamp, datetime.UTC).strftime('%Y-%m-%d %H:%M:%S UTC')} ({self.pyc.timestamp})\n\n"
)
except:
pass
def __call__(self):
with tempfile.TemporaryDirectory() as tmp:
self.tmp = Path(tmp)
self.decompile()
self.log_results()
logger.info(f"Checking decompilation for {self.file.name}...")
if shutil.which("pyenv") is None and self.version != sys.version_info:
logger.warning(f"pyenv is not installed so equivalence check cannot be performed. Please install pyenv manually along with the required Python version ({self.version}) or run PyLingual again with the --init-pyenv flag")
self.result = DecompilerResult([TestResult(False, "Cannot compare equivalence without pyenv installed", bc.name, bc.name) for bc in self.pyc.iter_bytecodes()], self.file, self.candidate_source_path, self.out_dir, self.version)
return
self.equivalence_results = self.check_reconstruction()
self.correct_failures()
if has_comp_error(self.equivalence_results):
self.equivalence_results += self.purge_comp_errors()
equivalence_report = self.out_dir / "equivalence_report.txt"
equivalence_report.write_text("\n".join(str(r) for r in self.equivalence_results))
self.result = DecompilerResult(self.equivalence_results, self.file, self.candidate_source_path, self.out_dir, self.version)
def decompile(self):
self.mask_bytecode()
if self.trust_lnotab:
self.update_segmentation_from_lnotab()
else:
self.mask_bytecode()
self.run_segmentation()
self.run_translation()
self.run_cflow_reconstruction()
self.reconstruct_source()
self.run_translation()
self.unmask_lines()
self.run_cflow_reconstruction()
self.reconstruct_source()
if not has_pyenv() and self.version != sys.version_info:
logger.warning(f"pyenv is not installed so equivalence check cannot be performed. Please install pyenv manually along with the required Python version ({self.version}) or run PyLingual again with the --init-pyenv flag")
return DecompilerResult(self.indented_source, [TestResult(False, "Cannot compare equivalence without pyenv installed", bc, bc) for bc in self.pyc.iter_bytecodes()], self.pyc, self.version)
self.equivalence_results = self.check_reconstruction(self.indented_source)
self.correct_failures()
if has_comp_error(self.equivalence_results):
self.equivalence_results += self.purge_comp_errors()
for tr in self.equivalence_results:
if tr.bc_a is not None and not tr.success:
self.source_context.cfts[tr.bc_a.codeobj].add_header(f"# {tr}", meta=True)
return DecompilerResult(str(self.source_context), self.equivalence_results, self.pyc, self.version)
def find_comp_error_cause(self, results: list[TestResult]):
# parse lno from exception
lno = int(lno_regex.search(str(results[0])).group(0)) - 1
# adjust for lines added in postprocessing
lno -= sum(1 for x in (self.header + self.indented_source).split("\n")[: lno + 1] if x.endswith("# postinserted") or not x.strip() or x.strip().startswith("#"))
lno = lno_regex.search(str(results[0]))
if lno is None:
return None
lno = int(lno.group(0)) - 1
# get offending codeobj
bad_codeobj = self.blame[lno]
bad_codeobj = self.source_context.source_lines()[lno].blame
bad_idx = next(i for i, e in enumerate(self.ordered_bytecodes) if e.codeobj == bad_codeobj)
return bad_idx
def correct_failures(self):
changed = False
try:
# fix compile errors
corrected_comp_errors = set()
while has_comp_error(self.equivalence_results):
bad_idx = self.find_comp_error_cause(self.equivalence_results)
# i don't think this will ever happen but better safe than sorry
if bad_idx in corrected_comp_errors:
if bad_idx is None or bad_idx in corrected_comp_errors:
return
if not self.correct_segmentation(bad_idx, from_comp_error=True):
return
changed = True
corrected_comp_errors.add(bad_idx)
failed = TrackedList(CORRECTION_STEP, [i for i, result in enumerate(self.equivalence_results) if not result.success])
for i in failed:
if self.correct_segmentation(i):
changed = True
continue
# other fixes...
except Exception as e:
e.add_note("From error correction")
raise
finally:
if changed:
self.log_results()
# get eq results after replacing all codeobjs with comp errors with pass, preserving nested codeobjs
def purge_comp_errors(self):
logger.info("Removing compile errors")
try:
equivalence_results = self.equivalence_results
def replace_line(line):
line = line.strip()
x = def_regex.search(line)
if x is not None:
try:
x = x.group(0)
x = self.global_masker.unmask(x) if x.startswith("<mask_") else x
if not x.isidentifier() or keyword.iskeyword(x):
x = "_"
except:
x = "_"
return f"def {x}():"
x = class_regex.search(line)
if x is not None:
try:
x = x.group(0)
x = self.global_masker.unmask(x) if x.startswith("<mask_") else x
if not x.isidentifier() or keyword.iskeyword(x):
x = "_"
except:
x = "_"
return f"class {x}:"
if line.endswith("# inserted"):
return "pass # inserted"
return "pass"
purged = []
while has_comp_error(equivalence_results):
bad_idx = self.find_comp_error_cause(equivalence_results)
bad_co = self.ordered_bytecodes[bad_idx].codeobj
if bad_idx in purged:
if bad_idx is None:
logger.info("Could not find line number of error, unable to fix compile errors")
self.source_context.purged_cfts = []
return []
bad_bc = self.ordered_bytecodes[bad_idx]
if bad_idx in purged:
logger.info(f"{bad_bc.name} was already purged, unable to fix compile errors")
self.source_context.purged_cfts = []
return []
logger.info(f"Purging {bad_bc.name}")
purged.append(bad_idx)
self.cflow_results[bad_co] = [replace_line(x) for x in "\n".join(self.cflow_results[bad_co]).split("\n")]
self.reconstruct_source()
equivalence_results = self.check_reconstruction(write_source=True)
self.source_context.purge(bad_bc.codeobj)
equivalence_results = self.check_reconstruction(str(self.source_context))
for i in purged:
r = equivalence_results[i]
equivalence_results[i] = TestResult(False, "Compilation Error", r.name_a, r.name_b)
equivalence_results[i] = TestResult(False, "Compilation Error", r.bc_a, r.bc_b)
self.source_context.purged_cfts = []
return equivalence_results
except:
self.source_context.purged_cfts = []
return []
def mask_bytecode(self):
logger.info(f"Masking bytecode for {self.file.name}...")
logger.info(f"Masking bytecode for {self.name}...")
try:
self.global_masker = create_global_masker(self.pyc)
# create a dict of line num : [bytecodes composing line]
@@ -241,8 +212,18 @@ class Decompiler:
e.add_note("From masking bytecode")
raise
def unmask_lines(self):
logger.info(f"Unmasking lines for {self.name}...")
try:
self.source_lines = restore_masked_source_text(self.source_lines, self.global_masker)
except Exception as e:
e.add_note("From masking bytecode")
raise
def run_segmentation(self):
logger.info(f"Segmenting bytecode for {self.file.name}...")
if self.trust_lnotab:
return self.update_segmentation_from_lnotab()
logger.info(f"Segmenting bytecode for {self.name}...")
try:
MAX_WINDOW_LENGTH = 512
STEP_SIZE = 128
@@ -308,7 +289,7 @@ class Decompiler:
self.update_starts_line()
def run_translation(self):
logger.info(f"Translating statements for {self.file.name}...")
logger.info(f"Translating statements for {self.name}...")
try:
translation_requests = []
for instructions, boundary_predictions in zip(self.ordered_instructions, self.segmentation_results):
@@ -322,41 +303,34 @@ class Decompiler:
raise
def run_cflow_reconstruction(self):
logger.info(f"Reconstructing control flow for {self.file.name}...")
logger.info(f"Reconstructing control flow for {self.name}...")
try:
self.cflow_results = {bc.codeobj: bytecode_to_indented_source(bc, self.source_lines) for bc in TrackedDataset(CFLOW_STEP, self.ordered_bytecodes)}
cfts = {bc.codeobj: bc_to_cft(bc) for bc in TrackedList(CFLOW_STEP, self.ordered_bytecodes)}
self.source_context = SourceContext(self.pyc, self.source_lines, cfts)
version = magicint2version.get(self.pyc.magic, "?")
time = datetime.datetime.fromtimestamp(self.pyc.timestamp, datetime.UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
self.source_context.header_lines = [
SourceLine("# Decompiled with PyLingual (https://pylingual.io)", 0, self.pyc.codeobj, meta=True),
SourceLine(f"# Internal filename: {self.pyc.codeobj.co_filename!r}", 0, self.pyc.codeobj, meta=True),
SourceLine(f"# Bytecode version: {version} ({self.pyc.magic})", 0, self.pyc.codeobj, meta=True),
SourceLine(f"# Source timestamp: {time} ({self.pyc.timestamp})", 0, self.pyc.codeobj, meta=True),
SourceLine("", 0, self.pyc.codeobj, meta=True),
]
except Exception as e:
e.add_note("From control flow reconstruction")
raise
# merge sources and unmask source
def reconstruct_source(self):
logger.info(f"Reconstructing source for {self.file.name}...")
# merge sources and postprocess results
logger.info(f"Reconstructing source for {self.name}...")
try:
self.indented_masked_source, self.blame = reconstruct_source(self.pyc, {a: b for a, b in self.cflow_results.items()})
self.indented_source = str(self.source_context)
except Exception as e:
e.add_note("From control flow reconstruction")
e.add_note("From source reconstruction")
raise
# undo the masking
try:
self.indented_source = restore_masked_source_text(self.indented_masked_source, self.global_masker, python_version=self.version)
except Exception as e:
e.add_note("From unmasking source")
raise
# write indented source and pyc
def log_results(self):
self.candidate_source_path = self.out_dir / self.file.with_suffix(".py").name
self.candidate_pyc_path = self.candidate_source_path.with_suffix(".pyc")
self.candidate_source_path.write_text(self.header + self.indented_source)
try:
compile_version(self.candidate_source_path, self.candidate_pyc_path, self.version)
except Exception:
pass # it's ok if the python doesn't compile
# make a translation request from a segmentation result
def make_translation_request(self, instructions: list[Inst], boundary_predictions: list[dict]) -> list[str]:
def make_translation_request(self, instructions: list[list["Inst"]], boundary_predictions: list[dict]) -> list[str]:
translation_requests = []
for inst, boundary_prediction in zip(instructions, boundary_predictions):
if boundary_prediction["entity"] == "B":
@@ -372,28 +346,30 @@ class Decompiler:
elif self.version >= (3, 10):
self.pyc.fix_while(self.source_lines)
# compiles and compares result to original pyc
def check_reconstruction(self, write_source=False) -> list:
candidate_source_path = self.candidate_source_path
candidate_pyc_path = self.candidate_pyc_path
if write_source:
tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py")
tmp.write(self.header + self.indented_source)
candidate_source_path = Path(tmp.name)
candidate_pyc_path = Path(tmp.name).with_suffix(".pyc")
def tmpfile(self):
self.tmpn += 1
return self.tmp / str(self.tmpn)
# compile source
# compiles and compares result to original pyc
def check_reconstruction(self, source: str) -> list[TestResult]:
logger.info(f"Checking decompilation for {self.name}...")
src = self.tmpfile()
pyc = self.tmpfile()
src.write_text(source)
try:
compile_version(candidate_source_path, candidate_pyc_path, self.version)
compile_version(src, pyc, self.version)
except CompileError as e:
return [e]
else:
return compare_pyc(self.file, candidate_pyc_path)
return compare_pyc(self.pyc, pyc)
# try to correct the segmentation of the ith code object
def correct_segmentation(self, i: int, from_comp_error=False) -> bool:
if not self.segmentation_results[i]:
return False
if isinstance(self.source_context.cfts[self.ordered_bytecodes[i].codeobj], MetaTemplate):
return False
logger.info(f"Trying to fix segmentation for {self.ordered_bytecodes[i].name}")
original_prediction = [r["entity"] for r in self.segmentation_results[i]]
strategy = functools.partial(m_deep_top_k, priority_function=naive_confidence_priority, m=2, k=self.top_k + 1)
# skip first prediction since it is the same as original
@@ -406,37 +382,38 @@ class Decompiler:
self.update_starts_line()
# retranslate affected bytecode
translation_request = self.make_translation_request(self.ordered_instructions[i], self.segmentation_results[i])
previous_lines, previous_indented_source = self.source_lines, self.indented_source
try:
self.translation_results[i] = self.translator(translation_request)
self.update_source_lines()
self.unmask_lines()
except Exception as e:
e.add_note("From translation")
raise
# redo cflow of affected bytecode
try:
bc = self.ordered_bytecodes[i]
self.cflow_results[bc.codeobj] = bytecode_to_indented_source(bc, self.source_lines)
except Exception as e:
e.add_note("From control flow reconstruction")
raise
self.source_context.update_lines(self.source_lines)
# check if new reconstruction is correct
previous_indented_masked_source, previous_blame, previous_indented_source = self.indented_masked_source, self.blame, self.indented_source
self.reconstruct_source()
equivalence_results = self.check_reconstruction(write_source=True)
equivalence_results = self.check_reconstruction(self.indented_source)
if from_comp_error:
if not has_comp_error(equivalence_results) or self.find_comp_error_cause(equivalence_results) != i:
if not has_comp_error(equivalence_results) or self.find_comp_error_cause(equivalence_results) not in [None, i]:
self.equivalence_results = equivalence_results
self.highest_k_used = max(self.highest_k_used, k)
logger.info(f"Updated segmentation for {self.ordered_bytecodes[i].name}")
return True
elif not has_comp_error(equivalence_results) and equivalence_results[i].success:
self.equivalence_results[i] = equivalence_results[i]
self.highest_k_used = max(self.highest_k_used, k)
logger.info(f"Updated segmentation for {self.ordered_bytecodes[i].name}")
return True
# correction failed, roll back changes to internal source code storage
self.indented_masked_source, self.blame, self.indented_source = previous_indented_masked_source, previous_blame, previous_indented_source
self.indented_source = previous_indented_source
self.source_lines = previous_lines
self.source_context.update_lines(previous_lines)
# revert to original segmentation
for r, p in zip(self.segmentation_results[i], original_prediction):
r["entity"] = p
self.update_starts_line()
logger.info(f"Could not fix segmentation for {self.ordered_bytecodes[i].name}")
return False
# update starts_line of all instructions based on segmentation results
@@ -451,20 +428,21 @@ class Decompiler:
inst.starts_line = None
def decompile(file: Path, out_dir: Path, config_file: Path | None = None, version: PythonVersion | tuple[int, int] | str | None = None, top_k: int = 10, trust_lnotab: bool = False) -> DecompilerResult:
def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Path | None = None, version: str | None = None, top_k: int = 10, trust_lnotab: bool = False) -> DecompilerResult:
"""
Decompile a PYC file.
:param file: path to pyc to decompile
:param out_dir: Path to save decompilation results and steps to. Defaults to ./decompiled_<pyc_name>/
:param config_file: Path to decompiler_config.yaml to load. recommended None, which loads the default pylingual config.
:param pyc: PYCFile or Path to decompile.
:param save_to: Path to save decompilation results to or None.
:param config_file: Path to decompiler_config.yaml to load. Use None to load the default PyLingual config (recommended).
:param version: Loads the models corresponding to this python version. if None, automatically detects version based on input PYC file.
:param top_k: Max number of pyc segmentations to consider.
:param trust_lnotab: Trust the lnotab in the input PYC for segmentation, recommended False.
:param trust_lnotab: Trust the lnotab in the input PYC for segmentation (False recommended).
:return: DecompilerResult class including important information about decompilation
"""
logger.info(f"Loading {file}...")
pyc = PYCFile(file)
logger.info(f"Loading {pyc}...")
if isinstance(pyc, Path):
pyc = PYCFile(pyc)
# try to auto resolve version
if version is None:
@@ -490,10 +468,16 @@ def decompile(file: Path, out_dir: Path, config_file: Path | None = None, versio
segmenter, translator = load_models(config_file, pversion)
logger.info(f"Decompiling pyc {file.resolve()} to {out_dir.resolve()}")
result = Decompiler(pyc, out_dir, segmenter, translator, pversion, top_k, trust_lnotab).result
if save_to:
logger.info(f"Decompiling pyc {pyc.pyc_path.resolve() if pyc.pyc_path else repr(pyc)} to {save_to.resolve()}")
else:
logger.info(f"Decompiling pyc {pyc.pyc_path.resolve() if pyc.pyc_path else repr(pyc)}")
decompiler = Decompiler(pyc, segmenter, translator, pversion, top_k, trust_lnotab)
result = decompiler()
logger.info("Decompilation complete")
logger.info(f"{round(result.calculate_success_rate(), 2)}% code object success rate")
logger.info(f"Result saved to {result.decompiled_source.resolve()}")
logger.info(f"{result.calculate_success_rate():.2%} code object success rate")
if save_to:
save_to.write_text(result.decompiled_source)
logger.info(f"Result saved to {save_to}")
return result
+2 -18
View File
@@ -14,6 +14,7 @@ class PYCFile(EditableBytecode):
def __init__(self, source, name_prefix=None):
self.pyc_path = None
self.source = source
source_tuple = (None, None, None, None, None, None, None)
if isinstance(source, bytes):
source = BytesIO(source)
@@ -46,24 +47,7 @@ class PYCFile(EditableBytecode):
)
def copy(self):
try:
copy = PYCFile(None)
EditableBytecode.__init__(copy, self.to_code(), self.opcode, self.version, self.name_prefix, False)
except IndexError:
copy = EditableBytecode.copy(self)
for attr in (
"version",
"timestamp",
"magic",
"code",
"ispypy",
"source_size",
"sip_hash",
):
setattr(copy, attr, getattr(self, attr))
return copy
return PYCFile(self.source)
def save(self, file, should_close=True, no_lnotab=False):
"""Saves the current recursive bytecode to the specified file."""
+21 -13
View File
@@ -5,7 +5,7 @@ from dataclasses import dataclass
from pathlib import Path
import networkx as nx
from pylingual.control_flow_reconstruction.structure_control_flow import condense_basic_blocks
from pylingual.control_flow_reconstruction.cfg import CFG
from pylingual.editable_bytecode import EditableBytecode, Inst, PYCFile
from pylingual.editable_bytecode.bytecode_patches import fix_indirect_jump, fix_unreachable, remove_extended_arg, remove_nop
from pylingual.editable_bytecode.control_flow_graph import bytecode_to_control_flow_graph
@@ -115,11 +115,19 @@ class TestResult:
success: bool
message: str
name_a: str
name_b: str
bc_a: EditableBytecode | None
bc_b: EditableBytecode | None
failed_line_number: int | None = None
failed_offset: int | None = None
@property
def name_a(self) -> str:
return self.bc_a.name if self.bc_a is not None else "None"
@property
def name_b(self) -> str:
return self.bc_a.name if self.bc_a is not None else "None"
def names(self):
if self.name_a == self.name_b:
return self.name_a
@@ -169,7 +177,7 @@ def matching_iter(pyc_a, pyc_b):
i_b += 1
def compare_pyc(pyc_path_a: Path, pyc_path_b: Path) -> list[TestResult]:
def compare_pyc(pyc_a: PYCFile | Path, pyc_b: PYCFile | Path) -> list[TestResult]:
"""
Tests the control flow of the two pyc files
Should not be imported as it relies on TestResult class.
@@ -180,8 +188,8 @@ def compare_pyc(pyc_path_a: Path, pyc_path_b: Path) -> list[TestResult]:
:param pyc_path_b: Second pyc to compare
"""
pyc_a = PYCFile(pyc_path_a)
pyc_b = PYCFile(pyc_path_b)
pyc_a = pyc_a.copy() if isinstance(pyc_a, PYCFile) else PYCFile(pyc_a)
pyc_b = pyc_b.copy() if isinstance(pyc_b, PYCFile) else PYCFile(pyc_b)
pyc_a.apply_patches([remove_extended_arg, remove_nop, fix_indirect_jump, fix_unreachable, remove_extended_arg])
pyc_b.apply_patches([remove_extended_arg, remove_nop, fix_indirect_jump, fix_unreachable, remove_extended_arg])
@@ -190,29 +198,29 @@ def compare_pyc(pyc_path_a: Path, pyc_path_b: Path) -> list[TestResult]:
for bytecode_a, bytecode_b in matching_iter(pyc_a, pyc_b):
if bytecode_a is None:
test_result = TestResult(False, "Extra bytecode", "None", bytecode_b.name)
test_result = TestResult(False, "Extra bytecode", None, bytecode_b)
results.append(test_result)
continue
if bytecode_b is None:
test_result = TestResult(False, "Missing bytecode", bytecode_a.name, "None")
test_result = TestResult(False, "Missing bytecode", bytecode_a, None)
results.append(test_result)
continue
cfg_a = bytecode_to_control_flow_graph(bytecode_a)
cfg_b = bytecode_to_control_flow_graph(bytecode_b)
block_graph_a = condense_basic_blocks(cfg_a)
block_graph_b = condense_basic_blocks(cfg_b)
block_graph_a = CFG.from_graph(cfg_a, bytecode_a)
block_graph_b = CFG.from_graph(cfg_b, bytecode_b)
if not is_control_flow_equivalent(block_graph_a, block_graph_b):
test_result = TestResult(False, "Different control flow", bytecode_a.name, bytecode_b.name)
test_result = TestResult(False, "Different control flow", bytecode_a, bytecode_b)
results.append(test_result)
continue
bytecode_result = compare_bytecode(bytecode_a, bytecode_b)
if not bytecode_result.result:
test_result = TestResult(False, "Different bytecode", bytecode_a.name, bytecode_b.name, bytecode_result.failed_line, bytecode_result.failed_offset)
test_result = TestResult(False, "Different bytecode", bytecode_a, bytecode_b, bytecode_result.failed_line, bytecode_result.failed_offset)
results.append(test_result)
continue
test_result = TestResult(True, "Equal", bytecode_a.name, bytecode_b.name)
test_result = TestResult(True, "Equal", bytecode_a, bytecode_b)
results.append(test_result)
return results
+19 -14
View File
@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING
import click
import logging
import shutil
import platform
import subprocess
import os
from pathlib import Path
import pylingual.utils.ascii_art as ascii_art
from pylingual.utils.generate_bytecode import CompileError
from pylingual.utils.generate_bytecode import CompileError, has_pyenv
from pylingual.utils.version import PythonVersion, supported_versions
from pylingual.utils.tracked_list import TrackedList, SEGMENTATION_STEP, TRANSLATION_STEP, CFLOW_STEP, CORRECTION_STEP
from pylingual.utils.lazy import lazy_import
@@ -41,8 +40,9 @@ def print_header():
console.rule()
def print_result(file: str, result: DecompilerResult):
table = Table(title=f"Equivalence Results for {file}")
def print_result(result: DecompilerResult):
pyc = result.original_pyc
table = Table(title=f"Equivalence Results for {pyc.pyc_path.name if pyc.pyc_path else repr(pyc)}")
table.add_column("Code Object")
table.add_column("Success")
table.add_column("Message")
@@ -78,8 +78,8 @@ def main(files: list[str], out_dir: Path | None, config_file: Path | None, versi
if init_pyenv and (not install_pyenv() or not files):
return
if out_dir is not None:
out_dir.mkdir(parents=True, exist_ok=True)
if out_dir:
Path(out_dir).mkdir(parents=True, exist_ok=True)
progress = Progress(
TextColumn("[progress.description]{task.description}"),
@@ -98,10 +98,10 @@ def main(files: list[str], out_dir: Path | None, config_file: Path | None, versi
TrackedList.init = init
TrackedList.progress = lambda self, i: progress.advance(self.task.id, i)
# the step is not done until the TrackedList is deleted
TrackedList.__del__ = lambda self: progress.advance(self.task.id, float("inf"))
TrackedList.__del__ = lambda self: progress.advance(self.task.id, 9e999)
n = len(files)
with Live(Group(Rule(), status, progress), transient=True, console=console, refresh_per_second=12.5):
with Live(Group(Rule(), status, progress), transient=True, console=console, refresh_per_second=12.5) as live:
transformers.logging.disable_default_handler()
transformers.logging.add_handler(log_handler)
progress.add_task(SEGMENTATION_STEP, start=False)
@@ -112,28 +112,32 @@ def main(files: list[str], out_dir: Path | None, config_file: Path | None, versi
for task in progress.tasks:
progress.reset(task.id, start=False)
pyc_path = Path(file)
log_handler.keywords = [file, pyc_path.name, pyc_path.with_suffix(".py").name]
log_handler.keywords = [file, pyc_path.name, pyc_path.with_suffix(".py").name, "decompiled_" + pyc_path.with_suffix(".py").name]
status.update(f"Decompiling {pyc_path} ({i + 1} / {n})")
if not pyc_path.exists():
raise FileNotFoundError(f"pyc file {pyc_path} does not exist")
try:
result = decompile(
file=pyc_path,
out_dir=out_dir / f"decompiled_{pyc_path.stem}" if out_dir is not None else Path(f"decompiled_{pyc_path.stem}"),
pyc=pyc_path,
save_to=Path(f"{out_dir}/decompiled_{pyc_path.with_suffix('.py').name}" if out_dir else f"decompiled_{pyc_path.with_suffix('.py').name}"),
config_file=Path(config_file) if config_file else None,
version=version,
top_k=top_k,
trust_lnotab=trust_lnotab,
)
print_result(pyc_path.name, result)
print_result(result)
except Exception:
import pdb
live.stop()
pdb.xpm()
logger.exception(f"Failed to decompile {pyc_path}")
console.rule()
def install_pyenv():
if shutil.which("pyenv") is not None:
if has_pyenv():
logger.warning("pyenv seems to already be installed, ignoring --init-pyenv...")
return True
if platform.system() not in ["Linux", "Darwin"] and not click.confirm("pyenv is probably not supported on your operating system. Continue?", default=False):
@@ -144,8 +148,9 @@ def install_pyenv():
if subprocess.run(cmd, shell=True).returncode != 0:
logger.error("pyenv install failed, exiting...")
return False
has_pyenv.cache_clear()
os.environ["PATH"] = f"{os.environ.get('PYENV_ROOT', os.path.expanduser('~/.pyenv'))}/bin:{os.environ['PATH']}"
if shutil.which("pyenv") is None:
if not has_pyenv():
logger.error("Could not find pyenv, exiting...")
return False
versions = click.prompt(
+21 -16
View File
@@ -4,11 +4,12 @@ import ast
import pathlib
import re
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Pattern
from pylingual.utils.use_escape_sequences import use_escape_sequences
from pylingual.utils.version import PythonVersion
if TYPE_CHECKING:
from pylingual.editable_bytecode import EditableBytecode
@@ -103,14 +104,28 @@ def restore_masked_source(file_path: pathlib.Path, masker: Masker, python_versio
def format_source_replacement(mask_value: str) -> str:
if mask_value is ...:
return "..."
if mask_value == 9e999: # infinity
return "9e999"
if type(mask_value) in (int, float) and mask_value < 0:
return f"({mask_value})"
if type(mask_value) != str:
return str(mask_value)
return mask_value
formatted_mask_value = use_escape_sequences(mask_value)
return formatted_mask_value
re_rel_pattern = re.compile(r"^(\s*)(import|from)\s*(\d+)(.*)", re.MULTILINE)
def unmask(source_line: str, replacements: dict, re_pattern: Pattern[str]):
def m(match):
s = match.span()
r = replacements[match.group()]
if s[0] == 0 or s[1] >= len(match.string) or match.string[s[0] - 1] not in "\"'{}" and match.string[s[1]] not in "\"'{}":
return r
return use_escape_sequences(r)
text = re_pattern.sub(m, source_line)
return re_rel_pattern.sub(lambda match: f"{match.group(1)}{match.group(2)} {'.' * int(match.group(3))}{match.group(4)}", text)
def fix_jump_targets(disasm: str) -> str:
@@ -126,22 +141,12 @@ def fix_jump_targets(disasm: str) -> str:
return result
def restore_masked_source_text(text: str, masker: Masker, python_version: PythonVersion) -> str:
def restore_masked_source_text(lines: list[str], masker: Masker) -> list[str]:
"""Creates a large regex of all the tokens and their respective values
Replaces everything in file text in one pass."""
replacements = {re.escape(v): format_source_replacement(k) for k, v in masker.global_tab.items()} # we use encode + decode so multiline strings get replaced correctly
replacements = {re.escape(v): format_source_replacement(k) for k, v in masker.global_tab.items()}
re_pattern = re.compile("|".join(replacements.keys()))
result = re_pattern.sub(lambda match: replacements[match.group()], text)
# replace imports with a module starting with a number, with that number amount of dots for relative imports
re_rel_pattern = r"^(\s*)(import|from)\s*(\d+)(.*)"
result_rel_imports = re.sub(re_rel_pattern, lambda match: f"{match.group(1)}{match.group(2)} {'.' * int(match.group(3))}{match.group(4)}", result, 0, re.MULTILINE)
# normalize with parse+unparse to catch replacement errors and simplify whitespace
try:
return ast.unparse(ast.parse(result_rel_imports, feature_version=python_version.as_tuple()))
except (SyntaxError, IndentationError):
return result_rel_imports
return [unmask(x, replacements, re_pattern) for x in lines]
# replace mask values to start at 0 and count up
+8
View File
@@ -3,13 +3,21 @@
import subprocess
import sys
import shlex
import shutil
import py_compile
import functools
from pylingual.utils.version import PythonVersion
class CompileError(Exception):
success = False
bc_a = None
@functools.cache
def has_pyenv():
return shutil.which("pyenv") is not None
def compile_version(py_file, out_file, version):
+15 -13
View File
@@ -1,17 +1,19 @@
escapes = {
"\\": "\\\\",
"'": "\\'",
'"': '\\"',
"\a": "\\a",
"\b": "\\b",
"\f": "\\f",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
"\v": "\\v",
"\x00": "\\x00",
}
def use_escape_sequences(s):
escapes = {
"\\": "\\\\",
"'": "\\'",
'"': '\\"',
"\a": "\\a",
"\b": "\\b",
"\f": "\\f",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
"\v": "\\v",
"\x00": "\\x00",
}
for a, b in escapes.items():
s = s.replace(a, b)
return s
+4 -1
View File
@@ -5,7 +5,7 @@ version_str = {f"{x[0]}{x[1]}": x for x in supported_tuples} | {f"{x[0]}.{x[1]}"
class PythonVersion:
major: int
minor: int
_t: tuple
_t: tuple[int, int]
@staticmethod
def normalize(x) -> tuple[int, int] | None:
@@ -59,6 +59,9 @@ class PythonVersion:
norm = PythonVersion.normalize(o)
return norm is not None and self._t < norm
def __hash__(self):
return hash(self._t)
def __getitem__(self, i):
return self._t[i]