diff --git a/pylingual/control_flow_reconstruction/cfg.py b/pylingual/control_flow_reconstruction/cfg.py
index 1e6c21d..41091f5 100644
--- a/pylingual/control_flow_reconstruction/cfg.py
+++ b/pylingual/control_flow_reconstruction/cfg.py
@@ -75,6 +75,77 @@ class CFG(DiGraph_CFT):
         self._create_dominator_tree()
         return nx.dfs_postorder_nodes(self, source=self.start, sort_neighbors=lambda nodes: sorted(nodes, key=lambda x: x.offset, reverse=True))
 
+    def is_loop_header(self, node):
+        """Return True if at least one predecessor of ``node`` is dominated by ``node``."""
+        # Check all predecessors
+        for predecessor in self.predecessors(node):
+            # A back edge exists if the predecessor is dominated by this node
+            if self.dominates(node, predecessor):
+                return True
+        return False
+
+    def dfs_labeled_edges_no_loop(self, source=None, depth_limit=None, *, sort_neighbors=None):
+        """Yield ``(parent, child, label)`` triples from a depth-first traversal.
+
+        A child is yielded as ``"nontree"`` and not descended into when it was
+        already visited, when ``is_loop_header(child)`` is true, or when one of
+        its predecessors has not been visited yet. Any other child is yielded as
+        ``"forward"`` and then either pushed for further traversal or, once
+        ``depth_limit`` is reached, yielded again as ``"reverse-depth_limit"``.
+        A ``"reverse"`` triple is yielded when a node's children are exhausted.
+        NOTE(review): the skeleton resembles ``networkx.dfs_labeled_edges`` --
+        confirm the origin before diverging further.
+
+        ``source=None`` starts a traversal from every node of the graph,
+        ``depth_limit=None`` means ``len(self)``, and ``sort_neighbors`` (if
+        given) is applied to ``self.neighbors(n)`` to order the children.
+        """
+        if source is None:
+            # edges for all components
+            nodes = self
+        else:
+            # edges for components with source
+            nodes = [source]
+        if depth_limit is None:
+            depth_limit = len(self)
+
+        get_children = (
+            self.neighbors
+            if sort_neighbors is None
+            else lambda n: iter(sort_neighbors(self.neighbors(n)))
+        )
+
+        visited = set()
+        for start in nodes:
+            if start in visited:
+                continue
+            yield start, start, "forward"
+            visited.add(start)
+            stack = [(start, get_children(start))]
+            depth_now = 1
+            while stack:
+                parent, children = stack[-1]
+                for child in children:
+                    # Never enter a loop header or a node that still has an unvisited predecessor
+                    if child in visited or self.is_loop_header(child) or not all(p in visited for p in self.predecessors(child)):
+                        yield parent, child, "nontree"
+                    else:
+                        yield parent, child, "forward"
+                        visited.add(child)
+                        if depth_now < depth_limit:
+                            stack.append((child, iter(get_children(child))))
+                            depth_now += 1
+                            break
+                        else:
+                            yield parent, child, "reverse-depth_limit"
+                else:
+                    # for/else: reached only when the children iterator ran out without a break
+                    stack.pop()
+                    depth_now -= 1
+                    if stack:
+                        yield stack[-1][0], parent, "reverse"
+            yield start, start, "reverse"
+
     def apply_graphs(self):
         graphs = self.iteration_graphs.pop()
         if self.iteration_graphs:
diff --git
a/pylingual/control_flow_reconstruction/templates/Conditional.py b/pylingual/control_flow_reconstruction/templates/Conditional.py index f163981..f0f06a8 100644 --- a/pylingual/control_flow_reconstruction/templates/Conditional.py +++ b/pylingual/control_flow_reconstruction/templates/Conditional.py @@ -1,13 +1,14 @@ from ..cft import ControlFlowTemplate, EdgeKind, register_template from ..utils import T, N, defer_source_to, run_is, has_no_lines, with_instructions, has_instval, starting_instructions, to_indented_source, make_try_match, without_top_level_instructions +from .Loop import BreakTemplate @register_template(1, 40) class IfElse(ControlFlowTemplate): template = T( if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")), - if_body=~N("tail.").with_in_deg(1), - else_body=~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1), + if_body=N.tail().with_in_deg(1).of_type(BreakTemplate) | ~N("tail.").with_in_deg(1), + else_body=N.tail().with_in_deg(1).of_type(BreakTemplate) | ~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1), tail=N.tail(), ) diff --git a/pylingual/control_flow_reconstruction/templates/Loop.py b/pylingual/control_flow_reconstruction/templates/Loop.py index 3f56a0c..0745c94 100644 --- a/pylingual/control_flow_reconstruction/templates/Loop.py +++ b/pylingual/control_flow_reconstruction/templates/Loop.py @@ -1,13 +1,21 @@ +from __future__ import annotations +from typing import TYPE_CHECKING from ..cft import ControlFlowTemplate, EdgeKind, register_template from ..utils import ( T, N, + with_instructions, + exact_instructions, + has_no_lines, + condense_mapping, defer_source_to, starting_instructions, to_indented_source, - make_try_match, + make_try_match, ) +if TYPE_CHECKING: + from pylingual.control_flow_reconstruction.cfg import CFG @register_template(0, 1) class ForLoop(ControlFlowTemplate): 
@@ -40,6 +48,24 @@ class SelfLoop(ControlFlowTemplate):
         {loop_body}
         """
 
+@register_template(0, 2)
+class TrueSelfLoop(ControlFlowTemplate):
+    template = T(
+        loop_body=~N("tail.", "loop_body"),
+        tail=N.tail())
+
+    try_match = make_try_match(
+        {
+            EdgeKind.Fall: "tail",
+        },
+        "loop_body"
+    )
+
+    @to_indented_source
+    def to_indented_source():
+        """
+        {loop_body}
+        """
 
 @register_template(0, 3)
 class InlinedComprehensionTemplate(ControlFlowTemplate):
@@ -58,3 +84,114 @@ class InlinedComprehensionTemplate(ControlFlowTemplate):
     )
 
     to_indented_source = defer_source_to("comp")
+
+class BreakTemplate(ControlFlowTemplate):
+    """Wrap a node so that its source is followed by a ``break`` line."""
+
+    @classmethod
+    def try_match(cls, cfg, node):
+        """Condense ``node`` into a BreakTemplate, or return None when it is rejected."""
+        # Rejected: already a BreakTemplate, has_no_lines(), or with_instructions("RAISE_VARARGS") matches
+        if isinstance(node, BreakTemplate) or has_no_lines(cfg, node) or with_instructions("RAISE_VARARGS")(cfg, node):
+            return None
+        return condense_mapping(cls, cfg, {'child': node}, 'child')
+
+    def to_indented_source(self, source):
+        """Return the child's indented source with ``break`` appended."""
+        return self.child.to_indented_source(source) + self.line('break')
+
+class ContinueTemplate(ControlFlowTemplate):
+    """Wrap a node so that its source is followed by a ``continue`` line."""
+
+    @classmethod
+    def try_match(cls, cfg, node):
+        """Condense ``node`` into a ContinueTemplate, or return None when it is rejected."""
+        if isinstance(node, ContinueTemplate) or has_no_lines(cfg, node):
+            return None
+        instructions = node.get_instructions()
+        ends_with_jump = instructions[-1].opname in {"JUMP_ABSOLUTE", "JUMP_BACKWARD", "CONTINUE_LOOP"}
+        # One of the last two instructions must start a source line. Slicing instead of
+        # indexing [-2] keeps a single-instruction node from raising IndexError.
+        if ends_with_jump and any(inst.starts_line is not None for inst in instructions[-2:]):
+            return condense_mapping(cls, cfg, {'child': node}, 'child')
+        return None
+
+    def to_indented_source(self, source):
+        """Return the child's indented source with ``continue`` appended."""
+        return self.child.to_indented_source(source) + self.line('continue')
+
+@register_template(0, 0)
+class FixLoop(ControlFlowTemplate):
+    """Wrap loop-exit nodes in BreakTemplate and latching nodes in ContinueTemplate."""
+
+    @classmethod
+    def try_match(cls, cfg: CFG, node: ControlFlowTemplate) -> ControlFlowTemplate | None:
+        """Detach break/continue edges of the loop headed by ``node``.
+
+        Works by mutating ``cfg``; every path returns None, so no FixLoop node is
+        ever inserted into the graph.
+        """
+        # check that it's a loop that we need to fix
+        # find the end of the loop
+        # find all nodes that belong to the loop
+        # find nodes in loop that go to end
+        # replace those edges with meta edges to the end
+        # find nodes in loop that go to header
+        # replace all but last of those edges with meta edge to end
+
+        # a node is a loop header if there are back-edges to it
+        # a latching node is a node with a back-edge to the loop header
+        # a back-edge is an edge from any node that is dominated by this node
+        back_edges = [pred for pred in cfg.predecessors(node) if cfg.dominates(node, pred)]
+        if not back_edges:
+            return None
+
+        # Get all nodes encompassed by the loop excluding source node and initial false jump
+        # NOTE(review): without a Fall successor loopnode stays None and the traversal
+        # below starts from every node of the CFG -- confirm that this is intended.
+        loopnode = None
+        for succ in cfg.successors(node):
+            if cfg.get_edge_data(node, succ).get("kind") == EdgeKind.Fall:
+                loopnode = succ
+                break
+
+        dfs_edges = cfg.dfs_labeled_edges_no_loop(source=loopnode)
+        encompassed_nodes = [v for _, v, label in dfs_edges if label == "forward"]
+
+        # Find the candidate end that break connects to
+        candidate_end = None
+        for succ in cfg.successors(node):
+            if cfg.get_edge_data(node, succ).get("kind") == EdgeKind.FalseJump and cfg.out_degree(succ) <= 1:
+                candidate_end = succ
+
+        edges_to_remove = []
+        if candidate_end is not None:
+            # Candidate end is a buffer node
+            buffer_shapes = [("POP_BLOCK",), ("END_FOR",), ("END_FOR", "POP_TOP"), ("LOAD_CONST", "RETURN_VALUE")]
+            if cfg.in_degree(candidate_end) == 1 and any(exact_instructions(*op)(cfg, candidate_end) for op in buffer_shapes):
+                for ss in cfg.successors(candidate_end):
+                    if cfg.out_degree(ss) <= 1:
+                        candidate_end = ss
+                        break
+
+            # Collect every loop-body edge into the end before any node is rewritten
+            edges_to_remove = [
+                (pred, candidate_end)
+                for pred in encompassed_nodes
+                if cfg.get_edge_data(pred, candidate_end) is not None
+            ]
+
+        for pred, succ in edges_to_remove:
+            break_node = BreakTemplate.try_match(cfg, pred)
+            if break_node is not None:
+                cfg.remove_edge(break_node, succ)
+
+        for candidate in back_edges:
+            # NOTE(review): ContinueTemplate.try_match runs even when the in-degree check
+            # below fails and no edge is removed -- confirm that this is intended.
+            cont_node = ContinueTemplate.try_match(cfg, candidate)
+            if cont_node is not None and cfg.in_degree(node) > 2:
+                cfg.remove_edge(cont_node, node)
+
+        cfg.iterate()
+        return None
\ No newline at end of file
diff --git a/test/Loop.py b/test/Loop.py
index 135e0a0..8c730bf 100644
--- a/test/Loop.py
+++ b/test/Loop.py
@@ -16,12 +16,14 @@ def b1_for_over_tuples_nofallthru():
         print("tuples")
     print("end")
 
+# 3.6/3.7 No else template
 def c0_for_else():
     for i in range(3):
         print("for body")
     else:
         print("for else")
 
+# 3.6/3.7 No else template
 def c1_for_else_nofallthru():
     for i in range(3):
         print("for body")
@@ -29,14 +31,15 @@ def c1_for_else_nofallthru():
         print("for else")
     print("end")
 
-# Fails due to no break
+# 3.6/3.7 Naive break detection, an unexpected buffer POP_BLOCK to end
+# 3.9 Naive break detection, an unexpected buffer block to end
 def d0_for_with_break():
     for x in range(10):
         if x == 5:
             print("breaking")
             break
 
-# Fails due to no break
+# 3.6/3.7 Naive break detection, an unexpected buffer POP_BLOCK to end
 def d1_for_with_break_nofallthru():
     for x in range(10):
         if x == 5:
@@ -135,6 +138,8 @@ def j1_for_with_empty_body_ellipsis_nofallthru():
         ...
print("end") +# 3.6/3.7 Naive break detection, no back edge +# 3.9/3.11 No while loop detection, self false_jump edge & naive break detection def k0_while_true_with_break(): x = 0 while True: @@ -143,6 +148,8 @@ def k0_while_true_with_break(): if x >= 1: break +# 3.6/3.7 Naive break detection, no back edge +# 3.9/3.11 No while loop detection, self false_jump edge & naive break detection def k1_while_true_with_break_nofallthru(): x = 0 while True: @@ -152,6 +159,8 @@ def k1_while_true_with_break_nofallthru(): break print("end") +# 3.6/3.7 No else template +# 3.11 No while loop detection, self false_jump edge def l0_while_with_else(): i = 0 while i < 3: @@ -160,6 +169,8 @@ def l0_while_with_else(): else: print("while else") +# 3.6/3.7 No else template +# 3.11 No while loop detection, self false_jump edge def l1_while_with_else_nofallthru(): i = 0 while i < 3: @@ -169,6 +180,7 @@ def l1_while_with_else_nofallthru(): print("while else") print("end") +# 3.11 No continue def m0_while_with_continue(): i = 0 while i < 5: @@ -178,6 +190,7 @@ def m0_while_with_continue(): continue print("after continue") +# 3.11 No continue def m1_while_with_continue_nofallthru(): i = 0 while i < 5: @@ -188,12 +201,14 @@ def m1_while_with_continue_nofallthru(): print("after continue") print("end") +# 3.6/3.7 Naive break detection, no back edge def n0_while_with_break(): i = 0 while True: print("break in while") break +# 3.6/3.7 Naive break detection, no back edge def n1_while_with_break_nofallthru(): i = 0 while True: @@ -201,6 +216,7 @@ def n1_while_with_break_nofallthru(): break print("end") +# 3.11 While template broke def o0_nested_while_loops(): i = 0 while i < 2: @@ -210,6 +226,7 @@ def o0_nested_while_loops(): j += 1 i += 1 +# 3.11 While template broke def o1_nested_while_loops_nofallthru(): i = 0 while i < 2: @@ -220,6 +237,8 @@ def o1_nested_while_loops_nofallthru(): i += 1 print("end") +# 3.6/3.7 While template broke (?) +# 3.9 Disconnected with MetaTemplate[end] (?) 
def p0_while_with_try_except(): while True: try: @@ -227,6 +246,8 @@ def p0_while_with_try_except(): except: print("except in while") +# 3.6/3.7 While template broke (?) +# 3.9 Disconnected with MetaTemplate[end] (?) def p1_while_with_try_except_nofallthru(): while True: try: @@ -235,34 +256,40 @@ def p1_while_with_try_except_nofallthru(): print("except in while") print("end") +# 3.6/3.7 While template broke (?) abandoning nodes def q0_while_with_with_statement(): while True: with a: print("inside while with") +# 3.6/3.7 While template broke (?) abandoning nodes def q1_while_with_with_statement_nofallthru(): while True: with a: print("inside while with") print("end") +# 3.6/3.7 While template broke def r0_for_inside_while(): while True: for x in [1, 2]: print("for in while") +# 3.6/3.7 While template broke def r1_for_inside_while_nofallthru(): while True: for x in [1, 2]: print("for in while") print("end") +# 3.6/3.7 While template broke def s0_while_inside_for(): for _ in range(1): while True: print("while in for") break +# 3.6/3.7 While template broke def s1_while_inside_for_nofallthru(): for _ in range(1): while True: @@ -270,10 +297,12 @@ def s1_while_inside_for_nofallthru(): break print("end") +# 3.6/3.7 While template broke def t0_while_with_empty_body_ellipsis(): while True: ... +# 3.6/3.7 While template broke def t1_while_with_empty_body_ellipsis_nofallthru(): while True: ... 
@@ -311,6 +340,7 @@ def v1_continue_in_nested_for_nofallthru(): print(f"Processing i={i}, j={j}") print("end") +# 3.13 if statement putting code in the else block def w0_break_with_else(): for i in range(5): if i == 3: @@ -319,6 +349,7 @@ def w0_break_with_else(): else: print("This won't execute due to break") +# 3.13 if statement putting code in the else block def w1_break_with_else_nofallthru(): for i in range(5): if i == 3: @@ -328,6 +359,7 @@ def w1_break_with_else_nofallthru(): print("This won't execute due to break") print("end") +# 3.6/3.7 No continue detection def x0_continue_with_else(): for i in range(3): if i == 1: @@ -336,6 +368,7 @@ def x0_continue_with_else(): else: print("Else clause still executes after continue") +# 3.6/3.7 No continue detection def x1_continue_with_else_nofallthru(): for i in range(3): if i == 1: @@ -345,6 +378,7 @@ def x1_continue_with_else_nofallthru(): print("Else clause still executes after continue") print("end") +# 3.9/3.11 Naive break detection, break statement is further up def y0_break_in_try_except(): for i in range(5): try: @@ -354,6 +388,7 @@ def y0_break_in_try_except(): except: print("Exception occurred") +# 3.9 Naive break detection, break statement is further up def y1_break_in_try_except_nofallthru(): for i in range(5): try: @@ -364,6 +399,19 @@ def y1_break_in_try_except_nofallthru(): print("Exception occurred") print("end") +# 3.9 Naive break detection, break statement is further up +def y2_return_in_try_except_nofallthru(): + for i in range(5): + try: + if i == 3: + print(f"Value: {i}") + else: + break + except: + print("Exception occurred") + print("end") + +# 3.6/3.9 No continue detection def z0_continue_in_try_except(): for i in range(5): try: @@ -373,6 +421,7 @@ def z0_continue_in_try_except(): except: print("Exception occurred") +# 3.6/3.9/3.11 No continue detection def z1_continue_in_try_except_nofallthru(): for i in range(5): try: