diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index aac47c86..06a89a96 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,3 +21,4 @@ Sub-Packages client transport dsl + utilities diff --git a/docs/modules/utilities.rst b/docs/modules/utilities.rst new file mode 100644 index 00000000..47043b98 --- /dev/null +++ b/docs/modules/utilities.rst @@ -0,0 +1,6 @@ +gql.utilities +============= + +.. currentmodule:: gql.utilities + +.. automodule:: gql.utilities diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst new file mode 100644 index 00000000..baee441e --- /dev/null +++ b/docs/usage/custom_scalars.rst @@ -0,0 +1,134 @@ +Custom Scalars +============== + +Scalar types represent primitive values at the leaves of a query. + +GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend +can add additional custom scalars to its schema to better express values in their data model. + +For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. + +The schema will then only contain: + +.. code-block:: python + + scalar Datetime + +When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), +their values need to be serialized to be composed +of only built-in scalars, then at the destination the serialized values will be parsed again to +be able to represent the scalar in its local internal representation. + +Because this serialization/unserialization is dependent on the language used at both sides, it is not +described in the schema and needs to be defined independently at both sides (client, backend). + +A custom scalar value can have two different representations during its transport: + + - as a serialized value (usually as json): + + * in the results sent by the backend + * in the variables sent by the client alongside the query + + - as "literal" inside the query itself sent by the client + +To define a custom scalar, you need 3 methods: + + - a :code:`serialize` method used: + + * by the backend to serialize a custom scalar output in the result + * by the client to serialize a custom scalar input in the variables + + - a :code:`parse_value` method used: + + * by the backend to unserialize custom scalars inputs in the variables sent by the client + * by the client to unserialize custom scalars outputs from the results + + - a :code:`parse_literal` method used: + + * by the backend to unserialize custom scalars inputs inside the query itself + +To define a custom scalar object, we define a :code:`GraphQLScalarType` from graphql-core with +its name and the implementation of the above methods. + +Example for Datetime: + +.. code-block:: python + + from datetime import datetime + from typing import Any, Dict, Optional + + from graphql import GraphQLScalarType, ValueNode + from graphql.utilities import value_from_ast_untyped + + + def serialize_datetime(value: Any) -> str: + return value.isoformat() + + + def parse_datetime_value(value: Any) -> datetime: + return datetime.fromisoformat(value) + + + def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None + ) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + return parse_datetime_value(ast_value) + + + DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, + ) + +Custom Scalars in inputs +------------------------ + +To provide custom scalars in input with gql, you can: + +- serialize the scalar yourself as "literal" in the query: + +.. code-block:: python + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + +- serialize the scalar yourself in a variable: + +.. code-block:: python + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + +- add a custom scalar to the schema with :func:`update_schema_scalars ` + and execute the query with :code:`serialize_variables=True` + and gql will serialize the variable values from a Python object representation. + +For this, you need to provide a schema or set :code:`fetch_schema_from_transport=True` +in the client to request the schema from the backend. + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) diff --git a/docs/usage/index.rst b/docs/usage/index.rst index a7dd4d56..4a38093a 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -10,3 +10,4 @@ Usage variables headers file_upload + custom_scalars diff --git a/gql/client.py b/gql/client.py index 6017ab69..368193cc 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,6 +17,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .variable_values import serialize_variable_values class Client: @@ -289,18 +290,79 @@ def __init__(self, client: Client): """:param client: the :class:`client ` used""" self.client = client - def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def _execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> ExecutionResult: + """Execute the provided document AST synchronously using + the sync transport, returning an ExecutionResult object. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) - return self.transport.execute(document, *args, **kwargs) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + + return self.transport.execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: + """Execute the provided document AST synchronously using + the sync transport. + + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = self._execute(document, *args, **kwargs) + result = self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: @@ -341,17 +403,52 @@ def __init__(self, client: Client): self.client = client async def _subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: + """Coroutine to subscribe asynchronously to the provided document AST + asynchronously using the async transport, + returning an async generator producing ExecutionResult objects. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport subscribe method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Subscribe to the transport inner_generator: AsyncGenerator[ ExecutionResult, None - ] = self.transport.subscribe(document, *args, **kwargs) + ] = self.transport.subscribe( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) # Keep a reference to the inner generator to allow the user to call aclose() # before a break if python version is too old (pypy3 py 3.6.1) @@ -364,15 +461,35 @@ async def _subscribe( await inner_generator.aclose() async def subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[Dict, None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + The extra arguments are passed to the transport subscribe method.""" inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, *args, **kwargs + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, ) try: @@ -391,27 +508,85 @@ async def subscribe( await inner_generator.aclose() async def _execute( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> ExecutionResult: + """Coroutine to execute the provided document AST asynchronously using + the async transport, returning an ExecutionResult object. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Execute the query with the transport with a timeout return await asyncio.wait_for( - self.transport.execute(document, *args, **kwargs), + self.transport.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + *args, + **kwargs, + ), self.client.execute_timeout, ) - async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + async def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: """Coroutine to execute the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = await self._execute(document, *args, **kwargs) + result = await self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py new file mode 100644 index 00000000..68b80156 --- /dev/null +++ b/gql/utilities/__init__.py @@ -0,0 +1,5 @@ +from .update_schema_scalars import update_schema_scalars + +__all__ = [ + "update_schema_scalars", +] diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py new file mode 100644 index 00000000..d5434c6b --- /dev/null +++ b/gql/utilities/update_schema_scalars.py @@ -0,0 +1,32 @@ +from typing import Iterable, List + +from graphql import GraphQLError, GraphQLScalarType, GraphQLSchema + + +def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): + """Update the scalars in a schema with the scalars provided. + + This can be used to update the default Custom Scalar implementation + when the schema has been provided from a text file or from introspection. + """ + + if not isinstance(scalars, Iterable): + raise GraphQLError("Scalars argument should be a list of scalars.") + + for scalar in scalars: + if not isinstance(scalar, GraphQLScalarType): + raise GraphQLError("Scalars should be instances of GraphQLScalarType.") + + try: + schema_scalar = schema.type_map[scalar.name] + except KeyError: + raise GraphQLError(f"Scalar '{scalar.name}' not found in schema.") + + assert isinstance(schema_scalar, GraphQLScalarType) + + # Update the conversion methods + # Using setattr because mypy has a false positive + # https://github.com/python/mypy/issues/2427 + setattr(schema_scalar, "serialize", scalar.serialize) + setattr(schema_scalar, "parse_value", scalar.parse_value) + setattr(schema_scalar, "parse_literal", scalar.parse_literal) diff --git a/gql/variable_values.py b/gql/variable_values.py new file mode 100644 index 00000000..7db7091a --- /dev/null +++ b/gql/variable_values.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, Optional + +from graphql import ( + DocumentNode, + GraphQLEnumType, + GraphQLError, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + GraphQLScalarType, + GraphQLSchema, + GraphQLType, + GraphQLWrappingType, + OperationDefinitionNode, + type_from_ast, +) +from graphql.pyutils import inspect + + +def get_document_operation( + document: DocumentNode, operation_name: Optional[str] = None +) -> OperationDefinitionNode: + """Returns the operation which should be executed in the document. + + Raises a GraphQLError if a single operation cannot be retrieved. + """ + + operation: Optional[OperationDefinitionNode] = None + + for definition in document.definitions: + if isinstance(definition, OperationDefinitionNode): + if operation_name is None: + if operation: + raise GraphQLError( + "Must provide operation name" + " if query contains multiple operations." + ) + operation = definition + elif definition.name and definition.name.value == operation_name: + operation = definition + + if not operation: + if operation_name is not None: + raise GraphQLError(f"Unknown operation named '{operation_name}'.") + + # The following line should never happen normally as the document is + # already verified before calling this function. + raise GraphQLError("Must provide an operation.") # pragma: no cover + + return operation + + +def serialize_value(type_: GraphQLType, value: Any) -> Any: + """Given a GraphQL type and a Python value, return the serialized value. + + Can be used to serialize Enums and/or Custom Scalars in variable values. + """ + + if value is None: + if isinstance(type_, GraphQLNonNull): + # raise GraphQLError(f"Type {type_.of_type.name} Cannot be None.") + raise GraphQLError(f"Type {inspect(type_)} Cannot be None.") + else: + return None + + if isinstance(type_, GraphQLWrappingType): + inner_type = type_.of_type + + if isinstance(type_, GraphQLNonNull): + return serialize_value(inner_type, value) + + elif isinstance(type_, GraphQLList): + return [serialize_value(inner_type, v) for v in value] + + elif isinstance(type_, (GraphQLScalarType, GraphQLEnumType)): + return type_.serialize(value) + + elif isinstance(type_, GraphQLInputObjectType): + return { + field_name: serialize_value(field.type, value[field_name]) + for field_name, field in type_.fields.items() + } + + raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.") + + +def serialize_variable_values( + schema: GraphQLSchema, + document: DocumentNode, + variable_values: Dict[str, Any], + operation_name: Optional[str] = None, +) -> Dict[str, Any]: + """Given a GraphQL document and a schema, serialize the Dictionary of + variable values. + + Useful to serialize Enums and/or Custom Scalars in variable values + """ + + parsed_variable_values: Dict[str, Any] = {} + + # Find the operation in the document + operation = get_document_operation(document, operation_name=operation_name) + + # Serialize every variable value defined for the operation + for var_def_node in operation.variable_definitions: + var_name = var_def_node.variable.name.value + var_type = type_from_ast(schema, var_def_node.type) + + if var_name in variable_values: + + assert var_type is not None + + var_value = variable_values[var_name] + + parsed_variable_values[var_name] = serialize_value(var_type, var_value) + + return parsed_variable_values diff --git a/tests/custom_scalars/__init__.py b/tests/custom_scalars/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py new file mode 100644 index 00000000..25c6bb31 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -0,0 +1,220 @@ +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +import pytest +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInt, + GraphQLList, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql + + +def serialize_datetime(value: Any) -> str: + if not isinstance(value, datetime): + raise GraphQLError("Cannot serialize datetime value: " + inspect(value)) + return value.isoformat() + + +def parse_datetime_value(value: Any) -> datetime: + + if isinstance(value, str): + try: + # Note: a more solid custom scalar should use dateutil.parser.isoparse + # Not using it here in the test to avoid adding another dependency + return datetime.fromisoformat(value) + except Exception: + raise GraphQLError("Cannot parse datetime value : " + inspect(value)) + + else: + raise GraphQLError("Cannot parse datetime value: " + inspect(value)) + + +def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + if not isinstance(ast_value, str): + raise GraphQLError("Cannot parse literal datetime value: " + inspect(ast_value)) + + return parse_datetime_value(ast_value) + + +DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, +) + + +def resolve_shift_days(root, _info, time, days): + return time + timedelta(days=days) + + +def resolve_latest(root, _info, times): + return max(times) + + +def resolve_seconds(root, _info, interval): + print(f"interval={interval!r}") + return (interval["end"] - interval["start"]).total_seconds() + + +IntervalInputType = GraphQLInputObjectType( + "IntervalInput", + fields={ + "start": GraphQLInputField(DatetimeScalar), + "end": GraphQLInputField(DatetimeScalar), + }, +) + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "shiftDays": GraphQLField( + DatetimeScalar, + args={ + "time": GraphQLArgument(DatetimeScalar), + "days": GraphQLArgument(GraphQLInt), + }, + resolve=resolve_shift_days, + ), + "latest": GraphQLField( + DatetimeScalar, + args={"times": GraphQLArgument(GraphQLList(DatetimeScalar))}, + resolve=resolve_latest, + ), + "seconds": GraphQLField( + GraphQLInt, + args={"interval": GraphQLArgument(IntervalInputType)}, + resolve=resolve_seconds, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": now, + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_query(): + + client = Client(schema=schema) + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + + result = client.execute(query) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_variables(): + + client = Client(schema=schema) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_latest(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql("query latest($times: [Datetime!]!) {latest(times: $times)}") + + variable_values = { + "times": [now, in_five_days], + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["latest"] == in_five_days.isoformat() + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_seconds(): + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql( + "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" + ) + + variable_values = {"interval": {"start": now, "end": in_five_days}} + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["seconds"] == 432000 diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py new file mode 100644 index 00000000..238308a9 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -0,0 +1,635 @@ +import asyncio +from typing import Any, Dict, NamedTuple, Optional + +import pytest +from graphql import graphql_sync +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect, is_finite +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.transport.exceptions import TransportQueryError +from gql.utilities import update_schema_scalars +from gql.variable_values import serialize_value + +from ..conftest import MS + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +class Money(NamedTuple): + amount: float + currency: str + + +def serialize_money(output_value: Any) -> Dict[str, Any]: + if not isinstance(output_value, Money): + raise GraphQLError("Cannot serialize money value: " + inspect(output_value)) + return output_value._asdict() + + +def parse_money_value(input_value: Any) -> Money: + """Using Money custom scalar from graphql-core tests except here the + input value is supposed to be a dict instead of a Money object.""" + + """ + if isinstance(input_value, Money): + return input_value + """ + + if isinstance(input_value, dict): + amount = input_value.get("amount", None) + currency = input_value.get("currency", None) + + if not is_finite(amount) or not isinstance(currency, str): + raise GraphQLError("Cannot parse money value dict: " + inspect(input_value)) + + return Money(float(amount), currency) + else: + raise GraphQLError("Cannot parse money value: " + inspect(input_value)) + + +def parse_money_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Money: + money = value_from_ast_untyped(value_node, variables) + if variables is not None and ( + # variables are not set when checked with ValuesIOfCorrectTypeRule + not money + or not is_finite(money.get("amount")) + or not isinstance(money.get("currency"), str) + ): + raise GraphQLError("Cannot parse literal money value: " + inspect(money)) + return Money(**money) + + +MoneyScalar = GraphQLScalarType( + name="Money", + serialize=serialize_money, + parse_value=parse_money_value, + parse_literal=parse_money_literal, +) + + +def resolve_balance(root, _info): + return root + + +def resolve_to_euros(_root, _info, money): + amount = money.amount + currency = money.currency + if not amount or currency == "EUR": + return amount + if currency == "DM": + return amount * 0.5 + raise ValueError("Cannot convert to euros: " + inspect(money)) + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "balance": GraphQLField(MoneyScalar, resolve=resolve_balance), + "toEuros": GraphQLField( + GraphQLFloat, + args={"money": GraphQLArgument(MoneyScalar)}, + resolve=resolve_to_euros, + ), + }, +) + + +def resolve_spent_money(spent_money, _info, **kwargs): + return spent_money + + +async def subscribe_spend_all(_root, _info, money): + while money.amount > 0: + money = Money(money.amount - 1, money.currency) + yield money + await asyncio.sleep(1 * MS) + + +subscriptionType = GraphQLObjectType( + "Subscription", + fields=lambda: { + "spend": GraphQLField( + MoneyScalar, + args={"money": GraphQLArgument(MoneyScalar)}, + subscribe=subscribe_spend_all, + resolve=resolve_spent_money, + ) + }, +) + +root_value = Money(42, "DM") + +schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) + + +def test_custom_scalar_in_output(): + + client = Client(schema=schema) + + query = gql("{balance}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +def test_custom_scalar_in_input_query(): + + client = Client(schema=schema) + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 10 + + +def test_custom_scalar_in_input_variable_values(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + + variable_values = {"money": money_value} + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized_with_operation_name(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="myquery", + ) + + assert result["toEuros"] == 5 + + +def test_serialize_variable_values_exception_multiple_ops_without_operation_name(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } + + query mybalance { + balance + }""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + exception = exc_info.value + + assert ( + str(exception) + == "Must provide operation name if query contains multiple operations." + ) + + +def test_serialize_variable_values_exception_operation_name_not_found(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } +""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="invalid_operation_name", + ) + + exception = exc_info.value + + assert str(exception) == "Unknown operation named 'invalid_operation_name'." + + +def test_custom_scalar_subscribe_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("subscription spendAll($money: Money) {spend(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + expected_result = {"spend": {"amount": 10, "currency": "DM"}} + + for result in client.subscribe( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ): + print(f"result = {result!r}") + expected_result["spend"]["amount"] = expected_result["spend"]["amount"] - 1 + assert expected_result == result + + +async def make_money_backend(aiohttp_server): + from aiohttp import web + + async def handler(request): + data = await request.json() + source = data["query"] + + print(f"data keys = {data.keys()}") + try: + variables = data["variables"] + print(f"variables = {variables!r}") + except KeyError: + variables = None + + result = graphql_sync( + schema, source, variable_values=variables, root_value=root_value + ) + + print(f"backend result = {result!r}") + + return web.json_response( + { + "data": result.data, + "errors": [str(e) for e in result.errors] if result.errors else None, + } + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + return server + + +async def make_money_transport(aiohttp_server): + from gql.transport.aiohttp import AIOHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + return transport + + +async def make_sync_money_transport(aiohttp_server): + from gql.transport.requests import RequestsHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = RequestsHTTPTransport(url=url, timeout=10) + + return (server, transport) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("{balance}") + + result = await session.execute(query) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 10 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + # money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_split_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql( + """ +query myquery($amount: Float, $currency: String) { + toEuros(money: {amount: $amount, currency: $currency}) +}""" + ) + + variable_values = {"amount": 10, "currency": "DM"} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + with pytest.raises(TransportQueryError): + await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_schema_from_introspection( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + schema = session.client.schema + + # Updating the Money Scalar in the schema + # We cannot replace it because some other objects keep a reference + # to the existing Scalar + # cannot do: schema.type_map["Money"] = MoneyScalar + + money_scalar = schema.type_map["Money"] + + money_scalar.serialize = MoneyScalar.serialize + money_scalar.parse_value = MoneyScalar.parse_value + money_scalar.parse_literal = MoneyScalar.parse_literal + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_update_schema_scalars(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # Update the schema MoneyScalar default implementation from + # introspection with our provided conversion methods + update_schema_scalars(session.client.schema, [MoneyScalar]) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +def test_update_schema_scalars_invalid_scalar(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [int]) + + exception = exc_info.value + + assert str(exception) == "Scalars should be instances of GraphQLScalarType." + + +def test_update_schema_scalars_invalid_scalar_argument(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, MoneyScalar) + + exception = exc_info.value + + assert str(exception) == "Scalars argument should be a list of scalars." + + +def test_update_schema_scalars_scalar_not_found_in_schema(): + + NotFoundScalar = GraphQLScalarType(name="abcd",) + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) + + exception = exc_info.value + + assert str(exception) == "Scalar 'abcd' not found in schema." + + +@pytest.mark.asyncio +@pytest.mark.requests +async def test_custom_scalar_serialize_variables_sync_transport( + event_loop, aiohttp_server, run_sync_test +): + + server, transport = await make_sync_money_transport(aiohttp_server) + + def test_code(): + with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + await run_sync_test(event_loop, server, test_code) + + +def test_serialize_value_with_invalid_type(): + + with pytest.raises(GraphQLError) as exc_info: + serialize_value("Not a valid type", 50) + + exception = exc_info.value + + assert ( + str(exception) == "Impossible to serialize value with type: 'Not a valid type'." + ) + + +def test_serialize_value_with_non_null_type_null(): + + non_null_int = GraphQLNonNull(GraphQLInt) + + with pytest.raises(GraphQLError) as exc_info: + serialize_value(non_null_int, None) + + exception = exc_info.value + + assert str(exception) == "Type Int! Cannot be None." + + +def test_serialize_value_with_nullable_type(): + + nullable_int = GraphQLInt + + assert serialize_value(nullable_int, None) is None