diff --git a/ast_scope/__init__.py b/ast_scope/__init__.py index b8b1ca0..fa8e391 100644 --- a/ast_scope/__init__.py +++ b/ast_scope/__init__.py @@ -1 +1,2 @@ -from .annotate import annotate +from .annotate import ScopeInfo, annotate +from .scope import Scope diff --git a/ast_scope/annotate.py b/ast_scope/annotate.py index 9ab705a..33a394b 100644 --- a/ast_scope/annotate.py +++ b/ast_scope/annotate.py @@ -1,13 +1,21 @@ -from ast_scope.scope import FunctionScope +import ast -from .annotator import AnnotateScope, IntermediateGlobalScope +from ast_scope.scope import ErrorScope, FunctionScope, GlobalScope, Scope + +from .annotator import AnnotateScope, IntermediateGlobalScope, IntermediateScope from .graph import DiGraph from .pull_scope import PullScopes from .utils import get_all_nodes, get_name class ScopeInfo: - def __init__(self, tree, global_scope, error_scope, node_to_containing_scope): + def __init__( + self, + tree: ast.AST, + global_scope: GlobalScope, + error_scope: ErrorScope, + node_to_containing_scope: dict[ast.AST, Scope], + ): self._tree = tree self._global_scope = global_scope self._error_scope = error_scope @@ -43,13 +51,13 @@ def static_dependency_graph(self): def __iter__(self): return iter(self._node_to_containing_scope) - def __contains__(self, node): + def __contains__(self, node: ast.AST): return node in self._node_to_containing_scope - def __getitem__(self, node): + def __getitem__(self, node: ast.AST): return self._node_to_containing_scope[node] - def function_scope_for(self, node): + def function_scope_for(self, node: ast.AST): """ Returns the function scope for the given FunctionDef node. """ @@ -62,8 +70,8 @@ def function_scope_for(self, node): return None -def annotate(tree, class_binds_near=False): - annotation_dict = {} +def annotate(tree: ast.AST, class_binds_near: bool = False): + annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]] = {} annotator = AnnotateScope( IntermediateGlobalScope(), annotation_dict, class_binds_near=class_binds_near ) diff --git a/ast_scope/annotator.py b/ast_scope/annotator.py index d94b0bb..1c1ec66 100644 --- a/ast_scope/annotator.py +++ b/ast_scope/annotator.py @@ -1,5 +1,10 @@ +from __future__ import annotations + import abc import ast +from typing import Iterable + +from typing_extensions import Self from .group_similar_constructs import GroupSimilarConstructsVisitor from .utils import compute_class_fields, name_of_alias @@ -12,36 +17,38 @@ class IntermediateScope(abc.ABC): """ def __init__(self): - self.referenced_variables = set() - self.written_variables = set() - self.nonlocal_variables = set() - self.global_variables = set() + self.referenced_variables: set[str] = set() + self.written_variables: set[str] = set() + self.nonlocal_variables: set[str] = set() + self.global_variables: set[str] = set() - def load(self, variable): + def load(self, variable: str): self.referenced_variables.add(variable) - def modify(self, variable): + def modify(self, variable: str): self.written_variables.add(variable) - def globalize(self, variable): + def globalize(self, variable: str): self.global_variables.add(variable) - def nonlocalize(self, variable): + def nonlocalize(self, variable: str): self.nonlocal_variables.add(variable) @abc.abstractmethod - def global_frame(self): + def global_frame(self) -> "IntermediateGlobalScope": pass @abc.abstractmethod - def find(self, name, is_assignment, global_acceptable=True): + def find( + self, name: str, is_assignment: bool, global_acceptable: bool = True + ) -> Self | None: """ Finds the actual frame containing the variable name, or None if no frame exists """ class IntermediateGlobalScope(IntermediateScope): - def find(self, name, is_assignment, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool = True): if not global_acceptable: return None return self @@ -51,11 +58,11 @@ def global_frame(self): class IntermediateScopeWithParent(IntermediateScope): - def __init__(self, parent): - super().__init__() + def __init__(self, parent: IntermediateScope): self.parent = parent + super().__init__() - def true_parent(self): + def true_parent(self) -> IntermediateScope: parent = self.parent while isinstance(parent, IntermediateClassScope): parent = parent.parent @@ -63,14 +70,18 @@ def true_parent(self): class IntermediateFunctionScope(IntermediateScopeWithParent): - def __init__(self, node, parent): + def __init__( + self, + node: ast.FunctionDef | ast.AsyncFunctionDef | ast.comprehension | ast.Lambda, + parent: IntermediateScope, + ): super().__init__(parent) self.node = node - def global_frame(self): + def global_frame(self) -> IntermediateGlobalScope: return self.true_parent().global_frame() - def find(self, name, is_assignment, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool = True): if name in self.global_variables: return self.global_frame() if name in self.nonlocal_variables: @@ -81,15 +92,17 @@ def find(self, name, is_assignment, global_acceptable=True): class IntermediateClassScope(IntermediateScopeWithParent): - def __init__(self, node, parent, class_binds_near): + def __init__( + self, node: ast.ClassDef, parent: IntermediateScope, class_binds_near: bool + ): super().__init__(parent) self.node = node self.class_binds_near = class_binds_near - def global_frame(self): - return self.true_parent().find(self) + def global_frame(self) -> IntermediateGlobalScope: + return self.true_parent().global_frame() - def find(self, name, is_assignment, global_acceptable=True): + def find(self, name: str, is_assignment: bool, global_acceptable: bool = True): if self.class_binds_near: # anything can be in a class frame return self @@ -103,12 +116,17 @@ class GrabVariable(ast.NodeVisitor): Dumps variables from a given name object into the given scope. """ - def __init__(self, scope, variable, annotation_dict): + def __init__( + self, + scope: IntermediateScope, + variable: ast.Name, + annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]], + ): self.scope = scope self.variable = variable self.annotation_dict = annotation_dict - def visit_generic(self, node): + def visit_generic(self, node: ast.AST): raise RuntimeError(f"Unsupported node type: {node}") def load(self): @@ -119,70 +137,79 @@ def modify(self): self.annotation_dict[self.variable] = self.variable.id, self.scope, True self.scope.modify(self.variable.id) - def visit_Load(self, _): + def visit_Load(self, node: ast.Load): + del node self.load() - def visit_Store(self, _): + def visit_Store(self, node: ast.Store): + del node self.modify() - def visit_Del(self, _): + def visit_Del(self, node: ast.Del): + del node self.modify() - def visit_AugLoad(self, _): + def visit_AugLoad(self, node: ast.AugLoad): raise RuntimeError("Unsupported: AugStore") - def visit_AugStore(self, _): + def visit_AugStore(self, node: ast.AugStore): raise RuntimeError("Unsupported: AugStore") class ProcessArguments(ast.NodeVisitor): - def __init__(self, expr_scope, arg_scope): + def __init__(self, expr_scope: "AnnotateScope", arg_scope: "AnnotateScope"): self.expr_scope = expr_scope self.arg_scope = arg_scope - def visit_arg(self, node): + def visit_arg(self, node: ast.arg): self.arg_scope.visit(node) visit_all(self.expr_scope, node.annotation, getattr(node, "type_comment", None)) - def visit_arguments(self, node): + def visit_arguments(self, node: ast.AST): super().generic_visit(node) - def generic_visit(self, node): + def generic_visit(self, node: ast.AST): self.expr_scope.visit(node) class AnnotateScope(GroupSimilarConstructsVisitor): - def __init__(self, scope, annotation_dict, class_binds_near): + def __init__( + self, + scope: IntermediateScope, + annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]], + class_binds_near: bool, + ): self.scope = scope self.annotation_dict = annotation_dict self.class_binds_near = class_binds_near - def annotate_intermediate_scope(self, node, name, is_assign): + def annotate_intermediate_scope(self, node: ast.AST, name: str, is_assign: bool): self.annotation_dict[node] = name, self.scope, is_assign - def visit_Name(self, name_node): - GrabVariable(self.scope, name_node, self.annotation_dict).generic_visit( - name_node - ) + def visit_Name(self, node: ast.Name): + GrabVariable(self.scope, node, self.annotation_dict).generic_visit(node) - def visit_ExceptHandler(self, handler_node): - self.annotate_intermediate_scope(handler_node, handler_node.name, True) - self.scope.modify(handler_node.name) - visit_all(self, handler_node.type, handler_node.body) + def visit_ExceptHandler(self, node: ast.ExceptHandler): + assert node.name + self.annotate_intermediate_scope(node, node.name, True) + self.scope.modify(node.name) + visit_all(self, node.type, node.body) - def visit_alias(self, alias_node): - variable = name_of_alias(alias_node) - self.annotate_intermediate_scope(alias_node, variable, True) + def visit_alias(self, node: ast.alias): + variable = name_of_alias(node) + self.annotate_intermediate_scope(node, variable, True) self.scope.modify(variable) - def visit_arg(self, arg): - self.annotate_intermediate_scope(arg, arg.arg, True) - self.scope.modify(arg.arg) + def visit_arg(self, node: ast.arg): + self.annotate_intermediate_scope(node, node.arg, True) + self.scope.modify(node.arg) - def create_subannotator(self, scope): + def create_subannotator(self, scope: IntermediateScope): return AnnotateScope(scope, self.annotation_dict, self.class_binds_near) - def visit_function_def(self, func_node, is_async): + def visit_function_def( + self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool + ): del is_async self.annotate_intermediate_scope(func_node, func_node.name, True) self.scope.modify(func_node.name) @@ -195,19 +222,22 @@ def visit_function_def(self, func_node, is_async): ProcessArguments(self, subscope).visit(func_node.args) visit_all(subscope, func_node.body, func_node.returns) - def visit_Lambda(self, func_node): - self.annotate_intermediate_scope(func_node, "", None) - subscope = self.create_subannotator( - IntermediateFunctionScope(func_node, self.scope) - ) - ProcessArguments(self, subscope).visit(func_node.args) - visit_all(subscope, func_node.body) - - def visit_comprehension_generic(self, targets, comprehensions, node): + def visit_Lambda(self, node: ast.Lambda): + self.annotate_intermediate_scope(node, "", False) + subscope = self.create_subannotator(IntermediateFunctionScope(node, self.scope)) + ProcessArguments(self, subscope).visit(node.args) + visit_all(subscope, node.body) + + def visit_comprehension_generic( + self, + targets: list[ast.expr], + comprehensions: list[ast.comprehension], + node: ast.AST, + ): del node current_scope = self for comprehension in comprehensions: - self.annotate_intermediate_scope(comprehension, "", None) + self.annotate_intermediate_scope(comprehension, "", False) subscope = self.create_subannotator( IntermediateFunctionScope(comprehension, current_scope.scope) ) @@ -216,30 +246,30 @@ def visit_comprehension_generic(self, targets, comprehensions, node): current_scope = subscope visit_all(current_scope, targets) - def visit_ClassDef(self, class_node): - self.annotate_intermediate_scope(class_node, class_node.name, True) - self.scope.modify(class_node.name) + def visit_ClassDef(self, node: ast.ClassDef): + self.annotate_intermediate_scope(node, node.name, True) + self.scope.modify(node.name) subscope = self.create_subannotator( - IntermediateClassScope(class_node, self.scope, self.class_binds_near) + IntermediateClassScope(node, self.scope, self.class_binds_near) ) - class_scope_fields, parent_scope_fields = compute_class_fields(class_node) + class_scope_fields, parent_scope_fields = compute_class_fields(node) visit_all(subscope, *class_scope_fields) visit_all(self, *parent_scope_fields) - def visit_Global(self, global_node): - for name in global_node.names: + def visit_Global(self, node: ast.Global): + for name in node.names: self.scope.globalize(name) - def visit_Nonlocal(self, nonlocal_node): - for name in nonlocal_node.names: + def visit_Nonlocal(self, node: ast.Nonlocal): + for name in node.names: self.scope.nonlocalize(name) -def visit_all(visitor, *nodes): +def visit_all(visitor: ast.NodeVisitor, *nodes: Iterable[ast.AST] | ast.AST | None): for node in nodes: if node is None: pass - elif isinstance(node, list): + elif isinstance(node, Iterable): visit_all(visitor, *node) else: visitor.visit(node) diff --git a/ast_scope/graph.py b/ast_scope/graph.py index 823d4b8..de3a37a 100644 --- a/ast_scope/graph.py +++ b/ast_scope/graph.py @@ -1,16 +1,19 @@ +from typing import Iterable + + class DiGraph: def __init__(self): - self.__adjacency_list = {} + self.__adjacency_list: dict[str, set[str]] = {} - def add_nodes_from(self, iterable): + def add_nodes_from(self, iterable: Iterable[str]): for node in iterable: self.add_node(node) - def add_node(self, node): + def add_node(self, node: str): if node not in self.__adjacency_list: self.__adjacency_list[node] = set() - def add_edge(self, source, target): + def add_edge(self, source: str, target: str): self.__adjacency_list[source].add(target) def nodes(self): @@ -23,5 +26,5 @@ def edges(self): for target in targets ) - def neighbors(self, node): + def neighbors(self, node: str): return list(self.__adjacency_list[node]) diff --git a/ast_scope/group_similar_constructs.py b/ast_scope/group_similar_constructs.py index a8dddf5..2a1930f 100644 --- a/ast_scope/group_similar_constructs.py +++ b/ast_scope/group_similar_constructs.py @@ -1,37 +1,40 @@ +from __future__ import annotations + import ast class GroupSimilarConstructsVisitor(ast.NodeVisitor): - def visit_function_def(self, func_node, is_async): + def visit_function_def( + self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool + ): del is_async return self.generic_visit(func_node) - def visit_FunctionDef(self, func_node): - return self.visit_function_def(func_node, is_async=False) + def visit_FunctionDef(self, node: ast.FunctionDef): + return self.visit_function_def(node, is_async=False) - def visit_AsyncFunctionDef(self, func_node): - return self.visit_function_def(func_node, is_async=True) + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + return self.visit_function_def(node, is_async=True) - def visit_comprehension_generic(self, targets, comprehensions, node): + def visit_comprehension_generic( + self, + targets: list[ast.expr], + comprehensions: list[ast.comprehension], + node: ast.AST, + ): del targets, comprehensions return self.generic_visit(node) - def visit_DictComp(self, comp_node): + def visit_DictComp(self, node: ast.DictComp): return self.visit_comprehension_generic( - [comp_node.key, comp_node.value], comp_node.generators, comp_node + [node.key, node.value], node.generators, node ) - def visit_ListComp(self, comp_node): - return self.visit_comprehension_generic( - [comp_node.elt], comp_node.generators, comp_node - ) + def visit_ListComp(self, node: ast.ListComp): + return self.visit_comprehension_generic([node.elt], node.generators, node) - def visit_SetComp(self, comp_node): - return self.visit_comprehension_generic( - [comp_node.elt], comp_node.generators, comp_node - ) + def visit_SetComp(self, node: ast.SetComp): + return self.visit_comprehension_generic([node.elt], node.generators, node) - def visit_GeneratorExp(self, comp_node): - return self.visit_comprehension_generic( - [comp_node.elt], comp_node.generators, comp_node - ) + def visit_GeneratorExp(self, node: ast.GeneratorExp): + return self.visit_comprehension_generic([node.elt], node.generators, node) diff --git a/ast_scope/pull_scope.py b/ast_scope/pull_scope.py index 979b2e5..4f3c060 100644 --- a/ast_scope/pull_scope.py +++ b/ast_scope/pull_scope.py @@ -1,26 +1,42 @@ +from __future__ import annotations + +import ast +from typing import Union, cast + from ast_scope.utils import compute_class_fields -from .annotator import IntermediateGlobalScope, visit_all +from .annotator import ( + IntermediateClassScope, + IntermediateFunctionScope, + IntermediateGlobalScope, + IntermediateScope, + visit_all, +) from .group_similar_constructs import GroupSimilarConstructsVisitor -from .scope import ClassScope, ErrorScope, FunctionScope, GlobalScope +from .scope import ClassScope, ErrorScope, FunctionScope, GlobalScope, Scope class PullScopes(GroupSimilarConstructsVisitor): - def __init__(self, annotation_dict): + def __init__( + self, annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]] + ): self.annotation_dict = annotation_dict - self.node_to_corresponding_scope = {} - self.node_to_containing_scope = {} + self.node_to_corresponding_scope: dict[ast.AST, Scope] = {} + self.node_to_containing_scope: dict[ast.AST, Scope] = {} self.global_scope = GlobalScope() self.error_scope = ErrorScope() - def convert(self, int_scope): + def convert(self, int_scope: IntermediateScope | None): if int_scope is None: return self.error_scope if isinstance(int_scope, IntermediateGlobalScope): return self.global_scope + int_scope = cast( + Union[IntermediateClassScope, IntermediateFunctionScope], int_scope + ) return self.node_to_corresponding_scope[int_scope.node] - def pull_scope(self, node, include_as_variable=True): + def pull_scope(self, node: ast.AST, include_as_variable: bool = True) -> Scope: name, intermediate_scope, is_assign = self.annotation_dict[node] true_intermediate_scope = intermediate_scope.find(name, is_assign) scope = self.convert(true_intermediate_scope) @@ -28,22 +44,24 @@ def pull_scope(self, node, include_as_variable=True): self.node_to_containing_scope[node] = scope return scope - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): scope = self.pull_scope(node) scope.add_variable(node) super().generic_visit(node) - def visit_arg(self, node): + def visit_arg(self, node: ast.arg): scope = self.pull_scope(node) scope.add_argument(node) super().generic_visit(node) - def visit_alias(self, node): + def visit_alias(self, node: ast.alias): scope = self.pull_scope(node) scope.add_import(node) super().generic_visit(node) - def visit_function_def(self, func_node, is_async): + def visit_function_def( + self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool + ): del is_async scope = self.pull_scope(func_node) if func_node not in self.node_to_corresponding_scope: @@ -57,7 +75,7 @@ def visit_function_def(self, func_node, is_async): ) super().generic_visit(func_node) - def visit_Lambda(self, node): + def visit_Lambda(self, node: ast.Lambda): scope = self.pull_scope(node, include_as_variable=False) if node not in self.node_to_corresponding_scope: self.node_to_corresponding_scope[node] = FunctionScope(node, scope) @@ -66,17 +84,22 @@ def visit_Lambda(self, node): ) super().generic_visit(node) - def visit_ExceptHandler(self, node): + def visit_ExceptHandler(self, node: ast.ExceptHandler): scope = self.pull_scope(node) scope.add_exception(node) super().generic_visit(node) - def visit_comprehension_generic(self, targets, comprehensions, node): + def visit_comprehension_generic( + self, + targets: list[ast.expr], + comprehensions: list[ast.comprehension], + node: ast.AST, + ): # mate sure to visit the comprehensions first visit_all(self, comprehensions) visit_all(self, targets) - def visit_comprehension(self, node): + def visit_comprehension(self, node: ast.comprehension): scope = self.pull_scope(node, include_as_variable=False) if node not in self.node_to_corresponding_scope: self.node_to_corresponding_scope[node] = FunctionScope(node, scope) @@ -85,7 +108,7 @@ def visit_comprehension(self, node): ) super().generic_visit(node) - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: ast.ClassDef): scope = self.pull_scope(node) if node not in self.node_to_corresponding_scope: self.node_to_corresponding_scope[node] = ClassScope(node, scope) diff --git a/ast_scope/scope.py b/ast_scope/scope.py index 1923cae..ff46515 100644 --- a/ast_scope/scope.py +++ b/ast_scope/scope.py @@ -1,27 +1,40 @@ +from __future__ import annotations + import abc +import ast import attr +from typing_extensions import Self from .annotator import name_of_alias @attr.s class Variables: - arguments = attr.ib(attr.Factory(set)) - variables = attr.ib(attr.Factory(set)) - functions = attr.ib(attr.Factory(set)) - classes = attr.ib(attr.Factory(set)) - import_statements = attr.ib(attr.Factory(set)) - exceptions = attr.ib(attr.Factory(set)) + arguments: set[ast.arg] = attr.ib(factory=set) + variables: set[ast.Name] = attr.ib(factory=set) + functions: set[ + ast.FunctionDef | ast.Lambda | ast.comprehension | ast.AsyncFunctionDef + ] = attr.ib(factory=set) + classes: set[ast.ClassDef] = attr.ib(factory=set) + import_statements: set[ast.alias] = attr.ib(factory=set) + exceptions: set[ast.ExceptHandler] = attr.ib(factory=set) @property def node_to_symbol(self): - result = {} + result: dict[ast.AST, str] = {} result.update({var: var.arg for var in self.arguments}) result.update({var: var.id for var in self.variables}) - result.update({var: var.name for var in self.functions | self.classes}) + result.update( + { + var: var.name + for var in self.functions + if isinstance(var, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + ) + result.update({var: var.name for var in self.classes}) result.update({var: name_of_alias(var) for var in self.import_statements}) - result.update({var: var.name for var in self.exceptions}) + result.update({var: var.name for var in self.exceptions if var.name}) return result @property @@ -33,28 +46,33 @@ class Scope(abc.ABC): def __init__(self): self.variables = Variables() - def add_argument(self, node): + def add_argument(self, node: ast.arg): self.variables.arguments.add(node) - def add_variable(self, node): + def add_variable(self, node: ast.Name): self.variables.variables.add(node) - def add_import(self, node): + def add_import(self, node: ast.alias): self.variables.import_statements.add(node) - def add_exception(self, node): + def add_exception(self, node: ast.ExceptHandler): self.variables.exceptions.add(node) @abc.abstractmethod - def add_child(self, scope): + def add_child(self, scope: Self): pass - def add_function(self, node, function_scope, include_as_variable): + def add_function( + self, + node: ast.FunctionDef | ast.Lambda | ast.comprehension | ast.AsyncFunctionDef, + function_scope: Self, + include_as_variable: bool, + ): if include_as_variable: self.variables.functions.add(node) self.add_child(function_scope) - def add_class(self, node, class_scope): + def add_class(self, node: ast.ClassDef, class_scope: Self): self.variables.classes.add(node) self.add_child(class_scope) @@ -66,14 +84,14 @@ def symbols_in_frame(self): class ScopeWithChildren(Scope): def __init__(self): Scope.__init__(self) - self.children = [] + self.children: list[Scope] = [] - def add_child(self, scope): + def add_child(self, scope: Scope): self.children.append(scope) class ScopeWithParent(Scope, abc.ABC): - def __init__(self, parent): + def __init__(self, parent: Scope): super().__init__() self.parent = parent @@ -88,14 +106,20 @@ class GlobalScope(ScopeWithChildren): class FunctionScope(ScopeWithChildren, ScopeWithParent): - def __init__(self, function_node, parent): + def __init__( + self, + function_node: ( + ast.FunctionDef | ast.Lambda | ast.comprehension | ast.AsyncFunctionDef + ), + parent: Scope, + ): ScopeWithChildren.__init__(self) ScopeWithParent.__init__(self, parent) self.function_node = function_node class ClassScope(ScopeWithParent): - def __init__(self, class_node, parent): + def __init__(self, class_node: ast.ClassDef, parent: Scope): super().__init__(parent) self.class_node = class_node diff --git a/ast_scope/utils.py b/ast_scope/utils.py index 2eae1c8..a55e6dd 100644 --- a/ast_scope/utils.py +++ b/ast_scope/utils.py @@ -1,18 +1,21 @@ +from __future__ import annotations + import ast +from typing import List from .group_similar_constructs import GroupSimilarConstructsVisitor class GetAllNodes(ast.NodeVisitor): def __init__(self): - self.nodes = [] + self.nodes: list[ast.AST] = [] - def generic_visit(self, node): + def generic_visit(self, node: ast.AST): self.nodes.append(node) super().generic_visit(node) -def get_all_nodes(*nodes): +def get_all_nodes(*nodes: List[ast.AST]): getter = GetAllNodes() for node in nodes: getter.visit(node) @@ -24,27 +27,29 @@ class GetName(GroupSimilarConstructsVisitor): def __init__(self): self.name = None - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): self.name = node.id - def visit_function_def(self, func_node, is_async): + def visit_function_def( + self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool + ): self.name = func_node.name - def visit_ClassDef(self, class_node): - self.name = class_node.name + def visit_ClassDef(self, node: ast.ClassDef): + self.name = node.name - def visit_alias(self, alias_node): - self.name = name_of_alias(alias_node) + def visit_alias(self, node: ast.alias): + self.name = name_of_alias(node) -def get_name(node): +def get_name(node: ast.AST): getter = GetName() getter.visit(node) assert getter.name is not None return getter.name -def name_of_alias(alias_node): +def name_of_alias(alias_node: ast.alias): if alias_node.asname is not None: return alias_node.asname return alias_node.name diff --git a/setup.py b/setup.py index 8afa49e..6ad2df9 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/kavigupta/ast_scope", - packages=setuptools.find_packages('ast_scope'), + packages=setuptools.find_packages("ast_scope"), classifiers=[ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -20,5 +20,5 @@ "Operating System :: OS Independent", ], python_requires=">=3.9", - install_requires=["attrs>=19.3.0"], + install_requires=["attrs>=19.3.0", "typing-extensions>=4.13.2"], )