This commit is contained in:
2025-08-15 20:12:35 -07:00
parent c686e60ec5
commit 387c694efe
14 changed files with 2724 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
# This file ensures that the transformers directory is treated as a Python package
+53
View File
@@ -0,0 +1,53 @@
import ast
from typing import Dict, Optional
class AttributeTransformer(ast.NodeTransformer):
    """
    Rewrites ``self.<attr>`` expressions found inside class bodies so they use
    the obfuscated attribute/method name registered for the enclosing class.
    """

    def __init__(self, class_attr_mapping: Dict[str, Dict[str, str]], class_renames: Dict[str, str]):
        """
        Store the lookup tables consulted during the walk.

        Args:
            class_attr_mapping: Maps class_name -> {attr_name -> obfuscated_attr_name}.
                It is looked up with the name stored in ``current_class``,
                i.e. the renamed class name whenever ``class_renames`` has one.
            class_renames: Maps original_class_name -> obfuscated_class_name
        """
        self.class_renames = class_renames
        self.class_attr_mapping = class_attr_mapping
        # Name used to look up the attribute table; None outside any class.
        self.current_class: Optional[str] = None

    def visit_ClassDef(self, node):
        """Visit a class body with ``current_class`` pointing at that class."""
        enclosing_class = self.current_class
        # Prefer the obfuscated class name; fall back to the name as written.
        self.current_class = self.class_renames.get(node.name, node.name)
        visited = self.generic_visit(node)
        # Nested classes must not leak their context into the outer class.
        self.current_class = enclosing_class
        return visited

    def visit_Attribute(self, node):
        """Rename ``self.<attr>`` when the current class has a mapping for it."""
        # Children first, so chained accesses are handled from the inside out.
        node = self.generic_visit(node)

        receiver = node.value
        if not (isinstance(receiver, ast.Name) and receiver.id == 'self' and self.current_class):
            return node
        if self.current_class not in self.class_attr_mapping:
            return node

        renames = self.class_attr_mapping[self.current_class]
        # Falling back to the current name makes "no entry" and "maps to
        # itself" both leave the node untouched.
        replacement = renames.get(node.attr, node.attr)
        if replacement != node.attr:
            node.attr = replacement
        return node
+215
View File
@@ -0,0 +1,215 @@
import ast
from typing import Dict, Set, Tuple, List
class ClassMethodMap:
    """Plain container for the rename tables produced by ``ClassAnalyzer``."""

    def __init__(self):
        # Maps: original_class_name -> {original_method_name -> obfuscated_method_name}
        self.class_methods: Dict[str, Dict[str, str]] = {}
        # Maps: original_class_name -> {original_attr_name -> obfuscated_attr_name}
        self.class_attributes: Dict[str, Dict[str, str]] = {}
        # Maps: original_class_name -> obfuscated_class_name
        self.class_renames: Dict[str, str] = {}
        # child_class -> [parent_classes]; only bases written as plain names
        self.inheritance: Dict[str, List[str]] = {}


class ClassAnalyzer(ast.NodeVisitor):
    """
    Pre-analyzes classes to ensure consistent renaming of methods and attributes.

    Every ``self.<name>`` reference seen inside a class receives an obfuscated
    name, and a name that a class shares with an analyzed parent class always
    resolves to the parent's obfuscated name, so overrides, inherited methods
    and attributes set by a parent stay connected after renaming.
    """

    def __init__(self, name_generator):
        """
        Args:
            name_generator: object exposing ``generate_name()``, called once
                for every new obfuscated name.
        """
        self.name_generator = name_generator
        self.method_map = ClassMethodMap()
        self.current_class = None
        # To avoid duplicate scanning
        self.scanned_classes: Set[str] = set()
        # class name -> names referenced as self.<name> inside that class
        self.method_calls: Dict[str, Set[str]] = {}

    def analyze(self, tree: ast.AST) -> ClassMethodMap:
        """Analyzes the entire AST and returns populated method mappings."""
        self.visit(tree)
        # Names must exist for every self.<name> reference *before* they are
        # propagated down the hierarchy; otherwise a parent and a child that
        # both touch an undeclared name each get an unrelated obfuscated name.
        self._ensure_consistent_method_mapping()
        self._resolve_inheritance()
        return self.method_map

    @staticmethod
    def _is_dunder(name: str) -> bool:
        """Return True for ``__special__`` names, which are never renamed."""
        return name.startswith('__') and name.endswith('__')

    def visit_ClassDef(self, node: ast.ClassDef):
        """Process a class definition and map its methods."""
        # A second definition reusing an already analyzed name is ignored.
        if node.name in self.scanned_classes:
            return

        prev_class = self.current_class
        self.current_class = node.name
        method_map = self.method_map

        self.method_calls[node.name] = set()

        # Only bases spelled as bare names can be matched to analyzed classes.
        parent_classes = [base.id for base in node.bases if isinstance(base, ast.Name)]
        if parent_classes:
            method_map.inheritance[node.name] = parent_classes

        methods = method_map.class_methods.setdefault(node.name, {})
        method_map.class_attributes.setdefault(node.name, {})
        if node.name not in method_map.class_renames:
            method_map.class_renames[node.name] = self.name_generator.generate_name()

        for item in node.body:
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if not self._is_dunder(item.name):
                    methods[item.name] = self.name_generator.generate_name()
                # Dunder methods are not renamed but are still scanned for
                # self.<name> references (e.g. attributes set in __init__).
                self.visit(item)
            elif isinstance(item, ast.Assign):
                self.visit_attribute_assign(item)
            else:
                # Nested classes, conditionals, expressions, ...
                self.visit(item)

        self.scanned_classes.add(node.name)
        self.current_class = prev_class

    def visit_attribute_assign(self, node):
        """Process attribute assignments like self.attr = value"""
        if not self.current_class:
            return
        attributes = self.method_map.class_attributes[self.current_class]
        for target in node.targets:
            if (isinstance(target, ast.Attribute)
                    and isinstance(target.value, ast.Name)
                    and target.value.id == 'self'
                    and target.attr not in attributes):
                attributes[target.attr] = self.name_generator.generate_name()
        # The assigned value may itself contain self.method() calls.
        self.visit(node.value)

    def visit_Attribute(self, node):
        """Track self.<name> references to ensure consistent naming"""
        if (self.current_class
                and isinstance(node.value, ast.Name)
                and node.value.id == 'self'):
            self.method_calls[self.current_class].add(node.attr)
        self.generic_visit(node)

    def _ensure_consistent_method_mapping(self):
        """
        Make sure that all names referenced via self.<name> have a mapping,
        even if they're not defined in the class.
        """
        for class_name, referenced in self.method_calls.items():
            class_methods = self.method_map.class_methods.get(class_name)
            if class_methods is None:
                continue
            for method_name in referenced:
                if method_name in class_methods or self._is_dunder(method_name):
                    continue
                class_methods[method_name] = self.name_generator.generate_name()

    def _resolve_inheritance(self):
        """
        Make child classes use their parents' obfuscated names.

        For every name an analyzed parent maps, the child adopts the parent's
        obfuscated name, whether it inherits, overrides or merely references
        that name. With several analyzed parents mapping the same name, the
        first listed parent wins. Parents missing from the mappings (external
        classes) are skipped.
        """
        class_methods = self.method_map.class_methods
        inheritance = self.method_map.inheritance
        resolved: Set[str] = set()

        def resolve(class_name):
            if class_name in resolved:
                return
            # Mark before recursing: "class Foo(Foo)" (extending an imported
            # class of the same name) would otherwise recurse forever.
            resolved.add(class_name)
            own_methods = class_methods.get(class_name)
            claimed: Set[str] = set()
            for parent in inheritance.get(class_name, ()):
                # Parents first, so names propagate across several levels.
                resolve(parent)
                parent_methods = class_methods.get(parent)
                if parent_methods is None or own_methods is None:
                    continue
                for method_name, obf_name in list(parent_methods.items()):
                    if method_name not in claimed:
                        own_methods[method_name] = obf_name
                        claimed.add(method_name)

        for class_name in list(class_methods):
            resolve(class_name)
def get_method_name(node: ast.Attribute) -> Tuple[bool, str]:
    """
    Tell whether ``node`` is an attribute access made directly on ``self``.

    Returns:
        ``(True, attribute_name)`` when the receiver is the bare name
        ``self``; ``(False, "")`` for any other receiver.
    """
    receiver = node.value
    on_self = isinstance(receiver, ast.Name) and receiver.id == 'self'
    return (True, node.attr) if on_self else (False, "")
def update_obfuscator_with_class_mappings(obfuscator, class_map: ClassMethodMap):
    """
    Copy the analyzed class data into the main obfuscator.

    Class renames are written into ``obfuscator.global_var_renames``. For each
    renamed class, its method and attribute renames are merged into
    ``obfuscator.class_attr_mapping`` under the *obfuscated* class name;
    attribute entries are written last, so they win over a method entry with
    the same original name.
    """
    renames = class_map.class_renames

    # Class names are renamed through the obfuscator's global rename table.
    for original, obfuscated in renames.items():
        obfuscator.global_var_renames[original] = obfuscated

    for original, obfuscated in renames.items():
        # Keep whatever was already registered for this class.
        merged = obfuscator.class_attr_mapping.setdefault(obfuscated, {})
        for table in (class_map.class_methods, class_map.class_attributes):
            if original in table:
                merged.update(table[original])
+300
View File
@@ -0,0 +1,300 @@
import ast
from typing import Dict, Set, List, Tuple, Optional
import logging
# Configure logging.
# NOTE(review): this executes when the module is imported. basicConfig() only
# has an effect while the root logger has no handlers, so the first import may
# install a root handler at INFO level — confirm that is intended for a
# library module.
logging.basicConfig(level=logging.INFO)
# Logger shared by the classes and helpers defined below in this module.
logger = logging.getLogger("ClassMapper")
class ClassMapping:
    """Stores all mappings related to classes in a centralized way."""

    def __init__(self):
        # Original class name -> obfuscated class name
        self.class_names: Dict[str, str] = {}
        # Original class name -> {original method name -> obfuscated method name}
        self.class_methods: Dict[str, Dict[str, str]] = {}
        # Original class name -> {original attr name -> obfuscated attr name}
        self.class_attributes: Dict[str, Dict[str, str]] = {}
        # Child class -> list of parent classes (original names)
        self.inheritance: Dict[str, List[str]] = {}
        # Original class name -> names referenced as self.<name> in that class
        self.seen_method_calls: Dict[str, Set[str]] = {}

    def debug_info(self) -> str:
        """Return a multi-line, human-readable dump of all mappings."""
        info = [f"Class mappings: {len(self.class_names)} classes"]
        for cls_name, obf_name in self.class_names.items():
            info.append(f" {cls_name} -> {obf_name}")
            if cls_name in self.class_methods:
                methods = self.class_methods[cls_name]
                info.append(f" Methods: {len(methods)}")
                info.extend(f" {method} -> {obf_method}" for method, obf_method in methods.items())
            if cls_name in self.class_attributes:
                attrs = self.class_attributes[cls_name]
                info.append(f" Attributes: {len(attrs)}")
                info.extend(f" {attr} -> {obf_attr}" for attr, obf_attr in attrs.items())
        return "\n".join(info)


class ClassMapAnalyzer(ast.NodeVisitor):
    """
    Analyzes the AST to create a complete class mapping.

    Every ``self.<name>`` reference inside a class receives an obfuscated
    name, and a name shared with an analyzed parent class always resolves to
    the parent's obfuscated name, so overrides, inherited methods and
    attributes set by a parent remain connected after renaming.
    """

    def __init__(self, name_generator):
        """
        Args:
            name_generator: object exposing ``generate_name()``, called once
                for every new obfuscated name.
        """
        self.name_generator = name_generator
        self.mapping = ClassMapping()
        self.current_class: Optional[str] = None
        self.current_method: Optional[str] = None
        self.processed_classes: Set[str] = set()

    def analyze(self, tree: ast.AST) -> ClassMapping:
        """Perform a complete analysis of the AST and return the mapping."""
        # First pass: collect class definitions, methods, and inheritance.
        self.visit(tree)
        # Second pass: every self.<name> reference needs a name *before*
        # names are propagated down the hierarchy; otherwise a parent and a
        # child touching the same undeclared name get unrelated names.
        self._ensure_complete_method_mapping()
        self._resolve_inheritance()
        logger.info("Class analysis complete: %d classes processed", len(self.mapping.class_names))
        # Building the dump is not free; skip it unless it will be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(self.mapping.debug_info())
        return self.mapping

    @staticmethod
    def _is_dunder(name: str) -> bool:
        """Return True for ``__special__`` names, which are never renamed."""
        return name.startswith('__') and name.endswith('__')

    def visit_ClassDef(self, node: ast.ClassDef):
        """Process a class definition."""
        # A second definition reusing an already analyzed name is ignored.
        if node.name in self.processed_classes:
            return

        prev_class = self.current_class
        self.current_class = node.name
        mapping = self.mapping

        if node.name not in mapping.class_names:
            mapping.class_names[node.name] = self.name_generator.generate_name()
        mapping.class_methods.setdefault(node.name, {})
        mapping.class_attributes.setdefault(node.name, {})
        mapping.seen_method_calls.setdefault(node.name, set())

        # Only bases spelled as bare names can be matched to analyzed classes.
        parent_classes = [base.id for base in node.bases if isinstance(base, ast.Name)]
        if parent_classes:
            mapping.inheritance[node.name] = parent_classes

        for item in node.body:
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                self.visit_method_def(item)
            elif isinstance(item, ast.Assign):
                self.visit_assign_in_class(item)
            else:
                # Nested classes, expressions, conditionals, ...
                self.visit(item)

        self.processed_classes.add(node.name)
        self.current_class = prev_class

    def visit_method_def(self, node: ast.FunctionDef):
        """Process a (sync or async) method definition in a class body."""
        if not self.current_class:
            return
        prev_method = self.current_method
        self.current_method = node.name

        methods = self.mapping.class_methods[self.current_class]
        # Dunder methods keep their names but their bodies are still scanned.
        if not self._is_dunder(node.name) and node.name not in methods:
            obf_name = self.name_generator.generate_name()
            methods[node.name] = obf_name
            logger.debug("Mapped method %s.%s to %s", self.current_class, node.name, obf_name)

        # Scan the body for self.<name> references.
        for item in node.body:
            self.visit(item)

        self.current_method = prev_method

    def visit_assign_in_class(self, node: ast.Assign):
        """Process an assignment that sits directly in a class body."""
        if not self.current_class:
            return
        attributes = self.mapping.class_attributes[self.current_class]
        for target in node.targets:
            if (isinstance(target, ast.Attribute)
                    and isinstance(target.value, ast.Name)
                    and target.value.id == 'self'
                    and target.attr not in attributes):
                obf_name = self.name_generator.generate_name()
                attributes[target.attr] = obf_name
                logger.debug("Mapped attribute %s.%s to %s", self.current_class, target.attr, obf_name)
        # The assigned value may contain self.method() calls.
        self.visit(node.value)

    def visit_Attribute(self, node: ast.Attribute):
        """Record ``self.<name>`` accesses made inside a class."""
        if self.current_class and isinstance(node.value, ast.Name) and node.value.id == 'self':
            self.mapping.seen_method_calls[self.current_class].add(node.attr)
            logger.debug("Recorded method call: %s.%s", self.current_class, node.attr)
        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign):
        """Visit both sides of an assignment so self.<name> uses are recorded."""
        for target in node.targets:
            self.visit(target)
        self.visit(node.value)

    def _resolve_inheritance(self):
        """
        Make child classes use their parents' obfuscated names.

        For every name an analyzed parent maps, the child adopts the parent's
        obfuscated name, whether it inherits, overrides or merely references
        that name. With several analyzed parents mapping the same name, the
        first listed parent wins. Parents missing from the mappings (external
        classes) are skipped.
        """
        class_methods = self.mapping.class_methods
        inheritance = self.mapping.inheritance
        resolved: Set[str] = set()

        def resolve(class_name):
            if class_name in resolved:
                return
            # Mark before recursing: "class Foo(Foo)" (extending an imported
            # class of the same name) would otherwise recurse forever.
            resolved.add(class_name)
            own_methods = class_methods.get(class_name)
            claimed: Set[str] = set()
            for parent in inheritance.get(class_name, ()):
                # Parents first, so names propagate across several levels.
                resolve(parent)
                parent_methods = class_methods.get(parent)
                if parent_methods is None or own_methods is None:
                    continue
                for method_name, obf_name in list(parent_methods.items()):
                    if method_name in claimed:
                        continue
                    own_methods[method_name] = obf_name
                    claimed.add(method_name)
                    logger.debug("Inherited method %s.%s from %s", class_name, method_name, parent)

        for class_name in list(class_methods):
            resolve(class_name)

    def _ensure_complete_method_mapping(self):
        """
        Make sure every recorded self.<name> reference has a mapping.
        This covers names that are used in a class but not defined there.
        """
        for class_name, referenced in self.mapping.seen_method_calls.items():
            methods = self.mapping.class_methods.get(class_name)
            if methods is None:
                continue
            for method_name in referenced:
                if self._is_dunder(method_name) or method_name in methods:
                    continue
                obf_name = self.name_generator.generate_name()
                methods[method_name] = obf_name
                logger.debug("Added mapping for called method %s.%s -> %s", class_name, method_name, obf_name)
class ClassTransformer(ast.NodeTransformer):
    """
    Transforms class-related nodes using the mapping.

    Renames class definitions, the methods defined directly in a class body
    (sync and async) and ``self.<name>`` references inside the class.
    """

    def __init__(self, mapping: ClassMapping):
        """
        Args:
            mapping: rename tables; ``class_names``, ``class_methods`` and
                ``class_attributes`` are read, keyed by original class name.
        """
        self.mapping = mapping
        self.current_class: Optional[str] = None
        # Number of function definitions entered since the enclosing class
        # body; 0 means "directly in the class body", i.e. a method.
        self._function_depth = 0

    def visit_ClassDef(self, node: ast.ClassDef):
        """Transform class name and process its body."""
        prev_class, prev_depth = self.current_class, self._function_depth
        orig_name = node.name
        # Lookups stay keyed by the original name even after the rename.
        self.current_class = orig_name
        self._function_depth = 0

        if orig_name in self.mapping.class_names:
            node.name = self.mapping.class_names[orig_name]
            logger.debug("Transformed class %s -> %s", orig_name, node.name)

        self.generic_visit(node)

        self.current_class, self._function_depth = prev_class, prev_depth
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Rename a method definition and process the function body."""
        # Only definitions sitting directly in the class body are methods. A
        # helper function nested inside a method is called by its bare name,
        # which this transformer never rewrites, so it must keep its name
        # even when it happens to share one with a method.
        if self.current_class and self._function_depth == 0:
            methods = self.mapping.class_methods.get(self.current_class, {})
            if node.name in methods:
                orig_name = node.name
                node.name = methods[orig_name]
                logger.debug("Transformed method %s.%s -> %s", self.current_class, orig_name, node.name)

        self._function_depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._function_depth -= 1
        return node

    # ``async def`` methods are renamed exactly like regular ones; without
    # this, self.<name> references to them would be renamed but not the def.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Attribute(self, node: ast.Attribute):
        """Transform self.method and self.attr references."""
        # Process the receiver first (for nested attributes).
        node.value = self.visit(node.value)

        if not (self.current_class and isinstance(node.value, ast.Name) and node.value.id == 'self'):
            return node

        orig_name = node.attr
        methods = self.mapping.class_methods.get(self.current_class, {})
        attributes = self.mapping.class_attributes.get(self.current_class, {})
        # Method mappings take precedence over attribute mappings.
        if orig_name in methods:
            node.attr = methods[orig_name]
            logger.debug("Transformed self.method %s.%s -> %s", self.current_class, orig_name, node.attr)
        elif orig_name in attributes:
            node.attr = attributes[orig_name]
            logger.debug("Transformed self.attr %s.%s -> %s", self.current_class, orig_name, node.attr)
        return node
# Helper function to apply the class mapping transformation
def apply_class_mapping(tree: ast.AST, name_generator) -> Tuple[ast.AST, ClassMapping]:
    """
    Analyze and transform classes consistently.

    Args:
        tree: AST to process; it is handed to ``ClassTransformer.visit``.
        name_generator: passed to ``ClassMapAnalyzer``, which calls
            ``generate_name()`` on it for every obfuscated name.

    Returns:
        A ``(transformed_tree, mapping)`` tuple: the value returned by
        ``ClassTransformer.visit`` and the ``ClassMapping`` built for it.
    """
    # First pass: analyze all classes
    mapping = ClassMapAnalyzer(name_generator).analyze(tree)
    # Second pass: transform using the mapping
    transformed = ClassTransformer(mapping).visit(tree)
    return transformed, mapping
+488
View File
@@ -0,0 +1,488 @@
import ast
import random
from typing import Dict, List, Set, Tuple, Any, Optional
import logging
# Configure logging.
# NOTE(review): this executes when the module is imported. basicConfig() only
# has an effect while the root logger has no handlers, so the first import may
# install a root handler at INFO level — confirm that is intended for a
# library module.
logging.basicConfig(level=logging.INFO)
# Module-level logger for the control-flow flattening code.
logger = logging.getLogger("ControlFlowFlattener")
class ControlFlowFlattener(ast.NodeTransformer):
"""
A transformer that flattens control flow by converting branching structures
into a dispatch table with a while loop.
"""
def __init__(self, debug_mode=False):
self.state_var_name = "_state"
self.states = {}
self.current_block_id = 0
self.debug_mode = debug_mode
# For tracking transformations when debug mode is enabled
if debug_mode:
self.debug_data = {
"flattened_functions": [],
"block_counts": {},
"transformations": []
}
def log_debug(self, category: str, data: Any):
"""Log debugging information if debug mode is enabled."""
if self.debug_mode:
self.debug_data["transformations"].append({
"type": category,
"data": data
})
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
"""
Transform a function body into a flattened control flow.
Keeps function signature the same but replaces the body with
state-machine style execution.
"""
# Skip small functions as flattening them might cause more overhead than obfuscation
if len(node.body) <= 2:
return self.generic_visit(node)
# Reset state for this function
self.states = {}
self.current_block_id = 0
# Process decorators and args normally
node.decorator_list = [self.visit(d) for d in node.decorator_list]
if hasattr(node, 'args'):
node.args = self.visit(node.args)
# Create entry block (state 0)
entry_block_id = self.current_block_id
self.current_block_id += 1
self.states[entry_block_id] = node.body
# Create the flattened control flow
flattened_body = self.flatten_blocks()
node.body = flattened_body
if self.debug_mode:
self.debug_data["flattened_functions"].append(node.name)
self.debug_data["block_counts"][node.name] = len(self.states)
self.log_debug("function_flattened", {
"name": node.name,
"original_statements": len(self.states[entry_block_id]),
"flattened_blocks": len(self.states)
})
return node
def flatten_blocks(self) -> List[ast.stmt]:
"""
Create a flattened control flow structure using a while loop and switch-like
dispatch based on a state variable.
"""
# Create the state variable and initialize it
state_var = ast.Name(id=self.state_var_name, ctx=ast.Store())
init_state = ast.Assign(
targets=[state_var],
value=ast.Constant(value=0, kind=None)
)
# Create the while loop condition (state != -1)
loop_condition = ast.Compare(
left=ast.Name(id=self.state_var_name, ctx=ast.Load()),
ops=[ast.NotEq()],
comparators=[ast.Constant(value=-1, kind=None)]
)
# Create the dispatch table as a series of if/elif statements
dispatch_cases = []
for state_id, block in self.states.items():
# Create the condition (state == state_id)
condition = ast.Compare(
left=ast.Name(id=self.state_var_name, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=state_id, kind=None)]
)
# Process the block statements
processed_block = []
for stmt in block:
processed_stmt = self.visit(stmt)
if isinstance(processed_stmt, list):
processed_block.extend(processed_stmt)
else:
processed_block.append(processed_stmt)
# If this block doesn't modify the state, add a transition to the next block
last_stmt = processed_block[-1] if processed_block else None
if not (isinstance(last_stmt, ast.Assign) and
isinstance(last_stmt.targets[0], ast.Name) and
last_stmt.targets[0].id == self.state_var_name):
# Add transition to next block
next_state = state_id + 1
if next_state not in self.states:
# If no next state, exit the loop
next_state = -1
processed_block.append(
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=next_state, kind=None)
)
)
# Create the if/elif body
if_body = processed_block
# Add to dispatch cases
dispatch_cases.append((condition, if_body))
# Convert dispatch cases to if/elif/else structure
if_node = None
for i, (condition, body) in enumerate(reversed(dispatch_cases)):
if i == 0: # Last case (will be the 'else' clause)
if_node = ast.If(
test=condition,
body=body,
orelse=[]
)
else:
if_node = ast.If(
test=condition,
body=body,
orelse=[if_node] if if_node else []
)
# Create the while loop with the dispatch logic
while_body = [if_node] if if_node else []
while_loop = ast.While(
test=loop_condition,
body=while_body,
orelse=[]
)
# Add some junk code to obscure the control flow
junk = self.generate_junk_code()
# Assemble the flattened function body
flattened_body = [init_state] + junk + [while_loop]
return flattened_body
def generate_junk_code(self) -> List[ast.stmt]:
"""Generate meaningless code to obscure the control flow."""
junk = []
# Add dummy variables that look like they're used for control flow
dummy_vars = [f"_cflow_{i}" for i in range(random.randint(2, 5))]
for var in dummy_vars:
# Initialize with random value
junk.append(
ast.Assign(
targets=[ast.Name(id=var, ctx=ast.Store())],
value=ast.Constant(value=random.randint(0, 100), kind=None)
)
)
# Add some conditional statements that don't do anything important
if dummy_vars:
cond = ast.Compare(
left=ast.Name(id=random.choice(dummy_vars), ctx=ast.Load()),
ops=[ast.Gt()],
comparators=[ast.Constant(value=50, kind=None)]
)
junk.append(
ast.If(
test=cond,
body=[
ast.Assign(
targets=[ast.Name(id=random.choice(dummy_vars), ctx=ast.Store())],
value=ast.Constant(value=random.randint(0, 100), kind=None)
)
],
orelse=[]
)
)
if self.debug_mode:
self.log_debug("junk_code", {
"statements": len(junk),
"variables": dummy_vars
})
return junk
def visit_If(self, node: ast.If) -> ast.stmt:
"""
Transform if statements into state transitions.
Conditional branches become separate states in the state machine.
"""
# Create new states for the true and false branches
true_branch_id = self.current_block_id
self.current_block_id += 1
self.states[true_branch_id] = node.body
if node.orelse:
false_branch_id = self.current_block_id
self.current_block_id += 1
self.states[false_branch_id] = node.orelse
else:
# If no else branch, use the next block in sequence
false_branch_id = self.current_block_id
# Create a conditional assignment to the state variable
result = ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.IfExp(
test=self.visit(node.test),
body=ast.Constant(value=true_branch_id, kind=None),
orelse=ast.Constant(value=false_branch_id, kind=None)
)
)
if self.debug_mode:
self.log_debug("if_transformation", {
"true_branch": true_branch_id,
"false_branch": false_branch_id,
"has_else": bool(node.orelse)
})
return result
def visit_For(self, node: ast.For) -> List[ast.stmt]:
"""
Transform for loops into state transitions with loop body and else clause
as separate states.
"""
# Create unique variable names for this loop
iter_var = f"_iter_{self.current_block_id}"
index_var = f"_idx_{self.current_block_id}"
# Setup the iteration
setup_stmts = [
# Create iterator: _iter_X = iter(iterable)
ast.Assign(
targets=[ast.Name(id=iter_var, ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="iter", ctx=ast.Load()),
args=[self.visit(node.iter)],
keywords=[]
)
),
# Initialize index: _idx_X = 0
ast.Assign(
targets=[ast.Name(id=index_var, ctx=ast.Store())],
value=ast.Constant(value=0, kind=None)
)
]
# Create states for the loop body and else clause
loop_body_id = self.current_block_id
self.current_block_id += 1
# Loop body needs to get the next item and assign it to the target
loop_body = [
# try: target = next(_iter_X)
ast.Try(
body=[
ast.Assign(
targets=[self.visit(node.target)],
value=ast.Call(
func=ast.Name(id="next", ctx=ast.Load()),
args=[ast.Name(id=iter_var, ctx=ast.Load())],
keywords=[]
)
)
],
# except StopIteration: goto else_clause or skip if no else
handlers=[
ast.ExceptHandler(
type=ast.Name(id="StopIteration", ctx=ast.Load()),
name=None,
body=[
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(
value=(self.current_block_id if node.orelse else -1),
kind=None
)
)
]
)
],
# no finally
orelse=[],
finalbody=[]
)
] + node.body + [
# Increment index
ast.AugAssign(
target=ast.Name(id=index_var, ctx=ast.Store()),
op=ast.Add(),
value=ast.Constant(value=1, kind=None)
),
# Loop back to body
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=loop_body_id, kind=None)
)
]
self.states[loop_body_id] = loop_body
# Handle else clause if present
if node.orelse:
else_id = self.current_block_id
self.current_block_id += 1
self.states[else_id] = node.orelse
if self.debug_mode:
self.log_debug("for_loop_transformation", {
"iterator_var": iter_var,
"index_var": index_var,
"body_block": loop_body_id,
"has_else": bool(node.orelse),
"else_block": else_id if node.orelse else None
})
# Transition to the loop body after setup
setup_stmts.append(
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=loop_body_id, kind=None)
)
)
return setup_stmts
def visit_While(self, node: ast.While) -> List[ast.stmt]:
"""
Transform while loops into state transitions with conditional jumps.
"""
# Create states for the condition check, body, and else clause
cond_check_id = self.current_block_id
self.current_block_id += 1
# Body state ID
body_id = self.current_block_id
self.current_block_id += 1
# Else state ID (if present)
else_id = self.current_block_id if node.orelse else -1
if node.orelse:
self.current_block_id += 1
# Condition check state: if test: goto body else: goto else/exit
cond_check = [
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.IfExp(
test=self.visit(node.test),
body=ast.Constant(value=body_id, kind=None),
orelse=ast.Constant(value=else_id, kind=None)
)
)
]
self.states[cond_check_id] = cond_check
# Body state: execute body then goto condition check
body = list(node.body) + [
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=cond_check_id, kind=None)
)
]
self.states[body_id] = body
# Else state (if present)
if node.orelse:
self.states[else_id] = node.orelse
if self.debug_mode:
self.log_debug("while_loop_transformation", {
"condition_check_block": cond_check_id,
"body_block": body_id,
"has_else": bool(node.orelse),
"else_block": else_id if else_id != -1 else None
})
# Initial transition to the condition check
return [
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=cond_check_id, kind=None)
)
]
def visit_Break(self, node: ast.Break) -> ast.Assign:
"""
Transform break statements into a state change to exit the current loop.
In a flattened control flow, this means finding the next state after the loop.
For simplicity, we'll just set to -1 (terminate), but a more sophisticated
approach would track enclosing loops and their exit states.
"""
if self.debug_mode:
self.log_debug("break_transformation", {
"exit_state": -1
})
return ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=-1, kind=None)
)
def visit_Continue(self, node: ast.Continue) -> ast.Assign:
"""
Transform continue statements to go back to the loop condition.
In a flattened control flow, we need to know the loop condition state.
For simplicity, we'll implement a jump to the current state, effectively
rerunning the current block, but a more sophisticated approach would track
enclosing loops and their condition states.
"""
if self.debug_mode:
self.log_debug("continue_transformation", {
"target_state": "current loop condition (simplified)"
})
# In this simplified model, we just loop back to the current state
# A more complete implementation would track the loop stack and jump to the loop start
return ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=0, kind=None)
)
def visit_Return(self, node: ast.Return) -> List[ast.stmt]:
"""
Transform return statements into state transitions that terminate the function.
"""
# Evaluate return value if present, then exit the state machine
if node.value:
return_stmt = ast.Return(value=self.visit(node.value))
exit_stmt = ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=-1, kind=None)
)
if self.debug_mode:
self.log_debug("return_transformation", {
"has_value": True
})
return [return_stmt, exit_stmt]
else:
# No return value
if self.debug_mode:
self.log_debug("return_transformation", {
"has_value": False
})
return [
ast.Return(value=None),
ast.Assign(
targets=[ast.Name(id=self.state_var_name, ctx=ast.Store())],
value=ast.Constant(value=-1, kind=None)
)
]
+423
View File
@@ -0,0 +1,423 @@
import ast
from utils.encryption import StringEncryptor
from utils.name_gen import NameGenerator
import logging
# Configure logging.
# NOTE(review): this executes when the module is imported. basicConfig() only
# has an effect while the root logger has no handlers, so the first import may
# install a root handler at INFO level — confirm that is intended for a
# library module.
logging.basicConfig(level=logging.INFO)
# Module-level logger for the rename transformer.
logger = logging.getLogger("RenameTransformer")
class RenameTransformer(ast.NodeTransformer):
def __init__(self, name_generator, global_var_renames, class_attr_mapping,
             primary_key, secondary_key, salt, debug_mode=False):
    """
    Set up the renaming state.

    Args:
        name_generator: Object whose ``generate_name()`` returns a fresh
            obfuscated identifier.
        global_var_renames: Dict mapping original module-level names to their
            obfuscated names; it is updated in place.
        class_attr_mapping: Dict mapping a (renamed) class name to its
            ``{attribute -> obfuscated attribute}`` dict; updated in place.
        primary_key, secondary_key, salt: Passed straight to
            ``StringEncryptor``.
        debug_mode: When true, ``debug_data`` is created and ``log_debug``
            records into it.
    """
    # Collaborators and rename tables shared with the caller.
    self.name_generator = name_generator
    self.global_var_renames = global_var_renames
    self.class_attr_mapping = class_attr_mapping

    # String encryption and the key setup code collected from encrypt_string calls.
    self.encryptor = StringEncryptor(primary_key, secondary_key, salt)
    self.key_setup_code = []

    # Stack of dicts, each mapping old_name -> new_name for one scope.
    self.scope_stack = [dict()]

    # Class context used for attribute and method renaming.
    self.in_class = False
    self.current_class_name = None
    # class name -> {method name -> obfuscated method name}
    self.class_method_mapping = {}
    # True until visit_Module has run its method pre-scan.
    self.first_pass = True

    self.debug_mode = debug_mode
    if not self.debug_mode:
        return
    # Debug records are only allocated when debug mode is enabled.
    self.debug_data = {
        "variable_mappings": {},
        "string_encryption": [],
        "renamed_nodes": 0,
        "issues": [],
    }
def log_debug(self, category: str, data: object) -> None:
    """
    Append ``data`` to the debug record list for ``category``.

    Does nothing unless ``self.debug_mode`` is true. A category that has not
    been seen before starts with an empty list.

    Args:
        category: Key in ``self.debug_data`` to record under.
        data: Arbitrary value to append. (The annotation used to be the
            builtin function ``any``, which is not a type.)
    """
    if not self.debug_mode:
        return
    self.debug_data.setdefault(category, []).append(data)
def visit_Module(self, node: ast.Module) -> ast.Module:
    """
    Register module-level renames, visit the module and rebuild its body.

    Steps:
      1. On the first call only, pre-scan every class found anywhere in the
         tree so method renames exist before any reference is visited.
      2. Register an obfuscated name for each top-level ``global`` name,
         each plain-name assignment target and each non-dunder function.
      3. Visit every top-level statement.
      4. Rebuild the body as: ``from __future__`` imports, ``import base64``,
         ``import random``, collected key setup code, transformed statements.

    ``from __future__`` imports are kept first because Python rejects them
    anywhere except the start of a module; previously the injected imports
    were placed ahead of them.

    Args:
        node: The module to transform (modified in place).

    Returns:
        The same module node with its body replaced.
    """
    # Pass 1 (once): discover methods of every class, including nested ones.
    if self.first_pass:
        self.first_pass = False
        for candidate in ast.walk(node):
            if isinstance(candidate, ast.ClassDef):
                self.scan_class_methods(candidate)

    # Pass 2: collect the names defined at module level.
    renames = self.global_var_renames
    for stmt in node.body:
        if isinstance(stmt, ast.Global):
            top_level_names = stmt.names
        elif isinstance(stmt, ast.Assign):
            top_level_names = [t.id for t in stmt.targets if isinstance(t, ast.Name)]
        elif isinstance(stmt, ast.FunctionDef):
            # Dunder functions keep their names.
            is_dunder = stmt.name.startswith('__') and stmt.name.endswith('__')
            top_level_names = [] if is_dunder else [stmt.name]
        else:
            continue
        for name in top_level_names:
            if name not in renames:
                renames[name] = self.name_generator.generate_name()

    # Pass 3: transform the statements, separating out __future__ imports.
    future_imports = []
    transformed_body = []
    for item in node.body:
        visited = self.visit(item)
        if visited is None:
            # A visitor removed the statement.
            continue
        for new_node in (visited if isinstance(visited, list) else [visited]):
            is_future = (
                isinstance(new_node, ast.ImportFrom)
                and new_node.module == "__future__"
                and not new_node.level
            )
            (future_imports if is_future else transformed_body).append(new_node)

    # Imports needed by the injected code go right after any __future__ imports.
    new_body = list(future_imports)
    new_body.append(ast.parse("import base64").body[0])
    new_body.append(ast.parse("import random").body[0])

    # Key setup code must precede the statements that decrypt strings.
    if self.key_setup_code:
        new_body.extend(ast.parse('\n'.join(self.key_setup_code)).body)

    new_body.extend(transformed_body)
    node.body = new_body
    return node
def scan_class_methods(self, node):
    """
    Assign an obfuscated name to every non-dunder method defined in a class.

    Only ``ast.FunctionDef`` nodes directly in the class body are considered.
    Results are stored in ``self.class_method_mapping[node.name]``; scanning
    the same class again generates new names and overwrites the old ones.

    Args:
        node: The ``ast.ClassDef`` to scan.
    """
    method_names = self.class_method_mapping.setdefault(node.name, {})
    for member in node.body:
        if not isinstance(member, ast.FunctionDef):
            continue
        name = member.name
        # Dunder methods keep their names.
        if name.startswith('__') and name.endswith('__'):
            continue
        method_names[name] = self.name_generator.generate_name()
def _push_scope(self):
    """Open a new, empty rename scope on top of ``self.scope_stack``."""
    fresh_scope = dict()
    self.scope_stack.append(fresh_scope)
def _pop_scope(self):
    """Discard the innermost rename scope from ``self.scope_stack``."""
    self.scope_stack.pop(-1)
def visit_Global(self, node: ast.Global) -> ast.Global:
    """
    Rewrite a ``global`` statement to use obfuscated names.

    Each declared name gets an entry in ``self.global_var_renames`` (created
    on demand) and is also recorded in the innermost scope so later lookups
    in this scope resolve to the global rename.

    Args:
        node: The ``global`` statement (modified in place).

    Returns:
        The same node with ``names`` replaced by the obfuscated names.
    """
    renames = self.global_var_renames
    obfuscated = []
    for name in node.names:
        if name not in renames:
            # First time this global is seen: allocate its obfuscated name.
            renames[name] = self.name_generator.generate_name()
        self.scope_stack[-1][name] = renames[name]
        obfuscated.append(renames[name])
    node.names = obfuscated
    return node
def visit_ListComp(self, node: ast.ListComp) -> ast.ListComp:
    """
    Visit a list comprehension inside its own rename scope.

    The generator clauses are visited first (their return values are not
    stored back), then the element expression is visited and replaced.
    """
    self._push_scope()
    for clause in node.generators:
        self.visit(clause)
    node.elt = self.visit(node.elt)
    self._pop_scope()
    return node
def visit_SetComp(self, node: ast.SetComp) -> ast.SetComp:
    """
    Visit a set comprehension inside its own rename scope.

    The generator clauses are visited first (their return values are not
    stored back), then the element expression is visited and replaced.
    """
    self._push_scope()
    for clause in node.generators:
        self.visit(clause)
    node.elt = self.visit(node.elt)
    self._pop_scope()
    return node
def visit_DictComp(self, node: ast.DictComp) -> ast.DictComp:
    """
    Visit a dict comprehension inside its own rename scope.

    The generator clauses are visited first (their return values are not
    stored back), then the key and value expressions are visited, in that
    order, and replaced.
    """
    self._push_scope()
    for clause in node.generators:
        self.visit(clause)
    new_key = self.visit(node.key)
    node.key = new_key
    new_value = self.visit(node.value)
    node.value = new_value
    self._pop_scope()
    return node
def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.GeneratorExp:
    """
    Visit a generator expression inside its own rename scope.

    The generator clauses are visited first (their return values are not
    stored back), then the element expression is visited and replaced.
    """
    self._push_scope()
    for clause in node.generators:
        self.visit(clause)
    node.elt = self.visit(node.elt)
    self._pop_scope()
    return node
def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
    """
    Rename the loop target of one comprehension clause and visit the rest.

    * A plain-name target that is not a known global gets a fresh obfuscated
      name, recorded in the innermost scope.
    * Every other target is visited normally: tuple/list/attribute targets,
      and names that collide with a renamed global. This keeps the target
      consistent with how ``visit_Name`` renames the references to it.

    Args:
        node: The comprehension clause (modified in place).

    Returns:
        The same clause with target, iterable and conditions processed.
    """
    target = node.target
    if isinstance(target, ast.Name) and target.id not in self.global_var_renames:
        new_name = self.name_generator.generate_name()
        self.scope_stack[-1][target.id] = new_name
        target.id = new_name
    else:
        node.target = self.visit(target)

    node.iter = self.visit(node.iter)
    node.ifs = [self.visit(condition) for condition in node.ifs]
    return node
def visit_Name(self, node: ast.Name) -> ast.AST:
    """
    Rename a single identifier.

    Resolution order:
      1. A name present in ``self.global_var_renames`` always uses that
         rename, regardless of context.
      2. A name being stored is looked up in (or added to) the innermost
         scope only. While inside a class it is also recorded in
         ``self.class_attr_mapping`` for the current class.
      3. Any other name (load/delete) takes the rename from the nearest
         enclosing scope that knows it, and is left untouched otherwise.

    Compared with the previous version, an unreachable second scan for
    globals and a check against the deprecated ``ast.Param`` context were
    removed, and a missing per-class attribute table is created on demand
    instead of raising ``KeyError``.

    Args:
        node: The name node (modified in place).

    Returns:
        The same node.
    """
    renames = self.global_var_renames
    if node.id in renames:
        node.id = renames[node.id]
        return node

    if isinstance(node.ctx, ast.Store):
        scope = self.scope_stack[-1]
        if node.id not in scope:
            new_name = self.name_generator.generate_name()
            scope[node.id] = new_name
            if self.in_class and self.current_class_name:
                class_attrs = self.class_attr_mapping.setdefault(self.current_class_name, {})
                class_attrs[node.id] = new_name
        node.id = scope[node.id]
        return node

    # Load/delete context: innermost matching scope wins.
    for scope in reversed(self.scope_stack):
        if node.id in scope:
            node.id = scope[node.id]
            break
    return node
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
    """
    Rename a class and process its decorators, bases, keywords and body.

    The enclosing scope receives the ``old name -> new name`` entry and,
    only while the class is being processed, a ``'_has_bases'`` marker that
    ``visit_FunctionDef`` reads. Method renames recorded under the original
    class name are moved to the new class name.

    Changes from the previous version:
      * decorators and class keywords (e.g. ``metaclass=...``) are visited,
        so references to renamed symbols there are renamed too;
      * the ``'_has_bases'`` marker is restored afterwards instead of
        leaking into code that follows the class.

    Args:
        node: The class definition (modified in place).

    Returns:
        The same node, renamed.
    """
    enclosing = self.scope_stack[-1]
    had_marker = '_has_bases' in enclosing
    prev_marker = enclosing.get('_has_bases')
    enclosing['_has_bases'] = len(node.bases) > 0

    prev_in_class = self.in_class
    prev_class_name = self.current_class_name
    self.in_class = True

    new_class_name = self.name_generator.generate_name()
    enclosing[node.name] = new_class_name
    self.current_class_name = new_class_name
    self.class_attr_mapping[new_class_name] = {}

    # Re-key the pre-scanned method renames under the obfuscated class name.
    if node.name in self.class_method_mapping:
        self.class_method_mapping[new_class_name] = self.class_method_mapping.pop(node.name)

    self.scope_stack.append({})
    node.decorator_list = [self.visit(decorator) for decorator in node.decorator_list]
    node.bases = [self.visit(base) for base in node.bases]
    for keyword in node.keywords:
        keyword.value = self.visit(keyword.value)
    node.body = [self.visit(statement) for statement in node.body]

    # Re-assert the class rename in the enclosing scope, then rename the node.
    self.scope_stack[-2][node.name] = new_class_name
    node.name = new_class_name

    self.in_class = prev_in_class
    self.current_class_name = prev_class_name
    self.scope_stack.pop()

    # Put the enclosing scope's marker back the way it was found.
    if had_marker:
        enclosing['_has_bases'] = prev_marker
    else:
        enclosing.pop('_has_bases', None)
    return node
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
"""
Rename a function's positional parameters, visit its body, and rename the
function itself unless its name is a dunder.

A new scope is pushed for the duration of the call. Only the parameters in
``node.args.args`` are renamed here. The new function name comes from
``global_var_renames``, from ``class_method_mapping`` for the current class,
or from the name generator.

NOTE(review): the indentation of this block is not visible in this view, so
which ``if`` the final ``else`` belongs to could not be confirmed - verify
against the original file before changing the renaming logic.
"""
is_special = node.name.startswith('__') and node.name.endswith('__')
# True when any scope on the stack carries a truthy '_has_bases' marker
# (visit_ClassDef stores that marker on the scope enclosing a class).
has_bases = any(('_has_bases' in scope and scope['_has_bases']) for scope in self.scope_stack)
# Save original name for later
orig_name = node.name
self._push_scope()
# Handle arguments (skip renaming for self if method):
if self.in_class and len(node.args.args) > 0:
# The first positional parameter is mapped to itself, so it keeps its name.
self.scope_stack[-1][node.args.args[0].arg] = node.args.args[0].arg
for arg in node.args.args[1:]:
# Parameters that share a name with a renamed global are left untouched.
if arg.arg not in self.global_var_renames:
new_arg_name = self.name_generator.generate_name()
self.scope_stack[-1][arg.arg] = new_arg_name
arg.arg = new_arg_name
else:
for arg in node.args.args:
if arg.arg not in self.global_var_renames:
new_arg_name = self.name_generator.generate_name()
self.scope_stack[-1][arg.arg] = new_arg_name
arg.arg = new_arg_name
# Visit body
# Only the body statements are visited; decorators, default values and
# annotations are not visited by this method.
node.body = [self.visit(n) for n in node.body]
# --------------------------------
# Rename the function itself (dunder names are never renamed).
# --------------------------------
if not is_special:
# If the function was declared top-level and recognized in global_var_renames,
# then rename using that. Otherwise generate a brand new obfuscated name.
if node.name in self.global_var_renames:
node.name = self.global_var_renames[node.name]
elif self.in_class and not has_bases:
# For class methods, use the name recorded by scan_class_methods
if self.current_class_name and orig_name in self.class_method_mapping.get(self.current_class_name, {}):
node.name = self.class_method_mapping[self.current_class_name][orig_name]
else:
# Fresh name, recorded in the scope just below this function's own scope.
new_fn_name = self.name_generator.generate_name()
self.scope_stack[-2][node.name] = new_fn_name
node.name = new_fn_name
self._pop_scope()
return node
def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
    """
    Rename ``self.<attr>`` accesses; every other attribute access is left as is.

    For ``self.<attr>`` the decision order is:
      1. Dunder attributes (``self.__class__`` ...) are never renamed.
      2. A method of the *current* class uses that class's rename.
      3. Otherwise a method of any other scanned class uses that rename.
      4. Outside a class nothing else is renamed.
      5. camelCase names (lowercase first letter plus an uppercase letter)
         are treated as external/inherited methods and kept.
      6. Anything else is a class attribute and gets a per-class rename.

    Changes from the previous version: the current class is consulted before
    other classes (two classes sharing a method name no longer cross-rename),
    dunder attributes are preserved, and the per-class attribute table is
    created with ``setdefault`` so new entries are never written to a
    throwaway dict.

    Args:
        node: The attribute node (modified in place).

    Returns:
        The same node.
    """
    # First visit any nested expressions.
    node = self.generic_visit(node)

    if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
        return node

    attr = node.attr
    if attr.startswith('__') and attr.endswith('__'):
        return node

    # Methods: prefer the class currently being visited.
    if self.current_class_name:
        own_methods = self.class_method_mapping.get(self.current_class_name, {})
        if attr in own_methods:
            node.attr = own_methods[attr]
            return node
    for methods in self.class_method_mapping.values():
        if attr in methods:
            node.attr = methods[attr]
            return node

    if not self.current_class_name:
        return node

    # Heuristic for externally defined methods with Qt-style naming.
    is_external_method = (
        attr[0].islower()
        and any(c.isupper() for c in attr)
        and not attr.startswith('__')
    )
    if is_external_method:
        return node

    # Plain class attribute: allocate or reuse its per-class rename.
    attr_map = self.class_attr_mapping.setdefault(self.current_class_name, {})
    if attr not in attr_map:
        attr_map[attr] = self.name_generator.generate_name()
    node.attr = attr_map[attr]
    return node
def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
    """
    Visit the subscripted value and the slice expression.

    The visited slice is stored back on the node. Previously its result was
    discarded, so a visitor that returns a replacement node (for example
    ``visit_Constant`` for a string key) had no effect on subscripts.

    Args:
        node: The subscript node (modified in place).

    Returns:
        The same node.
    """
    node.value = self.visit(node.value)
    if isinstance(node.slice, ast.AST):
        replacement = self.visit(node.slice)
        # Keep the original slice if a visitor chose to return nothing.
        if replacement is not None:
            node.slice = replacement
    return node
def visit_Constant(self, node: ast.Constant) -> ast.AST:
    """
    Replace a string literal with an expression that rebuilds it at runtime.

    Non-string constants are returned unchanged. For strings, the encryptor
    supplies the encoded payload, a key setup snippet (collected once in
    ``self.key_setup_code``) and a hex modifier. The generated expression
    base85-decodes the payload, XORs it with ``_sk``, XORs the result with
    ``_pk`` and the modifier bytes, and decodes the bytes to ``str``.

    NOTE(review): ``_sk`` and ``_pk`` are presumably defined by the key setup
    code - confirm against ``StringEncryptor``.

    Args:
        node: The constant to process.

    Returns:
        The original node, or the parsed decryption expression.
    """
    if not isinstance(node.value, str):
        return node

    encoded, key_setup, modifier = self.encryptor.encrypt_string(node.value)
    if key_setup not in self.key_setup_code:
        self.key_setup_code.append(key_setup)

    # Shared fragments of the generated source text.
    payload = f"base64.b85decode('{encoded}')"
    repeat = f"((len({payload})//8)+1)"
    first_pass = f"bytes(c^k for c,k in zip({payload},_sk*{repeat}))"
    decrypt_str = (
        f"bytes((k2^k1^m for k1,k2,m in zip({first_pass},"
        f"_pk*{repeat},"
        f"bytes.fromhex('{modifier}')*{repeat}))).decode()"
    )
    return ast.parse(decrypt_str).body[0].value
def visit_Import(self, node: ast.Import) -> ast.AST:
    """Process an ``import`` statement via ``generic_visit``; nothing is renamed here."""
    visited = self.generic_visit(node)
    return visited
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST:
    """Process a ``from ... import`` statement via ``generic_visit``; nothing is renamed here."""
    visited = self.generic_visit(node)
    return visited
def visit_Call(self, node: ast.Call) -> ast.Call:
    """
    Visit a call and re-resolve the class argument of ``super(Class, self)``.

    The call's function and arguments are visited first. For a call to the
    plain name ``super`` whose first argument is a name, that name is looked
    up in the scope stack (innermost first) and replaced when found.
    Zero-argument ``super()`` and all other calls are returned as visited.

    Args:
        node: The call node.

    Returns:
        The visited call node.
    """
    node = self.generic_visit(node)

    callee = node.func
    if not (isinstance(callee, ast.Name) and callee.id == 'super'):
        return node
    if not node.args:
        # Python 3 style super() needs no adjustment.
        return node

    first_arg = node.args[0]
    if isinstance(first_arg, ast.Name):
        for scope in reversed(self.scope_stack):
            if first_arg.id in scope:
                first_arg.id = scope[first_arg.id]
                break
    return node
+551
View File
@@ -0,0 +1,551 @@
import ast
from typing import Dict, Set, List, Optional, Union, Tuple
import logging
from enum import Enum
# Configure logging
# NOTE: basicConfig() runs at import time, so importing this module configures
# the root logger at INFO level (it is a no-op if handlers already exist).
logging.basicConfig(level=logging.INFO)
# Module-level logger registered under the name "SymbolTree".
logger = logging.getLogger("SymbolTree")
class SymbolType(Enum):
    """Kinds of symbols the symbol tree can track."""

    # Plain names bound by assignment.
    VARIABLE = "variable"
    # Callables: free functions and functions defined inside a class.
    FUNCTION = "function"
    CLASS = "class"
    METHOD = "method"
    # Function parameters.
    ARGUMENT = "argument"
    # ``self.<name>`` attributes seen inside a class.
    ATTRIBUTE = "attribute"
    MODULE = "module"
    # Names bound by ``import`` / ``from ... import``.
    IMPORT = "import"
class Symbol:
    """A named entity found in the code, with its kind and rename metadata."""

    def __init__(self, name: str, symbol_type: SymbolType, node: ast.AST = None):
        """
        Args:
            name: Original name of the symbol.
            symbol_type: Kind of symbol.
            node: AST node where the symbol is defined, if known.
        """
        # Identity
        self.name = name
        self.symbol_type = symbol_type
        self.node = node

        # Renaming state: the obfuscated name is filled in later, and only
        # symbols flagged as obfuscatable are expected to receive one.
        self.obfuscated_name: Optional[str] = None
        self.is_obfuscatable = True

        # Where the symbol lives and where it is used.
        self.parent: Optional['Scope'] = None
        self.references: List[ast.AST] = []

        # Import bookkeeping: whether the name was imported, and from where.
        self.is_imported = False
        self.original_module: Optional[str] = None

    def add_reference(self, node: ast.AST):
        """Record an AST node that refers to this symbol."""
        self.references.append(node)

    def __repr__(self):
        alias = self.obfuscated_name if self.obfuscated_name else ''
        kind = self.symbol_type.value
        return f"<Symbol {self.name} [{kind}] {alias} refs:{len(self.references)}>"
class Scope:
    """
    One lexical scope (module, function, class, ...) holding its symbols,
    its child scopes and a link to its parent scope.
    """

    def __init__(self, name: str, scope_type: str, node: ast.AST = None):
        """
        Args:
            name: Name of the scope.
            scope_type: Free-form kind label, e.g. ``"module"`` or ``"function"``.
            node: AST node that introduces the scope, if known.
        """
        self.name = name
        self.scope_type = scope_type
        self.node = node
        # symbol name -> Symbol defined directly in this scope
        self.symbols: Dict[str, Symbol] = {}
        # Scopes nested directly inside this one
        self.children: List['Scope'] = []
        # Enclosing scope; stays None until this scope is attached to one
        self.parent: Optional['Scope'] = None

    def add_symbol(self, symbol: Symbol) -> Symbol:
        """Store ``symbol`` under its name, make this scope its parent, return it."""
        symbol.parent = self
        self.symbols[symbol.name] = symbol
        return symbol

    def add_child_scope(self, scope: 'Scope') -> 'Scope':
        """Attach ``scope`` as a child of this scope and return it."""
        scope.parent = self
        self.children.append(scope)
        return scope

    def lookup(self, name: str) -> Optional[Symbol]:
        """Return the symbol called ``name`` from this scope or the nearest ancestor, else ``None``."""
        if name in self.symbols:
            return self.symbols[name]
        if not self.parent:
            return None
        return self.parent.lookup(name)

    def get_qualified_name(self) -> str:
        """Return the dotted path of scope names, starting at the outermost named ancestor."""
        enclosing = self.parent
        if not (enclosing and enclosing.name):
            return self.name
        return f"{enclosing.get_qualified_name()}.{self.name}"

    def __repr__(self):
        path = self.get_qualified_name()
        return f"<Scope {path} [{self.scope_type}] symbols:{len(self.symbols)} children:{len(self.children)}>"
class ClassScope(Scope):
    """Scope of a class body; additionally tracks base classes, methods and attributes."""

    def __init__(self, name: str, node: ast.ClassDef):
        super().__init__(name, "class", node)
        # Base class names, in the order they were added, without duplicates
        self.base_classes: List[str] = []
        # method name -> Symbol
        self.methods: Dict[str, Symbol] = {}
        # attribute name -> Symbol
        self.attributes: Dict[str, Symbol] = {}

    def add_base_class(self, base_name: str):
        """Record ``base_name`` as a base class unless it is already listed."""
        if base_name in self.base_classes:
            return
        self.base_classes.append(base_name)

    def add_method(self, method: Symbol) -> Symbol:
        """Register ``method`` as a method and as a regular symbol of this scope."""
        self.methods[method.name] = method
        registered = self.add_symbol(method)
        return registered

    def add_attribute(self, attr: Symbol) -> Symbol:
        """Register ``attr`` as an attribute and as a regular symbol of this scope."""
        self.attributes[attr.name] = attr
        registered = self.add_symbol(attr)
        return registered
class ModuleScope(Scope):
    """Scope of a module; additionally tracks the imports made in it."""

    def __init__(self, name: str, node: ast.Module):
        super().__init__(name, "module", node)
        # import alias -> original name
        self.imports: Dict[str, str] = {}
        # module -> {alias -> original name}
        self.from_imports: Dict[str, Dict[str, str]] = {}

    def add_import(self, alias: str, original: str):
        """Record ``import original`` bound locally as ``alias``."""
        self.imports[alias] = original

    def add_from_import(self, module: str, alias: str, original: str):
        """Record ``from module import original`` bound locally as ``alias``."""
        names_from_module = self.from_imports.setdefault(module, {})
        names_from_module[alias] = original
class SymbolTree:
"""
Global symbol tree that maintains a hierarchy of scopes and symbols
across the entire codebase.
"""
def __init__(self):
    """Create an empty tree whose root is a module scope named ``"__root__"``."""
    root = ModuleScope("__root__", None)
    self.root_scope = root
    # Scope currently being processed; starts at the root.
    self.current_scope = root
    # Class scopes by name, used for inheritance resolution
    self.classes: Dict[str, ClassScope] = {}
    # Every symbol added through add_symbol, keyed by "<scope path>.<name>"
    self.all_symbols: Dict[str, Symbol] = {}
    # alias -> module
    self.imports: Dict[str, str] = {}
def push_scope(self, name: str, scope_type: str, node: ast.AST) -> Scope:
    """
    Create a scope nested in the current one and make it current.

    ``"class"`` and ``"module"`` produce the specialised scope classes; any
    other ``scope_type`` produces a plain ``Scope``. Class scopes are also
    registered in ``self.classes`` under both their qualified name and their
    bare name.

    Returns:
        The newly created scope.
    """
    is_class = scope_type == "class"
    if is_class:
        new_scope = ClassScope(name, node)
    elif scope_type == "module":
        new_scope = ModuleScope(name, node)
    else:
        new_scope = Scope(name, scope_type, node)

    self.current_scope.add_child_scope(new_scope)
    self.current_scope = new_scope

    if is_class:
        # Two keys point at the same scope: the dotted path and the short name.
        self.classes[new_scope.get_qualified_name()] = new_scope
        self.classes[name] = new_scope
    return new_scope
def pop_scope(self) -> Scope:
    """
    Leave the current scope and return it.

    The current scope moves to its parent; a scope without a parent stays
    current.
    """
    leaving = self.current_scope
    parent = leaving.parent
    if parent:
        self.current_scope = parent
    return leaving
def add_symbol(self, name: str, symbol_type: SymbolType, node: ast.AST = None) -> Symbol:
    """
    Create a symbol in the current scope and index it globally.

    The symbol is stored in ``self.all_symbols`` under
    ``"<qualified scope name>.<name>"``. When the current scope is a class
    scope, methods and attributes are additionally registered through
    ``add_method`` / ``add_attribute``.

    Returns:
        The new symbol.
    """
    scope = self.current_scope
    symbol = Symbol(name, symbol_type, node)
    scope.add_symbol(symbol)

    self.all_symbols[f"{scope.get_qualified_name()}.{name}"] = symbol

    if isinstance(scope, ClassScope):
        if symbol_type == SymbolType.METHOD:
            scope.add_method(symbol)
        elif symbol_type == SymbolType.ATTRIBUTE:
            scope.add_attribute(symbol)
    return symbol
def add_reference(self, name: str, node: ast.AST):
    """
    Attach ``node`` as a reference to the symbol called ``name``.

    The symbol is resolved from the current scope outwards; nothing happens
    when no such symbol is found.
    """
    target = self.current_scope.lookup(name)
    if not target:
        return
    target.add_reference(node)
def resolve_inheritance(self):
    """
    Give every class a method entry for each method it inherits.

    For each class, base classes known to this tree are resolved first; then
    every base method the class does not define itself is added to the class
    as a new ``METHOD`` symbol. That symbol copies the base method's
    obfuscated name as it is at the time of the call, which may still be
    ``None``. Base classes not present in ``self.classes`` are skipped.
    """
    def resolve_class(class_scope: ClassScope, visited=None):
        seen = set() if visited is None else visited
        # Guard against revisiting a class (and against inheritance cycles).
        if class_scope.name in seen:
            return
        seen.add(class_scope.name)

        for base_name in class_scope.base_classes:
            if base_name not in self.classes:
                # E.g. a class from an external library.
                continue
            base_scope = self.classes[base_name]
            resolve_class(base_scope, seen)

            for method_name, base_method in base_scope.methods.items():
                if method_name in class_scope.methods:
                    # Overridden (or already inherited) in the derived class.
                    continue
                inherited = Symbol(method_name, SymbolType.METHOD)
                class_scope.add_method(inherited)
                inherited.obfuscated_name = base_method.obfuscated_name

    for class_scope in self.classes.values():
        resolve_class(class_scope)
def check_for_issues(self) -> List[Dict]:
    """
    Report suspicious states in the symbol tree.

    Two checks are made:
      * ``duplicate_obfuscated_name`` - two entries of ``all_symbols`` share
        an obfuscated name; reported once per later duplicate, paired with
        the first symbol that used the name.
      * ``inconsistent_method_obfuscation`` - a class and one of its known
        base classes both define a method, both are obfuscated, and the
        obfuscated names differ.

    Returns:
        A list of issue dicts; empty when nothing was found.
    """
    issues = []

    # Check 1: obfuscated names used by more than one symbol.
    first_owner = {}
    for qualified_name, symbol in self.all_symbols.items():
        alias = symbol.obfuscated_name
        if not alias:
            continue
        if alias not in first_owner:
            first_owner[alias] = qualified_name
            continue
        issues.append({
            "type": "duplicate_obfuscated_name",
            "obfuscated_name": alias,
            "symbols": [qualified_name, first_owner[alias]]
        })

    # Check 2: overrides whose obfuscated name differs from the base method's.
    for class_name, class_scope in self.classes.items():
        for base_name in class_scope.base_classes:
            if base_name not in self.classes:
                continue
            base_scope = self.classes[base_name]
            for method_name, base_method in base_scope.methods.items():
                if method_name not in class_scope.methods:
                    continue
                derived_method = class_scope.methods[method_name]
                base_alias = base_method.obfuscated_name
                derived_alias = derived_method.obfuscated_name
                if base_alias and derived_alias and base_alias != derived_alias:
                    issues.append({
                        "type": "inconsistent_method_obfuscation",
                        "method_name": method_name,
                        "base_class": base_name,
                        "derived_class": class_name,
                        "base_obfuscated": base_alias,
                        "derived_obfuscated": derived_alias
                    })
    return issues
def apply_name_generator(self, name_generator):
    """
    Assign obfuscated names to every symbol that needs one.

    Order of work:
      1. Class symbols. Each class symbol is resolved starting from the
         class's own scope, so the result does not depend on which scope
         happens to be current, and each class is handled once even though
         ``self.classes`` lists it under two keys.
      2. ``resolve_inheritance()`` is run.
      3. Every remaining obfuscatable, non-dunder symbol in ``all_symbols``
         that has no name yet receives one.
      4. Method names are propagated down the inheritance hierarchy so an
         overriding or inherited method carries the same obfuscated name as
         the base class method. With several bases defining the same method,
         the first listed base wins.

    Args:
        name_generator: Object whose ``generate_name()`` returns a fresh name.
    """
    # De-duplicate class scopes by identity, keeping first-seen order.
    unique_scopes = []
    seen_ids = set()
    for class_scope in self.classes.values():
        if id(class_scope) not in seen_ids:
            seen_ids.add(id(class_scope))
            unique_scopes.append(class_scope)

    # 1. Classes first.
    for class_scope in unique_scopes:
        class_symbol = class_scope.lookup(class_scope.name)
        if (class_symbol and class_symbol.is_obfuscatable
                and not class_symbol.obfuscated_name):
            class_symbol.obfuscated_name = name_generator.generate_name()

    # 2. Make inherited methods visible on derived classes.
    self.resolve_inheritance()

    # 3. Everything else.
    for symbol in self.all_symbols.values():
        if symbol.obfuscated_name or not symbol.is_obfuscatable:
            continue
        if symbol.name.startswith("__") and symbol.name.endswith("__"):
            continue
        symbol.obfuscated_name = name_generator.generate_name()

    # 4. Unify method names along the inheritance chain, bases before
    #    derived classes.
    unified = set()

    def unify(class_scope):
        if id(class_scope) in unified:
            return
        unified.add(id(class_scope))
        # Reversed so that the first listed base is applied last and wins.
        for base_name in reversed(class_scope.base_classes):
            base_scope = self.classes.get(base_name)
            if base_scope is None or base_scope is class_scope:
                continue
            unify(base_scope)
            for method_name, base_method in base_scope.methods.items():
                derived_method = class_scope.methods.get(method_name)
                if derived_method is None or derived_method is base_method:
                    continue
                if not base_method.obfuscated_name or not derived_method.is_obfuscatable:
                    continue
                derived_method.obfuscated_name = base_method.obfuscated_name

    for class_scope in unique_scopes:
        unify(class_scope)
def get_rename_mapping(self) -> Dict[str, Dict[str, str]]:
    """
    Collect the obfuscated names of all symbols, grouped by symbol kind.

    Returns:
        A dict with the keys ``"variables"``, ``"functions"`` and
        ``"classes"`` (each ``{name -> obfuscated name}``) and ``"methods"``
        and ``"attributes"`` (each ``{class name -> {name -> obfuscated
        name}}``). Symbols without an obfuscated name are left out, as are
        methods and attributes whose parent scope is not a class scope.
    """
    mapping = {
        "variables": {},
        "functions": {},
        "classes": {},
        "methods": {},
        "attributes": {}
    }
    # Kinds stored directly by name.
    flat_buckets = {
        SymbolType.VARIABLE: "variables",
        SymbolType.FUNCTION: "functions",
        SymbolType.CLASS: "classes",
    }
    # Kinds stored per owning class.
    per_class_buckets = {
        SymbolType.METHOD: "methods",
        SymbolType.ATTRIBUTE: "attributes",
    }

    for symbol in self.all_symbols.values():
        if not symbol.obfuscated_name:
            continue
        kind = symbol.symbol_type
        if kind in flat_buckets:
            mapping[flat_buckets[kind]][symbol.name] = symbol.obfuscated_name
        elif kind in per_class_buckets and isinstance(symbol.parent, ClassScope):
            per_class = mapping[per_class_buckets[kind]].setdefault(symbol.parent.name, {})
            per_class[symbol.name] = symbol.obfuscated_name
    return mapping
class SymbolTreeBuilder(ast.NodeVisitor):
"""
Builds a symbol tree by visiting all nodes in the AST.
"""
def __init__(self):
    """Start with an empty symbol tree and no class/function context."""
    self.tree = SymbolTree()

    # Class context: whether a class definition is being visited, and its name.
    self.in_class_def = False
    self.current_class = None

    # Names of the current function's arguments, so assignments to them are
    # not registered as new variables.
    self.current_function_args = set()

    # True while an attribute access is being visited.
    self.in_attribute_ctx = False
def visit_Module(self, node: ast.Module):
    """Visit every top-level statement inside a module scope named ``"__main__"``."""
    self.tree.push_scope("__main__", "module", node)
    for statement in node.body:
        self.visit(statement)
    self.tree.pop_scope()
def visit_ClassDef(self, node: ast.ClassDef):
    """
    Record a class definition and visit its body inside a class scope.

    The class symbol is registered in the scope that *contains* the class
    definition, before the class scope is entered. Previously it was added
    after entering the class scope, which filed the class under its own body.

    Base classes written as plain names are recorded on the class scope and
    counted as references. ``in_class_def`` / ``current_class`` are set while
    the body is visited and restored afterwards.
    """
    # The class name is bound in the enclosing scope.
    self.tree.add_symbol(node.name, SymbolType.CLASS, node)
    class_scope = self.tree.push_scope(node.name, "class", node)

    for base in node.bases:
        if isinstance(base, ast.Name):
            class_scope.add_base_class(base.id)
            self.tree.add_reference(base.id, base)

    # Save previous state and switch to this class.
    prev_in_class = self.in_class_def
    prev_class = self.current_class
    self.in_class_def = True
    self.current_class = node.name

    for item in node.body:
        self.visit(item)

    # Restore previous state and leave the class scope.
    self.in_class_def = prev_in_class
    self.current_class = prev_class
    self.tree.pop_scope()
def visit_FunctionDef(self, node: ast.FunctionDef):
    """
    Record a function or method and visit its arguments and body.

    The symbol is a ``METHOD`` while a class definition is being visited and
    a ``FUNCTION`` otherwise; it is added to the scope containing the
    definition. Decorators are visited in that same scope so the names they
    use are counted as references.

    The set of current function arguments is saved and restored around the
    function. Previously it was only reset, so after any function the
    enclosing code kept seeing that function's arguments.
    """
    symbol_type = SymbolType.METHOD if self.in_class_def else SymbolType.FUNCTION
    self.tree.add_symbol(node.name, symbol_type, node)

    # Decorators are evaluated outside the function's own scope.
    for decorator in node.decorator_list:
        self.visit(decorator)

    self.tree.push_scope(node.name, "function", node)

    outer_args = self.current_function_args
    self.current_function_args = set()

    self.visit(node.args)
    for item in node.body:
        self.visit(item)

    self.current_function_args = outer_args
    self.tree.pop_scope()
def visit_arguments(self, node: ast.arguments):
    """
    Register every parameter of a function as an ``ARGUMENT`` symbol.

    Parameters are handled in signature order: positional-only, regular
    positional, ``*args``, keyword-only, ``**kwargs``. Each name is also
    added to ``self.current_function_args``. Positional-only parameters
    (``def f(a, /, b)``) used to be skipped.
    """
    parameters = list(getattr(node, "posonlyargs", []))
    parameters.extend(node.args)
    if node.vararg:
        parameters.append(node.vararg)
    parameters.extend(node.kwonlyargs)
    if node.kwarg:
        parameters.append(node.kwarg)

    for parameter in parameters:
        self.current_function_args.add(parameter.arg)
        self.tree.add_symbol(parameter.arg, SymbolType.ARGUMENT, parameter)
def visit_Assign(self, node: ast.Assign):
    """
    Record the names bound by an assignment statement.

    The right-hand side is visited first so its references are captured.
    Each target is then handled by kind:
      * attribute targets go to ``visit_attribute_assignment``;
      * plain names become ``VARIABLE`` symbols unless they are arguments of
        the current function;
      * tuple, list and starred targets are unpacked and their elements
        handled the same way (names inside them used to be ignored);
      * anything else (e.g. subscripts) is visited normally.
    """
    self.visit(node.value)

    def bind(target):
        if isinstance(target, ast.Attribute):
            self.visit_attribute_assignment(target)
        elif isinstance(target, ast.Name):
            if target.id not in self.current_function_args:
                self.tree.add_symbol(target.id, SymbolType.VARIABLE, target)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                bind(element)
        elif isinstance(target, ast.Starred):
            bind(target.value)
        else:
            self.visit(target)

    for target in node.targets:
        bind(target)
def visit_attribute_assignment(self, node: ast.Attribute):
    """
    Handle an attribute used as an assignment target.

    ``self.<attr> = ...`` inside a class registers ``<attr>`` as an
    ``ATTRIBUTE`` symbol through the tree. For any other attribute target the
    object expression on the left of the dot is visited instead.
    """
    owner = node.value
    on_self = isinstance(owner, ast.Name) and owner.id == 'self'
    if on_self and self.in_class_def and self.current_class:
        self.tree.add_symbol(node.attr, SymbolType.ATTRIBUTE, node)
        return
    # Not a self attribute in a class: capture references in the owner expression.
    self.visit(owner)
def visit_Name(self, node: ast.Name):
    """Count a name used in load context as a reference; other contexts are ignored."""
    if not isinstance(node.ctx, ast.Load):
        return
    self.tree.add_reference(node.id, node)
def visit_Attribute(self, node: ast.Attribute):
    """
    Visit an attribute access and register ``self.<attr>`` attributes.

    The object expression is visited with ``in_attribute_ctx`` set. For
    ``self.<attr>`` inside a class, the attribute is added as an
    ``ATTRIBUTE`` symbol when the current class scope does not list it yet,
    so attributes read before they are assigned are still tracked.

    NOTE(review): ``SymbolTree.add_symbol`` adds to the tree's *current*
    scope; confirm that this reaches the class scope's ``attributes`` when
    the access happens inside a method.
    """
    outer_ctx = self.in_attribute_ctx
    self.in_attribute_ctx = True

    owner = node.value
    self.visit(owner)

    is_self_access = isinstance(owner, ast.Name) and owner.id == 'self'
    if is_self_access and self.in_class_def and self.current_class:
        class_scope = self.tree.classes.get(self.current_class)
        if class_scope and node.attr not in class_scope.attributes:
            self.tree.add_symbol(node.attr, SymbolType.ATTRIBUTE, node)

    self.in_attribute_ctx = outer_ctx
def visit_Import(self, node: ast.Import):
    """
    Register the names bound by an ``import`` statement.

    Each bound name becomes a non-obfuscatable ``IMPORT`` symbol and, when
    the current scope is a module scope, is recorded in that scope's imports.

    ``import a.b.c`` binds only the top-level package ``a``, so without an
    ``as`` clause the symbol is registered as ``a`` (it used to be registered
    under the dotted name, which never appears as an identifier). With an
    ``as`` clause the alias is bound to the full dotted module.
    """
    for item in node.names:
        if item.asname:
            bound_name = item.asname
            bound_module = item.name
        else:
            bound_name = item.name.split('.')[0]
            bound_module = bound_name

        # Imported names must keep their spelling.
        symbol = self.tree.add_symbol(bound_name, SymbolType.IMPORT, node)
        symbol.is_obfuscatable = False
        symbol.is_imported = True

        scope = self.tree.current_scope
        if isinstance(scope, ModuleScope):
            scope.add_import(bound_name, bound_module)
def visit_ImportFrom(self, node: ast.ImportFrom):
    """
    Register the names bound by a ``from ... import`` statement.

    Each imported name (or its alias) becomes a non-obfuscatable ``IMPORT``
    symbol that remembers ``node.module``. When the current scope is a module
    scope, the name is also recorded in that scope's from-imports.
    """
    for item in node.names:
        local_name = item.asname or item.name

        # Imported names must keep their spelling.
        symbol = self.tree.add_symbol(local_name, SymbolType.IMPORT, node)
        symbol.is_obfuscatable = False
        symbol.is_imported = True
        symbol.original_module = node.module

        scope = self.tree.current_scope
        if isinstance(scope, ModuleScope):
            scope.add_from_import(node.module, local_name, item.name)
# Add more visit methods for other AST node types as needed
def build_tree(self, tree: ast.AST) -> SymbolTree:
    """
    Populate the symbol tree from an AST and return it.

    Args:
        tree: Root AST node to visit.

    Returns:
        ``self.tree`` after the visit and after inheritance has been resolved.
    """
    self.visit(tree)
    # Final processing: propagate inherited methods to derived classes.
    self.tree.resolve_inheritance()
    symbol_tree = self.tree
    return symbol_tree