import ast
from typing import Dict, Set, List, Tuple, Optional
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ClassMapper")

# Receiver names whose attribute accesses refer to members of the enclosing
# class (instance methods use ``self``, classmethods use ``cls``).
_MEMBER_RECEIVERS = frozenset({"self", "cls"})


def _is_dunder(name: str) -> bool:
    """Return True for special ``__name__`` members, which are never renamed."""
    return name.startswith("__") and name.endswith("__")


def _is_mangled(name: str) -> bool:
    """Return True for private ``__name`` members that Python mangles per class.

    They are mapped per class instead of per hierarchy so that a parent's and
    a child's ``__x`` stay distinct attributes, as they are before renaming.
    """
    return name.startswith("__") and not name.endswith("__")


class ClassMapping:
    """Stores all mappings related to classes in a centralized way."""

    def __init__(self):
        # Original class name -> obfuscated class name
        self.class_names: Dict[str, str] = {}
        # Original class name -> {original member name -> obfuscated name}.
        # Members are methods plus attributes assigned through self/cls; all
        # classes of one in-file inheritance hierarchy share the same entries.
        self.class_methods: Dict[str, Dict[str, str]] = {}
        # Original class name -> {original attr name -> obfuscated attr name}
        self.class_attributes: Dict[str, Dict[str, str]] = {}
        # Child class -> list of parent classes (original names)
        self.inheritance: Dict[str, List[str]] = {}
        # Original class name -> every name accessed as self.<name> / cls.<name>
        self.seen_method_calls: Dict[str, Set[str]] = {}
        # Original class name -> names stored/deleted through self.<name> / cls.<name>
        self.assigned_attributes: Dict[str, Set[str]] = {}
        # Original class name -> names bound directly in the class body (class
        # variables, dataclass fields, __slots__ entries). They are never
        # renamed because class-body assignments are not rewritten.
        self.class_variables: Dict[str, Set[str]] = {}

    def debug_info(self) -> str:
        """Return a human-readable dump of the class, method and attribute maps."""
        info = [f"Class mappings: {len(self.class_names)} classes"]
        for cls_name, obf_name in self.class_names.items():
            info.append(f"  {cls_name} -> {obf_name}")
            if cls_name in self.class_methods:
                methods = self.class_methods[cls_name]
                info.append(f"    Methods: {len(methods)}")
                for method, obf_method in methods.items():
                    info.append(f"      {method} -> {obf_method}")
            if cls_name in self.class_attributes:
                attrs = self.class_attributes[cls_name]
                info.append(f"    Attributes: {len(attrs)}")
                for attr, obf_attr in attrs.items():
                    info.append(f"      {attr} -> {obf_attr}")
        return "\n".join(info)


class ClassMapAnalyzer(ast.NodeVisitor):
    """Analyzes the AST to create a complete class mapping.

    Only members defined in the analyzed code are mapped: methods defined in a
    class body and attributes assigned through ``self``/``cls``. Names that are
    only read (e.g. inherited from an external base such as ``Exception``) are
    left untouched, as are names bound directly in a class body.

    NOTE(review): methods/attributes that override or configure an *external*
    base class (e.g. ``threading.Thread.run`` / ``self.daemon``) are still
    renamed, which breaks that base's behavior; avoid obfuscating such classes.
    """

    def __init__(self, name_generator):
        self.name_generator = name_generator
        self.mapping = ClassMapping()
        self.current_class: Optional[str] = None
        self.current_method: Optional[str] = None
        self.processed_classes: Set[str] = set()
        # Function/lambda/comprehension scopes entered since the innermost
        # class body; 0 means code runs directly in the class namespace.
        self._function_depth = 0

    def analyze(self, tree: ast.AST) -> ClassMapping:
        """Perform a complete analysis of the AST and return the mapping."""
        # First pass: collect class definitions, members and inheritance
        self.visit(tree)
        # Second pass: give each hierarchy one consistent member mapping
        self._resolve_inheritance()
        self._ensure_complete_method_mapping()
        logger.info("Class analysis complete: %d classes processed", len(self.mapping.class_names))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("%s", self.mapping.debug_info())
        return self.mapping

    # ----------------------------------------------------------------- visitors

    def visit_ClassDef(self, node: ast.ClassDef):
        """Record a class name, its in-file bases and its members."""
        # Decorators, bases and keywords are evaluated in the enclosing scope.
        for expr in [*node.decorator_list, *node.bases, *node.keywords]:
            self.visit(expr)

        # Skip if already processed (a later class with the same name)
        if node.name in self.processed_classes:
            return

        name = node.name
        if name not in self.mapping.class_names:
            self.mapping.class_names[name] = self.name_generator.generate_name()
        self.mapping.class_methods.setdefault(name, {})
        self.mapping.class_attributes.setdefault(name, {})
        self.mapping.seen_method_calls.setdefault(name, set())
        self.mapping.assigned_attributes.setdefault(name, set())
        self.mapping.class_variables.setdefault(name, set())

        parent_classes = [base.id for base in node.bases if isinstance(base, ast.Name)]
        if parent_classes:
            self.mapping.inheritance[name] = parent_classes

        prev_class, prev_depth = self.current_class, self._function_depth
        self.current_class, self._function_depth = name, 0
        try:
            for item in node.body:
                self.visit(item)
        finally:
            self.current_class, self._function_depth = prev_class, prev_depth
        self.processed_classes.add(name)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Treat functions defined directly in a class body as methods."""
        if self.current_class is not None and self._function_depth == 0:
            self.visit_method_def(node)
        else:
            self._visit_function_scope(node)

    # ``async def`` methods are members exactly like plain ones.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_method_def(self, node: ast.FunctionDef):
        """Map a method name (dunders excluded) and scan the method for member use."""
        if not self.current_class:
            return
        prev_method = self.current_method
        self.current_method = node.name
        if not _is_dunder(node.name):
            methods = self.mapping.class_methods[self.current_class]
            if node.name not in methods:
                obf_name = self.name_generator.generate_name()
                methods[node.name] = obf_name
                logger.debug("Mapped method %s.%s to %s", self.current_class, node.name, obf_name)
        self._visit_function_scope(node)
        self.current_method = prev_method

    def _visit_function_scope(self, node):
        """Visit a function: signature in the current scope, body one scope deeper."""
        for decorator in node.decorator_list:
            self.visit(decorator)
        self.visit(node.args)
        if node.returns is not None:
            self.visit(node.returns)
        self._function_depth += 1
        try:
            for stmt in node.body:
                self.visit(stmt)
        finally:
            self._function_depth -= 1

    def _visit_nested_scope(self, node):
        """Visit a lambda or comprehension, whose bound names are local to it."""
        self._function_depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._function_depth -= 1

    visit_Lambda = _visit_nested_scope
    visit_ListComp = _visit_nested_scope
    visit_SetComp = _visit_nested_scope
    visit_DictComp = _visit_nested_scope
    visit_GeneratorExp = _visit_nested_scope

    def visit_assign_in_class(self, node: ast.Assign):
        """Process an assignment executed directly in a class body."""
        if not self.current_class:
            return
        for target in node.targets:
            if (
                isinstance(target, ast.Attribute)
                and isinstance(target.value, ast.Name)
                and target.value.id == "self"
            ):
                attr_name = target.attr
                attrs = self.mapping.class_attributes[self.current_class]
                if attr_name not in attrs:
                    obf_name = self.name_generator.generate_name()
                    attrs[attr_name] = obf_name
                    logger.debug("Mapped attribute %s.%s to %s", self.current_class, attr_name, obf_name)
            elif isinstance(target, ast.Name) and target.id == "__slots__":
                self._record_slots(node.value)
            # Plain names bound here are recorded as class variables by visit_Name.
            self.visit(target)
        # Visit the value to find nested member references
        self.visit(node.value)

    def _record_slots(self, value: ast.AST):
        """Exclude the string names of a literal ``__slots__`` from renaming."""
        if isinstance(value, ast.Dict):
            elements = value.keys
        elif isinstance(value, (ast.Tuple, ast.List, ast.Set)):
            elements = value.elts
        else:
            elements = [value]
        for element in elements:
            if isinstance(element, ast.Constant) and isinstance(element.value, str):
                self.mapping.class_variables[self.current_class].add(element.value)

    def visit_Assign(self, node: ast.Assign):
        """Route class-body assignments to visit_assign_in_class; otherwise visit both sides."""
        if self.current_class is not None and self._function_depth == 0:
            self.visit_assign_in_class(node)
            return
        for target in node.targets:
            self.visit(target)
        self.visit(node.value)

    def visit_Name(self, node: ast.Name):
        """Record names bound directly in a class body as class variables."""
        if (
            self.current_class is not None
            and self._function_depth == 0
            and isinstance(node.ctx, (ast.Store, ast.Del))
        ):
            self.mapping.class_variables.setdefault(self.current_class, set()).add(node.id)

    def visit_Attribute(self, node: ast.Attribute):
        """Record member accesses through self/cls, noting which ones assign."""
        if (
            self.current_class
            and isinstance(node.value, ast.Name)
            and node.value.id in _MEMBER_RECEIVERS
        ):
            member = node.attr
            self.mapping.seen_method_calls[self.current_class].add(member)
            if isinstance(node.ctx, (ast.Store, ast.Del)):
                self.mapping.assigned_attributes[self.current_class].add(member)
            logger.debug("Recorded method call: %s.%s", self.current_class, member)
        # Continue traversal
        self.generic_visit(node)

    # ------------------------------------------------------------ second pass

    def _hierarchy_groups(self) -> List[List[str]]:
        """Partition known classes into groups connected by in-file inheritance."""
        root: Dict[str, str] = {name: name for name in self.mapping.class_methods}

        def find(name: str) -> str:
            while root[name] != name:
                root[name] = root[root[name]]
                name = root[name]
            return name

        for child, parents in self.mapping.inheritance.items():
            for parent in parents:
                if child in root and parent in root:
                    root[find(child)] = find(parent)

        groups: Dict[str, List[str]] = {}
        for name in self.mapping.class_methods:  # definition order
            groups.setdefault(find(name), []).append(name)
        return list(groups.values())

    def _group_class_variables(self, group: List[str]) -> Set[str]:
        """Return the union of class-body names over every class in *group*."""
        excluded: Set[str] = set()
        for class_name in group:
            excluded |= self.mapping.class_variables.get(class_name, set())
        return excluded

    def _resolve_inheritance(self):
        """
        Share one member mapping across every class of an inheritance hierarchy.

        An overriding method must keep the obfuscated name of the method it
        overrides (so dynamic dispatch still works), so all classes connected by
        in-file inheritance get identical entries. Names bound in any class body
        of the hierarchy are dropped; name-mangled ``__private`` members stay
        per class.
        """
        for group in self._hierarchy_groups():
            excluded = self._group_class_variables(group)
            shared: Dict[str, str] = {}
            for class_name in group:
                for member, obf_name in self.mapping.class_methods[class_name].items():
                    if member not in excluded and not _is_mangled(member):
                        shared.setdefault(member, obf_name)
            for class_name in group:
                methods = self.mapping.class_methods[class_name]
                private = {
                    member: obf_name
                    for member, obf_name in methods.items()
                    if _is_mangled(member) and member not in excluded
                }
                methods.clear()
                methods.update(shared)
                methods.update(private)
            if len(group) > 1:
                logger.debug("Unified member mappings of hierarchy %s", group)

    def _ensure_complete_method_mapping(self):
        """
        Map attributes assigned through self/cls that have no mapping yet.

        The same name gets the same obfuscated name across the whole hierarchy.
        Names that are only read (never defined or assigned in the analyzed
        code) are assumed to come from an external base or from outside and are
        left unrenamed. Names are processed in sorted order so generation is
        deterministic.
        """
        for group in self._hierarchy_groups():
            excluded = self._group_class_variables(group)
            group_maps = [self.mapping.class_methods[name] for name in group]
            for class_name in group:
                for member in sorted(self.mapping.assigned_attributes.get(class_name, ())):
                    if _is_dunder(member) or member in excluded:
                        continue
                    if _is_mangled(member):
                        targets = [self.mapping.class_methods[class_name]]
                    else:
                        targets = group_maps
                    obf_name = next((m[member] for m in targets if member in m), None)
                    if obf_name is None:
                        obf_name = self.name_generator.generate_name()
                        logger.debug(
                            "Added mapping for assigned attribute %s.%s -> %s",
                            class_name, member, obf_name,
                        )
                    for member_map in targets:
                        member_map.setdefault(member, obf_name)


class ClassTransformer(ast.NodeTransformer):
    """Transforms class-related nodes using the mapping.

    Renames class definitions and every bare-name reference to them, methods
    defined directly in a class body, and member accesses through ``self``,
    ``cls``, ``super()`` and ``ClassName``. Accesses through arbitrary objects
    (``obj.method()``) cannot be resolved without type information and are left
    unchanged.
    """

    def __init__(self, mapping: ClassMapping):
        self.mapping = mapping
        self.current_class: Optional[str] = None
        # Function/lambda/comprehension scopes entered since the innermost
        # class body; 0 means code runs directly in the class namespace.
        self._function_depth = 0
        # Functions already defined in the current class body. A bare name
        # referring to one of them (e.g. ``@size.setter``) must follow its rename.
        self._class_scope_defs: Set[str] = set()

    def _in_class_scope(self) -> bool:
        """Return True while visiting code that runs directly in a class body."""
        return self.current_class is not None and self._function_depth == 0

    def _visit_list(self, nodes: list) -> list:
        """Visit *nodes*, honoring NodeTransformer's None (drop) / list (splice) results."""
        result = []
        for child in nodes:
            new = self.visit(child)
            if new is None:
                continue
            if isinstance(new, list):
                result.extend(new)
            else:
                result.append(new)
        return result

    def visit_ClassDef(self, node: ast.ClassDef):
        """Transform the class name, its bases/decorators and its body."""
        orig_name = node.name
        # Decorators, bases and keywords are evaluated in the enclosing scope.
        node.decorator_list = self._visit_list(node.decorator_list)
        node.bases = self._visit_list(node.bases)
        node.keywords = self._visit_list(node.keywords)

        # Rename class if it's in our mapping
        if orig_name in self.mapping.class_names:
            node.name = self.mapping.class_names[orig_name]
            logger.debug("Transformed class %s -> %s", orig_name, node.name)

        saved = (self.current_class, self._function_depth, self._class_scope_defs)
        self.current_class, self._function_depth, self._class_scope_defs = orig_name, 0, set()
        try:
            node.body = self._visit_list(node.body)
        finally:
            self.current_class, self._function_depth, self._class_scope_defs = saved
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Rename methods defined directly in a class body, then visit the function."""
        in_class_scope = self._in_class_scope()
        # Decorators, defaults and annotations run in the defining scope.
        node.decorator_list = self._visit_list(node.decorator_list)
        node.args = self.visit(node.args)
        if node.returns is not None:
            node.returns = self.visit(node.returns)

        if in_class_scope:
            orig_name = node.name
            methods = self.mapping.class_methods.get(self.current_class, {})
            if orig_name in methods:
                node.name = methods[orig_name]
                logger.debug("Transformed method %s.%s -> %s", self.current_class, orig_name, node.name)
            self._class_scope_defs.add(orig_name)

        self._function_depth += 1
        try:
            node.body = self._visit_list(node.body)
        finally:
            self._function_depth -= 1
        return node

    visit_AsyncFunctionDef = visit_FunctionDef

    def _visit_nested_scope(self, node):
        """Visit a lambda or comprehension one scope deeper than its parent."""
        self._function_depth += 1
        try:
            return self.generic_visit(node)
        finally:
            self._function_depth -= 1

    visit_Lambda = _visit_nested_scope
    visit_ListComp = _visit_nested_scope
    visit_SetComp = _visit_nested_scope
    visit_DictComp = _visit_nested_scope
    visit_GeneratorExp = _visit_nested_scope

    def visit_Name(self, node: ast.Name):
        """Rename references to class-body functions and to mapped classes."""
        if self._in_class_scope() and node.id in self._class_scope_defs:
            methods = self.mapping.class_methods.get(self.current_class, {})
            if node.id in methods:
                node.id = methods[node.id]
                return node
        if node.id in self.mapping.class_names:
            node.id = self.mapping.class_names[node.id]
        return node

    def _member_owner(self, value: ast.AST) -> Optional[str]:
        """Return the original class whose members ``value.<attr>`` refers to, if known."""
        if isinstance(value, ast.Name):
            if value.id in _MEMBER_RECEIVERS:
                return self.current_class
            shadowed = self._in_class_scope() and value.id in self._class_scope_defs
            if value.id in self.mapping.class_names and not shadowed:
                return value.id
            return None
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "super"
        ):
            return self.current_class
        return None

    def visit_Attribute(self, node: ast.Attribute):
        """Transform self/cls/super()/ClassName member references."""
        # Resolve the owner before the receiver itself may be renamed.
        owner = self._member_owner(node.value)
        node.value = self.visit(node.value)
        if owner is None:
            return node

        orig_name = node.attr
        methods = self.mapping.class_methods.get(owner, {})
        attributes = self.mapping.class_attributes.get(owner, {})
        # Check in method mappings first, then attribute mappings
        if orig_name in methods:
            node.attr = methods[orig_name]
            logger.debug("Transformed self.method %s.%s -> %s", owner, orig_name, node.attr)
        elif orig_name in attributes:
            node.attr = attributes[orig_name]
            logger.debug("Transformed self.attr %s.%s -> %s", owner, orig_name, node.attr)
        return node


# Helper function to apply the class mapping transformation
def apply_class_mapping(tree: ast.AST, name_generator) -> Tuple[ast.AST, ClassMapping]:
    """Analyze and transform classes consistently.

    Args:
        tree: Parsed module; its nodes are modified in place.
        name_generator: Object whose ``generate_name()`` returns a new identifier
            on each call (presumably unique - not checked here).

    Returns:
        ``(transformed_tree, mapping)``.
    """
    # First pass: analyze all classes
    analyzer = ClassMapAnalyzer(name_generator)
    mapping = analyzer.analyze(tree)
    # Second pass: transform using the mapping
    transformer = ClassTransformer(mapping)
    transformed = transformer.visit(tree)
    return transformed, mapping