""" @file transformers/control_flow.py @brief Control-flow flattening transformer. @details Converts structured control flow into a state-machine with a while/dispatch loop to hinder static analysis, and can emit debug telemetry in debug mode. """ import ast import random from typing import Dict, List, Set, Tuple, Any, Optional import logging # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("ControlFlowFlattener") class ControlFlowFlattener(ast.NodeTransformer): """ A transformer that flattens control flow by converting branching structures into a dispatch table with a while loop. """ def __init__(self, debug_mode=False): self.state_var_name = "_state" self.states = {} self.current_block_id = 0 self.debug_mode = debug_mode # For tracking transformations when debug mode is enabled if debug_mode: self.debug_data = { "flattened_functions": [], "block_counts": {}, "transformations": [] } def log_debug(self, category: str, data: Any): """Log debugging information if debug mode is enabled.""" if self.debug_mode: self.debug_data["transformations"].append({ "type": category, "data": data }) def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: """ Transform a function body into a flattened control flow. Keeps function signature the same but replaces the body with state-machine style execution. """ # Skip small functions as flattening them might cause more overhead than obfuscation if len(node.body) <= 2: return self.generic_visit(node) # Reset state for this function self.states = {} self.current_block_id = 0 # Process decorators and args normally node.decorator_list = [self.visit(d) for d in node.decorator_list] if hasattr(node, 'args'): node.args = self.visit(node.args) # Create entry block (state 0) entry_block_id = self.current_block_id self.current_block_id += 1 self.states[entry_block_id] = node.body # Create the flattened control flow flattened_body = self.flatten_blocks() node.body = flattened_body if self.debug_mode: self.debug_data["flattened_functions"].append(node.name) self.debug_data["block_counts"][node.name] = len(self.states) self.log_debug("function_flattened", { "name": node.name, "original_statements": len(self.states[entry_block_id]), "flattened_blocks": len(self.states) }) return node def flatten_blocks(self) -> List[ast.stmt]: """ Create a flattened control flow structure using a while loop and switch-like dispatch based on a state variable. """ # Create the state variable and initialize it state_var = ast.Name(id=self.state_var_name, ctx=ast.Store()) init_state = ast.Assign( targets=[state_var], value=ast.Constant(value=0, kind=None) ) # Create the while loop condition (state != -1) loop_condition = ast.Compare( left=ast.Name(id=self.state_var_name, ctx=ast.Load()), ops=[ast.NotEq()], comparators=[ast.Constant(value=-1, kind=None)] ) # Create the dispatch table as a series of if/elif statements dispatch_cases = [] for state_id, block in self.states.items(): # Create the condition (state == state_id) condition = ast.Compare( left=ast.Name(id=self.state_var_name, ctx=ast.Load()), ops=[ast.Eq()], comparators=[ast.Constant(value=state_id, kind=None)] ) # Process the block statements processed_block = [] for stmt in block: processed_stmt = self.visit(stmt) if isinstance(processed_stmt, list): processed_block.extend(processed_stmt) else: processed_block.append(processed_stmt) # If this block doesn't modify the state, add a transition to the next block last_stmt = processed_block[-1] if processed_block else None if not (isinstance(last_stmt, ast.Assign) and isinstance(last_stmt.targets[0], ast.Name) and last_stmt.targets[0].id == self.state_var_name): # Add transition to next block next_state = state_id + 1 if next_state not in self.states: # If no next state, exit the loop next_state = -1 processed_block.append( ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=next_state, kind=None) ) ) # Create the if/elif body if_body = processed_block # Add to dispatch cases dispatch_cases.append((condition, if_body)) # Convert dispatch cases to if/elif/else structure if_node = None for i, (condition, body) in enumerate(reversed(dispatch_cases)): if i == 0: # Last case (will be the 'else' clause) if_node = ast.If( test=condition, body=body, orelse=[] ) else: if_node = ast.If( test=condition, body=body, orelse=[if_node] if if_node else [] ) # Create the while loop with the dispatch logic while_body = [if_node] if if_node else [] while_loop = ast.While( test=loop_condition, body=while_body, orelse=[] ) # Add some junk code to obscure the control flow junk = self.generate_junk_code() # Assemble the flattened function body flattened_body = [init_state] + junk + [while_loop] return flattened_body def generate_junk_code(self) -> List[ast.stmt]: """Generate meaningless code to obscure the control flow.""" junk = [] # Add dummy variables that look like they're used for control flow dummy_vars = [f"_cflow_{i}" for i in range(random.randint(2, 5))] for var in dummy_vars: # Initialize with random value junk.append( ast.Assign( targets=[ast.Name(id=var, ctx=ast.Store())], value=ast.Constant(value=random.randint(0, 100), kind=None) ) ) # Add some conditional statements that don't do anything important if dummy_vars: cond = ast.Compare( left=ast.Name(id=random.choice(dummy_vars), ctx=ast.Load()), ops=[ast.Gt()], comparators=[ast.Constant(value=50, kind=None)] ) junk.append( ast.If( test=cond, body=[ ast.Assign( targets=[ast.Name(id=random.choice(dummy_vars), ctx=ast.Store())], value=ast.Constant(value=random.randint(0, 100), kind=None) ) ], orelse=[] ) ) if self.debug_mode: self.log_debug("junk_code", { "statements": len(junk), "variables": dummy_vars }) return junk def visit_If(self, node: ast.If) -> ast.stmt: """ Transform if statements into state transitions. Conditional branches become separate states in the state machine. """ # Create new states for the true and false branches true_branch_id = self.current_block_id self.current_block_id += 1 self.states[true_branch_id] = node.body if node.orelse: false_branch_id = self.current_block_id self.current_block_id += 1 self.states[false_branch_id] = node.orelse else: # If no else branch, use the next block in sequence false_branch_id = self.current_block_id # Create a conditional assignment to the state variable result = ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.IfExp( test=self.visit(node.test), body=ast.Constant(value=true_branch_id, kind=None), orelse=ast.Constant(value=false_branch_id, kind=None) ) ) if self.debug_mode: self.log_debug("if_transformation", { "true_branch": true_branch_id, "false_branch": false_branch_id, "has_else": bool(node.orelse) }) return result def visit_For(self, node: ast.For) -> List[ast.stmt]: """ Transform for loops into state transitions with loop body and else clause as separate states. """ # Create unique variable names for this loop iter_var = f"_iter_{self.current_block_id}" index_var = f"_idx_{self.current_block_id}" # Setup the iteration setup_stmts = [ # Create iterator: _iter_X = iter(iterable) ast.Assign( targets=[ast.Name(id=iter_var, ctx=ast.Store())], value=ast.Call( func=ast.Name(id="iter", ctx=ast.Load()), args=[self.visit(node.iter)], keywords=[] ) ), # Initialize index: _idx_X = 0 ast.Assign( targets=[ast.Name(id=index_var, ctx=ast.Store())], value=ast.Constant(value=0, kind=None) ) ] # Create states for the loop body and else clause loop_body_id = self.current_block_id self.current_block_id += 1 # Loop body needs to get the next item and assign it to the target loop_body = [ # try: target = next(_iter_X) ast.Try( body=[ ast.Assign( targets=[self.visit(node.target)], value=ast.Call( func=ast.Name(id="next", ctx=ast.Load()), args=[ast.Name(id=iter_var, ctx=ast.Load())], keywords=[] ) ) ], # except StopIteration: goto else_clause or skip if no else handlers=[ ast.ExceptHandler( type=ast.Name(id="StopIteration", ctx=ast.Load()), name=None, body=[ ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant( value=(self.current_block_id if node.orelse else -1), kind=None ) ) ] ) ], # no finally orelse=[], finalbody=[] ) ] + node.body + [ # Increment index ast.AugAssign( target=ast.Name(id=index_var, ctx=ast.Store()), op=ast.Add(), value=ast.Constant(value=1, kind=None) ), # Loop back to body ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=loop_body_id, kind=None) ) ] self.states[loop_body_id] = loop_body # Handle else clause if present if node.orelse: else_id = self.current_block_id self.current_block_id += 1 self.states[else_id] = node.orelse if self.debug_mode: self.log_debug("for_loop_transformation", { "iterator_var": iter_var, "index_var": index_var, "body_block": loop_body_id, "has_else": bool(node.orelse), "else_block": else_id if node.orelse else None }) # Transition to the loop body after setup setup_stmts.append( ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=loop_body_id, kind=None) ) ) return setup_stmts def visit_While(self, node: ast.While) -> List[ast.stmt]: """ Transform while loops into state transitions with conditional jumps. """ # Create states for the condition check, body, and else clause cond_check_id = self.current_block_id self.current_block_id += 1 # Body state ID body_id = self.current_block_id self.current_block_id += 1 # Else state ID (if present) else_id = self.current_block_id if node.orelse else -1 if node.orelse: self.current_block_id += 1 # Condition check state: if test: goto body else: goto else/exit cond_check = [ ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.IfExp( test=self.visit(node.test), body=ast.Constant(value=body_id, kind=None), orelse=ast.Constant(value=else_id, kind=None) ) ) ] self.states[cond_check_id] = cond_check # Body state: execute body then goto condition check body = list(node.body) + [ ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=cond_check_id, kind=None) ) ] self.states[body_id] = body # Else state (if present) if node.orelse: self.states[else_id] = node.orelse if self.debug_mode: self.log_debug("while_loop_transformation", { "condition_check_block": cond_check_id, "body_block": body_id, "has_else": bool(node.orelse), "else_block": else_id if else_id != -1 else None }) # Initial transition to the condition check return [ ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=cond_check_id, kind=None) ) ] def visit_Break(self, node: ast.Break) -> ast.Assign: """ Transform break statements into a state change to exit the current loop. In a flattened control flow, this means finding the next state after the loop. For simplicity, we'll just set to -1 (terminate), but a more sophisticated approach would track enclosing loops and their exit states. """ if self.debug_mode: self.log_debug("break_transformation", { "exit_state": -1 }) return ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=-1, kind=None) ) def visit_Continue(self, node: ast.Continue) -> ast.Assign: """ Transform continue statements to go back to the loop condition. In a flattened control flow, we need to know the loop condition state. For simplicity, we'll implement a jump to the current state, effectively rerunning the current block, but a more sophisticated approach would track enclosing loops and their condition states. """ if self.debug_mode: self.log_debug("continue_transformation", { "target_state": "current loop condition (simplified)" }) # In this simplified model, we just loop back to the current state # A more complete implementation would track the loop stack and jump to the loop start return ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=0, kind=None) ) def visit_Return(self, node: ast.Return) -> List[ast.stmt]: """ Transform return statements into state transitions that terminate the function. """ # Evaluate return value if present, then exit the state machine if node.value: return_stmt = ast.Return(value=self.visit(node.value)) exit_stmt = ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=-1, kind=None) ) if self.debug_mode: self.log_debug("return_transformation", { "has_value": True }) return [return_stmt, exit_stmt] else: # No return value if self.debug_mode: self.log_debug("return_transformation", { "has_value": False }) return [ ast.Return(value=None), ast.Assign( targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())], value=ast.Constant(value=-1, kind=None) ) ]