From 34fd57294928cf132527a2288410d5f874726274 Mon Sep 17 00:00:00 2001 From: Samir Gupta Date: Mon, 4 Sep 2023 04:06:03 +1000 Subject: [PATCH 1/5] added static types to all functions and classes --- ast_scope/annotate.py | 17 +-- ast_scope/annotator.py | 167 +++++++++++++------------- ast_scope/graph.py | 13 +- ast_scope/group_similar_constructs.py | 28 ++--- ast_scope/pull_scope.py | 44 ++++--- ast_scope/scope.py | 47 ++++---- ast_scope/utils.py | 22 ++-- 7 files changed, 177 insertions(+), 161 deletions(-) diff --git a/ast_scope/annotate.py b/ast_scope/annotate.py index 361df8e..b33b36d 100644 --- a/ast_scope/annotate.py +++ b/ast_scope/annotate.py @@ -1,12 +1,13 @@ -from ast_scope.scope import FunctionScope -from .annotator import AnnotateScope, IntermediateGlobalScope +import ast +from ast_scope.scope import ErrorScope, FunctionScope, GlobalScope, Scope +from .annotator import AnnotateScope, IntermediateGlobalScope, IntermediateScope from .pull_scope import PullScopes from .utils import get_all_nodes, get_name from .graph import DiGraph class ScopeInfo: - def __init__(self, tree, global_scope, error_scope, node_to_containing_scope): + def __init__(self, tree: ast.AST, global_scope: GlobalScope, error_scope: ErrorScope, node_to_containing_scope: dict[ast.AST, Scope]): self._tree = tree self._global_scope = global_scope self._error_scope = error_scope @@ -42,13 +43,13 @@ def static_dependency_graph(self): def __iter__(self): return iter(self._node_to_containing_scope) - def __contains__(self, node): + def __contains__(self, node: ast.AST): return node in self._node_to_containing_scope - def __getitem__(self, node): + def __getitem__(self, node: ast.AST): return self._node_to_containing_scope[node] - def function_scope_for(self, node): + def function_scope_for(self, node: ast.AST): """ Returns the function scope for the given FunctionDef node. """ @@ -61,8 +62,8 @@ def function_scope_for(self, node): return None -def annotate(tree, class_binds_near=False): - annotation_dict = {} +def annotate(tree: ast.AST, class_binds_near: bool=False): + annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]] = {} annotator = AnnotateScope( IntermediateGlobalScope(), annotation_dict, class_binds_near=class_binds_near ) diff --git a/ast_scope/annotator.py b/ast_scope/annotator.py index c0a6e97..3ec01f9 100644 --- a/ast_scope/annotator.py +++ b/ast_scope/annotator.py @@ -1,5 +1,6 @@ -import ast import abc +import ast +from typing import Iterable, Self from .group_similar_constructs import GroupSimilarConstructsVisitor from .utils import name_of_alias @@ -12,61 +13,66 @@ class IntermediateScope(abc.ABC): """ def __init__(self): - self.referenced_variables = set() - self.written_variables = set() - self.nonlocal_variables = set() - self.global_variables = set() + self.referenced_variables: set[str] = set() + self.written_variables: set[str] = set() + self.nonlocal_variables: set[str] = set() + self.global_variables: set[str] = set() - def load(self, variable): + def load(self, variable: str): self.referenced_variables.add(variable) - def modify(self, variable): + def modify(self, variable: str): self.written_variables.add(variable) - def globalize(self, variable): + def globalize(self, variable: str): self.global_variables.add(variable) - def nonlocalize(self, variable): + def nonlocalize(self, variable: str): self.nonlocal_variables.add(variable) @abc.abstractmethod - def global_frame(self): + def global_frame(self) -> 'IntermediateGlobalScope': pass @abc.abstractmethod - def find(self, name, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool=True) -> Self | None: """ Finds the actual frame containing the variable name, or None if no frame exists """ pass - def true_parent(self): - parent = self.parent - while isinstance(parent, IntermediateClassScope): - parent = parent.parent - return parent - class IntermediateGlobalScope(IntermediateScope): - def find(self, name, is_assignment, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool=True): if not global_acceptable: return None return self def global_frame(self): return self + - -class IntermediateFunctionScope(IntermediateScope): - def __init__(self, node, parent): +class IntermediateScopeWithParent(IntermediateScope): + def __init__(self, parent: IntermediateScope): + self.parent = parent super().__init__() + + def true_parent(self) -> IntermediateScope: + parent = self.parent + while isinstance(parent, IntermediateClassScope): + parent = parent.parent + return parent + + +class IntermediateFunctionScope(IntermediateScopeWithParent): + def __init__(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.comprehension | ast.Lambda, parent: IntermediateScope): + super().__init__(parent) self.node = node - self.parent = parent - def global_frame(self): + def global_frame(self) -> IntermediateGlobalScope: return self.true_parent().global_frame() - def find(self, name, is_assignment, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool=True): if name in self.global_variables: return self.global_frame() if name in self.nonlocal_variables: @@ -76,17 +82,16 @@ def find(self, name, is_assignment, global_acceptable=True): return self.true_parent().find(name, is_assignment, global_acceptable) -class IntermediateClassScope(IntermediateScope): - def __init__(self, node, parent, class_binds_near): - super().__init__() +class IntermediateClassScope(IntermediateScopeWithParent): + def __init__(self, node: ast.ClassDef, parent: IntermediateScope, class_binds_near: bool): + super().__init__(parent) self.node = node - self.parent = parent self.class_binds_near = class_binds_near - def global_frame(self): - return self.true_parent().find(self) + def global_frame(self) -> IntermediateGlobalScope: + return self.true_parent().global_frame() - def find(self, name, is_assignment, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool=True): if self.class_binds_near: # anything can be in a class frame return self @@ -94,22 +99,21 @@ def find(self, name, is_assignment, global_acceptable=True): return self return self.parent.find(name, is_assignment, global_acceptable) - class GrabVariable(ast.NodeVisitor): """ Dumps variables from a given name object into the given scope. """ - def __init__(self, scope, variable, annotation_dict): + def __init__(self, scope: IntermediateScope, variable: ast.Name, annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]]): self.scope = scope self.variable = variable self.annotation_dict = annotation_dict - def visit_generic(self, node): + def visit_generic(self, node: ast.AST): raise RuntimeError("Unsupported node type: {node}".format(node=node)) - def visit_Name(self, node): - super().visit_generic(node) + def visit_Name(self, node: ast.Name): + super().generic_visit(node) def load(self): self.annotation_dict[self.variable] = self.variable.id, self.scope, False @@ -119,70 +123,71 @@ def modify(self): self.annotation_dict[self.variable] = self.variable.id, self.scope, True self.scope.modify(self.variable.id) - def visit_Load(self, _): + def visit_Load(self, node: ast.Load): self.load() - def visit_Store(self, _): + def visit_Store(self, node: ast.Store): self.modify() - def visit_Del(self, _): + def visit_Del(self, node: ast.Del): self.modify() - def visit_AugLoad(self, _): + def visit_AugLoad(self, node: ast.AugLoad): raise RuntimeError("Unsupported: AugStore") - def visit_AugStore(self, _): + def visit_AugStore(self, node: ast.AugStore): raise RuntimeError("Unsupported: AugStore") class ProcessArguments(ast.NodeVisitor): - def __init__(self, expr_scope, arg_scope): + def __init__(self, expr_scope: 'AnnotateScope', arg_scope: 'AnnotateScope'): self.expr_scope = expr_scope self.arg_scope = arg_scope - def visit_arg(self, node): + def visit_arg(self, node: ast.arg): self.arg_scope.visit(node) visit_all(self.expr_scope, node.annotation, getattr(node, "type_comment", None)) - def visit_arguments(self, node): + def visit_arguments(self, node: ast.AST): super().generic_visit(node) - def generic_visit(self, node): + def generic_visit(self, node: ast.AST): self.expr_scope.visit(node) class AnnotateScope(GroupSimilarConstructsVisitor): - def __init__(self, scope, annotation_dict, class_binds_near): + def __init__(self, scope: IntermediateScope, annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]], class_binds_near: bool): self.scope = scope self.annotation_dict = annotation_dict self.class_binds_near = class_binds_near - def annotate_intermediate_scope(self, node, name, is_assign): + def annotate_intermediate_scope(self, node: ast.AST, name: str, is_assign: bool): self.annotation_dict[node] = name, self.scope, is_assign - def visit_Name(self, name_node): - GrabVariable(self.scope, name_node, self.annotation_dict).generic_visit( - name_node + def visit_Name(self, node: ast.Name): + GrabVariable(self.scope, node, self.annotation_dict).generic_visit( + node ) - def visit_ExceptHandler(self, handler_node): - self.annotate_intermediate_scope(handler_node, handler_node.name, True) - self.scope.modify(handler_node.name) - visit_all(self, handler_node.type, handler_node.body) + def visit_ExceptHandler(self, node: ast.ExceptHandler): + assert node.name + self.annotate_intermediate_scope(node, node.name, True) + self.scope.modify(node.name) + visit_all(self, node.type, node.body) - def visit_alias(self, alias_node): - variable = name_of_alias(alias_node) - self.annotate_intermediate_scope(alias_node, variable, True) + def visit_alias(self, node: ast.alias): + variable = name_of_alias(node) + self.annotate_intermediate_scope(node, variable, True) self.scope.modify(variable) - def visit_arg(self, arg): - self.annotate_intermediate_scope(arg, arg.arg, True) - self.scope.modify(arg.arg) + def visit_arg(self, node: ast.arg): + self.annotate_intermediate_scope(node, node.arg, True) + self.scope.modify(node.arg) - def create_subannotator(self, scope): + def create_subannotator(self, scope: IntermediateScope): return AnnotateScope(scope, self.annotation_dict, self.class_binds_near) - def visit_function_def(self, func_node, is_async): + def visit_function_def(self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool): del is_async self.annotate_intermediate_scope(func_node, func_node.name, True) self.scope.modify(func_node.name) @@ -195,19 +200,19 @@ def visit_function_def(self, func_node, is_async): ProcessArguments(self, subscope).visit(func_node.args) visit_all(subscope, func_node.body, func_node.returns) - def visit_Lambda(self, func_node): - self.annotate_intermediate_scope(func_node, "", None) + def visit_Lambda(self, node: ast.Lambda): + self.annotate_intermediate_scope(node, "", False) subscope = self.create_subannotator( - IntermediateFunctionScope(func_node, self.scope) + IntermediateFunctionScope(node, self.scope) ) - ProcessArguments(self, subscope).visit(func_node.args) - visit_all(subscope, func_node.body) + ProcessArguments(self, subscope).visit(node.args) + visit_all(subscope, node.body) - def visit_comprehension_generic(self, targets, comprehensions, typ): - del typ + def visit_comprehension_generic(self, targets: list[ast.expr], comprehensions: list[ast.comprehension], node: ast.AST): + del node current_scope = self for comprehension in comprehensions: - self.annotate_intermediate_scope(comprehension, "", None) + self.annotate_intermediate_scope(comprehension, "", False) subscope = self.create_subannotator( IntermediateFunctionScope(comprehension, current_scope.scope) ) @@ -216,28 +221,28 @@ def visit_comprehension_generic(self, targets, comprehensions, typ): current_scope = subscope visit_all(current_scope, targets) - def visit_ClassDef(self, class_node): - self.annotate_intermediate_scope(class_node, class_node.name, True) - self.scope.modify(class_node.name) + def visit_ClassDef(self, node: ast.ClassDef): + self.annotate_intermediate_scope(node, node.name, True) + self.scope.modify(node.name) subscope = self.create_subannotator( - IntermediateClassScope(class_node, self.scope, self.class_binds_near) + IntermediateClassScope(node, self.scope, self.class_binds_near) ) - ast.NodeVisitor.generic_visit(subscope, class_node) + ast.NodeVisitor.generic_visit(subscope, node) - def visit_Global(self, global_node): - for name in global_node.names: + def visit_Global(self, node: ast.Global): + for name in node.names: self.scope.globalize(name) - def visit_Nonlocal(self, nonlocal_node): - for name in nonlocal_node.names: + def visit_Nonlocal(self, node: ast.Nonlocal): + for name in node.names: self.scope.nonlocalize(name) -def visit_all(visitor, *nodes): +def visit_all(visitor: ast.NodeVisitor, *nodes: Iterable[ast.AST] | ast.AST | None): for node in nodes: if node is None: pass - elif isinstance(node, list): + elif isinstance(node, Iterable): visit_all(visitor, *node) else: visitor.visit(node) diff --git a/ast_scope/graph.py b/ast_scope/graph.py index 823d4b8..de3a37a 100644 --- a/ast_scope/graph.py +++ b/ast_scope/graph.py @@ -1,16 +1,19 @@ +from typing import Iterable + + class DiGraph: def __init__(self): - self.__adjacency_list = {} + self.__adjacency_list: dict[str, set[str]] = {} - def add_nodes_from(self, iterable): + def add_nodes_from(self, iterable: Iterable[str]): for node in iterable: self.add_node(node) - def add_node(self, node): + def add_node(self, node: str): if node not in self.__adjacency_list: self.__adjacency_list[node] = set() - def add_edge(self, source, target): + def add_edge(self, source: str, target: str): self.__adjacency_list[source].add(target) def nodes(self): @@ -23,5 +26,5 @@ def edges(self): for target in targets ) - def neighbors(self, node): + def neighbors(self, node: str): return list(self.__adjacency_list[node]) diff --git a/ast_scope/group_similar_constructs.py b/ast_scope/group_similar_constructs.py index 3239afd..c517bc9 100644 --- a/ast_scope/group_similar_constructs.py +++ b/ast_scope/group_similar_constructs.py @@ -2,34 +2,34 @@ class GroupSimilarConstructsVisitor(ast.NodeVisitor): - def visit_function_def(self, func_node, is_async): + def visit_function_def(self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool): return self.generic_visit(func_node) - def visit_FunctionDef(self, func_node): - return self.visit_function_def(func_node, is_async=False) + def visit_FunctionDef(self, node: ast.FunctionDef): + return self.visit_function_def(node, is_async=False) - def visit_AsyncFunctionDef(self, func_node): - return self.visit_function_def(func_node, is_async=True) + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + return self.visit_function_def(node, is_async=True) - def visit_comprehension_generic(self, targets, comprehensions, node): + def visit_comprehension_generic(self, targets: list[ast.expr], comprehensions: list[ast.comprehension], node: ast.AST): return self.generic_visit(node) - def visit_DictComp(self, comp_node): + def visit_DictComp(self, node: ast.DictComp): return self.visit_comprehension_generic( - [comp_node.key, comp_node.value], comp_node.generators, comp_node + [node.key, node.value], node.generators, node ) - def visit_ListComp(self, comp_node): + def visit_ListComp(self, node: ast.ListComp): return self.visit_comprehension_generic( - [comp_node.elt], comp_node.generators, comp_node + [node.elt], node.generators, node ) - def visit_SetComp(self, comp_node): + def visit_SetComp(self, node: ast.SetComp): return self.visit_comprehension_generic( - [comp_node.elt], comp_node.generators, comp_node + [node.elt], node.generators, node ) - def visit_GeneratorExp(self, comp_node): + def visit_GeneratorExp(self, node: ast.GeneratorExp): return self.visit_comprehension_generic( - [comp_node.elt], comp_node.generators, comp_node + [node.elt], node.generators, node ) diff --git a/ast_scope/pull_scope.py b/ast_scope/pull_scope.py index d7d1908..1cd788e 100644 --- a/ast_scope/pull_scope.py +++ b/ast_scope/pull_scope.py @@ -1,31 +1,35 @@ import ast +from typing import cast -from .scope import GlobalScope, ErrorScope, FunctionScope, ClassScope +from .scope import GlobalScope, ErrorScope, FunctionScope, ClassScope, Scope from .annotator import ( IntermediateGlobalScope, IntermediateFunctionScope, IntermediateClassScope, + IntermediateScope, visit_all, ) from .group_similar_constructs import GroupSimilarConstructsVisitor +from .scope import ClassScope, ErrorScope, FunctionScope, GlobalScope class PullScopes(GroupSimilarConstructsVisitor): - def __init__(self, annotation_dict): + def __init__(self, annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]]): self.annotation_dict = annotation_dict - self.node_to_corresponding_scope = {} - self.node_to_containing_scope = {} + self.node_to_corresponding_scope: dict[ast.AST, Scope] = {} + self.node_to_containing_scope: dict[ast.AST, Scope] = {} self.global_scope = GlobalScope() self.error_scope = ErrorScope() - def convert(self, int_scope): + def convert(self, int_scope: IntermediateScope | None): if int_scope is None: return self.error_scope if isinstance(int_scope, IntermediateGlobalScope): return self.global_scope + int_scope = cast(IntermediateClassScope | IntermediateFunctionScope, int_scope) return self.node_to_corresponding_scope[int_scope.node] - def pull_scope(self, node, include_as_variable=True): + def pull_scope(self, node: ast.AST, include_as_variable: bool=True) -> Scope: name, intermediate_scope, is_assign = self.annotation_dict[node] true_intermediate_scope = intermediate_scope.find(name, is_assign) scope = self.convert(true_intermediate_scope) @@ -33,32 +37,32 @@ def pull_scope(self, node, include_as_variable=True): self.node_to_containing_scope[node] = scope return scope - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): scope = self.pull_scope(node) scope.add_variable(node) super().generic_visit(node) - def visit_arg(self, node): + def visit_arg(self, node: ast.arg): scope = self.pull_scope(node) scope.add_argument(node) super().generic_visit(node) - def visit_alias(self, node): + def visit_alias(self, node: ast.alias): scope = self.pull_scope(node) scope.add_import(node) super().generic_visit(node) - def visit_function_def(self, node, is_async): + def visit_function_def(self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool): del is_async - scope = self.pull_scope(node) - if node not in self.node_to_corresponding_scope: - self.node_to_corresponding_scope[node] = FunctionScope(node, scope) + scope = self.pull_scope(func_node) + if func_node not in self.node_to_corresponding_scope: + self.node_to_corresponding_scope[func_node] = FunctionScope(func_node, scope) scope.add_function( - node, self.node_to_corresponding_scope[node], include_as_variable=True + func_node, self.node_to_corresponding_scope[func_node], include_as_variable=True ) - super().generic_visit(node) + super().generic_visit(func_node) - def visit_Lambda(self, node): + def visit_Lambda(self, node: ast.Lambda): scope = self.pull_scope(node, include_as_variable=False) if node not in self.node_to_corresponding_scope: self.node_to_corresponding_scope[node] = FunctionScope(node, scope) @@ -67,17 +71,17 @@ def visit_Lambda(self, node): ) super().generic_visit(node) - def visit_ExceptHandler(self, node): + def visit_ExceptHandler(self, node: ast.ExceptHandler): scope = self.pull_scope(node) scope.add_exception(node) super().generic_visit(node) - def visit_comprehension_generic(self, targets, comprehensions, node): + def visit_comprehension_generic(self, targets: list[ast.expr], comprehensions: list[ast.comprehension], node: ast.AST): # mate sure to visit the comprehensions first visit_all(self, comprehensions) visit_all(self, targets) - def visit_comprehension(self, node): + def visit_comprehension(self, node: ast.comprehension): scope = self.pull_scope(node, include_as_variable=False) if node not in self.node_to_corresponding_scope: self.node_to_corresponding_scope[node] = FunctionScope(node, scope) @@ -86,7 +90,7 @@ def visit_comprehension(self, node): ) super().generic_visit(node) - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: ast.ClassDef): scope = self.pull_scope(node) if node not in self.node_to_corresponding_scope: self.node_to_corresponding_scope[node] = ClassScope(node, scope) diff --git a/ast_scope/scope.py b/ast_scope/scope.py index b9d4c0e..bb13416 100644 --- a/ast_scope/scope.py +++ b/ast_scope/scope.py @@ -1,27 +1,30 @@ +import abc import ast +from typing import Self + import attr -import abc from .annotator import name_of_alias @attr.s class Variables: - arguments = attr.ib(attr.Factory(set)) - variables = attr.ib(attr.Factory(set)) - functions = attr.ib(attr.Factory(set)) - classes = attr.ib(attr.Factory(set)) - import_statements = attr.ib(attr.Factory(set)) - exceptions = attr.ib(attr.Factory(set)) + arguments: set[ast.arg] = attr.ib(factory=set) + variables: set[ast.Name] = attr.ib(factory=set) + functions: set[ast.FunctionDef | ast.Lambda | ast.comprehension | ast.AsyncFunctionDef] = attr.ib(factory=set) + classes: set[ast.ClassDef] = attr.ib(factory=set) + import_statements: set[ast.alias] = attr.ib(factory=set) + exceptions: set[ast.ExceptHandler] = attr.ib(factory=set) @property def node_to_symbol(self): - result = {} + result: dict[ast.AST, str] = {} result.update({var: var.arg for var in self.arguments}) result.update({var: var.id for var in self.variables}) - result.update({var: var.name for var in self.functions | self.classes}) + result.update({var: var.name for var in self.functions if isinstance(var, (ast.FunctionDef, ast.AsyncFunctionDef))}) + result.update({var: var.name for var in self.classes}) result.update({var: name_of_alias(var) for var in self.import_statements}) - result.update({var: var.name for var in self.exceptions}) + result.update({var: var.name for var in self.exceptions if var.name}) return result @property @@ -33,28 +36,28 @@ class Scope(abc.ABC): def __init__(self): self.variables = Variables() - def add_argument(self, node): + def add_argument(self, node: ast.arg): self.variables.arguments.add(node) - def add_variable(self, node): + def add_variable(self, node: ast.Name): self.variables.variables.add(node) - def add_import(self, node): + def add_import(self, node: ast.alias): self.variables.import_statements.add(node) - def add_exception(self, node): + def add_exception(self, node: ast.ExceptHandler): self.variables.exceptions.add(node) @abc.abstractmethod - def add_child(self, scope): + def add_child(self, scope: Self): pass - def add_function(self, node, function_scope, include_as_variable): + def add_function(self, node: ast.FunctionDef | ast.Lambda | ast.comprehension | ast.AsyncFunctionDef, function_scope: Self, include_as_variable: bool): if include_as_variable: self.variables.functions.add(node) self.add_child(function_scope) - def add_class(self, node, class_scope): + def add_class(self, node: ast.ClassDef, class_scope: Self): self.variables.classes.add(node) self.add_child(class_scope) @@ -66,14 +69,14 @@ def symbols_in_frame(self): class ScopeWithChildren(Scope): def __init__(self): Scope.__init__(self) - self.children = [] + self.children: list[Scope] = [] - def add_child(self, scope): + def add_child(self, scope: Scope): self.children.append(scope) class ScopeWithParent(Scope, abc.ABC): - def __init__(self, parent): + def __init__(self, parent: Scope): super().__init__() self.parent = parent @@ -88,14 +91,14 @@ class GlobalScope(ScopeWithChildren): class FunctionScope(ScopeWithChildren, ScopeWithParent): - def __init__(self, function_node, parent): + def __init__(self, function_node: ast.FunctionDef | ast.Lambda | ast.comprehension | ast.AsyncFunctionDef, parent: Scope): ScopeWithChildren.__init__(self) ScopeWithParent.__init__(self, parent) self.function_node = function_node class ClassScope(ScopeWithParent): - def __init__(self, class_node, parent): + def __init__(self, class_node: ast.ClassDef, parent: Scope): super().__init__(parent) self.class_node = class_node diff --git a/ast_scope/utils.py b/ast_scope/utils.py index 75a1272..ffb68c6 100644 --- a/ast_scope/utils.py +++ b/ast_scope/utils.py @@ -5,14 +5,14 @@ class GetAllNodes(ast.NodeVisitor): def __init__(self): - self.nodes = [] + self.nodes: list[ast.AST] = [] - def generic_visit(self, node): + def generic_visit(self, node: ast.AST): self.nodes.append(node) super().generic_visit(node) -def get_all_nodes(node): +def get_all_nodes(node: ast.AST): getter = GetAllNodes() getter.visit(node) return [subnode for subnode in getter.nodes if subnode is not node] @@ -22,27 +22,27 @@ class GetName(GroupSimilarConstructsVisitor): def __init__(self): self.name = None - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): self.name = node.id - def visit_function_def(self, func_node, is_async): + def visit_function_def(self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool): self.name = func_node.name - def visit_ClassDef(self, class_node): - self.name = class_node.name + def visit_ClassDef(self, node: ast.ClassDef): + self.name = node.name - def visit_alias(self, alias_node): - self.name = name_of_alias(alias_node) + def visit_alias(self, node: ast.alias): + self.name = name_of_alias(node) -def get_name(node): +def get_name(node: ast.AST): getter = GetName() getter.visit(node) assert getter.name is not None return getter.name -def name_of_alias(alias_node): +def name_of_alias(alias_node: ast.alias): if alias_node.asname is not None: return alias_node.asname return alias_node.name From 165c0490deb23c313490a8fdd5a40faaee5e18aa Mon Sep 17 00:00:00 2001 From: Samir Gupta Date: Mon, 4 Sep 2023 04:13:42 +1000 Subject: [PATCH 2/5] add some more exports to enhance typing --- ast_scope/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ast_scope/__init__.py b/ast_scope/__init__.py index b8b1ca0..648a4b5 100644 --- a/ast_scope/__init__.py +++ b/ast_scope/__init__.py @@ -1 +1,2 @@ -from .annotate import annotate +from .annotate import annotate, ScopeInfo +from .scope import Scope \ No newline at end of file From d2915002f63551dd495ae636d0ed098c48cc691e Mon Sep 17 00:00:00 2001 From: Kavi Gupta Date: Mon, 28 Apr 2025 15:58:53 -0400 Subject: [PATCH 3/5] make compatible with 3.9 --- ast_scope/annotator.py | 7 +++++-- ast_scope/scope.py | 2 +- setup.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ast_scope/annotator.py b/ast_scope/annotator.py index 41c4352..f063d7a 100644 --- a/ast_scope/annotator.py +++ b/ast_scope/annotator.py @@ -1,6 +1,7 @@ import abc import ast -from typing import Iterable, Self +from typing import Iterable +from typing_extensions import Self from .group_similar_constructs import GroupSimilarConstructsVisitor from .utils import compute_class_fields, name_of_alias @@ -41,7 +42,6 @@ def find( """ Finds the actual frame containing the variable name, or None if no frame exists """ - pass class IntermediateGlobalScope(IntermediateScope): @@ -135,12 +135,15 @@ def modify(self): self.scope.modify(self.variable.id) def visit_Load(self, node: ast.Load): + del node self.load() def visit_Store(self, node: ast.Store): + del node self.modify() def visit_Del(self, node: ast.Del): + del node self.modify() def visit_AugLoad(self, node: ast.AugLoad): diff --git a/ast_scope/scope.py b/ast_scope/scope.py index af84537..21e9e2c 100644 --- a/ast_scope/scope.py +++ b/ast_scope/scope.py @@ -1,6 +1,6 @@ import abc import ast -from typing import Self +from typing_extensions import Self import attr diff --git a/setup.py b/setup.py index 8afa49e..6ad2df9 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/kavigupta/ast_scope", - packages=setuptools.find_packages('ast_scope'), + packages=setuptools.find_packages("ast_scope"), classifiers=[ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -20,5 +20,5 @@ "Operating System :: OS Independent", ], python_requires=">=3.9", - install_requires=["attrs>=19.3.0"], + install_requires=["attrs>=19.3.0", "typing-extensions>=4.13.2"], ) From e5a8a9e38946a2c927bcd92eeae8b33866e26ae0 Mon Sep 17 00:00:00 2001 From: Kavi Gupta Date: Mon, 28 Apr 2025 16:02:37 -0400 Subject: [PATCH 4/5] add from __future__ to make the unions work --- ast_scope/annotator.py | 2 ++ ast_scope/group_similar_constructs.py | 2 ++ ast_scope/pull_scope.py | 6 ++++-- ast_scope/scope.py | 2 ++ ast_scope/utils.py | 2 ++ 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ast_scope/annotator.py b/ast_scope/annotator.py index f063d7a..4b96eb1 100644 --- a/ast_scope/annotator.py +++ b/ast_scope/annotator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import ast from typing import Iterable diff --git a/ast_scope/group_similar_constructs.py b/ast_scope/group_similar_constructs.py index 15ed0c1..2a1930f 100644 --- a/ast_scope/group_similar_constructs.py +++ b/ast_scope/group_similar_constructs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast diff --git a/ast_scope/pull_scope.py b/ast_scope/pull_scope.py index 4904ea9..a1d2bed 100644 --- a/ast_scope/pull_scope.py +++ b/ast_scope/pull_scope.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import ast -from typing import cast +from typing import cast, Union from ast_scope.utils import compute_class_fields @@ -29,7 +31,7 @@ def convert(self, int_scope: IntermediateScope | None): return self.error_scope if isinstance(int_scope, IntermediateGlobalScope): return self.global_scope - int_scope = cast(IntermediateClassScope | IntermediateFunctionScope, int_scope) + int_scope = cast(Union[IntermediateClassScope, IntermediateFunctionScope], int_scope) return self.node_to_corresponding_scope[int_scope.node] def pull_scope(self, node: ast.AST, include_as_variable: bool = True) -> Scope: diff --git a/ast_scope/scope.py b/ast_scope/scope.py index 21e9e2c..38e310f 100644 --- a/ast_scope/scope.py +++ b/ast_scope/scope.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import ast from typing_extensions import Self diff --git a/ast_scope/utils.py b/ast_scope/utils.py index 4768db7..a55e6dd 100644 --- a/ast_scope/utils.py +++ b/ast_scope/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast from typing import List From 15acef4915e54f977f1f6db586062905293bba11 Mon Sep 17 00:00:00 2001 From: Kavi Gupta Date: Mon, 28 Apr 2025 16:03:35 -0400 Subject: [PATCH 5/5] update --- ast_scope/annotator.py | 1 + ast_scope/pull_scope.py | 6 ++++-- ast_scope/scope.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ast_scope/annotator.py b/ast_scope/annotator.py index 4b96eb1..1c1ec66 100644 --- a/ast_scope/annotator.py +++ b/ast_scope/annotator.py @@ -3,6 +3,7 @@ import abc import ast from typing import Iterable + from typing_extensions import Self from .group_similar_constructs import GroupSimilarConstructsVisitor diff --git a/ast_scope/pull_scope.py b/ast_scope/pull_scope.py index a1d2bed..4f3c060 100644 --- a/ast_scope/pull_scope.py +++ b/ast_scope/pull_scope.py @@ -1,7 +1,7 @@ from __future__ import annotations import ast -from typing import cast, Union +from typing import Union, cast from ast_scope.utils import compute_class_fields @@ -31,7 +31,9 @@ def convert(self, int_scope: IntermediateScope | None): return self.error_scope if isinstance(int_scope, IntermediateGlobalScope): return self.global_scope - int_scope = cast(Union[IntermediateClassScope, IntermediateFunctionScope], int_scope) + int_scope = cast( + Union[IntermediateClassScope, IntermediateFunctionScope], int_scope + ) return self.node_to_corresponding_scope[int_scope.node] def pull_scope(self, node: ast.AST, include_as_variable: bool = True) -> Scope: diff --git a/ast_scope/scope.py b/ast_scope/scope.py index 38e310f..ff46515 100644 --- a/ast_scope/scope.py +++ b/ast_scope/scope.py @@ -2,9 +2,9 @@ import abc import ast -from typing_extensions import Self import attr +from typing_extensions import Self from .annotator import name_of_alias