While Loops + Breaks/Continues

This commit is contained in:
Xinlong Hu
2025-08-01 19:46:51 -05:00
parent d80b35ec3d
commit 6a3d869868
2 changed files with 110 additions and 21 deletions
@@ -1,17 +1,27 @@
from __future__ import annotations
from itertools import chain
from typing import TYPE_CHECKING
from pylingual.control_flow_reconstruction.source import SourceContext, SourceLine
from ..cft import ControlFlowTemplate, EdgeKind, register_template
from ..utils import (
T,
N,
no_back_edges,
versions_below,
versions_from,
with_instructions,
exact_instructions,
has_no_lines,
has_some_lines,
condense_mapping,
defer_source_to,
starting_instructions,
to_indented_source,
make_try_match,
with_top_level_instructions,
without_top_level_instructions,
)
if TYPE_CHECKING:
@@ -36,9 +46,11 @@ class ForLoop(ControlFlowTemplate):
"""
@register_template(0, 2)
class SelfLoop(ControlFlowTemplate):
template = T(loop_body=~N("loop_body", None))
@register_template(0, 2, *versions_below(3, 10))
class SelfLoop3_6(ControlFlowTemplate):
template = T(
loop_body=~N("loop_body", None)
)
try_match = make_try_match({}, "loop_body")
@@ -50,6 +62,26 @@ class SelfLoop(ControlFlowTemplate):
"""
# Self-loop variant registered for versions_from(3, 10): a header node that
# leads into a body node which loops back onto itself, plus a tail node named
# RET_CONST (marked optional with "?" on the header's edge list).
# NOTE(review): the exact meaning of the N(...) edge arguments and of `~` is
# defined by the template DSL outside this view — confirm there.
@register_template(0, 2, *versions_from(3, 10))
class SelfLoop3_10(ControlFlowTemplate):
    template = T(
        loop_header=~N("loop_body", "RET_CONST?").with_cond(no_back_edges),
        loop_body=~N("loop_body", None),
        RET_CONST=N.tail(),
    )

    try_match = make_try_match({}, "loop_header", "loop_body", "RET_CONST")

    # Renders the header followed by the body (requested with a second index of
    # 1 — presumably the indent level, confirm against SourceContext).  A literal
    # "while True:" line is inserted between them unless one of the header's
    # instruction source lines already begins with "while ".
    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        header_lines = source[self.loop_header]
        body_lines = source[self.loop_body, 1]
        # The tail is still looked up, in the same position as before, but its
        # lines are not part of the returned source.
        source[self.RET_CONST]

        header_has_while = False
        for inst in self.loop_header.get_instructions():
            if inst.starts_line is None:
                continue
            # starts_line is 1-based; source.lines is indexed from 0.
            if source.lines[inst.starts_line - 1].strip().startswith("while "):
                header_has_while = True
                break

        if header_has_while:
            return list(chain(header_lines, body_lines))
        return list(chain(header_lines, self.line("while True:"), body_lines))
@register_template(0, 2)
class TrueSelfLoop(ControlFlowTemplate):
template = T(loop_body=~N("tail.", "loop_body"), tail=N.tail())
@@ -68,6 +100,28 @@ class TrueSelfLoop(ControlFlowTemplate):
"""
# Template for an if/else pair rendered inside a `while True:` loop (see the
# source template at the bottom of the class).
# NOTE(review): the meaning of the two register_template arguments (1, 39) is
# not visible here — confirm against register_template.
@register_template(1, 39)
class WhileIfElseLoop(ControlFlowTemplate):
template = T(
# Header node listing "if_body" and "else_body" as its edges; the condition
# rejects headers that have WITH_EXCEPT_START, CHECK_EXC_MATCH or FOR_ITER
# as top-level instructions.
if_header=~N("if_body", "else_body").with_cond(without_top_level_instructions("WITH_EXCEPT_START", "CHECK_EXC_MATCH", "FOR_ITER")),
# Else branch: its only listed edge is "if_header" (presumably the jump
# back to the loop header — confirm), constrained with with_in_deg(1).
else_body=~N("if_header").with_in_deg(1),
# If branch: its listed edge is "tail."; rejected when it has RERAISE or
# END_FINALLY as top-level instructions; constrained with with_in_deg(1).
if_body=~N("tail.").with_cond(without_top_level_instructions("RERAISE", "END_FINALLY")).with_in_deg(1),
tail=N.tail(),
)
# Matcher built from the template: the Fall edge kind is mapped to "tail", and
# the three named nodes are passed as the remaining arguments.
# NOTE(review): presumably those are the nodes folded into this template,
# leaving "tail" outside it — confirm against make_try_match.
try_match = make_try_match({EdgeKind.Fall: "tail"}, "if_header", "if_body", "else_body")
# The docstring below is consumed by the @to_indented_source decorator as the
# output template; it is runtime text, not documentation, so do not edit it
# casually.  The "{else_body?else:}" line presumably emits "else:" only when an
# else body is present — confirm against to_indented_source.
@to_indented_source
def to_indented_source():
"""
while True:
{if_header}
{if_body}
{else_body?else:}
{else_body}
"""
@register_template(0, 3)
class InlinedComprehensionTemplate(ControlFlowTemplate):
template = T(
@@ -90,7 +144,7 @@ class InlinedComprehensionTemplate(ControlFlowTemplate):
class BreakTemplate(ControlFlowTemplate):
@classmethod
def try_match(cls, cfg, node):
if isinstance(node, BreakTemplate) or has_no_lines(cfg, node):
if not with_top_level_instructions("POP_TOP", "LOAD_CONST", "RETURN_VALUE", "RETURN_CONST", "JUMP_ABSOLUTE", "JUMP_FORWARD", "JUMP_BACKWARD", "BREAK_LOOP")(cfg, node) or has_no_lines(cfg, node):
return None
i = len(node.get_instructions()) - 1
@@ -113,7 +167,7 @@ class BreakTemplate(ControlFlowTemplate):
class ContinueTemplate(ControlFlowTemplate):
@classmethod
def try_match(cls, cfg, node):
if isinstance(node, ContinueTemplate) or has_no_lines(cfg, node):
if not with_top_level_instructions("JUMP_ABSOLUTE", "JUMP_BACKWARD", "CONTINUE_LOOP", "POP_EXCEPT")(cfg, node) or has_no_lines(cfg, node):
return None
i = len(node.get_instructions()) - 1
@@ -153,8 +207,8 @@ class FixLoop(ControlFlowTemplate):
# A back edge exists if the predecessor is reachable from the node (node dominates predecessor)
if cfg.dominates(node, predecessor):
back_edges.append(predecessor)
if not back_edges:
if not back_edges or with_top_level_instructions("SEND")(cfg, node):
return None
# Get all nodes encompassed by the loop excluding source node and initial false jump
@@ -172,7 +226,7 @@ class FixLoop(ControlFlowTemplate):
# Find the candidate end that break connects to
candidate_end = None
for succ in cfg.successors(node):
if cfg.get_edge_data(node, succ).get("kind") == EdgeKind.FalseJump:
if cfg.get_edge_data(node, succ).get("kind") == EdgeKind.FalseJump and not any(n == node for n in cfg.successors(succ)):
candidate_end = succ
# Candidate end is a buffer node
@@ -181,21 +235,43 @@ class FixLoop(ControlFlowTemplate):
if cfg.get_edge_data(candidate_end, ss).get("kind") != EdgeKind.Exception:
candidate_end = ss
break
if encompassed_nodes is not None:
for succ in encompassed_nodes:
if cfg.get_edge_data(succ, candidate_end) != None:
edges_to_remove.append((succ, candidate_end))
for pred, succ in edges_to_remove:
break_node = BreakTemplate.try_match(cfg, pred)
if break_node is not None:
cfg.remove_edge(break_node, succ)
if candidate_end == None:
# While loops
for candidate in back_edges:
cont_node = ContinueTemplate.try_match(cfg, candidate)
if cont_node is not None and not cfg.has_edge(node, cont_node):
cfg.remove_edge(cont_node, node)
dfs_edges = cfg.dfs_labeled_edges_no_loop(source=node)
candidates = [v for u, v, d in dfs_edges if d == "forward"][1:]
for candidate in back_edges:
cont_node = ContinueTemplate.try_match(cfg, candidate)
if cont_node is not None and cfg.in_degree(node) > 2:
cfg.remove_edge(cont_node, node)
for n in candidates:
for s in cfg.successors(n):
if cfg.get_edge_data(n, s).get("kind") != EdgeKind.Exception and not all(cfg.get_edge_data(p, n).get("kind") == EdgeKind.Exception for p in cfg.predecessors(n)):
edges_to_remove.append((n, s))
for pred, succ in edges_to_remove:
break_node = BreakTemplate.try_match(cfg, pred)
if break_node is not None and cfg.in_degree(succ) > 2:
cfg.remove_edge(break_node, succ)
else:
# For loops
if encompassed_nodes is not None:
for succ in encompassed_nodes:
if cfg.get_edge_data(succ, candidate_end) != None:
edges_to_remove.append((succ, candidate_end))
for candidate in back_edges:
cont_node = ContinueTemplate.try_match(cfg, candidate)
if cont_node is not None and cfg.in_degree(node) > 2:
cfg.remove_edge(cont_node, node)
for pred, succ in edges_to_remove:
break_node = BreakTemplate.try_match(cfg, pred)
if break_node is not None:
cfg.remove_edge(break_node, succ)
cfg.iterate()
return
@@ -96,6 +96,19 @@ def without_top_level_instructions(*opnames: str):
return check_instructions
def with_top_level_instructions(*opnames: str):
    """
    Build a node predicate that tests for any of *opnames* at the node's top level.

    The returned ``check_instructions(cfg, node)`` callable answers:

    - ``BlockTemplate`` node: True if at least one direct member that is an
      ``InstTemplate`` has an opname listed in *opnames*; members of any other
      type are not inspected.
    - ``InstTemplate`` node: True if its own opname is listed in *opnames*.
    - anything else (including ``None``): False.

    ``cfg`` is accepted by the predicate but not used by it.
    """
    # Imported at call time rather than at module level — presumably to avoid a
    # circular import between utils and the templates package; confirm.
    from .templates.Block import BlockTemplate

    def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
        if isinstance(node, BlockTemplate):
            for member in node.members:
                if isinstance(member, InstTemplate) and member.inst.opname in opnames:
                    return True
            return False
        if isinstance(node, InstTemplate):
            return node.inst.opname in opnames
        return False

    return check_instructions
def has_type(*template_type: type[ControlFlowTemplate]):
def check_type(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
return isinstance(node, template_type)