Skip to content

Instantly share code, notes, and snippets.

@webel
Last active November 11, 2021 16:44
Show Gist options
  • Select an option

  • Save webel/de230773df9638096af8dbcbfe782ab4 to your computer and use it in GitHub Desktop.

Select an option

Save webel/de230773df9638096af8dbcbfe782ab4 to your computer and use it in GitHub Desktop.
Get the select_related and prefetch_related fields for a queryset, given a GraphQL request

graphene-django-extras is great.

Here's a homegrown solution to getting the fields that can be selected and prefetched at the qs stage in FilterSet classes.

Simply inherit your FilterSet from `OptimizedFilterSet` to automatically get a qs with every relation the query asks for already selected/prefetched.

import functools
import json
import logging
import re
from dataclasses import dataclass
from dataclasses import field as dfield
from typing import List, Tuple

import django_filters
from django.contrib.contenttypes.fields import GenericRelation
from django.core.exceptions import FieldDoesNotExist
from django.db import connection, reset_queries
from django.db.models import Field as ModelField
from django.db.models import fields as model_fields
from graphql.backend import GraphQLCoreBackend
from graphql.utils.ast_to_dict import ast_to_dict
from requests import Response
def database_debug(func):
    """Decorate ``func`` so that the SQL it executes is printed to stdout.

    Before each call the connection's query log is reset; after the call the
    wrapped function's name, the number of logged queries and the SQL of each
    query are printed.

    NOTE(review): Django documents that ``connection.queries`` is only
    populated when ``DEBUG`` is enabled, so outside debug mode this is
    expected to report zero queries -- confirm against project settings.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
    """

    # functools.wraps preserves __name__, __doc__, __module__ and __wrapped__,
    # not just the name, so the decorated function stays introspectable.
    @functools.wraps(func)
    def inner_func(*args, **kwargs):
        reset_queries()
        results = func(*args, **kwargs)
        query_info = connection.queries
        print(f"function_name: {func.__name__}")
        print(f"query_count: {len(query_info)}")
        queries = "".join(f"{query['sql']}\n" for query in query_info)
        print(f"queries: \n{queries}")
        return results

    return inner_func
@dataclass
class GraphQLRequestUtilities:
    """Helpers for inspecting the GraphQL payload carried by an HTTP request.

    On construction the request body is decoded into ``data`` and the
    operation name, when present, is stored in ``operation_name``.
    """

    # NOTE(review): annotated as ``requests.Response``, but only
    # ``request.content_type`` and ``request.body`` are read here, which looks
    # like a Django ``HttpRequest`` -- confirm and tighten the annotation.
    request: Response
    # Decoded request payload; always overwritten in ``__post_init__``.
    data: dict = dfield(default_factory=dict)
    # Value of the payload's "operationName" key, or "" when missing/null.
    operation_name: str = ""

    def __post_init__(self):
        self.data = self.__parse_request_body()
        # "operationName" is absent for application/graphql bodies, for
        # unsupported content types and for JSON payloads that omit it, so a
        # plain ``self.data["operationName"]`` would raise KeyError here.
        self.operation_name = self.data.get("operationName") or ""

    def __parse_request_body(self) -> dict:
        """Decode the request body to a dict.

        Returns:
            ``{"query": <body>}`` for ``application/graphql``, the parsed JSON
            document for ``application/json`` and ``{}`` for any other
            content type.
        """
        content_type = self.request.content_type
        if content_type == "application/graphql":
            return {"query": self.request.body.decode("utf-8")}
        if content_type == "application/json":
            return json.loads(self.request.body.decode("utf-8"))
        return {}

    @staticmethod
    def to_snake_case(string) -> str:
        """Convert a camelCase / PascalCase identifier to snake_case."""
        string = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", string)
        # Collapse the double underscore produced when the input already had
        # an underscore in front of a capitalised word.
        string = re.sub(r"__([A-Z])", r"_\1", string)
        string = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", string)
        return string.lower()

    @staticmethod
    def collect_fields(node) -> dict:
        """Recursively collect the fields selected under an AST node.

        Only selections whose ``kind`` is ``"Field"`` are followed; any other
        selection kind is skipped.

        Args:
            node (dict): A node in the AST, in dict form.

        Returns:
            A dict mapping each field name to the dict of its sub fields, e.g.

            {'name': {},
             'sentimentsPerLanguage': {'id': {},
                                       'name': {},
                                       'totalSentiments': {}},
             'slug': {}}
        """
        selection_set = node.get("selection_set")
        if not selection_set:
            return {}
        return {
            leaf["name"]["value"]: GraphQLRequestUtilities.collect_fields(leaf)
            for leaf in selection_set["selections"]
            if leaf["kind"] == "Field"
        }

    def get_request_data(self) -> dict:
        """Return the decoded request payload."""
        return self.data

    def get_operation_name(self) -> str:
        """Return the operation name, or "" when the request did not name one."""
        return self.operation_name

    def get_graphql_document(self) -> dict:
        """Build a GraphQL document from the request's "query" string.

        NOTE(review): the ``-> dict`` annotation is kept from the original;
        callers use the result's ``document_ast`` attribute, so it is
        presumably a document object rather than a dict -- confirm.

        Raises:
            KeyError: If the decoded payload has no "query" entry.
        """
        # Imported at call time rather than at module import; presumably to
        # avoid a circular import with the schema module -- confirm.
        from graphy.schema import schema

        return GraphQLCoreBackend().document_from_string(schema, self.data["query"])

    def get_query_fields(self) -> dict:
        """Return the nested field selection of the request's first definition."""
        document = self.get_graphql_document()
        # NOTE: Here we hope that there's a single operation in the request
        definition = document.document_ast.definitions[0]
        return self.collect_fields(ast_to_dict(definition))
class OptimizedFilterSet(django_filters.FilterSet):
    """FilterSet whose ``qs`` eagerly loads the relations a GraphQL query asks for.

    The fields requested under the query's ``results`` selection are matched
    against the queryset model's fields; relational fields are then passed to
    ``select_related`` / ``prefetch_related``.
    """

    def get_result_fields(self) -> dict:
        """Return the fields selected under the ``results`` node of the query.

        Here we hope the operation is named the same as the query we want to
        get the fields from, i.e. if we have

            query searchNotes {
                _debug {
                    sql {
                        duration
                    }
                }
                searchNotes {
                    results {
                        ...
                    }
                }
            }

        then, even though there are two queries, the operation is named after
        the paginated query that we want to get. If this is not the case, we
        will still **try** to find a query with non-empty ``results``.

        Returns:
            The nested field dict found under ``results``, or ``{}`` when no
            query in the request has one.
        """
        gl_utils = GraphQLRequestUtilities(self.request)
        query_fields = gl_utils.get_query_fields()
        operation_name = gl_utils.get_operation_name()
        try:
            return query_fields[operation_name]["results"]
        except KeyError:
            # The operation is not named after a query with results; fall
            # back to scanning every top-level query below.
            pass
        for sub_fields in query_fields.values():
            if sub_fields.get("results"):
                return sub_fields["results"]
        return {}

    @property
    def model_prefetch_fields(self) -> List[ModelField]:
        """Return list of Django model field classes that can be prefetched"""
        return [
            model_fields.related.ManyToManyField,
            model_fields.reverse_related.ManyToOneRel,
            GenericRelation,
        ]

    @property
    def model_select_fields(self) -> List[ModelField]:
        """Return list of Django model field classes that can be selected,
        like ForeignKey or OneToOne"""
        return [
            model_fields.related.ForeignKey,
            model_fields.related.OneToOneField,
            model_fields.reverse_related.OneToOneRel,
        ]

    def get_select_and_prefetch_lists(
        self, model, queried_fields
    ) -> Tuple[List[str], List[str]]:
        """Get the possible select_related and prefetch_related lookups
        given a model and the queried_fields dict.

        Args:
            model: Django model class whose fields are inspected.
            queried_fields: Nested dict of camelCase field names, as produced
                by ``GraphQLRequestUtilities.collect_fields``.

        Returns:
            ``(select_fields, prefetch_fields)``: lookup paths using ``__``
            between nested relations. Queried names that are not model fields,
            or are not relational, are ignored.
        """
        select_fields = []
        prefetch_fields = []
        select_types = tuple(self.model_select_fields)
        prefetch_types = tuple(self.model_prefetch_fields)
        for field, child_nodes in queried_fields.items():
            snake_case = GraphQLRequestUtilities.to_snake_case(field)
            try:
                model_field = model._meta.get_field(snake_case)
            # Would occur for properties from custom resolvers
            except FieldDoesNotExist:
                continue
            # Select types are tested first: OneToOneRel subclasses
            # ManyToOneRel and must not be classified as a prefetch.
            if isinstance(model_field, select_types):
                via_select = True
            elif isinstance(model_field, prefetch_types):
                via_select = False
            else:
                continue
            name = model_field.name
            (select_fields if via_select else prefetch_fields).append(name)
            if not child_nodes:
                continue
            sub_select_fields, sub_prefetch_fields = self.get_select_and_prefetch_lists(
                model_field.related_model, child_nodes
            )
            # A single-valued relation nested under a selected parent can
            # extend the join; nested under a prefetched parent it has to
            # stay on the prefetch path.
            nested_single = [f"{name}__{sub}" for sub in sub_select_fields]
            if via_select:
                select_fields += nested_single
            else:
                prefetch_fields += nested_single
            # Many-valued relations are always prefetched, whatever the
            # parent is: select_related rejects them.
            prefetch_fields += [f"{name}__{sub}" for sub in sub_prefetch_fields]
        return select_fields, prefetch_fields

    @property
    def qs(self):
        """Return the filtered queryset with queried relations eagerly loaded.

        When the filter set has no request the parent queryset is returned
        unchanged.
        """
        qs = super().qs
        if self.request is None:
            return qs
        queried_fields = self.get_result_fields()
        select_fields, prefetch_fields = self.get_select_and_prefetch_lists(
            qs.model, queried_fields
        )
        logging.getLogger(__name__).debug(
            "Optimizing %s queryset: select_related=%s prefetch_related=%s",
            qs.model.__name__,
            select_fields,
            prefetch_fields,
        )
        # select_related() with no arguments follows every non-null foreign
        # key, so it must only be called when there is something to select.
        if select_fields:
            qs = qs.select_related(*select_fields)
        if prefetch_fields:
            qs = qs.prefetch_related(*prefetch_fields)
        return qs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment