Skip to content

Instantly share code, notes, and snippets.

@Nicarim
Created June 28, 2019 11:59
Show Gist options
  • Select an option

  • Save Nicarim/c2235b6ba58bf2ca9bd898cc541146af to your computer and use it in GitHub Desktop.

Select an option

Save Nicarim/c2235b6ba58bf2ca9bd898cc541146af to your computer and use it in GitHub Desktop.
Generator of Swagger (OpenAPI) schemas based on common sense.
from importlib import import_module
from inspect import isclass
from os.path import dirname, join

import django
import oyaml as yaml
from django.conf import settings
from django.core.exceptions import ViewDoesNotExist
from django.core.management import BaseCommand
from django.urls import URLPattern, URLResolver
from rest_framework import serializers
from rest_framework.serializers import ModelSerializer
def describe_pattern(p):
    """Return the text of a URL entry's ``pattern`` attribute.

    :param p: URL pattern/resolver object exposing a ``pattern`` attribute.
    :return: ``str()`` of that attribute.
    """
    raw_pattern = p.pattern
    return str(raw_pattern)
class Command(BaseCommand):
    """Generate OpenAPI component schemas from Django REST framework serializers.

    The command scans ``settings.ROOT_URLCONF``. Each named, non-admin route
    whose view class declares a ``serializer_class`` contributes one object
    schema, keyed by the serializer's class name. The result is written under
    ``components: schemas:`` to ``spec/definitions_gen.yml``, which is
    resolved three directory levels above this file.
    """

    help = ('Generates schema based on serializers used in viewsets and class based views. '
            'Relies on using serializer_class variables')

    # Maps each serializer field class to its OpenAPI ``type``/``format`` fragment.
    # Lookup is by exact ``type(field)``, so unlisted subclasses are reported
    # as unsupported. The ``type`` of ChoiceField is refined from its choice
    # keys in ``fields_mapper``.
    map_serializer_field = {
        serializers.IntegerField: {"type": "integer", "format": "int32"},
        serializers.CharField: {"type": "string"},
        serializers.URLField: {"type": "string", "format": "url"},
        serializers.DateTimeField: {"type": "string", "format": "date-time"},
        serializers.DateField: {"type": "string", "format": "date"},
        serializers.FloatField: {"type": "number", "format": "double"},
        serializers.BooleanField: {"type": "boolean"},
        serializers.EmailField: {"type": "string", "format": "email"},
        serializers.UUIDField: {"type": "string", "format": "uuid"},
        serializers.ChoiceField: {"type": "integer"},
        serializers.FileField: {"type": 'string', 'format': 'binary'},
        serializers.ImageField: {"type": 'string', 'format': 'binary'},
        serializers.SerializerMethodField: {"type": 'string'},
        serializers.SlugField: {"type": 'string', 'format': 'slug'},
    }

    @staticmethod
    def _append_description(schema, text):
        """Append ``text`` to ``schema['description']`` on a new line, or set it if absent."""
        if schema.get('description'):
            schema['description'] += f'\n{text}'
        else:
            schema['description'] = text

    @staticmethod
    def _choices_are_integers(choices):
        """Return True when every choice key is an ``int`` (``bool`` excluded).

        An empty ``choices`` counts as integer, which matches the static
        ``ChoiceField`` mapping.
        """
        return all(isinstance(key, int) and not isinstance(key, bool) for key in choices)

    @staticmethod
    def _related_model(related_field):
        """Return the model of ``related_field.queryset``, or None if there is no queryset.

        The queryset may be None (presumably for read-only related fields). In
        that case the model cannot be determined from the field alone.
        """
        queryset = getattr(related_field, 'queryset', None)
        return None if queryset is None else queryset.model

    @staticmethod
    def _object_schema(properties, required):
        """Build an ``object`` schema, omitting ``required`` when it is empty.

        The OpenAPI 3.0 schema requires a ``required`` array to be non-empty.
        """
        schema = {'type': 'object'}
        if required:
            schema['required'] = required
        schema['properties'] = properties
        return schema

    @staticmethod
    def additional_universal_properties(field_cls):
        """Return the schema keys that apply to any serializer field.

        The following mappings are applied:

        - ``read_only``, ``write_only`` and ``allow_null`` become ``readOnly``,
          ``writeOnly`` and ``nullable``.
        - ``help_text`` becomes ``description``.
        - For fields that expose ``choices``, an ``enum`` is added, and a
          readable list of the choices is appended to the description.
          Integer choice keys stay integers so that ``enum`` agrees with
          ``type: integer``.

        :param field_cls: serializer field instance (despite the name).
        :return: dict of OpenAPI schema keys.
        """
        _dict = {}
        if field_cls.read_only:
            _dict['readOnly'] = True
        if field_cls.write_only:
            _dict['writeOnly'] = True
        if field_cls.allow_null:
            _dict['nullable'] = True
        if field_cls.help_text:
            _dict['description'] = str(field_cls.help_text)
        # Related fields are excluded from choice handling. Presumably this is
        # because their ``choices`` would be built from the related queryset.
        if type(field_cls) not in (serializers.ManyRelatedField, serializers.PrimaryKeyRelatedField):
            choices = getattr(field_cls, 'choices', None)
            if choices:
                if Command._choices_are_integers(choices):
                    _dict['enum'] = list(choices)
                else:
                    _dict['enum'] = [str(key) for key in choices]
                lines = ['Choices are as follows:']
                lines.extend(f" * `{key}` - {val}" for key, val in choices.items())
                Command._append_description(_dict, '\n'.join(lines))
        return _dict

    def fields_mapper(self, serializer_class):
        """Translate a serializer's fields into OpenAPI property schemas.

        Recurses into nested serializers and ``many=True`` list serializers.
        Fields without a mapping are reported on stderr and left out of the
        properties. They are still listed as required if they are required.

        :param serializer_class: serializer class, or instance (nested
            serializers arrive as instances).
        :return: ``(properties, required_field_names)``.
        """
        field_map = {}
        required_fields = []
        serializer_obj = serializer_class() if isclass(serializer_class) else serializer_class
        serializer_class = serializer_class if isclass(serializer_class) else serializer_class.__class__
        # For model serializers, only names in an explicit Meta.fields list or
        # tuple are kept. ``fields = '__all__'`` is a string, where ``in``
        # would do a substring test, so it disables the filter. An
        # ``exclude``-only Meta also disables it.
        meta_fields = getattr(getattr(serializer_class, 'Meta', None), 'fields', None)
        if isinstance(serializer_obj, ModelSerializer) and isinstance(meta_fields, (list, tuple)):
            allowed_names = meta_fields
        else:
            allowed_names = None
        for name, field_cls in serializer_obj.get_fields().items():
            if allowed_names is not None and name not in allowed_names:
                continue
            if getattr(field_cls, 'required', False):
                required_fields.append(name)
            field_type = type(field_cls)
            if field_type in self.map_serializer_field:
                field_map[name] = {**self.additional_universal_properties(field_cls),
                                   **self.map_serializer_field[field_type]}
                if field_type is serializers.ChoiceField:
                    # Keep ``type`` consistent with the ``enum`` values emitted above.
                    field_map[name]['type'] = ('integer' if self._choices_are_integers(field_cls.choices)
                                               else 'string')
            elif field_type is serializers.PrimaryKeyRelatedField:
                schema = {**self.additional_universal_properties(field_cls), "type": "string"}
                model = self._related_model(field_cls)
                if not field_cls.pk_field and model is not None:
                    # noinspection PyProtectedMember
                    self._append_description(
                        schema, f"Primary key field `{model._meta.pk.name}` of model `{model.__name__}`")
                field_map[name] = schema
            elif field_type is serializers.ListSerializer:
                item_properties, item_required = self.fields_mapper(field_cls.child)
                field_map[name] = {**self.additional_universal_properties(field_cls),
                                   "type": "array",
                                   "items": self._object_schema(item_properties, item_required)}
            elif isinstance(field_cls, serializers.Serializer):
                # Nested serializer, either a ModelSerializer or a plain Serializer.
                nested_properties, nested_required = self.fields_mapper(field_cls)
                field_map[name] = {**self.additional_universal_properties(field_cls),
                                   **self._object_schema(nested_properties, nested_required)}
            elif field_type is serializers.ManyRelatedField:
                child = field_cls.child_relation
                if type(child) is serializers.PrimaryKeyRelatedField:
                    model = self._related_model(child)
                    other_name = model.__name__ if model is not None else 'a related model'
                    schema = {**self.additional_universal_properties(field_cls),
                              "type": "array",
                              "items": {"type": "string"}}
                    self._append_description(
                        schema, f"Many to many relation between {serializer_class.__name__} "
                                f"and {other_name} by Primary Key (ID)")
                    field_map[name] = schema
                else:
                    self.stderr.write(f"Unsupported child of many2many relation {name} "
                                      f"of class {str(type(child))}")
            else:
                self.stderr.write(f"Unsupported field {name} of class {str(field_type)}")
        return field_map, required_fields

    def handle(self, *args, **options):
        """Build schemas for every serializer reachable from ROOT_URLCONF and write them to YAML."""
        urlconf = import_module(settings.ROOT_URLCONF)
        view_functions = self.extract_views_from_urlpatterns(urlconf.urlpatterns)
        # Only named routes whose (namespaced) name does not start with "admin".
        custom_views = [view for view in view_functions if view[2] and not view[2].startswith("admin")]
        serializers_to_map = {}
        seen = set()
        for callback, _pattern, _name in custom_views:
            # The view class is read from the callback's ``cls`` attribute.
            # Callbacks without one (e.g. plain function views) are skipped
            # instead of raising AttributeError.
            view_cls = getattr(callback, 'cls', None)
            # ``serializer_class`` may be missing or None, for example on
            # views that choose the serializer dynamically. Nothing can be
            # generated for those.
            serializer_class = getattr(view_cls, 'serializer_class', None)
            # Routers expose one viewset under several URLs, so each
            # serializer is mapped only once.
            if serializer_class is None or serializer_class in seen:
                continue
            seen.add(serializer_class)
            field_map, required_fields = self.fields_mapper(serializer_class)
            serializers_to_map[serializer_class.__name__] = self._object_schema(field_map, required_fields)
        self.write_to_yaml(serializers_to_map)

    @staticmethod
    def write_to_yaml(data):
        """Write ``data`` under ``components.schemas`` to ``spec/definitions_gen.yml``.

        The path is resolved three directory levels above this module. YAML
        anchors and aliases are disabled, so shared objects are written out in
        full rather than as references.
        """
        class NoAliasDumper(yaml.Dumper):
            # A local subclass avoids patching yaml.Dumper for the whole process.
            def ignore_aliases(self, _data):
                return True

        path = join(dirname(dirname(dirname(__file__))), 'spec/definitions_gen.yml')
        document = {'components': {'schemas': data}}
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(document, f, Dumper=NoAliasDumper, default_flow_style=False)

    def extract_views_from_urlpatterns(self, urlpatterns, base='', namespace=None):
        """
        Return a list of views from a list of urlpatterns.

        Each object in the returned list is a three-tuple:
        (view_func, pattern, name). ``pattern`` is the concatenated route text,
        and ``name`` is prefixed with the namespace when there is one.
        Resolvers whose patterns fail to import are skipped.

        :raises TypeError: if an entry is neither a pattern nor a resolver.
        """
        views = []
        for p in urlpatterns:
            if isinstance(p, URLPattern):
                try:
                    name = '{0}:{1}'.format(namespace, p.name) if namespace and p.name else p.name
                    views.append((p.callback, base + describe_pattern(p), name))
                except ViewDoesNotExist:
                    continue
            elif isinstance(p, URLResolver):
                try:
                    patterns = p.url_patterns
                except ImportError:
                    continue
                if namespace and p.namespace:
                    _namespace = '{0}:{1}'.format(namespace, p.namespace)
                else:
                    _namespace = p.namespace or namespace
                views.extend(self.extract_views_from_urlpatterns(
                    patterns, base + describe_pattern(p), namespace=_namespace))
            elif hasattr(p, '_get_callback'):
                # These two branches handle objects that are not
                # URLPattern/URLResolver but expose the older hooks.
                try:
                    views.append((p._get_callback(), base + describe_pattern(p), p.name))
                except ViewDoesNotExist:
                    continue
            elif hasattr(p, 'url_patterns') or hasattr(p, '_get_url_patterns'):
                try:
                    patterns = p.url_patterns
                except ImportError:
                    continue
                views.extend(
                    self.extract_views_from_urlpatterns(patterns, base + describe_pattern(p), namespace=namespace))
            else:
                # A one-element tuple keeps %-formatting safe even when p is itself a tuple.
                raise TypeError("%s does not appear to be a urlpattern object" % (p,))
        return views
if __name__ == "__main__":
    # Debug-only entry point that runs the generator without manage.py.
    # Django settings must already be configured (presumably via
    # DJANGO_SETTINGS_MODULE) before django.setup() is called.
    django.setup(set_prefix=False)
    generator = Command()
    generator.handle()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment