Repository URL to install this package:
|
Version:
0.22.2 ▾
|
"""This module contains the base classes for dealing with extensions."""
from __future__ import annotations
import ast
import enum
from collections import defaultdict
from inspect import isclass
from typing import TYPE_CHECKING, Any, Sequence, Type, Union
from griffe.agents.base import BaseInspector, BaseVisitor
from griffe.agents.nodes import ObjectNode
from griffe.exceptions import ExtensionNotLoadedError
from griffe.importer import dynamic_import
if TYPE_CHECKING:
from griffe.agents.inspector import Inspector
from griffe.agents.visitor import Visitor
class When(enum.Enum):
    """This enumeration contains the different times at which an extension is used.

    Members are deliberately left unannotated: type checkers infer a literal
    type for each enum member, and an explicit annotation (such as `int`)
    is reported as an error by them.

    Attributes:
        before_all: For each node, before the visit/inspection.
        before_children: For each node, after the visit has started, and before the children visit/inspection.
        after_children: For each node, after the children have been visited/inspected, and before finishing the visit/inspection.
        after_all: For each node, after the visit/inspection.
    """

    before_all = 1
    before_children = 2
    after_children = 3
    after_all = 4
class VisitorExtension(BaseVisitor):
    """Base class to inherit from when writing a node visitor extension."""

    # The moment at which the extension runs; no value is assigned here.
    when: When

    def __init__(self) -> None:
        """Set up the extension, without any parent visitor attached yet."""
        super().__init__()
        # Stays `None` until `attach` is called with the parent visitor.
        self.visitor: Visitor = None  # type: ignore[assignment]

    def attach(self, visitor: Visitor) -> None:
        """Remember the parent visitor on this extension.

        Parameters:
            visitor: The parent visitor.
        """
        self.visitor = visitor

    def visit(self, node: ast.AST) -> None:
        """Dispatch a node to the matching `visit_<kind>` method, if any.

        When this extension defines no method for the node's kind,
        the node is silently ignored.

        Parameters:
            node: The node to visit.
        """
        method_name = f"visit_{node.kind}"  # type: ignore[attr-defined]
        handler = getattr(self, method_name, lambda _: None)
        handler(node)
class InspectorExtension(BaseInspector):
    """Base class to inherit from when writing an object inspector extension."""

    # The moment at which the extension runs; no value is assigned here.
    when: When

    def __init__(self) -> None:
        """Set up the extension, without any parent inspector attached yet."""
        super().__init__()
        # Stays `None` until `attach` is called with the parent inspector.
        self.inspector: Inspector = None  # type: ignore[assignment]

    def attach(self, inspector: Inspector) -> None:
        """Remember the parent inspector on this extension.

        Parameters:
            inspector: The parent inspector.
        """
        self.inspector = inspector

    def inspect(self, node: ObjectNode) -> None:
        """Dispatch a node to the matching `inspect_<kind>` method, if any.

        When this extension defines no method for the node's kind,
        the node is silently ignored.

        Parameters:
            node: The node to inspect.
        """
        method_name = f"inspect_{node.kind}"
        handler = getattr(self, method_name, lambda _: None)
        handler(node)
# Type alias: an extension is either a visitor extension or an inspector extension.
Extension = Union[VisitorExtension, InspectorExtension]
class Extensions:
    """A container grouping extensions by the moment at which they must run."""

    def __init__(self, *extensions: Extension) -> None:
        """Create the container and register the given extensions.

        Parameters:
            *extensions: The extensions to add.
        """
        # Both registries map a `When` value to the extensions running at that time.
        self._visitors: dict[When, list[VisitorExtension]] = defaultdict(list)
        self._inspectors: dict[When, list[InspectorExtension]] = defaultdict(list)
        self.add(*extensions)

    def add(self, *extensions: Extension) -> None:
        """Register extensions in this container.

        Visitor extensions go to the visitors registry; anything else
        is stored in the inspectors registry.

        Parameters:
            *extensions: The extensions to add.
        """
        for ext in extensions:
            if not isinstance(ext, VisitorExtension):
                self._inspectors[ext.when].append(ext)
                continue
            self._visitors[ext.when].append(ext)

    def attach_visitor(self, parent_visitor: Visitor) -> Extensions:
        """Attach a parent visitor to every registered visitor extension.

        Parameters:
            parent_visitor: The parent visitor, leading the visit.

        Returns:
            Self, conveniently.
        """
        for group in self._visitors.values():
            for ext in group:
                ext.attach(parent_visitor)
        return self

    def attach_inspector(self, parent_inspector: Inspector) -> Extensions:
        """Attach a parent inspector to every registered inspector extension.

        Parameters:
            parent_inspector: The parent inspector, leading the inspection.

        Returns:
            Self, conveniently.
        """
        for group in self._inspectors.values():
            for ext in group:
                ext.attach(parent_inspector)
        return self

    @property
    def before_visit(self) -> list[VisitorExtension]:
        """The visitor extensions registered for `When.before_all`.

        Returns:
            Visitors.
        """
        return self._visitors[When.before_all]

    @property
    def before_children_visit(self) -> list[VisitorExtension]:
        """The visitor extensions registered for `When.before_children`.

        Returns:
            Visitors.
        """
        return self._visitors[When.before_children]

    @property
    def after_children_visit(self) -> list[VisitorExtension]:
        """The visitor extensions registered for `When.after_children`.

        Returns:
            Visitors.
        """
        return self._visitors[When.after_children]

    @property
    def after_visit(self) -> list[VisitorExtension]:
        """The visitor extensions registered for `When.after_all`.

        Returns:
            Visitors.
        """
        return self._visitors[When.after_all]

    @property
    def before_inspection(self) -> list[InspectorExtension]:
        """The inspector extensions registered for `When.before_all`.

        Returns:
            Inspectors.
        """
        return self._inspectors[When.before_all]

    @property
    def before_children_inspection(self) -> list[InspectorExtension]:
        """The inspector extensions registered for `When.before_children`.

        Returns:
            Inspectors.
        """
        return self._inspectors[When.before_children]

    @property
    def after_children_inspection(self) -> list[InspectorExtension]:
        """The inspector extensions registered for `When.after_children`.

        Returns:
            Inspectors.
        """
        return self._inspectors[When.after_children]

    @property
    def after_inspection(self) -> list[InspectorExtension]:
        """The inspector extensions registered for `When.after_all`.

        Returns:
            Inspectors.
        """
        return self._inspectors[When.after_all]
# Short names of built-in extensions: `load_extension` expands such a name
# into the `griffe.agents.extensions.<name>` import path before importing it.
builtin_extensions: set[str] = {
    "hybrid",
}
def load_extension(extension: str | dict[str, Any] | Extension | Type[Extension]) -> Extension:
    """Load a configured extension.

    Parameters:
        extension: An extension, with potential configuration options.
            It can be an extension instance (returned as is), an extension
            class (instantiated without arguments), an import path, or a
            dictionary whose first item maps an import path to its options.

    Raises:
        ExtensionNotLoadedError: When the extension cannot be loaded,
            either because the module is not found, because importing it
            failed, or because it does not expose the Extension attribute.
            The original exception is chained as the cause so users can see
            the traceback.

    Returns:
        An extension instance.
    """
    if isinstance(extension, (VisitorExtension, InspectorExtension)):
        return extension

    if isclass(extension) and issubclass(extension, (VisitorExtension, InspectorExtension)):  # type: ignore[arg-type]
        return extension()  # type: ignore[operator]

    if isinstance(extension, dict):
        import_path, options = next(iter(extension.items()))
        # An entry declared without a value (e.g. `{"path": None}`) means "no options".
        options = options or {}
    else:  # we consider it's a string
        import_path = str(extension)
        options = {}

    if import_path in builtin_extensions:
        import_path = f"griffe.agents.extensions.{import_path}"

    try:
        ext_module = dynamic_import(import_path)
    except ModuleNotFoundError as error:
        raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error
    except ImportError as error:
        raise ExtensionNotLoadedError(f"Error while importing extension module '{import_path}': {error}") from error

    # Only the attribute lookup is guarded: an AttributeError raised by the
    # extension's own constructor must not be reported as a missing attribute.
    try:
        extension_class = ext_module.Extension
    except AttributeError as error:  # noqa: WPS440
        raise ExtensionNotLoadedError(f"Extension module '{import_path}' has no 'Extension' attribute") from error
    return extension_class(**options)
def load_extensions(exts: Sequence[str | dict[str, Any] | Extension | Type[Extension]]) -> Extensions:  # noqa: WPS231
    """Load every configured extension and gather them in a container.

    Parameters:
        exts: A sequence of extension, with potential configuration options.

    Returns:
        An extensions container.
    """
    loaded = (load_extension(configured) for configured in exts)
    return Extensions(*loaded)