diff --git a/src/graphql/language/ast.py b/src/graphql/language/ast.py index ddbf6520..0189723f 100644 --- a/src/graphql/language/ast.py +++ b/src/graphql/language/ast.py @@ -2,17 +2,17 @@ from __future__ import annotations -from copy import copy, deepcopy +from dataclasses import dataclass, field, fields from enum import Enum -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, ClassVar, Union + +from ..pyutils import camel_to_snake try: from typing import TypeAlias except ImportError: # Python < 3.10 from typing_extensions import TypeAlias -from ..pyutils import camel_to_snake - if TYPE_CHECKING: from .source import Source from .token_kind import TokenKind @@ -168,7 +168,7 @@ def __copy__(self) -> Token: def __deepcopy__(self, memo: dict) -> Token: """Allow only shallow copies to avoid recursion.""" - return copy(self) + return self.__copy__() def __getstate__(self) -> dict[str, Any]: """Remove the links when pickling. @@ -341,24 +341,25 @@ class OperationType(Enum): # Base AST Node -class Node: - """AST nodes""" +class _KeysProperty: + """Descriptor providing .keys at both class and instance level. - # allow custom attributes and weak references (not used internally) - __slots__ = "__dict__", "__weakref__", "_hash", "loc" + For backwards compatibility only. Prefer using dataclasses.fields() instead. + """ - loc: Location | None + def __get__(self, obj: object, cls: type) -> tuple[str, ...]: + if not hasattr(cls, "__dataclass_fields__"): + return () # During class construction + return tuple(f.name for f in fields(cls)) - kind: str = "ast" # the kind of the node as a snake_case string - keys: tuple[str, ...] = ("loc",) # the names of the attributes of this node - def __init__(self, **kwargs: Any) -> None: - """Initialize the node with the given keyword arguments.""" - for key in self.keys: - value = kwargs.get(key) - if isinstance(value, list): - value = tuple(value) - setattr(self, key, value) +@dataclass(frozen=True, repr=False, kw_only=True) +class Node: + """AST nodes""" + + kind: ClassVar[str] = "ast" + keys: ClassVar[tuple[str, ...]] = _KeysProperty() # type: ignore[assignment] + loc: Location | None = None def __repr__(self) -> str: """Get a simple representation of the node.""" @@ -369,64 +370,17 @@ def __repr__(self) -> str: name = getattr(self, "name", None) if name: rep += f"(name={name.value!r})" - loc = getattr(self, "loc", None) - if loc: - rep += f" at {loc}" + if self.loc: + rep += f" at {self.loc}" return rep - def __eq__(self, other: object) -> bool: - """Test whether two nodes are equal (recursively).""" - return ( - isinstance(other, Node) - and self.__class__ == other.__class__ - and all(getattr(self, key) == getattr(other, key) for key in self.keys) - ) - - def __hash__(self) -> int: - """Get a cached hash value for the node.""" - # Caching the hash values improves the performance of AST validators - hashed = getattr(self, "_hash", None) - if hashed is None: - self._hash = id(self) # avoid recursion - hashed = hash(tuple(getattr(self, key) for key in self.keys)) - self._hash = hashed - return hashed - - def __setattr__(self, key: str, value: Any) -> None: - # reset cashed hash value if attributes are changed - if hasattr(self, "_hash") and key in self.keys: - del self._hash - super().__setattr__(key, value) - - def __copy__(self) -> Node: - """Create a shallow copy of the node.""" - return self.__class__(**{key: getattr(self, key) for key in self.keys}) - - def __deepcopy__(self, memo: dict) -> Node: - """Create a deep copy of the node""" - return self.__class__( - **{key: deepcopy(getattr(self, key), memo) for key in self.keys} - ) - def __init_subclass__(cls) -> None: super().__init_subclass__() - name = cls.__name__ - try: - name = name.removeprefix("Const").removesuffix("Node") - except AttributeError: # pragma: no cover (Python < 3.9) - if name.startswith("Const"): - name = name[5:] - if name.endswith("Node"): - name = name[:-4] + name = cls.__name__.removeprefix("Const").removesuffix("Node") cls.kind = camel_to_snake(name) - keys: list[str] = [] - for base in cls.__bases__: - keys.extend(base.keys) # type: ignore - keys.extend(cls.__slots__) - cls.keys = tuple(keys) def to_dict(self, locations: bool = False) -> dict: - """Concert node to a dictionary.""" + """Convert node to a dictionary.""" from ..utilities import ast_to_dict return ast_to_dict(self, locations) @@ -435,203 +389,161 @@ def to_dict(self, locations: bool = False) -> dict: # Name +@dataclass(frozen=True, repr=False, kw_only=True) class NameNode(Node): - __slots__ = ("value",) - value: str -# Document - - -class DocumentNode(Node): - __slots__ = ("definitions",) - - definitions: tuple[DefinitionNode, ...] +# Base classes for node categories +@dataclass(frozen=True, repr=False, kw_only=True) class DefinitionNode(Node): - __slots__ = () + """Base class for all definition nodes.""" +@dataclass(frozen=True, repr=False, kw_only=True) class ExecutableDefinitionNode(DefinitionNode): - __slots__ = "directives", "name", "selection_set", "variable_definitions" + """Base class for executable definition nodes.""" - name: NameNode | None - directives: tuple[DirectiveNode, ...] - variable_definitions: tuple[VariableDefinitionNode, ...] selection_set: SelectionSetNode + name: NameNode | None = None + variable_definitions: tuple[VariableDefinitionNode, ...] = () + directives: tuple[DirectiveNode, ...] = () -class OperationDefinitionNode(ExecutableDefinitionNode): - __slots__ = ("operation",) - - operation: OperationType - - -class VariableDefinitionNode(Node): - __slots__ = "default_value", "directives", "type", "variable" - - variable: VariableNode - type: TypeNode - default_value: ConstValueNode | None - directives: tuple[ConstDirectiveNode, ...] - - -class SelectionSetNode(Node): - __slots__ = ("selections",) - - selections: tuple[SelectionNode, ...] - - +@dataclass(frozen=True, repr=False, kw_only=True) class SelectionNode(Node): - __slots__ = ("directives",) + """Base class for selection nodes.""" - directives: tuple[DirectiveNode, ...] - - -class FieldNode(SelectionNode): - __slots__ = "alias", "arguments", "name", "nullability_assertion", "selection_set" - - alias: NameNode | None - name: NameNode - arguments: tuple[ArgumentNode, ...] - # Note: Client Controlled Nullability is experimental - # and may be changed or removed in the future. - nullability_assertion: NullabilityAssertionNode - selection_set: SelectionSetNode | None + directives: tuple[DirectiveNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class NullabilityAssertionNode(Node): - __slots__ = ("nullability_assertion",) - nullability_assertion: NullabilityAssertionNode | None + """Base class for nullability assertion nodes.""" -class ListNullabilityOperatorNode(NullabilityAssertionNode): - pass +@dataclass(frozen=True, repr=False, kw_only=True) +class ValueNode(Node): + """Base class for value nodes.""" -class NonNullAssertionNode(NullabilityAssertionNode): - nullability_assertion: ListNullabilityOperatorNode +@dataclass(frozen=True, repr=False, kw_only=True) +class TypeNode(Node): + """Base class for type nodes.""" -class ErrorBoundaryNode(NullabilityAssertionNode): - nullability_assertion: ListNullabilityOperatorNode +@dataclass(frozen=True, repr=False, kw_only=True) +class TypeSystemDefinitionNode(DefinitionNode): + """Base class for type system definition nodes.""" -class ArgumentNode(Node): - __slots__ = "name", "value" +@dataclass(frozen=True, repr=False, kw_only=True) +class TypeDefinitionNode(TypeSystemDefinitionNode): + """Base class for type definition nodes.""" name: NameNode - value: ValueNode - - -class ConstArgumentNode(ArgumentNode): - value: ConstValueNode - - -# Fragments + description: StringValueNode | None = None + directives: tuple[ConstDirectiveNode, ...] = () -class FragmentSpreadNode(SelectionNode): - __slots__ = ("name",) +@dataclass(frozen=True, repr=False, kw_only=True) +class TypeExtensionNode(TypeSystemDefinitionNode): + """Base class for type extension nodes.""" name: NameNode + directives: tuple[ConstDirectiveNode, ...] = () -class InlineFragmentNode(SelectionNode): - __slots__ = "selection_set", "type_condition" +# Type Reference nodes - type_condition: NamedTypeNode - selection_set: SelectionSetNode +@dataclass(frozen=True, repr=False, kw_only=True) +class NamedTypeNode(TypeNode): + name: NameNode -class FragmentDefinitionNode(ExecutableDefinitionNode): - __slots__ = ("type_condition",) - name: NameNode - type_condition: NamedTypeNode +@dataclass(frozen=True, repr=False, kw_only=True) +class ListTypeNode(TypeNode): + type: TypeNode -# Values +@dataclass(frozen=True, repr=False, kw_only=True) +class NonNullTypeNode(TypeNode): + type: NamedTypeNode | ListTypeNode -class ValueNode(Node): - __slots__ = () +# Value nodes +@dataclass(frozen=True, repr=False, kw_only=True) class VariableNode(ValueNode): - __slots__ = ("name",) - name: NameNode +@dataclass(frozen=True, repr=False, kw_only=True) class IntValueNode(ValueNode): - __slots__ = ("value",) - value: str +@dataclass(frozen=True, repr=False, kw_only=True) class FloatValueNode(ValueNode): - __slots__ = ("value",) - value: str +@dataclass(frozen=True, repr=False, kw_only=True) class StringValueNode(ValueNode): - __slots__ = "block", "value" - value: str - block: bool | None + block: bool | None = None +@dataclass(frozen=True, repr=False, kw_only=True) class BooleanValueNode(ValueNode): - __slots__ = ("value",) - value: bool +@dataclass(frozen=True, repr=False, kw_only=True) class NullValueNode(ValueNode): - __slots__ = () + """A null value node has no fields.""" +@dataclass(frozen=True, repr=False, kw_only=True) class EnumValueNode(ValueNode): - __slots__ = ("value",) - value: str +@dataclass(frozen=True, repr=False, kw_only=True) class ListValueNode(ValueNode): - __slots__ = ("values",) - - values: tuple[ValueNode, ...] + values: tuple[ValueNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class ConstListValueNode(ListValueNode): - values: tuple[ConstValueNode, ...] - - -class ObjectValueNode(ValueNode): - __slots__ = ("fields",) - - fields: tuple[ObjectFieldNode, ...] - - -class ConstObjectValueNode(ObjectValueNode): - fields: tuple[ConstObjectFieldNode, ...] + values: tuple[ConstValueNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class ObjectFieldNode(Node): - __slots__ = "name", "value" - name: NameNode value: ValueNode +@dataclass(frozen=True, repr=False, kw_only=True) class ConstObjectFieldNode(ObjectFieldNode): value: ConstValueNode +@dataclass(frozen=True, repr=False, kw_only=True) +class ObjectValueNode(ValueNode): + fields: tuple[ObjectFieldNode, ...] = () + + +@dataclass(frozen=True, repr=False, kw_only=True) +class ConstObjectValueNode(ObjectValueNode): + fields: tuple[ConstObjectFieldNode, ...] = () + + ConstValueNode: TypeAlias = Union[ IntValueNode, FloatValueNode, @@ -644,216 +556,249 @@ class ConstObjectFieldNode(ObjectFieldNode): ] -# Directives +# Directive nodes +@dataclass(frozen=True, repr=False, kw_only=True) class DirectiveNode(Node): - __slots__ = "arguments", "name" - name: NameNode - arguments: tuple[ArgumentNode, ...] + arguments: tuple[ArgumentNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class ConstDirectiveNode(DirectiveNode): - arguments: tuple[ConstArgumentNode, ...] + arguments: tuple[ConstArgumentNode, ...] = () -# Type Reference +# Nullability Assertion nodes -class TypeNode(Node): - __slots__ = () +@dataclass(frozen=True, repr=False, kw_only=True) +class ListNullabilityOperatorNode(NullabilityAssertionNode): + nullability_assertion: NullabilityAssertionNode | None = None -class NamedTypeNode(TypeNode): - __slots__ = ("name",) +@dataclass(frozen=True, repr=False, kw_only=True) +class NonNullAssertionNode(NullabilityAssertionNode): + nullability_assertion: ListNullabilityOperatorNode | None = None + +@dataclass(frozen=True, repr=False, kw_only=True) +class ErrorBoundaryNode(NullabilityAssertionNode): + nullability_assertion: ListNullabilityOperatorNode | None = None + + +# Selection nodes + + +@dataclass(frozen=True, repr=False, kw_only=True) +class FieldNode(SelectionNode): name: NameNode + alias: NameNode | None = None + arguments: tuple[ArgumentNode, ...] = () + directives: tuple[DirectiveNode, ...] = () + nullability_assertion: NullabilityAssertionNode | None = None + selection_set: SelectionSetNode | None = None -class ListTypeNode(TypeNode): - __slots__ = ("type",) +@dataclass(frozen=True, repr=False, kw_only=True) +class FragmentSpreadNode(SelectionNode): + name: NameNode + directives: tuple[DirectiveNode, ...] = () - type: TypeNode +@dataclass(frozen=True, repr=False, kw_only=True) +class InlineFragmentNode(SelectionNode): + selection_set: SelectionSetNode + type_condition: NamedTypeNode | None = None + directives: tuple[DirectiveNode, ...] = () -class NonNullTypeNode(TypeNode): - __slots__ = ("type",) - type: NamedTypeNode | ListTypeNode +# Argument nodes -# Type System Definition +@dataclass(frozen=True, repr=False, kw_only=True) +class ArgumentNode(Node): + name: NameNode + value: ValueNode -class TypeSystemDefinitionNode(DefinitionNode): - __slots__ = () +@dataclass(frozen=True, repr=False, kw_only=True) +class ConstArgumentNode(ArgumentNode): + value: ConstValueNode -class SchemaDefinitionNode(TypeSystemDefinitionNode): - __slots__ = "description", "directives", "operation_types" +# Selection Set - description: StringValueNode | None - directives: tuple[ConstDirectiveNode, ...] - operation_types: tuple[OperationTypeDefinitionNode, ...] +@dataclass(frozen=True, repr=False, kw_only=True) +class SelectionSetNode(Node): + selections: tuple[SelectionNode, ...] = () + + +# Variable Definition + + +@dataclass(frozen=True, repr=False, kw_only=True) +class VariableDefinitionNode(Node): + variable: VariableNode + type: TypeNode + default_value: ConstValueNode | None = None + directives: tuple[ConstDirectiveNode, ...] = () + + +# Executable Definition nodes -class OperationTypeDefinitionNode(Node): - __slots__ = "operation", "type" +@dataclass(frozen=True, repr=False, kw_only=True) +class OperationDefinitionNode(ExecutableDefinitionNode): operation: OperationType - type: NamedTypeNode -# Type Definition +@dataclass(frozen=True, repr=False, kw_only=True) +class FragmentDefinitionNode(ExecutableDefinitionNode): + name: NameNode # Required (overrides optional in parent) + type_condition: NamedTypeNode -class TypeDefinitionNode(TypeSystemDefinitionNode): - __slots__ = "description", "directives", "name" +# Document - description: StringValueNode | None - name: NameNode - directives: tuple[DirectiveNode, ...] +@dataclass(frozen=True, repr=False, kw_only=True) +class DocumentNode(Node): + definitions: tuple[DefinitionNode, ...] = () -class ScalarTypeDefinitionNode(TypeDefinitionNode): - __slots__ = () - directives: tuple[ConstDirectiveNode, ...] +# Type System Definition nodes -class ObjectTypeDefinitionNode(TypeDefinitionNode): - __slots__ = "fields", "interfaces" +@dataclass(frozen=True, repr=False, kw_only=True) +class SchemaDefinitionNode(TypeSystemDefinitionNode): + description: StringValueNode | None = None + directives: tuple[ConstDirectiveNode, ...] = () + operation_types: tuple[OperationTypeDefinitionNode, ...] = () - interfaces: tuple[NamedTypeNode, ...] - directives: tuple[ConstDirectiveNode, ...] - fields: tuple[FieldDefinitionNode, ...] +@dataclass(frozen=True, repr=False, kw_only=True) +class OperationTypeDefinitionNode(Node): + operation: OperationType + type: NamedTypeNode -class FieldDefinitionNode(DefinitionNode): - __slots__ = "arguments", "description", "directives", "name", "type" - description: StringValueNode | None +# Type Definition nodes + + +@dataclass(frozen=True, repr=False, kw_only=True) +class ScalarTypeDefinitionNode(TypeDefinitionNode): + """Scalar type definition node - inherits name, description, directives.""" + + +@dataclass(frozen=True, repr=False, kw_only=True) +class ObjectTypeDefinitionNode(TypeDefinitionNode): + interfaces: tuple[NamedTypeNode, ...] = () + fields: tuple[FieldDefinitionNode, ...] = () + + +@dataclass(frozen=True, repr=False, kw_only=True) +class FieldDefinitionNode(DefinitionNode): name: NameNode - directives: tuple[ConstDirectiveNode, ...] - arguments: tuple[InputValueDefinitionNode, ...] type: TypeNode + description: StringValueNode | None = None + arguments: tuple[InputValueDefinitionNode, ...] = () + directives: tuple[ConstDirectiveNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class InputValueDefinitionNode(DefinitionNode): - __slots__ = "default_value", "description", "directives", "name", "type" - - description: StringValueNode | None name: NameNode - directives: tuple[ConstDirectiveNode, ...] type: TypeNode - default_value: ConstValueNode | None + description: StringValueNode | None = None + default_value: ConstValueNode | None = None + directives: tuple[ConstDirectiveNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class InterfaceTypeDefinitionNode(TypeDefinitionNode): - __slots__ = "fields", "interfaces" - - fields: tuple[FieldDefinitionNode, ...] - directives: tuple[ConstDirectiveNode, ...] - interfaces: tuple[NamedTypeNode, ...] + interfaces: tuple[NamedTypeNode, ...] = () + fields: tuple[FieldDefinitionNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class UnionTypeDefinitionNode(TypeDefinitionNode): - __slots__ = ("types",) - - directives: tuple[ConstDirectiveNode, ...] - types: tuple[NamedTypeNode, ...] + types: tuple[NamedTypeNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class EnumTypeDefinitionNode(TypeDefinitionNode): - __slots__ = ("values",) - - directives: tuple[ConstDirectiveNode, ...] - values: tuple[EnumValueDefinitionNode, ...] + values: tuple[EnumValueDefinitionNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class EnumValueDefinitionNode(DefinitionNode): - __slots__ = "description", "directives", "name" - - description: StringValueNode | None name: NameNode - directives: tuple[ConstDirectiveNode, ...] + description: StringValueNode | None = None + directives: tuple[ConstDirectiveNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class InputObjectTypeDefinitionNode(TypeDefinitionNode): - __slots__ = ("fields",) - - directives: tuple[ConstDirectiveNode, ...] - fields: tuple[InputValueDefinitionNode, ...] + fields: tuple[InputValueDefinitionNode, ...] = () -# Directive Definitions +# Directive Definition +@dataclass(frozen=True, repr=False, kw_only=True) class DirectiveDefinitionNode(TypeSystemDefinitionNode): - __slots__ = "arguments", "description", "locations", "name", "repeatable" - - description: StringValueNode | None name: NameNode - arguments: tuple[InputValueDefinitionNode, ...] - repeatable: bool locations: tuple[NameNode, ...] + description: StringValueNode | None = None + arguments: tuple[InputValueDefinitionNode, ...] = () + repeatable: bool = False -# Type System Extensions +# Type System Extension nodes +@dataclass(frozen=True, repr=False, kw_only=True) class SchemaExtensionNode(Node): - __slots__ = "directives", "operation_types" - - directives: tuple[ConstDirectiveNode, ...] - operation_types: tuple[OperationTypeDefinitionNode, ...] - + directives: tuple[ConstDirectiveNode, ...] = () + operation_types: tuple[OperationTypeDefinitionNode, ...] = () -# Type Extensions - -class TypeExtensionNode(TypeSystemDefinitionNode): - __slots__ = "directives", "name" - - name: NameNode - directives: tuple[ConstDirectiveNode, ...] +TypeSystemExtensionNode: TypeAlias = Union[SchemaExtensionNode, TypeExtensionNode] -TypeSystemExtensionNode: TypeAlias = Union[SchemaExtensionNode, TypeExtensionNode] +# Type Extension nodes +@dataclass(frozen=True, repr=False, kw_only=True) class ScalarTypeExtensionNode(TypeExtensionNode): - __slots__ = () + """Scalar type extension node - inherits name, directives.""" +@dataclass(frozen=True, repr=False, kw_only=True) class ObjectTypeExtensionNode(TypeExtensionNode): - __slots__ = "fields", "interfaces" - - interfaces: tuple[NamedTypeNode, ...] - fields: tuple[FieldDefinitionNode, ...] + interfaces: tuple[NamedTypeNode, ...] = () + fields: tuple[FieldDefinitionNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class InterfaceTypeExtensionNode(TypeExtensionNode): - __slots__ = "fields", "interfaces" - - interfaces: tuple[NamedTypeNode, ...] - fields: tuple[FieldDefinitionNode, ...] + interfaces: tuple[NamedTypeNode, ...] = () + fields: tuple[FieldDefinitionNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class UnionTypeExtensionNode(TypeExtensionNode): - __slots__ = ("types",) - - types: tuple[NamedTypeNode, ...] + types: tuple[NamedTypeNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class EnumTypeExtensionNode(TypeExtensionNode): - __slots__ = ("values",) - - values: tuple[EnumValueDefinitionNode, ...] + values: tuple[EnumValueDefinitionNode, ...] = () +@dataclass(frozen=True, repr=False, kw_only=True) class InputObjectTypeExtensionNode(TypeExtensionNode): - __slots__ = ("fields",) - - fields: tuple[InputValueDefinitionNode, ...] + fields: tuple[InputValueDefinitionNode, ...] = () diff --git a/tests/language/test_ast.py b/tests/language/test_ast.py index 9c1f5c84..f9734c56 100644 --- a/tests/language/test_ast.py +++ b/tests/language/test_ast.py @@ -2,23 +2,25 @@ import weakref from copy import copy, deepcopy +from dataclasses import dataclass +from typing import ClassVar from graphql.language import Location, NameNode, Node, Source, Token, TokenKind from graphql.pyutils import inspect +@dataclass(frozen=True, repr=False, kw_only=True) class SampleTestNode(Node): - __slots__ = "alpha", "beta" - - alpha: int | Node # Union with Node to support copy tests with nested nodes - beta: int | Node | None + kind: ClassVar[str] = "sample_test" + alpha: int | Node = 0 # Union with Node to support copy tests with nested nodes + beta: int | Node | None = None +@dataclass(frozen=True, repr=False, kw_only=True) class SampleNamedNode(Node): - __slots__ = "foo", "name" - - foo: str - name: NameNode | None + kind: ClassVar[str] = "sample_named" + foo: str = "" + name: NameNode | None = None def make_loc(start: int = 1, end: int = 3) -> Location: @@ -176,13 +178,17 @@ def initializes_with_none_location(): assert node.beta == 2 assert not hasattr(node, "gamma") - def converts_list_to_tuple_on_init(): + def does_not_convert_list_to_tuple(): + """Lists are not auto-converted to tuples - pass tuples directly.""" from graphql.language import FieldNode, SelectionSetNode field = FieldNode(name=NameNode(value="foo")) - node = SelectionSetNode(selections=[field]) # Pass list, not tuple + # Passing a list stores it as-is (no conversion) + node = SelectionSetNode(selections=[field]) # type: ignore[arg-type] + assert isinstance(node.selections, list) + # Users should pass tuples directly for proper typing + node = SelectionSetNode(selections=(field,)) assert isinstance(node.selections, tuple) - assert node.selections == (field,) def has_representation_with_loc(): node = SampleTestNode(alpha=1, beta=2) @@ -220,8 +226,6 @@ def can_check_equality(): assert node2 == node node2 = SampleTestNode(alpha=1, beta=1) assert node2 != node - node3 = Node(alpha=1, beta=2) - assert node3 != node def can_hash(): node = SampleTestNode(alpha=1, beta=2) @@ -233,29 +237,18 @@ def can_hash(): assert node3 != node assert hash(node3) != hash(node) - def caches_are_hashed(): - node = SampleTestNode(alpha=1) - assert not hasattr(node, "_hash") + def is_hashable(): + node = SampleTestNode(alpha=1, beta=2) hash1 = hash(node) - assert hasattr(node, "_hash") - assert hash1 == node._hash # noqa: SLF001 - node.alpha = 2 - assert not hasattr(node, "_hash") + # Hash should be stable hash2 = hash(node) - assert hash2 != hash1 - assert hasattr(node, "_hash") - assert hash2 == node._hash # noqa: SLF001 + assert hash1 == hash2 def can_create_weak_reference(): node = SampleTestNode(alpha=1, beta=2) ref = weakref.ref(node) assert ref() is node - def can_create_custom_attribute(): - node = SampleTestNode(alpha=1, beta=2) - node.gamma = 3 - assert node.gamma == 3 # type: ignore - def can_create_shallow_copy(): node = SampleTestNode(alpha=1, beta=2) node2 = copy(node) diff --git a/tests/language/test_predicates.py b/tests/language/test_predicates.py index f87148e4..46dcca21 100644 --- a/tests/language/test_predicates.py +++ b/tests/language/test_predicates.py @@ -18,13 +18,101 @@ parse_value, ) + +def _make_name() -> ast.NameNode: + """Create a dummy NameNode.""" + return ast.NameNode(value="x") + + +def _make_named_type() -> ast.NamedTypeNode: + """Create a dummy NamedTypeNode.""" + return ast.NamedTypeNode(name=_make_name()) + + +def _make_selection_set() -> ast.SelectionSetNode: + """Create a dummy SelectionSetNode.""" + return ast.SelectionSetNode() + + +def _create_node(node_class: type) -> Node: + """Create a minimal valid instance of a node class.""" + name = _make_name() + named_type = _make_named_type() + selection_set = _make_selection_set() + + # Map node classes to their required constructor arguments + constructors: dict[type, dict] = { + # Nodes with required fields + ast.NameNode: {"value": "x"}, + ast.FieldNode: {"name": name}, + ast.FragmentSpreadNode: {"name": name}, + ast.InlineFragmentNode: {"selection_set": selection_set}, + ast.ArgumentNode: {"name": name, "value": ast.NullValueNode()}, + ast.VariableNode: {"name": name}, + ast.IntValueNode: {"value": "0"}, + ast.FloatValueNode: {"value": "0.0"}, + ast.StringValueNode: {"value": ""}, + ast.BooleanValueNode: {"value": True}, + ast.EnumValueNode: {"value": "X"}, + ast.ObjectFieldNode: {"name": name, "value": ast.NullValueNode()}, + ast.ListTypeNode: {"type": named_type}, + ast.NonNullTypeNode: {"type": named_type}, + ast.NamedTypeNode: {"name": name}, + ast.OperationDefinitionNode: { + "operation": ast.OperationType.QUERY, + "selection_set": selection_set, + }, + ast.VariableDefinitionNode: { + "variable": ast.VariableNode(name=name), + "type": named_type, + }, + ast.FragmentDefinitionNode: { + "name": name, + "type_condition": named_type, + "selection_set": selection_set, + }, + ast.DirectiveNode: {"name": name}, + # Base classes with required fields + ast.ExecutableDefinitionNode: {"selection_set": selection_set}, + ast.TypeDefinitionNode: {"name": name}, + ast.OperationTypeDefinitionNode: { + "operation": ast.OperationType.QUERY, + "type": named_type, + }, + ast.ScalarTypeDefinitionNode: {"name": name}, + ast.ObjectTypeDefinitionNode: {"name": name}, + ast.FieldDefinitionNode: {"name": name, "type": named_type}, + ast.InputValueDefinitionNode: {"name": name, "type": named_type}, + ast.InterfaceTypeDefinitionNode: {"name": name}, + ast.UnionTypeDefinitionNode: {"name": name}, + ast.EnumTypeDefinitionNode: {"name": name}, + ast.EnumValueDefinitionNode: {"name": name}, + ast.InputObjectTypeDefinitionNode: {"name": name}, + ast.DirectiveDefinitionNode: {"name": name, "locations": ()}, + ast.TypeExtensionNode: {"name": name}, + ast.ScalarTypeExtensionNode: {"name": name}, + ast.ObjectTypeExtensionNode: {"name": name}, + ast.InterfaceTypeExtensionNode: {"name": name}, + ast.UnionTypeExtensionNode: {"name": name}, + ast.EnumTypeExtensionNode: {"name": name}, + ast.InputObjectTypeExtensionNode: {"name": name}, + } + + if node_class in constructors: + return node_class(**constructors[node_class]) + # Node types with no required fields (base classes and simple nodes) + return node_class() + + +# Build list of all concrete AST node types (excluding Const* variants) all_ast_nodes = sorted( [ - node_type() - for node_type in vars(ast).values() - if type(node_type) is type - and issubclass(node_type, Node) - and not node_type.__name__.startswith("Const") + _create_node(node_class) + for node_class in vars(ast).values() + if isinstance(node_class, type) + and issubclass(node_class, Node) + and node_class is not Node + and not node_class.__name__.startswith("Const") ], key=attrgetter("kind"), ) diff --git a/tests/language/test_printer.py b/tests/language/test_printer.py index 48ba150f..16e40072 100644 --- a/tests/language/test_printer.py +++ b/tests/language/test_printer.py @@ -18,7 +18,7 @@ def produces_helpful_error_messages(): with pytest.raises(TypeError) as exc_info: print_ast(bad_ast) # type: ignore assert str(exc_info.value) == "Not an AST Node: {'random': 'Data'}." - corrupt_ast = FieldNode(name="random data") + corrupt_ast = FieldNode(name="random data") # type: ignore[arg-type] with pytest.raises(TypeError) as exc_info: print_ast(corrupt_ast) assert str(exc_info.value) == "Invalid AST Node: 'random data'." diff --git a/tests/language/test_visitor.py b/tests/language/test_visitor.py index b373dbfd..39a9b562 100644 --- a/tests/language/test_visitor.py +++ b/tests/language/test_visitor.py @@ -1,7 +1,8 @@ from __future__ import annotations +from dataclasses import dataclass from functools import partial -from typing import Any, cast +from typing import Any, ClassVar, cast import pytest @@ -595,11 +596,11 @@ def visit_nodes_with_custom_kinds_but_does_not_traverse_deeper(): # so we keep allowing this and test this feature here. parsed_ast = parse("{ a }") + @dataclass(frozen=True, repr=False, kw_only=True) class CustomFieldNode(SelectionNode): - __slots__ = "name", "selection_set" - + kind: ClassVar[str] = "custom_field" name: NameNode - selection_set: SelectionSetNode | None + selection_set: SelectionSetNode | None = None # Build custom AST immutably op_def = cast("OperationDefinitionNode", parsed_ast.definitions[0])