class CaseOne(ControlFlowTemplate):
    """Control-flow template for one ``case`` clause: a header node and a body node.

    The header node must satisfy ``has_start_end_source("case", ":")``, i.e.
    contain an instruction whose source line starts with ``case`` and ends
    with ``:``.  ``case_header`` and ``case_body`` are the node names handed to
    ``make_try_match``, with ``EdgeKind.Fall`` mapped to ``tail``.
    """

    template = T(
        case_header=~N("case_body", None).with_cond(has_start_end_source("case", ":")),
        case_body=~N("tail."),
        tail=N.tail(),
    )

    try_match = make_try_match({EdgeKind.Fall: "tail"}, "case_header", "case_body")

    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        """Return the source lines of the clause: header, trailing header members, body.

        The integer in ``source[node, n]`` is presumably an indentation
        level -- confirm against ``SourceContext``.
        """
        body_lines = source[self.case_body]
        header = self.case_header

        # Position of the first instruction whose source begins with "case";
        # stays 0 when no instruction qualifies.
        case_index = 0
        for position, instruction in enumerate(header.get_instructions()):
            if instruction.source_line.strip().startswith("case"):
                case_index = position
                break

        if not isinstance(header, BlockTemplate):
            # A non-block header is emitted as a whole at level 1.
            return [*source[header, 1], *body_lines]

        # Split the block's members just after the ``case`` position.  The
        # split point is always >= 1, so the leading slice is always rendered.
        split_at = case_index + 1
        members = header.members
        head_lines = source[BlockTemplate(members[:split_at]), 1]
        if split_at < len(members):
            rest_lines = source[BlockTemplate(members[split_at:]), 2]
        else:
            rest_lines = []

        return [*head_lines, *rest_lines, *body_lines]
@register_template(0, 0, *versions_from(3, 10))
class Match(ControlFlowTemplate):
    """Control-flow template for a ``match`` statement header plus its first case body.

    The header node must satisfy ``has_start_end_source("match", ":")``, i.e.
    contain an instruction whose source line starts with ``match`` and ends
    with ``:``.  The tail is either a node matched by ``CaseWrapper`` or a
    plain tail node.
    """

    template = T(
        match_header=~N("case_body", "tail").with_cond(has_start_end_source("match", ":")),
        case_body=~N("tail.").with_in_deg(1) | ~N("tail").with_in_deg(1).with_cond(ending_instructions("POP_TOP")),
        tail=~N.tail().of_subtemplate(CaseWrapper) | N.tail(),
    )

    # NOTE(review): "POP_TOP" is not a node name in the template above --
    # confirm that make_try_match tolerates an unknown node name.
    try_match = make_try_match({EdgeKind.Fall: "tail"}, "match_header", "case_body", "POP_TOP")

    def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
        """Return the source lines: ``match`` line(s), remaining header members, case body.

        The integer in ``source[node, n]`` is presumably an indentation
        level -- confirm against ``SourceContext``.
        """
        body_lines = source[self.case_body, 2]
        header = self.match_header

        # Position of the first instruction whose source begins with "match";
        # stays 0 when no instruction qualifies.
        match_index = 0
        for position, instruction in enumerate(header.get_instructions()):
            if instruction.source_line.strip().startswith("match"):
                match_index = position
                break

        if not isinstance(header, BlockTemplate):
            # NOTE(review): the whole header is rendered at level 1 here, while
            # the block branch renders the ``match`` part with no level -- confirm.
            return [*source[header, 1], *body_lines]

        # Members up to and including the ``match`` position are rendered with
        # no level argument; whatever follows in the header is rendered at
        # level 1.  The split point is always >= 1.
        split_at = match_index + 1
        members = header.members
        match_lines = source[BlockTemplate(members[:split_at])]
        if split_at < len(members):
            case_lines = source[BlockTemplate(members[split_at:]), 1]
        else:
            case_lines = []

        return [*match_lines, *case_lines, *body_lines]
"match_header", "case_body", "POP_TOP") + + def to_indented_source(self, source: SourceContext) -> list[SourceLine]: + match_line = None + case_line = None + case_body = source[self.case_body, 2] + + cutoff = next((i for i, x in enumerate(self.match_header.get_instructions()) if x.source_line.strip().startswith("match")), 0) + + if isinstance(self.match_header, BlockTemplate): + i = cutoff + 1 + match_line = source[BlockTemplate(self.match_header.members[:i])] if i > 0 else [] + case_line = source[BlockTemplate(self.match_header.members[i:]), 1] if i < len(self.match_header.members) else [] + else: + match_line = source[self.match_header, 1] + + case_line = [] + + return list(chain(match_line, case_line, case_body)) \ No newline at end of file diff --git a/pylingual/control_flow_reconstruction/utils.py b/pylingual/control_flow_reconstruction/utils.py index 16edadb..dc02c5c 100644 --- a/pylingual/control_flow_reconstruction/utils.py +++ b/pylingual/control_flow_reconstruction/utils.py @@ -145,6 +145,16 @@ def has_instval(opname: str, argval: Any): return check_instructions +def has_start_end_source(argval: Any, endval: Any): + def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool: + for x in node.get_instructions(): + if x.source_line.startswith(argval) and x.source_line.endswith(endval): + return True + return False + + return check_instructions + + def has_no_lines(cfg: CFG, node: ControlFlowTemplate | None) -> bool: return node is None or all(i.starts_line is None for i in node.get_instructions()) diff --git a/test/Match.py b/test/Match.py new file mode 100644 index 0000000..e1790ad --- /dev/null +++ b/test/Match.py @@ -0,0 +1,206 @@ +def a0_bare_match(): + match x: + case 1: + print(1) + print(2) + + +def a1_bare_match(x): + match x: + case 1: + print(1) + case _: + print(2) + print(3) + + +def a2_bare_match(x): + match x: + case 1: + print(1) + case 2: + print(2) + case _: + print(3) + print(4) + + +def a3_bare_match(x): + match x: + case 
1: + print(1) + case 2: + print(2) + case 3: + print(3) + case _: + print(4) + print(5) + + +def a4_bare_match(x): + match x: + case 1: + print(1) + case 2: + print(2) + case 3: + print(3) + case 4: + print(4) + print(5) + + +def b0_multi_case(): + match x: + case 1 | 2: + print(1) + + +def b1_multi_case_fallthrough(): + match x: + case 1 | 2: + print(1) + print(2) + + +def c0_match_with_as(): + match x: + case [1, 2] as y: + print(1) + + +def c1_match_with_as(): + match x: + case [1, 2] as y: + print(1) + case [3, 4] as z: + print(2) + + +def c1_match_with_as_fallthrough(): + match x: + case [1, 2] as y: + print(1) + print(2) + + +def d0_match_sequence(): + match x: + case [a, b, c]: + print(1) + + +def d1_match_sequence_fallthrough(): + match x: + case [a, b, c]: + print(1) + print(2) + + +def e0_match_mapping(): + match x: + case {'key': value}: + print(1) + + +def e1_match_mapping_fallthrough(): + match x: + case {'key': value}: + print(1) + print(2) + + +def f0_match_class(): + match x: + case Point(x=0, y=0): + print(1) + + +def f1_match_class_fallthrough(): + match x: + case Point(x=0, y=0): + print(1) + print(2) + + +def g0_match_complex(): + match x: + case [Point(x1, y1), Point(x2, y2) as p2]: + print(1) + + +def g1_match_complex_fallthrough(): + match x: + case [Point(x1, y1), Point(x2, y2) as p2]: + print(1) + print(2) + + +def h0_try_match_except(): + try: + match x: + case 1: + print(1) + except: + print(2) + print(3) + + +def i0_match_return(): + match x: + case 1: + return 1 + print(1) + + +def j0_match_raise(): + match x: + case 1: + raise Exc + print(1) + + +async def k0_bare_match(): + match x: + case 1: + print(1) + + +async def k1_bare_match_fallthrough(): + match x: + case 1: + print(1) + print(2) + + +def n0_match_guard(): + match x: + case [a, b] if a > b: + print(1) + + +def n1_match_guard_fallthrough(): + match x: + case [a, b] if a > b: + print(1) + print(2) + + +def m0_nested_match(): + match x: + case [a, b]: + match b: + case 1: + 
print(1) + print(2) + + +def m1_nested_match_fallthrough(): + match x: + case [a, b]: + match b: + case 1: + print(1) + print(2) + print(3) \ No newline at end of file