Merge branch 'Loops/Breaks/Continues' into cflow-refactor

This commit is contained in:
Xinlong Hu
2025-07-20 17:38:27 -05:00
4 changed files with 226 additions and 5 deletions
@@ -75,6 +75,59 @@ class CFG(DiGraph_CFT):
self._create_dominator_tree()
return nx.dfs_postorder_nodes(self, source=self.start, sort_neighbors=lambda nodes: sorted(nodes, key=lambda x: x.offset, reverse=True))
def is_loop_header(self, node):
    """Tell whether ``node`` receives at least one back edge.

    An incoming edge counts as a back edge when its source is dominated
    by ``node``; a single such predecessor is enough.

    Returns:
        True if some predecessor of ``node`` is dominated by ``node``,
        False otherwise (including when ``node`` has no predecessors).
    """
    return any(self.dominates(node, pred) for pred in self.predecessors(node))
def dfs_labeled_edges_no_loop(self, source=None, depth_limit=None, *, sort_neighbors=None):
    """Iterate over labeled depth-first-search edges without entering loop headers.

    NOTE(review): the overall shape appears to be adapted from
    ``networkx.dfs_labeled_edges``; the extra conditions on the "nontree"
    branch are what differ -- confirm against the networkx version in use.

    Args:
        source: Node to start from. When None, every node of ``self`` is
            tried as a start node in iteration order.
        depth_limit: Maximum depth to descend to. Defaults to ``len(self)``.
        sort_neighbors: Optional callable applied to the neighbor iterable
            of a node; its result is iterated in place of the raw neighbors.

    Yields:
        ``(parent, child, label)`` triples where ``label`` is one of
        "forward", "nontree", "reverse" or "reverse-depth_limit". Each start
        node is additionally reported as ``(start, start, "forward")`` before
        and ``(start, start, "reverse")`` after its traversal.
    """
    if source is None:
        # edges for all components
        nodes = self
    else:
        # edges for components with source
        nodes = [source]
    if depth_limit is None:
        depth_limit = len(self)
    get_children = (
        self.neighbors
        if sort_neighbors is None
        else lambda n: iter(sort_neighbors(self.neighbors(n)))
    )
    visited = set()
    for start in nodes:
        if start in visited:
            continue
        yield start, start, "forward"
        visited.add(start)
        # Each stack entry pairs a node with its partially consumed child iterator.
        stack = [(start, get_children(start))]
        depth_now = 1
        while stack:
            parent, children = stack[-1]
            for child in children:
                # A child is not descended into when it was already visited,
                # when it is a loop header, or when at least one of its
                # predecessors has not been visited yet.
                if child in visited or self.is_loop_header(child) or not all(p in visited for p in self.predecessors(child)):
                    yield parent, child, "nontree"
                else:
                    yield parent, child, "forward"
                    visited.add(child)
                    if depth_now < depth_limit:
                        stack.append((child, iter(get_children(child))))
                        depth_now += 1
                        # Leave the for loop so the while loop resumes from the new stack top.
                        break
                    else:
                        yield parent, child, "reverse-depth_limit"
            else:
                # The child iterator is exhausted: unwind one level.
                stack.pop()
                depth_now -= 1
                if stack:
                    yield stack[-1][0], parent, "reverse"
        yield start, start, "reverse"
def apply_graphs(self):
graphs = self.iteration_graphs.pop()
if self.iteration_graphs:
@@ -1,13 +1,14 @@
from ..cft import ControlFlowTemplate, EdgeKind, register_template
from ..utils import T, N, defer_source_to, run_is, has_no_lines, with_instructions, has_instval, starting_instructions, to_indented_source, make_try_match, without_top_level_instructions
from .Loop import BreakTemplate
@register_template(1, 40)
class IfElse(ControlFlowTemplate):
template = T(
if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")),
if_body=~N("tail.").with_in_deg(1),
else_body=~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1),
if_body=N.tail().with_in_deg(1).of_type(BreakTemplate) | ~N("tail.").with_in_deg(1),
else_body=N.tail().with_in_deg(1).of_type(BreakTemplate) | ~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1),
tail=N.tail(),
)
@@ -1,13 +1,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from ..cft import ControlFlowTemplate, EdgeKind, register_template
from ..utils import (
T,
N,
with_instructions,
exact_instructions,
has_no_lines,
condense_mapping,
defer_source_to,
starting_instructions,
to_indented_source,
make_try_match,
make_try_match,
)
if TYPE_CHECKING:
from pylingual.control_flow_reconstruction.cfg import CFG
@register_template(0, 1)
class ForLoop(ControlFlowTemplate):
@@ -40,6 +48,24 @@ class SelfLoop(ControlFlowTemplate):
{loop_body}
"""
@register_template(0, 2)
class TrueSelfLoop(ControlFlowTemplate):
    # Template over two named nodes, `loop_body` and `tail`.
    # NOTE(review): ~N("tail.", "loop_body") presumably declares that
    # loop_body has edges to `tail` and back to itself -- confirm against
    # the T/N helpers in ..utils.
    template = T(
        loop_body=~N("tail.", "loop_body"),
        tail=N.tail())

    # Matcher built by make_try_match from an {EdgeKind.Fall: "tail"} mapping
    # and the "loop_body" name.
    try_match = make_try_match(
        {
            EdgeKind.Fall: "tail",
        },
        "loop_body"
    )

    # The function body is only a string containing the {loop_body}
    # placeholder; presumably the decorator uses it as the source layout
    # (same pattern as the preceding SelfLoop template) -- TODO confirm.
    @to_indented_source
    def to_indented_source():
        """
        {loop_body}
        """
@register_template(0, 3)
class InlinedComprehensionTemplate(ControlFlowTemplate):
@@ -58,3 +84,95 @@ class InlinedComprehensionTemplate(ControlFlowTemplate):
)
to_indented_source = defer_source_to("comp")
class BreakTemplate(ControlFlowTemplate):
    # Wraps a single `child` node and appends a `break` line to its source.

    @classmethod
    def try_match(cls, cfg, node):
        """Wrap ``node`` as the ``child`` of a new BreakTemplate.

        Returns None, leaving the graph untouched, when ``node`` is already
        a BreakTemplate, when ``has_no_lines`` holds for it, or when the
        ``with_instructions("RAISE_VARARGS")`` predicate holds for it.
        Otherwise returns whatever ``condense_mapping`` produces.
        """
        if isinstance(node, BreakTemplate):
            return None
        if has_no_lines(cfg, node):
            return None
        contains_raise = with_instructions("RAISE_VARARGS")
        if contains_raise(cfg, node):
            return None
        mapping = {'child': node}
        return condense_mapping(cls, cfg, mapping, 'child')

    def to_indented_source(self, source):
        """Return the child's indented source followed by a ``break`` line."""
        child_source = self.child.to_indented_source(source)
        return child_source + self.line('break')
class ContinueTemplate(ControlFlowTemplate):
    # Wraps a single `child` node and appends a `continue` line to its source.

    @classmethod
    def try_match(cls, cfg, node):
        """Wrap ``node`` as the ``child`` of a new ContinueTemplate.

        A node qualifies when its final instruction is JUMP_ABSOLUTE,
        JUMP_BACKWARD or CONTINUE_LOOP and at least one of its last two
        instructions has a ``starts_line`` that is not None.

        Returns:
            The result of ``condense_mapping`` for a qualifying node.
            None when ``node`` is already a ContinueTemplate, when
            ``has_no_lines`` holds for it, when it has no instructions,
            or when the conditions above are not met.
        """
        if isinstance(node, ContinueTemplate) or has_no_lines(cfg, node):
            return None
        # Fetch the instruction list once instead of once per lookup.
        instructions = node.get_instructions()
        if not instructions:
            return None
        if instructions[-1].opname not in {"JUMP_ABSOLUTE", "JUMP_BACKWARD", "CONTINUE_LOOP"}:
            return None
        # A slice tolerates single-instruction nodes, where indexing [-2]
        # would raise IndexError.
        if any(inst.starts_line is not None for inst in instructions[-2:]):
            return condense_mapping(cls, cfg, {'child': node}, 'child')
        return None

    def to_indented_source(self, source):
        """Return the child's indented source followed by a ``continue`` line."""
        return self.child.to_indented_source(source) + self.line('continue')
@register_template(0, 0)
class FixLoop(ControlFlowTemplate):
    # Works purely through side effects on `cfg`: try_match removes edges and
    # calls cfg.iterate(), and every return path yields None.

    @classmethod
    def try_match(cls, cfg: CFG, node: ControlFlowTemplate) -> ControlFlowTemplate | None:
        """Detach break and continue edges of the loop headed by ``node``.

        Args:
            cfg: The control-flow graph; it is mutated in place.
            node: Candidate loop header.

        Returns:
            Always None. When ``node`` has no back edge nothing is changed;
            otherwise matching edges are removed and ``cfg.iterate()`` is
            called before returning.
        """
        # check that it's a loop that we need to fix
        # find the end of the loop
        # find all nodes that belong to the loop
        # find nodes in loop that go to end
        # replace those edges with meta edges to the end
        # find nodes in loop that go to header
        # replace all but last of those edges with meta edge to end
        # a node is a loop header if there are back-edges to it
        # a latching node is a node with a back-edge to the loop header
        # a back-edge is an edge from any node that is dominated by this node
        back_edges = []
        for predecessor in cfg.predecessors(node):
            # A back edge exists if the predecessor is reachable from the node (node dominates predecessor)
            if cfg.dominates(node, predecessor):
                back_edges.append(predecessor)
        if not back_edges:
            return None
        # Get all nodes encompassed by the loop excluding source node and initial false jump
        loopnode = None
        for succ in cfg.successors(node):
            if cfg.get_edge_data(node, succ).get("kind") == EdgeKind.Fall:
                loopnode = succ
                break
        # NOTE(review): if no Fall successor exists, loopnode stays None and
        # the traversal below is started with source=None -- confirm that is
        # intended.
        dfs_edges = cfg.dfs_labeled_edges_no_loop(source=loopnode)
        # Only targets of "forward" edges are kept as loop members.
        encompassed_nodes = [v for u, v, d in dfs_edges if d == "forward"]
        edges_to_remove = []
        # Find the candidate end that break connects to
        candidate_end = None
        for succ in cfg.successors(node):
            if cfg.get_edge_data(node, succ).get("kind") == EdgeKind.FalseJump and cfg.out_degree(succ) <= 1:
                candidate_end = succ
        # NOTE(review): candidate_end may still be None here when no
        # FalseJump successor qualified; the calls below are then made with
        # None -- verify this is harmless for the graph type in use.
        # Candidate end is a buffer node
        if cfg.in_degree(candidate_end) == 1 and any(exact_instructions(*op)(cfg, candidate_end) for op in [
            ("POP_BLOCK",), ("END_FOR",), ("END_FOR", "POP_TOP"), ("LOAD_CONST", "RETURN_VALUE")]):
            # Step past the buffer node to its first successor with at most one out-edge.
            for ss in cfg.successors(candidate_end):
                if cfg.out_degree(ss) <= 1:
                    candidate_end = ss
                    break
        # encompassed_nodes is always a list at this point, so this check always passes.
        if encompassed_nodes is not None:
            for succ in encompassed_nodes:
                if cfg.get_edge_data(succ, candidate_end) != None:
                    edges_to_remove.append((succ, candidate_end))
        # Wrap each loop member that jumps to the end in a BreakTemplate and
        # drop the wrapper's edge to the end.
        for pred, succ in edges_to_remove:
            break_node = BreakTemplate.try_match(cfg, pred)
            if break_node is not None:
                cfg.remove_edge(break_node, succ)
        # Wrap back-edge sources in a ContinueTemplate; the edge to the header
        # is only removed while the header still has more than two in-edges.
        for candidate in back_edges:
            cont_node = ContinueTemplate.try_match(cfg, candidate)
            if cont_node is not None and cfg.in_degree(node) > 2:
                cfg.remove_edge(cont_node, node)
        cfg.iterate()
        return
+51 -2
View File
@@ -16,12 +16,14 @@ def b1_for_over_tuples_nofallthru():
print("tuples")
print("end")
# 3.6/3.7 No else template
def c0_for_else():
    # A for loop with an else clause; nothing in the loop body breaks out.
    for i in range(3):
        print("for body")
    else:
        print("for else")
print("for else")
# 3.6/3.7 No else template
def c1_for_else_nofallthru():
for i in range(3):
print("for body")
@@ -29,14 +31,15 @@ def c1_for_else_nofallthru():
print("for else")
print("end")
# Fails due to no break
# 3.6/3.7 Naive break detection, an unexpected buffer POP_BLOCK to end
# 3.9 Naive break detection, an unexpected buffer block to end
def d0_for_with_break():
    # A for loop whose body prints and breaks once x reaches 5.
    for x in range(10):
        if x == 5:
            print("breaking")
            break
# Fails due to no break
# 3.6/3.7 Naive break detection, an unexpected buffer POP_BLOCK to end
def d1_for_with_break_nofallthru():
for x in range(10):
if x == 5:
@@ -135,6 +138,8 @@ def j1_for_with_empty_body_ellipsis_nofallthru():
...
print("end")
# 3.6/3.7 Naive break detection, no back edge
# 3.9/3.11 No while loop detection, self false_jump edge & naive break detection
def k0_while_true_with_break():
x = 0
while True:
@@ -143,6 +148,8 @@ def k0_while_true_with_break():
if x >= 1:
break
# 3.6/3.7 Naive break detection, no back edge
# 3.9/3.11 No while loop detection, self false_jump edge & naive break detection
def k1_while_true_with_break_nofallthru():
x = 0
while True:
@@ -152,6 +159,8 @@ def k1_while_true_with_break_nofallthru():
break
print("end")
# 3.6/3.7 No else template
# 3.11 No while loop detection, self false_jump edge
def l0_while_with_else():
i = 0
while i < 3:
@@ -160,6 +169,8 @@ def l0_while_with_else():
else:
print("while else")
# 3.6/3.7 No else template
# 3.11 No while loop detection, self false_jump edge
def l1_while_with_else_nofallthru():
i = 0
while i < 3:
@@ -169,6 +180,7 @@ def l1_while_with_else_nofallthru():
print("while else")
print("end")
# 3.11 No continue
def m0_while_with_continue():
i = 0
while i < 5:
@@ -178,6 +190,7 @@ def m0_while_with_continue():
continue
print("after continue")
# 3.11 No continue
def m1_while_with_continue_nofallthru():
i = 0
while i < 5:
@@ -188,12 +201,14 @@ def m1_while_with_continue_nofallthru():
print("after continue")
print("end")
# 3.6/3.7 Naive break detection, no back edge
def n0_while_with_break():
    # `while True` whose body prints once and then breaks unconditionally.
    i = 0
    while True:
        print("break in while")
        break
# 3.6/3.7 Naive break detection, no back edge
def n1_while_with_break_nofallthru():
i = 0
while True:
@@ -201,6 +216,7 @@ def n1_while_with_break_nofallthru():
break
print("end")
# 3.11 While template broke
def o0_nested_while_loops():
i = 0
while i < 2:
@@ -210,6 +226,7 @@ def o0_nested_while_loops():
j += 1
i += 1
# 3.11 While template broke
def o1_nested_while_loops_nofallthru():
i = 0
while i < 2:
@@ -220,6 +237,8 @@ def o1_nested_while_loops_nofallthru():
i += 1
print("end")
# 3.6/3.7 While template broke (?)
# 3.9 Disconnected with MetaTemplate[end] (?)
def p0_while_with_try_except():
while True:
try:
@@ -227,6 +246,8 @@ def p0_while_with_try_except():
except:
print("except in while")
# 3.6/3.7 While template broke (?)
# 3.9 Disconnected with MetaTemplate[end] (?)
def p1_while_with_try_except_nofallthru():
while True:
try:
@@ -235,34 +256,40 @@ def p1_while_with_try_except_nofallthru():
print("except in while")
print("end")
# 3.6/3.7 While template broke (?) abandoning nodes
def q0_while_with_with_statement():
    # Endless `while True` whose body is a single with statement.
    # `a` is not defined inside this function.
    while True:
        with a:
            print("inside while with")
# 3.6/3.7 While template broke (?) abandoning nodes
def q1_while_with_with_statement_nofallthru():
    # Endless `while True` around a with statement, followed by a trailing print.
    # `a` is not defined inside this function.
    while True:
        with a:
            print("inside while with")
    print("end")
# 3.6/3.7 While template broke
def r0_for_inside_while():
    # A for loop over a two-element list nested inside an endless `while True`.
    while True:
        for x in [1, 2]:
            print("for in while")
# 3.6/3.7 While template broke
def r1_for_inside_while_nofallthru():
    # A for loop nested inside an endless `while True`, followed by a trailing print.
    while True:
        for x in [1, 2]:
            print("for in while")
    print("end")
# 3.6/3.7 While template broke
def s0_while_inside_for():
    # `while True` nested inside a one-iteration for loop; the inner loop
    # prints once and breaks.
    for _ in range(1):
        while True:
            print("while in for")
            break
# 3.6/3.7 While template broke
def s1_while_inside_for_nofallthru():
for _ in range(1):
while True:
@@ -270,10 +297,12 @@ def s1_while_inside_for_nofallthru():
break
print("end")
# 3.6/3.7 While template broke
def t0_while_with_empty_body_ellipsis():
    # Endless `while True` whose body is just the Ellipsis literal.
    while True:
        ...
# 3.6/3.7 While template broke
def t1_while_with_empty_body_ellipsis_nofallthru():
while True:
...
@@ -311,6 +340,7 @@ def v1_continue_in_nested_for_nofallthru():
print(f"Processing i={i}, j={j}")
print("end")
# 3.13 if statement putting code in the else block
def w0_break_with_else():
for i in range(5):
if i == 3:
@@ -319,6 +349,7 @@ def w0_break_with_else():
else:
print("This won't execute due to break")
# 3.13 if statement putting code in the else block
def w1_break_with_else_nofallthru():
for i in range(5):
if i == 3:
@@ -328,6 +359,7 @@ def w1_break_with_else_nofallthru():
print("This won't execute due to break")
print("end")
# 3.6/3.7 No continue detection
def x0_continue_with_else():
for i in range(3):
if i == 1:
@@ -336,6 +368,7 @@ def x0_continue_with_else():
else:
print("Else clause still executes after continue")
# 3.6/3.7 No continue detection
def x1_continue_with_else_nofallthru():
for i in range(3):
if i == 1:
@@ -345,6 +378,7 @@ def x1_continue_with_else_nofallthru():
print("Else clause still executes after continue")
print("end")
# 3.9/3.11 Naive break detection, break statement is further up
def y0_break_in_try_except():
for i in range(5):
try:
@@ -354,6 +388,7 @@ def y0_break_in_try_except():
except:
print("Exception occurred")
# 3.9 Naive break detection, break statement is further up
def y1_break_in_try_except_nofallthru():
for i in range(5):
try:
@@ -364,6 +399,19 @@ def y1_break_in_try_except_nofallthru():
print("Exception occurred")
print("end")
# 3.9 Naive break detection, break statement is further up
def y2_return_in_try_except_nofallthru():
    # A for loop around try/except where the else branch of the inner `if`
    # breaks out of the loop, followed by a trailing print.
    # NOTE(review): the name says "return" but the body uses `break` --
    # confirm which statement was intended.
    for i in range(5):
        try:
            if i == 3:
                print(f"Value: {i}")
            else:
                break
        except:
            print("Exception occurred")
    print("end")
# 3.6/3.9 No continue detection
def z0_continue_in_try_except():
for i in range(5):
try:
@@ -373,6 +421,7 @@ def z0_continue_in_try_except():
except:
print("Exception occurred")
# 3.6/3.9/3.11 No continue detection
def z1_continue_in_try_except_nofallthru():
for i in range(5):
try: