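# Commit message: init
# (Web-view label "This commit is contained in:" - the branch list was not captured.)
# Diff hunk header: @@ -0,0 +1,551 @@ (new file, 551 lines added)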
|
||||
import ast
|
||||
from typing import Dict, Set, List, Optional, Union, Tuple
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
# Logging setup: install the default root handler at INFO level (a no-op when
# the root logger is already configured) and expose this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("SymbolTree")
|
||||
|
||||
|
||||
class SymbolType(Enum):
    """Kinds of symbols the symbol tree can record."""

    # Plain name bindings and callables
    VARIABLE = "variable"
    FUNCTION = "function"

    # Class-related symbols
    CLASS = "class"
    METHOD = "method"

    # Function parameters and object attributes
    ARGUMENT = "argument"
    ATTRIBUTE = "attribute"

    # Module-level entities
    MODULE = "module"
    IMPORT = "import"
|
||||
|
||||
|
||||
class Symbol:
    """A named entity found in the code, together with its renaming state."""

    def __init__(self, name: str, symbol_type: SymbolType, node: ast.AST = None):
        # Name as written in the source, and the replacement name once one
        # has been assigned (None until then).
        self.name = name
        self.obfuscated_name: Optional[str] = None

        self.symbol_type = symbol_type

        # Defining AST node, and every AST node recorded as a use.
        self.node = node
        self.references: List[ast.AST] = []

        # Scope that owns this symbol; filled in when it is added to a scope.
        self.parent: Optional['Scope'] = None

        # Symbols flagged False here are left out of renaming.
        self.is_obfuscatable = True

        # Import bookkeeping: whether the name came from an import and, for
        # from-imports, which module it came from.
        self.is_imported = False
        self.original_module: Optional[str] = None

    def add_reference(self, node: ast.AST):
        """Record *node* as a use of this symbol."""
        self.references.append(node)

    def __repr__(self):
        # The arrow section is empty while no obfuscated name is assigned,
        # which leaves two consecutive spaces in the output.
        if self.obfuscated_name:
            renamed = '→ ' + self.obfuscated_name
        else:
            renamed = ''
        kind = self.symbol_type.value
        ref_count = len(self.references)
        return f"<Symbol {self.name} [{kind}] {renamed} refs:{ref_count}>"
|
||||
|
||||
|
||||
class Scope:
    """
    A region of code that owns name bindings: a module, function, class,
    or comprehension.

    Scopes form a tree through ``parent`` / ``children``; name resolution
    walks from a scope outward through its ancestors.
    """

    def __init__(self, name: str, scope_type: str, node: ast.AST = None):
        self.name = name
        self.scope_type = scope_type
        self.node = node

        # Symbols bound directly in this scope, keyed by symbol name.
        self.symbols: Dict[str, Symbol] = {}

        # Scopes nested directly inside this one.
        self.children: List['Scope'] = []

        # Enclosing scope; stays None for the outermost scope.
        self.parent: Optional['Scope'] = None

    def add_symbol(self, symbol: Symbol) -> Symbol:
        """Bind *symbol* in this scope (replacing any same-named entry) and return it."""
        self.symbols[symbol.name] = symbol
        symbol.parent = self
        return symbol

    def add_child_scope(self, scope: 'Scope') -> 'Scope':
        """Attach *scope* as a nested scope of this one and return it."""
        self.children.append(scope)
        scope.parent = self
        return scope

    def lookup(self, name: str) -> Optional[Symbol]:
        """Resolve *name* here or in the nearest enclosing scope; None if unbound."""
        scope = self
        while scope:
            if name in scope.symbols:
                return scope.symbols[name]
            scope = scope.parent
        return None

    def get_qualified_name(self) -> str:
        """Return the dotted path to this scope, starting at the outermost named ancestor."""
        enclosing = self.parent
        # Stop at a missing parent or at a parent whose name is empty.
        if not (enclosing and enclosing.name):
            return self.name
        return f"{enclosing.get_qualified_name()}.{self.name}"

    def __repr__(self):
        path = self.get_qualified_name()
        symbol_count = len(self.symbols)
        child_count = len(self.children)
        return f"<Scope {path} [{self.scope_type}] symbols:{symbol_count} children:{child_count}>"
|
||||
|
||||
|
||||
class ClassScope(Scope):
    """Scope of a class body, with extra bookkeeping for bases, methods and attributes."""

    def __init__(self, name: str, node: ast.ClassDef):
        super().__init__(name, "class", node)
        # Base-class names in declaration order, without duplicates.
        self.base_classes: List[str] = []
        # Methods and attributes of this class, keyed by name. Entries are
        # also present in the inherited ``symbols`` table.
        self.methods: Dict[str, Symbol] = {}
        self.attributes: Dict[str, Symbol] = {}

    def add_base_class(self, base_name: str):
        """Record *base_name* as a base class unless it is already listed."""
        if base_name in self.base_classes:
            return
        self.base_classes.append(base_name)

    def add_method(self, method: Symbol) -> Symbol:
        """Register *method* in the method table and in the scope's symbols."""
        self.methods.update({method.name: method})
        return self.add_symbol(method)

    def add_attribute(self, attr: Symbol) -> Symbol:
        """Register *attr* in the attribute table and in the scope's symbols."""
        self.attributes.update({attr.name: attr})
        return self.add_symbol(attr)
|
||||
|
||||
|
||||
class ModuleScope(Scope):
    """Scope of a module, with extra bookkeeping for its import statements."""

    def __init__(self, name: str, node: ast.Module):
        super().__init__(name, "module", node)
        # ``import`` statements: alias -> original name.
        self.imports: Dict[str, str] = {}
        # ``from ... import`` statements: module -> {alias -> original name}.
        self.from_imports: Dict[str, Dict[str, str]] = {}

    def add_import(self, alias: str, original: str):
        """Record that *alias* refers to the imported name *original*."""
        self.imports[alias] = original

    def add_from_import(self, module: str, alias: str, original: str):
        """Record that *alias* refers to *original* imported from *module*."""
        per_module = self.from_imports.setdefault(module, {})
        per_module[alias] = original
|
||||
|
||||
|
||||
class SymbolTree:
    """
    Global symbol tree that maintains a hierarchy of scopes and symbols
    across the entire codebase.

    The tree starts at a synthetic ``__root__`` module scope. A builder calls
    ``push_scope`` / ``pop_scope`` while walking an AST and registers
    definitions and uses through ``add_symbol`` / ``add_reference``.
    """

    def __init__(self):
        # The root scope is a special module scope named "__root__"
        self.root_scope = ModuleScope("__root__", None)
        # Current scope being processed
        self.current_scope = self.root_scope
        # Track classes for inheritance resolution. Every class is stored
        # twice: under its qualified scope name and under its bare name.
        self.classes: Dict[str, ClassScope] = {}
        # Track all symbols by their fully qualified name
        self.all_symbols: Dict[str, Symbol] = {}
        # Track imports for proper resolution
        self.imports: Dict[str, str] = {}  # alias -> module

    def push_scope(self, name: str, scope_type: str, node: ast.AST) -> Scope:
        """
        Create a new scope nested in the current one and make it current.

        ``scope_type`` "class" creates a ``ClassScope``, "module" a
        ``ModuleScope`` and anything else a plain ``Scope``. Class scopes are
        additionally indexed in ``self.classes``.
        """
        if scope_type == "class":
            new_scope = ClassScope(name, node)
        elif scope_type == "module":
            new_scope = ModuleScope(name, node)
        else:
            new_scope = Scope(name, scope_type, node)

        self.current_scope.add_child_scope(new_scope)
        self.current_scope = new_scope

        if scope_type == "class":
            self.classes[new_scope.get_qualified_name()] = new_scope
            # Also track with just the class name for simpler lookups; a
            # later class with the same bare name replaces this entry.
            self.classes[name] = new_scope

        return new_scope

    def pop_scope(self) -> Scope:
        """
        Leave the current scope and return it.

        The root scope has no parent, so popping it leaves it current.
        """
        old_scope = self.current_scope
        if old_scope.parent:
            self.current_scope = old_scope.parent
        return old_scope

    def add_symbol(self, name: str, symbol_type: SymbolType, node: ast.AST = None) -> Symbol:
        """
        Add a symbol to the current scope and return it.

        When the scope already binds ``name`` with the same symbol type, the
        existing symbol is returned instead of being replaced, so references
        collected so far are kept; the re-binding ``node`` (if given) is
        recorded as a reference. A binding of a different type replaces the
        old symbol.
        """
        scope = self.current_scope

        existing = scope.symbols.get(name)
        if existing is not None and existing.symbol_type == symbol_type:
            if node is not None:
                existing.add_reference(node)
            return existing

        symbol = Symbol(name, symbol_type, node)

        # Class scopes keep dedicated method/attribute tables; their add_*
        # helpers also store the symbol in the scope's symbol table.
        in_class = isinstance(scope, ClassScope)
        if in_class and symbol_type == SymbolType.METHOD:
            scope.add_method(symbol)
        elif in_class and symbol_type == SymbolType.ATTRIBUTE:
            scope.add_attribute(symbol)
        else:
            scope.add_symbol(symbol)

        # Track in the global map
        self.all_symbols[f"{scope.get_qualified_name()}.{name}"] = symbol
        return symbol

    def add_reference(self, name: str, node: ast.AST):
        """
        Record ``node`` as a use of ``name``.

        The name is resolved from the current scope outward; unknown names
        are ignored.
        """
        symbol = self.current_scope.lookup(name)
        if symbol:
            symbol.add_reference(node)

    def resolve_inheritance(self):
        """
        Resolve inheritance relationships between classes to ensure
        consistent method and attribute renaming.

        For every class whose base is known to the tree, each base method is
        mirrored into the derived class: a placeholder symbol is created when
        the derived class does not define the method, and the derived symbol
        (placeholder or override) always takes the base method's obfuscated
        name, including ``None``. Bases are resolved before the classes that
        derive from them. The method may be called repeatedly.
        """
        def resolve_class(class_scope: ClassScope, visited=None):
            if visited is None:
                visited = set()

            # Skip if already visited to prevent infinite recursion
            if class_scope.name in visited:
                return
            visited.add(class_scope.name)

            for base_name in class_scope.base_classes:
                base_scope = self.classes.get(base_name)
                # Base classes outside the tree (e.g. external libraries)
                if base_scope is None:
                    continue

                # Resolve the base class first so names propagate transitively
                resolve_class(base_scope, visited)

                # list(): the base may be the very same scope being updated
                for method_name, base_method in list(base_scope.methods.items()):
                    derived_method = class_scope.methods.get(method_name)
                    if derived_method is None:
                        derived_method = Symbol(method_name, SymbolType.METHOD)
                        derived_method.is_obfuscatable = base_method.is_obfuscatable
                        class_scope.add_method(derived_method)
                    # Overrides must share the base name as well, otherwise
                    # calls through the base type stop reaching them.
                    derived_method.obfuscated_name = base_method.obfuscated_name

        for class_scope in list(self.classes.values()):
            resolve_class(class_scope)

    def check_for_issues(self) -> List[Dict]:
        """
        Check for potential issues in the symbol tree.

        Returns a list of issue dicts of two kinds:
        ``duplicate_obfuscated_name`` (two symbols share an obfuscated name,
        except methods with the same original name, which share on purpose)
        and ``inconsistent_method_obfuscation`` (a method and the base method
        it overrides both have obfuscated names, and they differ).
        """
        issues = []

        # Check for duplicated obfuscated names
        first_owner: Dict[str, Tuple[str, Symbol]] = {}
        for qualified_name, symbol in self.all_symbols.items():
            obfuscated = symbol.obfuscated_name
            if not obfuscated:
                continue

            if obfuscated not in first_owner:
                first_owner[obfuscated] = (qualified_name, symbol)
                continue

            other_name, other_symbol = first_owner[obfuscated]
            shared_method_name = (
                symbol.symbol_type == SymbolType.METHOD
                and other_symbol.symbol_type == SymbolType.METHOD
                and symbol.name == other_symbol.name
            )
            if shared_method_name:
                continue

            issues.append({
                "type": "duplicate_obfuscated_name",
                "obfuscated_name": obfuscated,
                "symbols": [qualified_name, other_name]
            })

        # Check for inconsistent method obfuscation in inheritance hierarchies
        for class_name, class_scope in self.classes.items():
            for base_name in class_scope.base_classes:
                base_scope = self.classes.get(base_name)
                if base_scope is None:
                    continue

                for method_name, base_method in base_scope.methods.items():
                    derived_method = class_scope.methods.get(method_name)
                    if derived_method is None:
                        continue
                    base_obf = base_method.obfuscated_name
                    derived_obf = derived_method.obfuscated_name
                    if base_obf and derived_obf and base_obf != derived_obf:
                        issues.append({
                            "type": "inconsistent_method_obfuscation",
                            "method_name": method_name,
                            "base_class": base_name,
                            "derived_class": class_name,
                            "base_obfuscated": base_obf,
                            "derived_obfuscated": derived_obf
                        })

        return issues

    def apply_name_generator(self, name_generator):
        """
        Apply a name generator to all symbols that need obfuscation.
        Ensures consistent renaming across the entire codebase.

        ``name_generator`` must provide ``generate_name()``. Classes are
        named first, then every remaining obfuscatable symbol without a name
        (dunder names are skipped); finally method names are synchronised
        along inheritance chains.
        """
        # First, handle classes. self.classes holds each scope under two
        # keys, so visit every scope only once.
        handled = set()
        for class_scope in self.classes.values():
            if id(class_scope) in handled:
                continue
            handled.add(id(class_scope))

            # Resolve from the class scope itself: self.current_scope is
            # normally back at the root once building has finished, and
            # lookups only walk outward.
            class_symbol = class_scope.lookup(class_scope.name)
            if (class_symbol is not None
                    and class_symbol.symbol_type == SymbolType.CLASS
                    and class_symbol.is_obfuscatable
                    and not class_symbol.obfuscated_name):
                class_symbol.obfuscated_name = name_generator.generate_name()

        # Apply to all other symbols
        for symbol in self.all_symbols.values():
            # Skip if already obfuscated or not obfuscatable
            if symbol.obfuscated_name or not symbol.is_obfuscatable:
                continue

            # Skip special names
            if symbol.name.startswith("__") and symbol.name.endswith("__"):
                continue

            symbol.obfuscated_name = name_generator.generate_name()

        # Propagate base-method names only now that they exist, so inherited
        # and overriding methods end up with the base method's name.
        self.resolve_inheritance()

    def get_rename_mapping(self) -> Dict[str, Dict[str, str]]:
        """
        Get a mapping for all symbols to their obfuscated names,
        organized by symbol type for use in transformers.

        "variables", "functions" and "classes" map original name to
        obfuscated name. "methods" and "attributes" map the bare class name
        to such a dict. Symbols without an obfuscated name are left out.
        """
        mapping = {
            "variables": {},
            "functions": {},
            "classes": {},
            "methods": {},
            "attributes": {}
        }

        flat_sections = {
            SymbolType.VARIABLE: "variables",
            SymbolType.FUNCTION: "functions",
            SymbolType.CLASS: "classes",
        }
        per_class_sections = {
            SymbolType.METHOD: "methods",
            SymbolType.ATTRIBUTE: "attributes",
        }

        for symbol in self.all_symbols.values():
            if not symbol.obfuscated_name:
                continue

            kind = symbol.symbol_type
            if kind in flat_sections:
                mapping[flat_sections[kind]][symbol.name] = symbol.obfuscated_name
            elif kind in per_class_sections and isinstance(symbol.parent, ClassScope):
                # Methods and attributes are grouped under their class name
                per_class = mapping[per_class_sections[kind]]
                per_class.setdefault(symbol.parent.name, {})[symbol.name] = symbol.obfuscated_name

        # Inherited-method placeholders exist only on the class scopes, not
        # in all_symbols, so add them from there.
        for class_scope in self.classes.values():
            for method in class_scope.methods.values():
                if method.obfuscated_name:
                    per_class = mapping["methods"].setdefault(class_scope.name, {})
                    per_class.setdefault(method.name, method.obfuscated_name)

        return mapping
|
||||
|
||||
|
||||
class SymbolTreeBuilder(ast.NodeVisitor):
    """
    Builds a symbol tree by visiting all nodes in the AST.

    Definitions (classes, functions, arguments, assigned names, ``self``
    attributes, imports) become symbols in the scope that binds them; name
    loads become references on the symbol they resolve to.
    """

    def __init__(self):
        self.tree = SymbolTree()

        # Track whether we're in a class definition
        self.in_class_def = False
        self.current_class = None
        # Scope object of the innermost enclosing class (None outside classes)
        self._current_class_scope = None

        # Track function arguments to avoid creating symbols for them twice
        self.current_function_args = set()

        # Track whether we're in an attribute context
        self.in_attribute_ctx = False

    def visit_Module(self, node: ast.Module):
        """Process a module node inside a module scope named "__main__"."""
        self.tree.push_scope("__main__", "module", node)
        try:
            for stmt in node.body:
                self.visit(stmt)
        finally:
            self.tree.pop_scope()

    def visit_ClassDef(self, node: ast.ClassDef):
        """Process a class definition."""
        # Decorators, bases and keywords are evaluated in the enclosing
        # scope, and the class name is bound there too, so handle all of
        # them before entering the class scope.
        for decorator in node.decorator_list:
            self.visit(decorator)

        self.tree.add_symbol(node.name, SymbolType.CLASS, node)

        base_names = []
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_names.append(base.id)
                # Track reference to the base class
                self.tree.add_reference(base.id, base)
            else:
                # e.g. ``module.Base`` or ``Generic[T]``: record the names used
                self.visit(base)
        for keyword in node.keywords:
            self.visit(keyword.value)

        class_scope = self.tree.push_scope(node.name, "class", node)
        for base_name in base_names:
            class_scope.add_base_class(base_name)

        # Save previous state and update current state
        saved_state = (
            self.in_class_def,
            self.current_class,
            self._current_class_scope,
            self.current_function_args,
        )
        self.in_class_def = True
        self.current_class = node.name
        self._current_class_scope = class_scope
        # Names bound in a class body are not the enclosing function's arguments
        self.current_function_args = set()
        try:
            for item in node.body:
                self.visit(item)
        finally:
            (self.in_class_def,
             self.current_class,
             self._current_class_scope,
             self.current_function_args) = saved_state
            self.tree.pop_scope()

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Process a function or method definition (sync or async)."""
        enclosing_scope = self.tree.current_scope
        # Only functions defined directly in a class body are methods;
        # functions nested inside a method are ordinary functions.
        is_method = isinstance(enclosing_scope, ClassScope)
        if is_method:
            # A ``self.name`` use seen before this definition may have left
            # an attribute placeholder; the method definition supersedes it.
            enclosing_scope.attributes.pop(node.name, None)

        # Decorators and default values are evaluated in the enclosing scope
        for decorator in node.decorator_list:
            self.visit(decorator)
        for default in node.args.defaults:
            self.visit(default)
        for default in node.args.kw_defaults:
            if default is not None:  # None marks a keyword-only arg without default
                self.visit(default)

        symbol_type = SymbolType.METHOD if is_method else SymbolType.FUNCTION
        self.tree.add_symbol(node.name, symbol_type, node)

        self.tree.push_scope(node.name, "function", node)
        outer_args = self.current_function_args
        self.current_function_args = set()
        try:
            self.visit(node.args)
            for item in node.body:
                self.visit(item)
        finally:
            # Restore the enclosing function's arguments so they are not
            # mistaken for (or hidden by) this function's arguments.
            self.current_function_args = outer_args
            self.tree.pop_scope()

    # ``async def`` has the same fields as ``def``
    visit_AsyncFunctionDef = visit_FunctionDef

    def _declare_argument(self, arg: ast.arg):
        """Register one parameter as an ARGUMENT symbol of the current scope."""
        self.current_function_args.add(arg.arg)
        self.tree.add_symbol(arg.arg, SymbolType.ARGUMENT, arg)

    def visit_arguments(self, node: ast.arguments):
        """Process function arguments (default values are not visited here)."""
        # Positional-only arguments (absent on old Python versions)
        for arg in getattr(node, "posonlyargs", []):
            self._declare_argument(arg)

        # Regular positional arguments
        for arg in node.args:
            self._declare_argument(arg)

        # Vararg (e.g., *args)
        if node.vararg:
            self._declare_argument(node.vararg)

        # Keyword-only arguments
        for kwarg in node.kwonlyargs:
            self._declare_argument(kwarg)

        # Kwarg (e.g., **kwargs)
        if node.kwarg:
            self._declare_argument(node.kwarg)

    def visit_Assign(self, node: ast.Assign):
        """Process an assignment statement."""
        # Visit the right side first to capture any variable references
        self.visit(node.value)

        # Now bind the targets (left-hand side)
        for target in node.targets:
            self._bind_target(target)

    def _bind_target(self, target: ast.AST):
        """Register the names bound by one assignment target."""
        if isinstance(target, ast.Attribute):
            # Attribute assignment (e.g., self.x = value)
            self.visit_attribute_assignment(target)
        elif isinstance(target, ast.Name):
            if target.id in self.current_function_args:
                # Re-binding an argument: no new symbol, but the node must
                # still be tied to the argument's symbol.
                self.tree.add_reference(target.id, target)
            else:
                self.tree.add_symbol(target.id, SymbolType.VARIABLE, target)
        elif isinstance(target, (ast.Tuple, ast.List)):
            # Unpacking: every element is a target of its own
            for element in target.elts:
                self._bind_target(element)
        elif isinstance(target, ast.Starred):
            self._bind_target(target.value)
        else:
            # Other target types (e.g., subscripts) only use names
            self.visit(target)

    def _is_self_attribute(self, node: ast.Attribute) -> bool:
        """Return True for ``self.<attr>`` written somewhere inside a class."""
        return (isinstance(node.value, ast.Name)
                and node.value.id == 'self'
                and isinstance(self._current_class_scope, ClassScope))

    def _register_self_attribute(self, name: str, node: ast.Attribute):
        """
        Tie a ``self.<name>`` node to the enclosing class.

        If the class already has a symbol called ``name`` (method, attribute
        or class-level variable) the node is recorded as a reference to it;
        otherwise a new ATTRIBUTE symbol is created in the class scope.
        """
        class_scope = self._current_class_scope
        known = class_scope.symbols.get(name)
        if known is not None:
            known.add_reference(node)
            return

        # The tree only adds to its current scope, which here is usually the
        # method's function scope; switch to the class scope for the add.
        active_scope = self.tree.current_scope
        self.tree.current_scope = class_scope
        try:
            self.tree.add_symbol(name, SymbolType.ATTRIBUTE, node)
        finally:
            self.tree.current_scope = active_scope

    def visit_attribute_assignment(self, node: ast.Attribute):
        """Process attribute assignment (e.g., self.x = value)."""
        # Visit the object expression to capture references (including
        # ``self`` itself, which is an ordinary argument symbol).
        self.visit(node.value)

        if self._is_self_attribute(node):
            self._register_self_attribute(node.attr, node)

    def visit_Name(self, node: ast.Name):
        """Process a name (variable reference)."""
        # Only loads are references; bindings are handled by their statements
        if isinstance(node.ctx, ast.Load):
            self.tree.add_reference(node.id, node)

    def visit_Attribute(self, node: ast.Attribute):
        """Process attribute access (e.g., obj.attr)."""
        # Track that we're in an attribute context
        prev_in_attribute_ctx = self.in_attribute_ctx
        self.in_attribute_ctx = True
        try:
            # Visit the left side
            self.visit(node.value)

            # self.attr may be read before any assignment to it was seen, so
            # make sure it is in the class's symbol table.
            if self._is_self_attribute(node):
                self._register_self_attribute(node.attr, node)
        finally:
            self.in_attribute_ctx = prev_in_attribute_ctx

    def visit_Import(self, node: ast.Import):
        """Process an import statement."""
        for item in node.names:
            # ``import a.b`` binds only the top-level name ``a``;
            # ``import a.b as c`` binds ``c``.
            bound_name = item.asname or item.name.split(".")[0]

            # The imported name should not be obfuscated
            symbol = self.tree.add_symbol(bound_name, SymbolType.IMPORT, node)
            symbol.is_obfuscatable = False
            symbol.is_imported = True

            # Track the import
            current_scope = self.tree.current_scope
            if isinstance(current_scope, ModuleScope):
                current_scope.add_import(item.asname or item.name, item.name)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Process a from-import statement."""
        # ``from . import x`` has no module name; use the leading dots instead
        module_name = node.module if node.module is not None else "." * node.level

        for item in node.names:
            local_name = item.asname or item.name

            # The imported name should not be obfuscated
            symbol = self.tree.add_symbol(local_name, SymbolType.IMPORT, node)
            symbol.is_obfuscatable = False
            symbol.is_imported = True
            symbol.original_module = module_name

            # Track the import
            current_scope = self.tree.current_scope
            if isinstance(current_scope, ModuleScope):
                current_scope.add_from_import(module_name, local_name, item.name)

    # Add more visit methods for other AST node types as needed

    def build_tree(self, tree: ast.AST) -> SymbolTree:
        """Build the symbol tree from the AST and return it."""
        self.visit(tree)
        # Perform final processing
        self.tree.resolve_inheritance()
        return self.tree
|
||||
# (Web-view footer controls "Reference in New Issue" / "Block a user" - not part of the module.)