301 lines
12 KiB
Python
301 lines
12 KiB
Python
import ast
|
|
from typing import Dict, Set, List, Tuple, Optional
|
|
import logging
|
|
|
|
# Module-level logging setup.
# NOTE(review): basicConfig() runs at import time. It installs a default
# root handler at INFO level for the whole process, unless the root logger
# already has handlers. Confirm this side effect is wanted for a library module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="ClassMapper")
|
|
|
|
class ClassMapping:
    """Central store for every class-related name mapping.

    All keys are *original* names; the values are their obfuscated
    replacements (or, for ``inheritance``/``seen_method_calls``, the
    original names they relate to).
    """

    def __init__(self):
        # original class name -> obfuscated class name
        self.class_names: Dict[str, str] = {}
        # original class name -> {original method name: obfuscated method name}
        self.class_methods: Dict[str, Dict[str, str]] = {}
        # original class name -> {original attribute name: obfuscated attribute name}
        self.class_attributes: Dict[str, Dict[str, str]] = {}
        # child class -> original names of its parent classes
        self.inheritance: Dict[str, List[str]] = {}
        # original class name -> every member name seen being accessed on it
        self.seen_method_calls: Dict[str, Set[str]] = {}

    def debug_info(self) -> str:
        """Render the class, method and attribute mappings as text.

        One header line is emitted, followed by one entry per class in
        ``class_names``, with its method and attribute tables when present.
        ``inheritance`` and ``seen_method_calls`` are not included.
        """
        lines = [f"Class mappings: {len(self.class_names)} classes"]
        for original, obfuscated in self.class_names.items():
            lines.append(f" {original} -> {obfuscated}")

            if original in self.class_methods:
                method_map = self.class_methods[original]
                lines.append(f" Methods: {len(method_map)}")
                lines.extend(f" {name} -> {new_name}" for name, new_name in method_map.items())

            if original in self.class_attributes:
                attr_map = self.class_attributes[original]
                lines.append(f" Attributes: {len(attr_map)}")
                lines.extend(f" {name} -> {new_name}" for name, new_name in attr_map.items())

        return "\n".join(lines)
|
|
|
|
|
|
class ClassMapAnalyzer(ast.NodeVisitor):
    """Analyze an AST and build a complete :class:`ClassMapping`.

    ``name_generator`` must provide ``generate_name()`` returning a fresh
    identifier string; that is the only method called on it here.

    What gets mapped:

    * every ``class`` statement's name;
    * every non-dunder method (plain or ``async``) defined directly in a
      class body;
    * every non-dunder name accessed as ``self.<name>`` inside a class.
      Instance attributes assigned in methods (``self.x = ...``) are reached
      this way too, so they end up in ``class_methods`` next to methods.

    Member names are then unified along the inheritance chain so that a
    parent and its subclasses use the same obfuscated name for a member.

    NOTE(review): classes deriving from bases that are not defined in the
    analyzed tree (e.g. ``Exception`` or ``ast.NodeVisitor``) get their
    members renamed as well. That breaks overrides of, and accesses to,
    members those bases provide. Confirm this is acceptable for callers.
    """

    def __init__(self, name_generator):
        self.name_generator = name_generator
        self.mapping = ClassMapping()
        # Original name of the class whose body is being visited.
        self.current_class: Optional[str] = None
        # Name of the method being visited (tracked as context only).
        self.current_method: Optional[str] = None
        # Class names already analyzed. A later ClassDef with the same name
        # is skipped and shares the first definition's mappings.
        self.processed_classes: Set[str] = set()

    def analyze(self, tree: ast.AST) -> ClassMapping:
        """Run the full analysis over *tree* and return the resulting mapping."""
        # Collect class definitions, methods, self-accesses and inheritance.
        self.visit(tree)

        # Propagate parents' names first, so subclasses reuse them instead
        # of drawing fresh names for inherited members...
        self._resolve_inheritance()
        self._ensure_complete_method_mapping()
        # ...and again afterwards. Names a parent only reaches through
        # ``self.<name>`` (e.g. hooks that subclasses implement) then get
        # shared with the subclasses that define them.
        self._resolve_inheritance()

        logger.info(f"Class analysis complete: {len(self.mapping.class_names)} classes processed")
        logger.debug(self.mapping.debug_info())
        return self.mapping

    def visit_ClassDef(self, node: ast.ClassDef):
        """Register a class, its bare-name bases and the members in its body."""
        name = node.name
        if name in self.processed_classes:
            return

        mapping = self.mapping
        if name not in mapping.class_names:
            mapping.class_names[name] = self.name_generator.generate_name()
        mapping.class_methods.setdefault(name, {})
        mapping.class_attributes.setdefault(name, {})
        mapping.seen_method_calls.setdefault(name, set())

        # Only bases written as a bare name (``class B(A)``) are tracked;
        # dotted bases such as ``module.A`` are ignored.
        parents = [base.id for base in node.bases if isinstance(base, ast.Name)]
        if parents:
            mapping.inheritance[name] = parents

        prev_class = self.current_class
        self.current_class = name
        try:
            for item in node.body:
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    self.visit_method_def(item)
                elif isinstance(item, ast.Assign):
                    self.visit_assign_in_class(item)
                else:
                    # Nested classes, expressions and other statements may
                    # still contain ``self.<name>`` references.
                    self.visit(item)
        finally:
            self.current_class = prev_class

        self.processed_classes.add(name)

    def visit_method_def(self, node: ast.FunctionDef):
        """Map a method defined directly in the current class body.

        Accepts ``ast.FunctionDef`` and ``ast.AsyncFunctionDef`` nodes.
        Dunder methods keep their names. The body is then traversed so that
        ``self.<name>`` accesses inside it are recorded.
        """
        if not self.current_class:
            return

        prev_method = self.current_method
        self.current_method = node.name

        is_dunder = node.name.startswith('__') and node.name.endswith('__')
        methods = self.mapping.class_methods[self.current_class]
        if not is_dunder and node.name not in methods:
            obf_name = self.name_generator.generate_name()
            methods[node.name] = obf_name
            logger.debug(f"Mapped method {self.current_class}.{node.name} to {obf_name}")

        for item in node.body:
            self.visit(item)

        self.current_method = prev_method

    def visit_assign_in_class(self, node: ast.Assign):
        """Handle an assignment statement found directly in a class body.

        Any ``self.<attr> = ...`` target is mapped into ``class_attributes``.
        The assigned value is then visited for ``self.<name>`` accesses.
        Assignments inside methods do not come through here. Their
        ``self.<attr>`` targets are recorded by :meth:`visit_Attribute`
        instead and mapped into ``class_methods``.
        """
        if not self.current_class:
            return

        attributes = self.mapping.class_attributes[self.current_class]
        for target in node.targets:
            if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) and target.value.id == 'self':
                attr_name = target.attr
                if attr_name not in attributes:
                    obf_name = self.name_generator.generate_name()
                    attributes[attr_name] = obf_name
                    logger.debug(f"Mapped attribute {self.current_class}.{attr_name} to {obf_name}")

        self.visit(node.value)

    def visit_Attribute(self, node: ast.Attribute):
        """Record a ``self.<name>`` access made inside a class, then keep traversing."""
        if self.current_class and isinstance(node.value, ast.Name) and node.value.id == 'self':
            method_name = node.attr
            self.mapping.seen_method_calls[self.current_class].add(method_name)
            logger.debug(f"Recorded method call: {self.current_class}.{method_name}")

        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign):
        """Visit both sides of an assignment so ``self.<name>`` uses are recorded."""
        for target in node.targets:
            self.visit(target)
        self.visit(node.value)

    def _resolve_inheritance(self):
        """Give each member a single obfuscated name across an inheritance chain.

        Parents are resolved before their children. For each base, in the
        order the bases are listed, the child adopts the parent's obfuscated
        name for every member the parent maps. This includes members the
        child overrides, so ``self.<member>`` calls inside a parent still
        dispatch to the child's override after renaming. When two bases map
        the same member, the first listed base wins. Inheritance cycles are
        skipped rather than followed.
        """
        methods = self.mapping.class_methods
        inheritance = self.mapping.inheritance
        resolved: Set[str] = set()
        in_progress: Set[str] = set()

        def process_inheritance(class_name: str) -> None:
            if class_name in resolved or class_name in in_progress:
                return
            in_progress.add(class_name)

            child_methods = methods.get(class_name)
            adopted: Set[str] = set()
            for parent in inheritance.get(class_name, ()):
                if parent in in_progress:
                    # Cycle, e.g. ``class Foo(Foo)`` subclassing an imported
                    # ``Foo``. There is no consistent source to copy from.
                    continue
                process_inheritance(parent)

                parent_methods = methods.get(parent)
                if parent_methods is None or child_methods is None:
                    # Parent is not a class defined in the analyzed tree.
                    continue

                for method_name, obf_name in parent_methods.items():
                    if method_name in adopted:
                        # An earlier base already supplied this member.
                        continue
                    adopted.add(method_name)
                    if child_methods.get(method_name) != obf_name:
                        child_methods[method_name] = obf_name
                        logger.debug(f"Inherited method {class_name}.{method_name} from {parent}")

            in_progress.discard(class_name)
            resolved.add(class_name)

        for class_name in list(methods):
            process_inheritance(class_name)

    def _ensure_complete_method_mapping(self):
        """Map every recorded ``self.<name>`` access that has no mapping yet.

        This covers members used but not defined in the class body, such as
        instance attributes. Dunder names are never mapped.
        """
        for class_name, method_calls in self.mapping.seen_method_calls.items():
            methods = self.mapping.class_methods.get(class_name)
            if methods is None:
                continue

            for method_name in method_calls:
                if method_name.startswith('__') and method_name.endswith('__'):
                    continue
                if method_name not in methods:
                    obf_name = self.name_generator.generate_name()
                    methods[method_name] = obf_name
                    logger.debug(f"Added mapping for called method {class_name}.{method_name} -> {obf_name}")
|
|
|
|
|
class ClassTransformer(ast.NodeTransformer):
    """Rename classes, their methods and ``self.<member>`` references.

    Names are looked up in a :class:`ClassMapping`, keyed by the *original*
    class name. Only method definitions that sit directly in a class body
    are renamed. Functions nested inside a method keep their names, because
    their call sites are plain names that this transformer does not touch.

    NOTE(review): bare-name references to a class (``Foo()``,
    ``class B(Foo)``, ``isinstance(x, Foo)``) are not rewritten here.
    Presumably a later pass uses the mapping returned by
    ``apply_class_mapping`` for that — confirm.
    """

    # ``property`` accessor decorators that refer back to a method by its
    # class-body name, e.g. ``@value.setter``.
    _PROPERTY_ACCESSORS = frozenset({"getter", "setter", "deleter"})

    def __init__(self, mapping: ClassMapping):
        self.mapping = mapping
        # Original name of the class whose body is being transformed.
        self.current_class: Optional[str] = None
        # True while visiting statements that run in a class body, including
        # ones nested under ``if``/``try`` there. False inside a function body.
        self._in_class_body = False

    def _visit_list(self, nodes):
        """Visit each node, honouring NodeTransformer removal (None) and splicing (list)."""
        new_nodes = []
        for item in nodes:
            result = self.visit(item)
            if result is None:
                continue
            if isinstance(result, list):
                new_nodes.extend(result)
            else:
                new_nodes.append(result)
        return new_nodes

    def visit_ClassDef(self, node: ast.ClassDef):
        """Rename the class (if mapped) and transform its body."""
        orig_name = node.name
        if orig_name in self.mapping.class_names:
            node.name = self.mapping.class_names[orig_name]
            logger.debug(f"Transformed class {orig_name} -> {node.name}")

        prev_class, prev_in_body = self.current_class, self._in_class_body
        # Lookups inside the body use the original name, which keys the mapping.
        self.current_class = orig_name
        self._in_class_body = True
        try:
            node.body = self._visit_list(node.body)
        finally:
            self.current_class = prev_class
            self._in_class_body = prev_in_body
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Rename a method defined in a class body, then transform its body."""
        return self._transform_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Rename an ``async`` method defined in a class body, then transform its body."""
        return self._transform_function(node)

    def _transform_function(self, node):
        """Shared handling for plain and async function definitions."""
        if self.current_class and self._in_class_body:
            methods = self.mapping.class_methods.get(self.current_class, {})
            self._rename_accessor_decorators(node, methods)
            if node.name in methods:
                orig_name = node.name
                node.name = methods[orig_name]
                logger.debug(f"Transformed method {self.current_class}.{orig_name} -> {node.name}")

        # Definitions inside this body are local functions, not methods.
        prev_in_body = self._in_class_body
        self._in_class_body = False
        try:
            node.body = self._visit_list(node.body)
        finally:
            self._in_class_body = prev_in_body
        return node

    def _rename_accessor_decorators(self, node, methods):
        """Point ``@name.setter``-style decorators at the renamed property.

        Without this, the class body would still reference the property's
        original name after its getter was renamed.
        """
        for decorator in node.decorator_list:
            if (
                isinstance(decorator, ast.Attribute)
                and decorator.attr in self._PROPERTY_ACCESSORS
                and isinstance(decorator.value, ast.Name)
                and decorator.value.id in methods
            ):
                decorator.value.id = methods[decorator.value.id]

    def visit_Attribute(self, node: ast.Attribute):
        """Rewrite ``self.<member>`` using method mappings first, then attribute mappings."""
        # Transform nested parts (e.g. ``self.a`` in ``self.a.b``) first.
        node.value = self.visit(node.value)

        if self.current_class and isinstance(node.value, ast.Name) and node.value.id == 'self':
            orig_name = node.attr
            methods = self.mapping.class_methods.get(self.current_class, {})
            attributes = self.mapping.class_attributes.get(self.current_class, {})

            if orig_name in methods:
                node.attr = methods[orig_name]
                logger.debug(f"Transformed self.method {self.current_class}.{orig_name} -> {node.attr}")
            elif orig_name in attributes:
                node.attr = attributes[orig_name]
                logger.debug(f"Transformed self.attr {self.current_class}.{orig_name} -> {node.attr}")

        return node
|
|
|
|
# Helper function to apply the class mapping transformation
def apply_class_mapping(tree: ast.AST, name_generator) -> Tuple[ast.AST, "ClassMapping"]:
    """Analyze *tree*'s classes and rename them consistently.

    The function runs two passes. First, ``ClassMapAnalyzer`` builds the name
    mapping. Then ``ClassTransformer`` applies it. The transformer assigns
    new names onto the existing nodes, so *tree* is modified in place.

    Args:
        tree: AST to transform.
        name_generator: Object whose ``generate_name()`` returns a fresh
            identifier for each call.

    Returns:
        A ``(transformed_tree, mapping)`` tuple. ``mapping`` is the
        ``ClassMapping`` that was applied.
    """
    # First pass: analyze all classes.
    analyzer = ClassMapAnalyzer(name_generator)
    mapping = analyzer.analyze(tree)

    # Second pass: transform using the mapping.
    transformer = ClassTransformer(mapping)
    transformed = transformer.visit(tree)

    return transformed, mapping
|