Created
June 28, 2019 11:59
-
-
Save Nicarim/c2235b6ba58bf2ca9bd898cc541146af to your computer and use it in GitHub Desktop.
Generator of Swagger schemas based on common sense.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from inspect import isclass | |
| from os.path import dirname, join | |
| import django | |
| import oyaml as yaml | |
| from django.conf import settings | |
| from django.core.exceptions import ViewDoesNotExist | |
| from django.core.management import BaseCommand | |
| from django.urls import URLPattern, URLResolver | |
| from rest_framework import serializers | |
| from rest_framework.serializers import ModelSerializer | |
def describe_pattern(p):
    """Return the textual form of the ``pattern`` attribute of *p*."""
    pattern = p.pattern
    return str(pattern)
class Command(BaseCommand):
    """Generate OpenAPI component schemas from the serializers used by views.

    Every named URL pattern reachable from ``settings.ROOT_URLCONF`` whose
    name does not start with ``admin`` is inspected. Views that expose a
    ``serializer_class`` get that serializer translated into an ``object``
    schema, and all schemas are written to ``spec/definitions_gen.yml``.
    """

    help = 'Generates schema based on serializers used in viewsets and class based views. ' \
           'Relies on using serializer_class variables'

    # Serializer field type -> OpenAPI type/format fragment. Lookup is done on
    # the exact ``type(field)``, so subclasses of these fields are not matched
    # implicitly.
    map_serializer_field = {
        serializers.IntegerField: {"type": "integer", "format": "int32"},
        serializers.CharField: {"type": "string"},
        serializers.URLField: {"type": "string", "format": "url"},
        serializers.DateTimeField: {"type": "string", "format": "date-time"},
        serializers.DateField: {"type": "string", "format": "date"},
        serializers.FloatField: {"type": "number", "format": "double"},
        serializers.BooleanField: {"type": "boolean"},
        serializers.EmailField: {"type": "string", "format": "email"},
        serializers.UUIDField: {"type": "string", "format": "uuid"},
        # Default only; fields_mapper switches this to "string" when the
        # declared choices are not plain integers.
        serializers.ChoiceField: {"type": "integer"},
        serializers.FileField: {"type": 'string', 'format': 'binary'},
        serializers.ImageField: {"type": 'string', 'format': 'binary'},
        serializers.SerializerMethodField: {"type": 'string'},
        serializers.SlugField: {"type": 'string', 'format': 'slug'},
    }

    @staticmethod
    def additional_universal_properties(field_cls):
        """Collect schema attributes that do not depend on the field's type.

        Args:
            field_cls: A serializer field *instance* (despite the name).

        Returns:
            A dict holding any of ``readOnly``, ``writeOnly``, ``nullable``,
            ``description`` and ``enum``. ``enum`` holds the choice keys
            unchanged when every key is a plain ``int``; otherwise every key
            is converted with ``str``.
        """
        props = {}
        if field_cls.read_only:
            props['readOnly'] = True
        if field_cls.write_only:
            props['writeOnly'] = True
        if field_cls.allow_null:
            props['nullable'] = True
        if field_cls.help_text:
            props['description'] = str(field_cls.help_text)

        # Relational fields expose ``choices`` as well, but there it is
        # derived from the related objects rather than being a fixed enum, so
        # the attribute is deliberately not touched for them.
        if isinstance(field_cls, (serializers.ManyRelatedField, serializers.RelatedField)):
            return props

        choices = getattr(field_cls, 'choices', None)
        if not choices:
            return props

        keys = list(choices)
        # Exact type check: bools and enum members are not emitted as raw
        # values, they go through str() like any other non-integer key.
        if all(type(key) is int for key in keys):
            props['enum'] = keys
        else:
            props['enum'] = [str(key) for key in keys]

        listing = ''.join(f"\n * `{key}` - {val}" for key, val in choices.items())
        choices_description = f'Choices are as follows:{listing}'
        if 'description' in props:
            props['description'] += f'\n{choices_description}'
        else:
            props['description'] = choices_description
        return props

    @staticmethod
    def _related_model(relation):
        """Return the model behind a relational field, or None without a queryset."""
        queryset = getattr(relation, 'queryset', None)
        return None if queryset is None else queryset.model

    def _field_schema(self, name, field, owner_class):
        """Build the schema for a single field; return None when unsupported."""
        field_type = type(field)

        if field_type in self.map_serializer_field:
            schema = {**self.additional_universal_properties(field),
                      **self.map_serializer_field[field_type]}
            enum = schema.get('enum')
            # Keep ``type`` consistent with the enum values: string keys must
            # not be advertised as integers.
            if enum and isinstance(enum[0], str):
                schema['type'] = 'string'
            return schema

        if field_type is serializers.PrimaryKeyRelatedField:
            schema = {**self.additional_universal_properties(field), "type": "string"}
            model = self._related_model(field)
            # A relation without a queryset has no model to describe here.
            if not field.pk_field and model is not None:
                # noinspection PyProtectedMember
                schema['description'] = f"Primary key field" \
                                        f" `{model._meta.pk.name}` of model `{model.__name__}`"
            return schema

        if isinstance(field, serializers.ListSerializer):
            schema = {**self.additional_universal_properties(field), "type": "array"}
            properties, required = self.fields_mapper(field.child)
            items = {'type': 'object', 'properties': properties}
            if required:
                items['required'] = required
            schema['items'] = items
            return schema

        if isinstance(field, serializers.Serializer):
            schema = {**self.additional_universal_properties(field), "type": "object"}
            properties, required = self.fields_mapper(field)
            schema["properties"] = properties
            if required:
                schema['required'] = required
            return schema

        if field_type is serializers.ManyRelatedField:
            child = field.child_relation
            if type(child) is not serializers.PrimaryKeyRelatedField:
                self.stdout.write(f"Unsupported child of many2many relation {name} "
                                  f"of class {str(type(child))}")
                return None
            schema = {**self.additional_universal_properties(field), "type": "array"}
            model = self._related_model(child)
            if model is not None:
                schema["description"] = (f"Many to many relation between {owner_class.__name__} "
                                         f"and {model.__name__} by Primary Key (ID)")
            schema["items"] = {"type": "string"}
            return schema

        self.stdout.write(f"Unsupported field {name} of class {str(type(field))}")
        return None

    def fields_mapper(self, serializer_class):
        """Translate a serializer into OpenAPI ``properties``.

        Args:
            serializer_class: A serializer class or an already constructed
                serializer instance.

        Returns:
            A ``(field_map, required_fields)`` tuple: ``field_map`` maps field
            names to schema dicts, ``required_fields`` lists the names of the
            fields whose ``required`` attribute is truthy. A required field
            is listed even when its type is unsupported and therefore absent
            from ``field_map``.
        """
        if isclass(serializer_class):
            serializer_obj = serializer_class()
        else:
            serializer_obj = serializer_class
            serializer_class = type(serializer_obj)

        # Restrict to Meta.fields only when it is a real collection of names.
        # A string such as '__all__' would turn ``in`` into a substring test,
        # and Meta may define ``exclude`` instead of ``fields``.
        allowed_names = None
        if isinstance(serializer_obj, ModelSerializer):
            meta_fields = getattr(getattr(serializer_class, 'Meta', None), 'fields', None)
            if meta_fields is not None and not isinstance(meta_fields, str):
                allowed_names = meta_fields

        field_map = {}
        required_fields = []
        for name, field_cls in serializer_obj.get_fields().items():
            if allowed_names is not None and name not in allowed_names:
                continue
            if getattr(field_cls, 'required', False):
                required_fields.append(name)
            schema = self._field_schema(name, field_cls, serializer_class)
            if schema is not None:
                field_map[name] = schema
        return field_map, required_fields

    def handle(self, *args, **options):
        """Map the serializer of every eligible view and write the YAML file."""
        urlconf = __import__(getattr(settings, 'ROOT_URLCONF'), {}, {}, [''])
        serializers_to_map = {}
        seen_serializers = set()
        for callback, _pattern, name in self.extract_views_from_urlpatterns(urlconf.urlpatterns):
            # Only named, non-admin routes are considered.
            if not name or name.startswith("admin"):
                continue
            # Callbacks without a ``cls`` attribute (e.g. plain function
            # views) and views whose serializer_class is unset are skipped.
            view_cls = getattr(callback, 'cls', None)
            serializer_class = getattr(view_cls, 'serializer_class', None)
            if serializer_class is None or serializer_class in seen_serializers:
                continue
            seen_serializers.add(serializer_class)
            field_map, required_fields = self.fields_mapper(serializer_class)
            serializers_to_map[serializer_class.__name__] = {
                'type': 'object',
                'required': required_fields,
                'properties': field_map,
            }
        return self.write_to_yaml(serializers_to_map)

    @staticmethod
    def write_to_yaml(data):
        """Write *data* under ``components.schemas`` to ``spec/definitions_gen.yml``.

        The target path is resolved three directory levels above this file.
        """

        class _NoAliasDumper(yaml.Dumper):
            """Dumper that never emits YAML anchors/aliases for repeated objects."""

            def ignore_aliases(self, data):
                return True

        path = join(dirname(dirname(dirname(__file__))), 'spec/definitions_gen.yml')
        document = {'components': {'schemas': data}}
        with open(path, 'w', encoding='utf-8') as f:
            # A local Dumper subclass is used so the shared yaml.Dumper class
            # is left unmodified for the rest of the process.
            yaml.dump(document, f, Dumper=_NoAliasDumper, default_flow_style=False)

    def extract_views_from_urlpatterns(self, urlpatterns, base='', namespace=None):
        """
        Return a list of views from a list of urlpatterns.

        Each object in the returned list is a three-tuple: (view_func, regex, name).
        Resolvers are walked recursively; ``base`` accumulates the pattern
        prefix and ``namespace`` the colon-joined namespace of the parents.

        Raises:
            TypeError: if an entry is neither a pattern nor a resolver.
        """
        views = []
        for p in urlpatterns:
            if isinstance(p, URLPattern):
                try:
                    if p.name and namespace:
                        name = f'{namespace}:{p.name}'
                    else:
                        name = p.name
                    views.append((p.callback, base + describe_pattern(p), name))
                except ViewDoesNotExist:
                    continue
            elif isinstance(p, URLResolver):
                try:
                    patterns = p.url_patterns
                except ImportError:
                    continue
                if namespace and p.namespace:
                    nested_namespace = f'{namespace}:{p.namespace}'
                else:
                    nested_namespace = p.namespace or namespace
                views.extend(self.extract_views_from_urlpatterns(
                    patterns, base + describe_pattern(p), namespace=nested_namespace))
            elif hasattr(p, '_get_callback'):
                try:
                    views.append((p._get_callback(), base + describe_pattern(p), p.name))
                except ViewDoesNotExist:
                    continue
            elif hasattr(p, 'url_patterns') or hasattr(p, '_get_url_patterns'):
                try:
                    patterns = p.url_patterns
                except ImportError:
                    continue
                views.extend(self.extract_views_from_urlpatterns(
                    patterns, base + describe_pattern(p), namespace=namespace))
            else:
                # Wrap in a 1-tuple so a tuple-valued entry is formatted as a
                # single value instead of being unpacked by ``%``.
                raise TypeError("%s does not appear to be a urlpattern object" % (p,))
        return views
if __name__ == "__main__":
    # Direct execution is meant for testing only: initialise Django, then run
    # the command's handler without going through manage.py.
    django.setup(set_prefix=False)
    command = Command()
    command.handle()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment