3.10+Match Statements

This commit is contained in:
Xinlong Hu
2025-08-09 21:03:56 -05:00
committed by Joel-Flores123
parent f7ed7a3539
commit 9bf67312f9
3 changed files with 334 additions and 0 deletions
@@ -0,0 +1,118 @@
from itertools import chain
from typing import override
from .Block import BlockTemplate
from ..cft import SourceContext, SourceLine, ControlFlowTemplate, EdgeKind, register_template
from ..utils import T, N, ending_instructions, has_start_end_source, versions_from, make_try_match
class CaseOne(ControlFlowTemplate):
template = T(
case_header=~N("case_body", None).with_cond(has_start_end_source("case", ":")),
case_body=~N("tail."),
tail=N.tail(),
)
try_match = make_try_match({EdgeKind.Fall: "tail"}, "case_header", "case_body")
def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
case_body = source[self.case_body]
cutoff = next((i for i, x in enumerate(self.case_header.get_instructions()) if x.source_line.strip().startswith("case")), 0)
if isinstance(self.case_header, BlockTemplate):
i = cutoff + 1
case_header = source[BlockTemplate(self.case_header.members[:i]), 1] if i > 0 else []
case_lines = source[BlockTemplate(self.case_header.members[i:]), 2] if i < len(self.case_header.members) else []
else:
case_header = source[self.case_header, 1]
case_lines = []
return list(chain(case_header, case_lines, case_body))
class CaseWrapper(ControlFlowTemplate):
@classmethod
@override
def try_match(cls, cfg, node) -> ControlFlowTemplate | None:
if x := CaseTwo.try_match(cfg, node):
return x
if x := CaseOne.try_match(cfg, node):
return x
@register_template(1, 0, *versions_from(3, 10))
class CaseTwo(ControlFlowTemplate):
template = T(
case_header=~N("case_body", "other.").with_cond(has_start_end_source("case", ":")),
case_body=~N("other."),
other=~N("tail.").of_subtemplate(CaseWrapper) | N.tail(),
tail=N.tail(),
)
try_match = make_try_match({EdgeKind.Fall: "tail"}, "case_header", "case_body", "other")
def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
case_header = source[self.case_header, 1]
case_body = source[self.case_body, 2]
other = source[self.other]
return list(chain(case_header, case_body, other))
@register_template(0, 0, *versions_from(3, 10))
class Match(ControlFlowTemplate):
template = T(
match_header=~N("case_body", "tail").with_cond(has_start_end_source("match", ":")),
case_body=~N("tail.").with_in_deg(1) | ~N("tail").with_in_deg(1).with_cond(ending_instructions("POP_TOP")),
tail=~N.tail().of_subtemplate(CaseWrapper) | N.tail(),
)
try_match = make_try_match({EdgeKind.Fall: "tail"}, "match_header", "case_body", "POP_TOP")
def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
match_line = None
case_line = None
case_body = source[self.case_body, 2]
cutoff = next((i for i, x in enumerate(self.match_header.get_instructions()) if x.source_line.strip().startswith("match")), 0)
if isinstance(self.match_header, BlockTemplate):
i = cutoff + 1
match_line = source[BlockTemplate(self.match_header.members[:i])] if i > 0 else []
case_line = source[BlockTemplate(self.match_header.members[i:]), 1] if i < len(self.match_header.members) else []
else:
match_line = source[self.match_header, 1]
case_line = []
return list(chain(match_line, case_line, case_body))
@register_template(0, 0, *versions_from(3, 10))
class MultiMatch(ControlFlowTemplate):
template = T(
match_header=~N("multi_header", "POP_TOP").with_cond(has_start_end_source("match", ":")),
multi_header=~N("case_body", "POP_TOP"),
case_body=~N("tail.").with_in_deg(1) | ~N("tail").with_in_deg(1).with_cond(ending_instructions("POP_TOP")),
POP_TOP=~N("tail."),
tail=~N.tail().of_subtemplate(CaseWrapper) | ~N.tail(),
)
try_match = make_try_match({EdgeKind.Fall: "tail"}, "multi_header", "match_header", "case_body", "POP_TOP")
def to_indented_source(self, source: SourceContext) -> list[SourceLine]:
match_line = None
case_line = None
case_body = source[self.case_body, 2]
cutoff = next((i for i, x in enumerate(self.match_header.get_instructions()) if x.source_line.strip().startswith("match")), 0)
if isinstance(self.match_header, BlockTemplate):
i = cutoff + 1
match_line = source[BlockTemplate(self.match_header.members[:i])] if i > 0 else []
case_line = source[BlockTemplate(self.match_header.members[i:]), 1] if i < len(self.match_header.members) else []
else:
match_line = source[self.match_header, 1]
case_line = []
return list(chain(match_line, case_line, case_body))
@@ -145,6 +145,16 @@ def has_instval(opname: str, argval: Any):
return check_instructions return check_instructions
def has_start_end_source(argval: Any, endval: Any):
def check_instructions(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
for x in node.get_instructions():
if x.source_line.startswith(argval) and x.source_line.endswith(endval):
return True
return False
return check_instructions
def has_no_lines(cfg: CFG, node: ControlFlowTemplate | None) -> bool: def has_no_lines(cfg: CFG, node: ControlFlowTemplate | None) -> bool:
return node is None or all(i.starts_line is None for i in node.get_instructions()) return node is None or all(i.starts_line is None for i in node.get_instructions())
+206
View File
@@ -0,0 +1,206 @@
def a0_bare_match():
match x:
case 1:
print(1)
print(2)
def a1_bare_match(x):
match x:
case 1:
print(1)
case _:
print(2)
print(3)
def a2_bare_match(x):
match x:
case 1:
print(1)
case 2:
print(2)
case _:
print(3)
print(4)
def a3_bare_match(x):
match x:
case 1:
print(1)
case 2:
print(2)
case 3:
print(3)
case _:
print(4)
print(5)
def a4_bare_match(x):
match x:
case 1:
print(1)
case 2:
print(2)
case 3:
print(3)
case 4:
print(4)
print(5)
def b0_multi_case():
match x:
case 1 | 2:
print(1)
def b1_multi_case_fallthrough():
match x:
case 1 | 2:
print(1)
print(2)
def c0_match_with_as():
match x:
case [1, 2] as y:
print(1)
def c1_match_with_as():
match x:
case [1, 2] as y:
print(1)
case [3, 4] as z:
print(2)
def c1_match_with_as_fallthrough():
match x:
case [1, 2] as y:
print(1)
print(2)
def d0_match_sequence():
match x:
case [a, b, c]:
print(1)
def d1_match_sequence_fallthrough():
match x:
case [a, b, c]:
print(1)
print(2)
def e0_match_mapping():
match x:
case {'key': value}:
print(1)
def e1_match_mapping_fallthrough():
match x:
case {'key': value}:
print(1)
print(2)
def f0_match_class():
match x:
case Point(x=0, y=0):
print(1)
def f1_match_class_fallthrough():
match x:
case Point(x=0, y=0):
print(1)
print(2)
def g0_match_complex():
match x:
case [Point(x1, y1), Point(x2, y2) as p2]:
print(1)
def g1_match_complex_fallthrough():
match x:
case [Point(x1, y1), Point(x2, y2) as p2]:
print(1)
print(2)
def h0_try_match_except():
try:
match x:
case 1:
print(1)
except:
print(2)
print(3)
def i0_match_return():
match x:
case 1:
return 1
print(1)
def j0_match_raise():
match x:
case 1:
raise Exc
print(1)
async def k0_bare_match():
match x:
case 1:
print(1)
async def k1_bare_match_fallthrough():
match x:
case 1:
print(1)
print(2)
def n0_match_guard():
match x:
case [a, b] if a > b:
print(1)
def n1_match_guard_fallthrough():
match x:
case [a, b] if a > b:
print(1)
print(2)
def m0_nested_match():
match x:
case [a, b]:
match b:
case 1:
print(1)
print(2)
def m1_nested_match_fallthrough():
match x:
case [a, b]:
match b:
case 1:
print(1)
print(2)
print(3)