From 549104355c0fcfe745e380ed3e3e847d17779943 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 19 Jan 2021 23:25:51 +0800 Subject: [PATCH 01/34] first commit , create class APIView --- .gitignore | 2 + Apps/__init__.py | 26 +++++++ Cofing/__init__.py | 22 ++++++ Cofing/develop.py | 12 ++++ Cofing/formal.py | 12 ++++ Cofing/local.py | 12 ++++ db.py | 22 ++++++ run.py | 33 +++++++++ sanic_rest_framework/__init__.py | 12 ++++ sanic_rest_framework/constant.py | 19 ++++++ sanic_rest_framework/routes.py | 71 +++++++++++++++++++ sanic_rest_framework/serializers.py | 0 sanic_rest_framework/status.py | 101 ++++++++++++++++++++++++++++ sanic_rest_framework/views.py | 94 ++++++++++++++++++++++++++ 14 files changed, 438 insertions(+) create mode 100644 Apps/__init__.py create mode 100644 Cofing/__init__.py create mode 100644 Cofing/develop.py create mode 100644 Cofing/formal.py create mode 100644 Cofing/local.py create mode 100644 db.py create mode 100644 run.py create mode 100644 sanic_rest_framework/__init__.py create mode 100644 sanic_rest_framework/constant.py create mode 100644 sanic_rest_framework/routes.py create mode 100644 sanic_rest_framework/serializers.py create mode 100644 sanic_rest_framework/status.py create mode 100644 sanic_rest_framework/views.py diff --git a/.gitignore b/.gitignore index 11614af..51cb344 100644 --- a/.gitignore +++ b/.gitignore @@ -113,3 +113,5 @@ dmypy.json # Pyre type checker .pyre/ +.idea/ +.vscode/ \ No newline at end of file diff --git a/Apps/__init__.py b/Apps/__init__.py new file mode 100644 index 0000000..e5646ee --- /dev/null +++ b/Apps/__init__.py @@ -0,0 +1,26 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/15 14:01 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + 无 +@FileDoc: + __init__.py + 工厂函数 +""" +import os + +from sanic import Sanic + +from Cofing import get_config +from tortoise.contrib.sanic import register_tortoise + +def create_app(): + app = Sanic(__name__) + app.config.from_object(get_config(config_name='develop')) + register_tortoise( + app, config=db_config, modules={"models": ["models"]}, generate_schemas=False + ) + return app diff --git a/Cofing/__init__.py b/Cofing/__init__.py new file mode 100644 index 0000000..ad88d06 --- /dev/null +++ b/Cofing/__init__.py @@ -0,0 +1,22 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/15 14:07 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + TODO: 需要填写入口函数 +@FileDoc: + __init__.py + 配置初始化文件 +""" +import importlib + + +def get_config(config_name): + """ + 得到配置文件 + :param config_name: + :return: + """ + return importlib.import_module('Config.{}.Config'.format(config_name)) diff --git a/Cofing/develop.py b/Cofing/develop.py new file mode 100644 index 0000000..aed0f16 --- /dev/null +++ b/Cofing/develop.py @@ -0,0 +1,12 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/15 14:04 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + TODO: 需要填写入口函数 +@FileDoc: + develop.py + 开发环境配置文件 +""" diff --git a/Cofing/formal.py b/Cofing/formal.py new file mode 100644 index 0000000..033ff8d --- /dev/null +++ b/Cofing/formal.py @@ -0,0 +1,12 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/15 14:04 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + TODO: 需要填写入口函数 +@FileDoc: + formal.py + 正式运行环境配置文件 +""" diff --git a/Cofing/local.py b/Cofing/local.py new file mode 100644 index 0000000..276864e --- /dev/null +++ b/Cofing/local.py @@ -0,0 +1,12 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/15 14:03 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + TODO: 需要填写入口函数 +@FileDoc: + local.py + 本地运行配置文件 +""" diff --git a/db.py b/db.py new file mode 100644 index 0000000..3d62573 --- /dev/null +++ b/db.py @@ -0,0 +1,22 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/15 14:00 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + TODO: 需要填写入口函数 +@FileDoc: + db.py + 文件说明 +""" +from datetime import date + +from tortoise.fields import CharField, IntField, DateField +from tortoise import Model + + +class TestModel(Model): + name = CharField(max_length=30) + ages = IntField() + birthday = DateField(default=date.today) diff --git a/run.py b/run.py new file mode 100644 index 0000000..d4d15a6 --- /dev/null +++ b/run.py @@ -0,0 +1,33 @@ +from sanic import Sanic +from sanic.blueprints import Blueprint +from tortoise.contrib.sanic import register_tortoise + +from sanic_rest_framework.routes import Route +from sanic_rest_framework.views import BaseAPIView + +app = Sanic(__name__) +admin = Blueprint('admin', '/admin') + + +class TestView(BaseAPIView): + + async def get(self, request): + return self.success_json_response() + + async def post(self, request): + return self.success_json_response() + + async def put(self, request, pk): + return self.success_json_response() + + +route = Route() +route.register_route('test', TestView) +route.initialize(admin) +app.blueprint(admin) + +register_tortoise( + app, db_url="sqlite:///db.sqlite", modules={"models": ["db"]}, generate_schemas=True +) + +app.run(host="127.0.0.1", port=8000, debug=True, auto_reload=True) diff --git a/sanic_rest_framework/__init__.py b/sanic_rest_framework/__init__.py new file mode 100644 index 0000000..a85a032 --- /dev/null +++ b/sanic_rest_framework/__init__.py @@ -0,0 +1,12 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/19 15:47 +@DependencyLibrary: + TODO: 需要填写依赖安装方法 例:python -m pip install requests +@MainFunction: + TODO: 需要填写入口函数 +@FileDoc: + __init__.py + 文件说明 +""" diff --git a/sanic_rest_framework/constant.py b/sanic_rest_framework/constant.py new file mode 100644 index 0000000..8ea0801 --- /dev/null +++ b/sanic_rest_framework/constant.py @@ -0,0 +1,19 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/19 15:45 +@DependencyLibrary: +@MainFunction: +@FileDoc: + constant.py + 全局常量 +""" +ALL_METHOD = {'GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS'} +DETAIL_METHOD_GROUP = { + 'dynamic_method': ['GET', 'PUT', 'DELETE', 'PATCH'], + 'static_method': ['POST', 'OPTION'] +} +LIST_METHOD_GROUP = { + 'dynamic_method': ['PUT', 'DELETE', 'PATCH'], + 'static_method': ['GET', 'POST', 'OPTION'] +} diff --git a/sanic_rest_framework/routes.py b/sanic_rest_framework/routes.py new file mode 100644 index 0000000..fd5e046 --- /dev/null +++ b/sanic_rest_framework/routes.py @@ -0,0 +1,71 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/19 16:08 +@DependencyLibrary: +@MainFunction: +@FileDoc: + routes.py + 便捷路由文件 +""" +from typing import List, Type, Union + +from sanic import Sanic, Blueprint + +from .constant import ALL_METHOD, DETAIL_METHOD_GROUP, LIST_METHOD_GROUP + + +# 默认分组 + + +class Route: + def __init__(self): + self.routes = [] + + def register_route(self, prefix, viewset, name=None): + """ + 注册路由 + :param prefix: url 前缀 + :param viewset: 视图类 + :param name: 供 url_for 使用的名称 + :return: + """ + if name is None: + name = prefix + + dynamic_uri = '/{prefix}/' + static_uri = '/{prefix}' + base_method_group = LIST_METHOD_GROUP + if viewset.detail: + base_method_group = DETAIL_METHOD_GROUP + + viewset_methods = self.get_viewset_methods(viewset) + viewset_dynamic_method = [i for i in viewset_methods if i in base_method_group['dynamic_method']] + viewset_static_method = [i for i in viewset_methods if i in base_method_group['static_method']] + + if viewset_dynamic_method: + self.routes.append({ + 'handler': viewset.as_view(viewset_dynamic_method), + 'uri': dynamic_uri.format(prefix=prefix), + 'name': '{name}_detail'.format(name=name) + }) + if viewset_static_method: + self.routes.append({ + 'handler': viewset.as_view(viewset_static_method), + 'uri': static_uri.format(prefix=prefix), + 'name': '{name}_list'.format(name=name) + }) + + def get_viewset_methods(self, viewset): + """得到viewSet所有请求方法""" + methods = [] + for method in ALL_METHOD: + if hasattr(viewset, method.lower()): + methods.append(method) + return methods + + def initialize(self, destination: Union[Sanic, Blueprint]): + """注册路由""" + for route in self.routes: + route['methods'] = ALL_METHOD + destination.add_route(**route) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py new file mode 100644 index 0000000..e69de29 diff --git a/sanic_rest_framework/status.py b/sanic_rest_framework/status.py new file mode 100644 index 0000000..e482042 --- /dev/null +++ b/sanic_rest_framework/status.py @@ -0,0 +1,101 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/19 9:34 +@DependencyLibrary: +@MainFunction: +@FileDoc: + status.py + Http status describe file +""" +from enum import Enum + + +def is_informational(code): + return 100 <= code <= 199 + + +def is_success(code): + return 200 <= code <= 299 + + +def is_redirect(code): + return 300 <= code <= 399 + + +def is_client_error(code): + return 400 <= code <= 499 + + +def is_server_error(code): + return 500 <= code <= 599 + + +# 约定规则状态 +class RuleStatus: + STATUS_0_FAIL = 0 + STATUS_1_SUCCESS = 1 + + +# 协议状态 +class HttpStatus: + HTTP_100_CONTINUE = 100 + HTTP_101_SWITCHING_PROTOCOLS = 101 + HTTP_200_OK = 200 + HTTP_201_CREATED = 201 + HTTP_202_ACCEPTED = 202 + HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203 + HTTP_204_NO_CONTENT = 204 + HTTP_205_RESET_CONTENT = 205 + HTTP_206_PARTIAL_CONTENT = 206 + HTTP_207_MULTI_STATUS = 207 + HTTP_208_ALREADY_REPORTED = 208 + HTTP_226_IM_USED = 226 + HTTP_300_MULTIPLE_CHOICES = 300 + HTTP_301_MOVED_PERMANENTLY = 301 + HTTP_302_FOUND = 302 + HTTP_303_SEE_OTHER = 303 + HTTP_304_NOT_MODIFIED = 304 + HTTP_305_USE_PROXY = 305 + HTTP_306_RESERVED = 306 + HTTP_307_TEMPORARY_REDIRECT = 307 + HTTP_308_PERMANENT_REDIRECT = 308 + HTTP_400_BAD_REQUEST = 400 + HTTP_401_UNAUTHORIZED = 401 + HTTP_402_PAYMENT_REQUIRED = 402 + HTTP_403_FORBIDDEN = 403 + HTTP_404_NOT_FOUND = 404 + HTTP_405_METHOD_NOT_ALLOWED = 405 + HTTP_406_NOT_ACCEPTABLE = 406 + HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407 + HTTP_408_REQUEST_TIMEOUT = 408 + HTTP_409_CONFLICT = 409 + HTTP_410_GONE = 410 + HTTP_411_LENGTH_REQUIRED = 411 + HTTP_412_PRECONDITION_FAILED = 412 + HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413 + HTTP_414_REQUEST_URI_TOO_LONG = 414 + HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415 + HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416 + HTTP_417_EXPECTATION_FAILED = 417 + HTTP_418_IM_A_TEAPOT = 418 + HTTP_422_UNPROCESSABLE_ENTITY = 422 + HTTP_423_LOCKED = 423 + HTTP_424_FAILED_DEPENDENCY = 424 + HTTP_426_UPGRADE_REQUIRED = 426 + HTTP_428_PRECONDITION_REQUIRED = 428 + HTTP_429_TOO_MANY_REQUESTS = 429 + HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431 + HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451 + HTTP_500_INTERNAL_SERVER_ERROR = 500 + HTTP_501_NOT_IMPLEMENTED = 501 + HTTP_502_BAD_GATEWAY = 502 + HTTP_503_SERVICE_UNAVAILABLE = 503 + HTTP_504_GATEWAY_TIMEOUT = 504 + HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 + HTTP_506_VARIANT_ALSO_NEGOTIATES = 506 + HTTP_507_INSUFFICIENT_STORAGE = 507 + HTTP_508_LOOP_DETECTED = 508 + HTTP_509_BANDWIDTH_LIMIT_EXCEEDED = 509 + HTTP_510_NOT_EXTENDED = 510 + HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py new file mode 100644 index 0000000..a502360 --- /dev/null +++ b/sanic_rest_framework/views.py @@ -0,0 +1,94 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/19 15:44 +@DependencyLibrary: +@MainFunction: +@FileDoc: + views.py + 基础视图文件 +""" +from sanic.response import json +from sanic_rest_framework.constant import ALL_METHOD +from sanic_rest_framework.status import RuleStatus, HttpStatus + + +class BaseAPIView: + """基础API视图""" + detail = False + + def dispatch(self, request, *args, **kwargs): + """分发路由""" + method = request.method + if method not in self.licensed_methods: + return self.json_response(msg='发生错误:未找到%s方法' % method, status=RuleStatus.STATUS_0_FAIL, + response_status=HttpStatus.HTTP_405_METHOD_NOT_ALLOWED) + handler = getattr(self, method.lower(), None) + return handler(request, *args, **kwargs) + + @classmethod + def as_view(cls, methods=None, *class_args, **class_kwargs): + + # 许可的方法 + if methods is None: + methods = ALL_METHOD + + # 返回的响应方法闭包 + def view(request, *args, **kwargs): + self = view.base_class(*class_args, **class_kwargs) + self.licensed_methods = methods + self.request = request + self.args = args + self.kwargs = kwargs + return self.dispatch(request, *args, **kwargs) + + view.base_class = cls + view.API_DOC_CONFIG = class_kwargs.get('API_DOC_CONFIG') # 未来的API文档配置属性 + view.__doc__ = cls.__doc__ + view.__module__ = cls.__module__ + view.__name__ = cls.__name__ + return view + + def json_response(self, data=None, msg="OK", status=RuleStatus.STATUS_1_SUCCESS, + response_status=HttpStatus.HTTP_200_OK): + """ + Json 相应体 + :param data: 返回的数据主题 + :param msg: 前台提示字符串 + :param status: 前台约定状态,供前台判断是否成功 + :param response_status: Http响应数据 + :return: + """ + if data is None: + data = {} + response_body = { + 'data': data, + 'message': msg, + 'status': status + } + return json(body=response_body, status=response_status) + + def success_json_response(self, data=None, msg="Success"): + """ + 快捷的成功的json响应体 + :param data: 返回的数据主题 + :param msg: 前台提示字符串 + :return: json + """ + return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_1_SUCCESS) + + def error_json_response(self, data=None, msg="Fail"): + """ + 快捷的失败的json响应体 + :param data: 返回的数据主题 + :param msg: 前台提示字符串 + :return: json + """ + return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_0_FAIL, + response_status=HttpStatus.HTTP_400_BAD_REQUEST) + + +# class ViewJsonHelperMixin: + + +# class APIView(ViewJsonHelperMixin, BaseAPIView): -- Gitee From c5a6b27876c5a47634d7ca52c8e87a85acaed9c6 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Wed, 20 Jan 2021 17:59:10 +0800 Subject: [PATCH 02/34] serializers --- requirements.txt | 18 ++ run.py | 3 +- sanic_rest_framework/fields.py | 403 ++++++++++++++++++++++++++++ sanic_rest_framework/serializers.py | 10 + 4 files changed, 433 insertions(+), 1 deletion(-) create mode 100644 requirements.txt create mode 100644 sanic_rest_framework/fields.py diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..01050ec --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +aiofiles==0.6.0 +aiosqlite==0.16.0 +certifi==2020.12.5 +h11==0.9.0 +httpcore==0.11.1 +httptools==0.1.1 +httpx==0.15.4 +idna==3.1 +iso8601==0.1.13 +multidict==5.1.0 +PyPika==0.44.1 +pytz==2020.5 +rfc3986==1.4.0 +sanic==20.12.1 +sniffio==1.2.0 +tortoise-orm==0.16.19 +typing-extensions==3.7.4.3 +websockets==8.1 diff --git a/run.py b/run.py index d4d15a6..f203742 100644 --- a/run.py +++ b/run.py @@ -2,6 +2,7 @@ from sanic import Sanic from sanic.blueprints import Blueprint from tortoise.contrib.sanic import register_tortoise +from db import TestModel from sanic_rest_framework.routes import Route from sanic_rest_framework.views import BaseAPIView @@ -27,7 +28,7 @@ route.initialize(admin) app.blueprint(admin) register_tortoise( - app, db_url="sqlite:///db.sqlite", modules={"models": ["db"]}, generate_schemas=True + app, db_url="sqlite://./db.sqlite", modules={"models": ["db"]}, generate_schemas=True ) app.run(host="127.0.0.1", port=8000, debug=True, auto_reload=True) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py new file mode 100644 index 0000000..774c346 --- /dev/null +++ b/sanic_rest_framework/fields.py @@ -0,0 +1,403 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/20 13:20 +@DependencyLibrary: +@MainFunction: +@FileDoc: + fields.py + 文件说明 +""" + + +class empty: + """ + 此类代表空,因为有些字段可以为 None + 所以需要一个可以替代 None 代表空变量 + """ + pass + + +class Field: + _creation_counter = 0 + + default_error_messages = { + 'required': '此字段必填.', + 'null': '此字段不能为空' + } + default_validators = [] + default_empty_value = empty + initial = None + + def __init__(self, read_only=False, write_only=False, + required=None, default=empty, initial=empty, source=None, + label=None, help_text=None, style=None, + error_messages=None, validators=None, allow_null=False): + self._creation_counter = Field._creation_counter + Field._creation_counter += 1 + + # If `required` is unset, then use `True` unless a default is provided. + if required is None: + required = default is empty and not read_only + + # Some combinations of keyword arguments do not make sense. + assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY + assert not (read_only and required), NOT_READ_ONLY_REQUIRED + assert not (required and default is not empty), NOT_REQUIRED_DEFAULT + assert not (read_only and self.__class__ == Field), USE_READONLYFIELD + + self.read_only = read_only + self.write_only = write_only + self.required = required + self.default = default + self.source = source + self.initial = self.initial if (initial is empty) else initial + self.label = label + self.help_text = help_text + self.style = {} if style is None else style + self.allow_null = allow_null + + if self.default_empty_value is not empty: + if default is not empty: + self.default_empty_value = default + + if validators is not None: + self.validators = list(validators) + + # These are set up by `.bind()` when the field is added to a serializer. + self.field_name = None + self.parent = None + + # Collect default error message from self and parent classes + messages = {} + for cls in reversed(self.__class__.__mro__): + messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(error_messages or {}) + self.error_messages = messages + + def bind(self, field_name, parent): + """ + Initializes the field name and parent for the field instance. + Called when a field is added to the parent serializer instance. + """ + + # In order to enforce a consistent style, we error if a redundant + # 'source' argument has been used. For example: + # my_field = serializer.CharField(source='my_field') + assert self.source != field_name, ( + "It is redundant to specify `source='%s'` on field '%s' in " + "serializer '%s', because it is the same as the field name. " + "Remove the `source` keyword argument." % + (field_name, self.__class__.__name__, parent.__class__.__name__) + ) + + self.field_name = field_name + self.parent = parent + + # `self.label` should default to being based on the field name. + if self.label is None: + self.label = field_name.replace('_', ' ').capitalize() + + # self.source should default to being the same as the field name. + if self.source is None: + self.source = field_name + + # self.source_attrs is a list of attributes that need to be looked up + # when serializing the instance, or populating the validated data. + if self.source == '*': + self.source_attrs = [] + else: + self.source_attrs = self.source.split('.') + + # .validators is a lazily loaded property, that gets its default + # value from `get_validators`. + @property + def validators(self): + if not hasattr(self, '_validators'): + self._validators = self.get_validators() + return self._validators + + @validators.setter + def validators(self, validators): + self._validators = validators + + def get_validators(self): + return list(self.default_validators) + + def get_initial(self): + """ + Return a value to use when the field is being returned as a primitive + value, without any object instance. + """ + if callable(self.initial): + return self.initial() + return self.initial + + def get_value(self, dictionary): + """ + Given the *incoming* primitive data, return the value for this field + that should be validated and transformed to a native value. + """ + if html.is_html_input(dictionary): + # HTML forms will represent empty fields as '', and cannot + # represent None or False values directly. + if self.field_name not in dictionary: + if getattr(self.root, 'partial', False): + return empty + return self.default_empty_value + ret = dictionary[self.field_name] + if ret == '' and self.allow_null: + # If the field is blank, and null is a valid value then + # determine if we should use null instead. + return '' if getattr(self, 'allow_blank', False) else None + elif ret == '' and not self.required: + # If the field is blank, and emptiness is valid then + # determine if we should use emptiness instead. + return '' if getattr(self, 'allow_blank', False) else empty + return ret + return dictionary.get(self.field_name, empty) + + def get_attribute(self, instance): + """ + Given the *outgoing* object instance, return the primitive value + that should be used for this field. + """ + try: + return get_attribute(instance, self.source_attrs) + except BuiltinSignatureError as exc: + msg = ( + 'Field source for `{serializer}.{field}` maps to a built-in ' + 'function type and is invalid. Define a property or method on ' + 'the `{instance}` instance that wraps the call to the built-in ' + 'function.'.format( + serializer=self.parent.__class__.__name__, + field=self.field_name, + instance=instance.__class__.__name__, + ) + ) + raise type(exc)(msg) + except (KeyError, AttributeError) as exc: + if self.default is not empty: + return self.get_default() + if self.allow_null: + return None + if not self.required: + raise SkipField() + msg = ( + 'Got {exc_type} when attempting to get a value for field ' + '`{field}` on serializer `{serializer}`.\nThe serializer ' + 'field might be named incorrectly and not match ' + 'any attribute or key on the `{instance}` instance.\n' + 'Original exception text was: {exc}.'.format( + exc_type=type(exc).__name__, + field=self.field_name, + serializer=self.parent.__class__.__name__, + instance=instance.__class__.__name__, + exc=exc + ) + ) + raise type(exc)(msg) + + def get_default(self): + """ + Return the default value to use when validating data if no input + is provided for this field. + + If a default has not been set for this field then this will simply + raise `SkipField`, indicating that no value should be set in the + validated data for this field. + """ + if self.default is empty or getattr(self.root, 'partial', False): + # No default, or this is a partial update. + raise SkipField() + if callable(self.default): + if hasattr(self.default, 'set_context'): + warnings.warn( + "Method `set_context` on defaults is deprecated and will " + "no longer be called starting with 3.13. Instead set " + "`requires_context = True` on the class, and accept the " + "context as an additional argument.", + RemovedInDRF313Warning, stacklevel=2 + ) + self.default.set_context(self) + + if getattr(self.default, 'requires_context', False): + return self.default(self) + else: + return self.default() + + return self.default + + def validate_empty_values(self, data): + """ + Validate empty values, and either: + + * Raise `ValidationError`, indicating invalid data. + * Raise `SkipField`, indicating that the field should be ignored. + * Return (True, data), indicating an empty value that should be + returned without any further validation being applied. + * Return (False, data), indicating a non-empty value, that should + have validation applied as normal. + """ + if self.read_only: + return (True, self.get_default()) + + if data is empty: + if getattr(self.root, 'partial', False): + raise SkipField() + if self.required: + self.fail('required') + return (True, self.get_default()) + + if data is None: + if not self.allow_null: + self.fail('null') + # Nullable `source='*'` fields should not be skipped when its named + # field is given a null value. This is because `source='*'` means + # the field is passed the entire object, which is not null. + elif self.source == '*': + return (False, None) + return (True, None) + + return (False, data) + + def run_validation(self, data=empty): + """ + Validate a simple representation and return the internal value. + + The provided data may be `empty` if no representation was included + in the input. + + May raise `SkipField` if the field should not be included in the + validated data. + """ + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + value = self.to_internal_value(data) + self.run_validators(value) + return value + + def run_validators(self, value): + """ + Test the given value against all the validators on the field, + and either raise a `ValidationError` or simply return. + """ + errors = [] + for validator in self.validators: + if hasattr(validator, 'set_context'): + warnings.warn( + "Method `set_context` on validators is deprecated and will " + "no longer be called starting with 3.13. Instead set " + "`requires_context = True` on the class, and accept the " + "context as an additional argument.", + RemovedInDRF313Warning, stacklevel=2 + ) + validator.set_context(self) + + try: + if getattr(validator, 'requires_context', False): + validator(value, self) + else: + validator(value) + except ValidationError as exc: + # If the validation error contains a mapping of fields to + # errors then simply raise it immediately rather than + # attempting to accumulate a list of errors. + if isinstance(exc.detail, dict): + raise + errors.extend(exc.detail) + except DjangoValidationError as exc: + errors.extend(get_error_detail(exc)) + if errors: + raise ValidationError(errors) + + def to_internal_value(self, data): + """ + Transform the *incoming* primitive data into a native value. + """ + raise NotImplementedError( + '{cls}.to_internal_value() must be implemented for field ' + '{field_name}. If you do not need to support write operations ' + 'you probably want to subclass `ReadOnlyField` instead.'.format( + cls=self.__class__.__name__, + field_name=self.field_name, + ) + ) + + def to_representation(self, value): + """ + Transform the *outgoing* native value into primitive data. + """ + raise NotImplementedError( + '{cls}.to_representation() must be implemented for field {field_name}.'.format( + cls=self.__class__.__name__, + field_name=self.field_name, + ) + ) + + def fail(self, key, **kwargs): + """ + A helper method that simply raises a validation error. + """ + try: + msg = self.error_messages[key] + except KeyError: + class_name = self.__class__.__name__ + msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + raise AssertionError(msg) + message_string = msg.format(**kwargs) + raise ValidationError(message_string, code=key) + + @property + def root(self): + """ + Returns the top-level serializer for this field. + """ + root = self + while root.parent is not None: + root = root.parent + return root + + @property + def context(self): + """ + Returns the context as passed to the root serializer on initialization. + """ + return getattr(self.root, '_context', {}) + + def __new__(cls, *args, **kwargs): + """ + When a field is instantiated, we store the arguments that were used, + so that we can present a helpful representation of the object. + """ + instance = super().__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance + + def __deepcopy__(self, memo): + """ + When cloning fields we instantiate using the arguments it was + originally created with, rather than copying the complete state. + """ + # Treat regexes and validators as immutable. + # See https://github.com/encode/django-rest-framework/issues/1954 + # and https://github.com/encode/django-rest-framework/pull/4489 + args = [ + copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item + for item in self._args + ] + kwargs = { + key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) + for key, value in self._kwargs.items() + } + return self.__class__(*args, **kwargs) + + def __repr__(self): + """ + Fields are represented using their initial calling arguments. + This allows us to create descriptive representations for serializer + instances that show all the declared fields on the serializer. + """ + return representation.field_repr(self) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index e69de29..2907fdd 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -0,0 +1,10 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/20 13:20 +@DependencyLibrary: +@MainFunction: +@FileDoc: + serializers.py + 序列化器文件 +""" -- Gitee From 2c447ed8ca4ddf8f0f6a5d0d15a1be067a0c0b51 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Wed, 20 Jan 2021 23:17:17 +0800 Subject: [PATCH 03/34] fields.py --- sanic_rest_framework/exceptions.py | 10 + sanic_rest_framework/fields.py | 405 +++-------------------------- 2 files changed, 52 insertions(+), 363 deletions(-) create mode 100644 sanic_rest_framework/exceptions.py diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py new file mode 100644 index 0000000..fc388d6 --- /dev/null +++ b/sanic_rest_framework/exceptions.py @@ -0,0 +1,10 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/20 20:03 +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: + exceptions.py + 序列化器文件 +""" diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 774c346..2d65ebc 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -8,6 +8,7 @@ fields.py 文件说明 """ +import copy class empty: @@ -18,386 +19,64 @@ class empty: pass +NOT_RAED_ONLY_AND_WRITE_ONLY = 'read_only 和 write_only 不能同时为True, 只能二选一' +NOT_RAED_ONLY_REQUIRED_ONLY = 'read_only 为 True 时 required 不能为True , 只能二选一' + + class Field: - _creation_counter = 0 + _sort_counter = 0 default_error_messages = { - 'required': '此字段必填.', - 'null': '此字段不能为空' + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空' } - default_validators = [] - default_empty_value = empty - initial = None - - def __init__(self, read_only=False, write_only=False, - required=None, default=empty, initial=empty, source=None, - label=None, help_text=None, style=None, - error_messages=None, validators=None, allow_null=False): - self._creation_counter = Field._creation_counter - Field._creation_counter += 1 - - # If `required` is unset, then use `True` unless a default is provided. - if required is None: - required = default is empty and not read_only - - # Some combinations of keyword arguments do not make sense. - assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY - assert not (read_only and required), NOT_READ_ONLY_REQUIRED - assert not (required and default is not empty), NOT_REQUIRED_DEFAULT - assert not (read_only and self.__class__ == Field), USE_READONLYFIELD + default_validators = None + + def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, + default=None, initial=empty, source=None, error_messages=None, + label=None, description=None, validators=None, ): + """ + 字段基类 + :param read_only: 是否只反序列化 + :param write_only: 是否只序列化 + :param required: 是否必填 与 read_only 冲突 + :param allow_null: 是否允许为空 + :param default: 默认值 + :param initial: 初始化值 + :param source: 来源 getattr(initial,source) + :param label: 易懂的语言描述 列如:机器编号, 呈现在ApiDoc内 + :param description: 详细描述 呈现在ApiDoc内 + :param validators: 自定义的验证器 + """ + # 优先检查冲突 + assert (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY + assert (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY + + self._sort_counter = Field._sort_counter + Field._sort_counter += 1 self.read_only = read_only self.write_only = write_only self.required = required + self.allow_null = allow_null self.default = default self.source = source self.initial = self.initial if (initial is empty) else initial self.label = label - self.help_text = help_text - self.style = {} if style is None else style - self.allow_null = allow_null - - if self.default_empty_value is not empty: - if default is not empty: - self.default_empty_value = default - - if validators is not None: - self.validators = list(validators) - - # These are set up by `.bind()` when the field is added to a serializer. + self.description = description + self.error_messages = copy.deepcopy(self.default_error_messages) + self.validators = [] if self.default_validators is None else copy.deepcopy(self.default_validators) + if error_messages: + self.error_messages.update(error_messages) + if validators: + self.validators.extend(validators) + + # 为绑定做准备 .bind() self.field_name = None self.parent = None - # Collect default error message from self and parent classes - messages = {} - for cls in reversed(self.__class__.__mro__): - messages.update(getattr(cls, 'default_error_messages', {})) - messages.update(error_messages or {}) - self.error_messages = messages - def bind(self, field_name, parent): - """ - Initializes the field name and parent for the field instance. - Called when a field is added to the parent serializer instance. - """ - - # In order to enforce a consistent style, we error if a redundant - # 'source' argument has been used. For example: - # my_field = serializer.CharField(source='my_field') - assert self.source != field_name, ( - "It is redundant to specify `source='%s'` on field '%s' in " - "serializer '%s', because it is the same as the field name. " - "Remove the `source` keyword argument." % - (field_name, self.__class__.__name__, parent.__class__.__name__) - ) - self.field_name = field_name self.parent = parent - - # `self.label` should default to being based on the field name. - if self.label is None: - self.label = field_name.replace('_', ' ').capitalize() - - # self.source should default to being the same as the field name. if self.source is None: self.source = field_name - - # self.source_attrs is a list of attributes that need to be looked up - # when serializing the instance, or populating the validated data. - if self.source == '*': - self.source_attrs = [] - else: - self.source_attrs = self.source.split('.') - - # .validators is a lazily loaded property, that gets its default - # value from `get_validators`. - @property - def validators(self): - if not hasattr(self, '_validators'): - self._validators = self.get_validators() - return self._validators - - @validators.setter - def validators(self, validators): - self._validators = validators - - def get_validators(self): - return list(self.default_validators) - - def get_initial(self): - """ - Return a value to use when the field is being returned as a primitive - value, without any object instance. - """ - if callable(self.initial): - return self.initial() - return self.initial - - def get_value(self, dictionary): - """ - Given the *incoming* primitive data, return the value for this field - that should be validated and transformed to a native value. - """ - if html.is_html_input(dictionary): - # HTML forms will represent empty fields as '', and cannot - # represent None or False values directly. - if self.field_name not in dictionary: - if getattr(self.root, 'partial', False): - return empty - return self.default_empty_value - ret = dictionary[self.field_name] - if ret == '' and self.allow_null: - # If the field is blank, and null is a valid value then - # determine if we should use null instead. - return '' if getattr(self, 'allow_blank', False) else None - elif ret == '' and not self.required: - # If the field is blank, and emptiness is valid then - # determine if we should use emptiness instead. - return '' if getattr(self, 'allow_blank', False) else empty - return ret - return dictionary.get(self.field_name, empty) - - def get_attribute(self, instance): - """ - Given the *outgoing* object instance, return the primitive value - that should be used for this field. - """ - try: - return get_attribute(instance, self.source_attrs) - except BuiltinSignatureError as exc: - msg = ( - 'Field source for `{serializer}.{field}` maps to a built-in ' - 'function type and is invalid. Define a property or method on ' - 'the `{instance}` instance that wraps the call to the built-in ' - 'function.'.format( - serializer=self.parent.__class__.__name__, - field=self.field_name, - instance=instance.__class__.__name__, - ) - ) - raise type(exc)(msg) - except (KeyError, AttributeError) as exc: - if self.default is not empty: - return self.get_default() - if self.allow_null: - return None - if not self.required: - raise SkipField() - msg = ( - 'Got {exc_type} when attempting to get a value for field ' - '`{field}` on serializer `{serializer}`.\nThe serializer ' - 'field might be named incorrectly and not match ' - 'any attribute or key on the `{instance}` instance.\n' - 'Original exception text was: {exc}.'.format( - exc_type=type(exc).__name__, - field=self.field_name, - serializer=self.parent.__class__.__name__, - instance=instance.__class__.__name__, - exc=exc - ) - ) - raise type(exc)(msg) - - def get_default(self): - """ - Return the default value to use when validating data if no input - is provided for this field. - - If a default has not been set for this field then this will simply - raise `SkipField`, indicating that no value should be set in the - validated data for this field. - """ - if self.default is empty or getattr(self.root, 'partial', False): - # No default, or this is a partial update. - raise SkipField() - if callable(self.default): - if hasattr(self.default, 'set_context'): - warnings.warn( - "Method `set_context` on defaults is deprecated and will " - "no longer be called starting with 3.13. Instead set " - "`requires_context = True` on the class, and accept the " - "context as an additional argument.", - RemovedInDRF313Warning, stacklevel=2 - ) - self.default.set_context(self) - - if getattr(self.default, 'requires_context', False): - return self.default(self) - else: - return self.default() - - return self.default - - def validate_empty_values(self, data): - """ - Validate empty values, and either: - - * Raise `ValidationError`, indicating invalid data. - * Raise `SkipField`, indicating that the field should be ignored. - * Return (True, data), indicating an empty value that should be - returned without any further validation being applied. - * Return (False, data), indicating a non-empty value, that should - have validation applied as normal. - """ - if self.read_only: - return (True, self.get_default()) - - if data is empty: - if getattr(self.root, 'partial', False): - raise SkipField() - if self.required: - self.fail('required') - return (True, self.get_default()) - - if data is None: - if not self.allow_null: - self.fail('null') - # Nullable `source='*'` fields should not be skipped when its named - # field is given a null value. This is because `source='*'` means - # the field is passed the entire object, which is not null. - elif self.source == '*': - return (False, None) - return (True, None) - - return (False, data) - - def run_validation(self, data=empty): - """ - Validate a simple representation and return the internal value. - - The provided data may be `empty` if no representation was included - in the input. - - May raise `SkipField` if the field should not be included in the - validated data. - """ - (is_empty_value, data) = self.validate_empty_values(data) - if is_empty_value: - return data - value = self.to_internal_value(data) - self.run_validators(value) - return value - - def run_validators(self, value): - """ - Test the given value against all the validators on the field, - and either raise a `ValidationError` or simply return. - """ - errors = [] - for validator in self.validators: - if hasattr(validator, 'set_context'): - warnings.warn( - "Method `set_context` on validators is deprecated and will " - "no longer be called starting with 3.13. Instead set " - "`requires_context = True` on the class, and accept the " - "context as an additional argument.", - RemovedInDRF313Warning, stacklevel=2 - ) - validator.set_context(self) - - try: - if getattr(validator, 'requires_context', False): - validator(value, self) - else: - validator(value) - except ValidationError as exc: - # If the validation error contains a mapping of fields to - # errors then simply raise it immediately rather than - # attempting to accumulate a list of errors. - if isinstance(exc.detail, dict): - raise - errors.extend(exc.detail) - except DjangoValidationError as exc: - errors.extend(get_error_detail(exc)) - if errors: - raise ValidationError(errors) - - def to_internal_value(self, data): - """ - Transform the *incoming* primitive data into a native value. - """ - raise NotImplementedError( - '{cls}.to_internal_value() must be implemented for field ' - '{field_name}. If you do not need to support write operations ' - 'you probably want to subclass `ReadOnlyField` instead.'.format( - cls=self.__class__.__name__, - field_name=self.field_name, - ) - ) - - def to_representation(self, value): - """ - Transform the *outgoing* native value into primitive data. - """ - raise NotImplementedError( - '{cls}.to_representation() must be implemented for field {field_name}.'.format( - cls=self.__class__.__name__, - field_name=self.field_name, - ) - ) - - def fail(self, key, **kwargs): - """ - A helper method that simply raises a validation error. - """ - try: - msg = self.error_messages[key] - except KeyError: - class_name = self.__class__.__name__ - msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) - raise AssertionError(msg) - message_string = msg.format(**kwargs) - raise ValidationError(message_string, code=key) - - @property - def root(self): - """ - Returns the top-level serializer for this field. - """ - root = self - while root.parent is not None: - root = root.parent - return root - - @property - def context(self): - """ - Returns the context as passed to the root serializer on initialization. - """ - return getattr(self.root, '_context', {}) - - def __new__(cls, *args, **kwargs): - """ - When a field is instantiated, we store the arguments that were used, - so that we can present a helpful representation of the object. - """ - instance = super().__new__(cls) - instance._args = args - instance._kwargs = kwargs - return instance - - def __deepcopy__(self, memo): - """ - When cloning fields we instantiate using the arguments it was - originally created with, rather than copying the complete state. - """ - # Treat regexes and validators as immutable. - # See https://github.com/encode/django-rest-framework/issues/1954 - # and https://github.com/encode/django-rest-framework/pull/4489 - args = [ - copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item - for item in self._args - ] - kwargs = { - key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) - for key, value in self._kwargs.items() - } - return self.__class__(*args, **kwargs) - - def __repr__(self): - """ - Fields are represented using their initial calling arguments. - This allows us to create descriptive representations for serializer - instances that show all the declared fields on the serializer. - """ - return representation.field_repr(self) -- Gitee From 3c01156c2967c9548186d95dc8a72e229160e9d3 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 21 Jan 2021 13:56:51 +0800 Subject: [PATCH 04/34] serializers comm --- sanic_rest_framework/exceptions.py | 77 ++++++++++++++++++++++++++++++ sanic_rest_framework/fields.py | 75 +++++++++++++++++++++++++++-- 2 files changed, 147 insertions(+), 5 deletions(-) diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py index fc388d6..907725c 100644 --- a/sanic_rest_framework/exceptions.py +++ b/sanic_rest_framework/exceptions.py @@ -8,3 +8,80 @@ exceptions.py 序列化器文件 """ + + +class ValidationError(Exception): + """验证器通用错误类 发生错误即抛出此类""" + + def __init__(self, message, code=None, params=None): + super().__init__(message, code, params) + if isinstance(message, ValidationError): + if hasattr(message, 'error_dict'): + message = message.error_dict + elif not hasattr(message, 'message'): + message = message.error_list + else: + message, code, params = message.message, message.code, message.params + if isinstance(message, dict): + self.error_dict = {} + for field, messages in message.items(): + if not isinstance(messages, ValidationError): + messages = ValidationError(messages) + self.error_dict[field] = messages.error_list + elif isinstance(message, list): + self.error_list = [] + for message in message: + if not isinstance(message, ValidationError): + message = ValidationError(message) + if hasattr(message, 'error_dict'): + self.error_list.extend(sum(message.error_dict.values(), [])) + else: + self.error_list.extend(message.error_list) + else: + self.message = message + self.code = code + self.params = params + self.error_list = [self] + + @property + def message_dict(self): + getattr(self, 'error_dict') + return dict(self) + + @property + def messages(self): + if hasattr(self, 'error_dict'): + return sum(dict(self).values(), []) + return list(self) + + def update_error_dict(self, error_dict): + if hasattr(self, 'error_dict'): + for field, error_list in self.error_dict.items(): + error_dict.setdefault(field, []).extend(error_list) + else: + error_dict.setdefault('__all__', []).extend(self.error_list) + return error_dict + + def __iter__(self): + if hasattr(self, 'error_dict'): + for field, errors in self.error_dict.items(): + yield field, list(ValidationError(errors)) + else: + for error in self.error_list: + message = error.message + if error.params: + message %= error.params + yield str(message) + + def __str__(self): + if hasattr(self, 'error_dict'): + return repr(dict(self)) + return repr(list(self)) + + def __repr__(self): + return 'ValidationError(%s)' % self + + def __eq__(self, other): + if not isinstance(other, ValidationError): + return NotImplemented + return hash(self) == hash(other) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 2d65ebc..9a2c649 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -9,6 +9,9 @@ 文件说明 """ import copy +from typing import Any, AnyStr, Optional, List + +from sanic_rest_framework.exceptions import ValidationError class empty: @@ -27,7 +30,7 @@ class Field: _sort_counter = 0 default_error_messages = { - 'required': '{field_name}是必填项', + 'required': '{field_name}是必填项,为空', 'null': '{field_name}不能为空' } default_validators = None @@ -37,9 +40,9 @@ class Field: label=None, description=None, validators=None, ): """ 字段基类 - :param read_only: 是否只反序列化 - :param write_only: 是否只序列化 - :param required: 是否必填 与 read_only 冲突 + :param read_only: 是否只反序列化, 供 serializers 使用 + :param write_only: 是否只序列化, 供 serializers 使用 + :param required: 是否必填 与 read_only 冲突,供 serializers 使用 :param allow_null: 是否允许为空 :param default: 默认值 :param initial: 初始化值 @@ -74,9 +77,71 @@ class Field: # 为绑定做准备 .bind() self.field_name = None self.parent = None + # 存储错误 + self._errors: List[ValidationError] = [] - def bind(self, field_name, parent): + def bind(self, field_name: str, parent): self.field_name = field_name self.parent = parent if self.source is None: self.source = field_name + + # 数据处理 + def to_python(self, data: Any): + """序列化时用到的数据处理函数 + 会在验证器验证前执行,届时验证器接受到的data + 就是 to_python 执行后返回的数据 + """ + raise NotImplementedError( + '{cls}类在继承 Field 类后内部的 .to_python() 必须重写' + '请勿忘记处理 write_only 时的情况'.format(cls=self.__class__.__name__, ) + ) + + def to_string(self, data: Any): + """反序列化时用到的数据处理函数""" + raise NotImplementedError( + '{cls}类在继承 Field 类后内部的 .to_string() 必须重写' + '请勿忘记处理 read_only 时的情况'.format(cls=self.__class__.__name__, ) + ) + + # 验证处理 + def validate(self, data: Any): + """自带的验证方法""" + pass + + def run_validation(self, data: Any, raise_exception: bool = False): + """执行验证器""" + errors = [] + data = self.to_python(data) + try: + self.validate(data) + except ValidationError as e: + if hasattr(e, 'code') and e.code in self.error_messages: + e.message = self.error_messages[e.code] + errors.extend(e.error_list) + + for validator in self.validators: + if callable(validator): + try: + validator(data) + except ValidationError as e: + if hasattr(e, 'code') and e.code in self.error_messages: + e.message = self.error_messages[e.code] + errors.extend(e.error_list) + if errors: + raise ValidationError(errors) + self._errors.extend(errors) + + def validate_empty_values(self, data): + """ + 验证空值 + """ + if self.allow_null: + if data is empty: + return True, None + return False, data + + def add_error(self, error): + if self._errors is None: + self._errors = [] + self._errors.append(error) -- Gitee From ecc0d1823680203e56aafb533c0dea445763b9cd Mon Sep 17 00:00:00 2001 From: LaoSi Date: Fri, 22 Jan 2021 17:58:05 +0800 Subject: [PATCH 05/34] field --- Apps/__init__.py | 8 +- Cofing/__init__.py | 6 +- Cofing/develop.py | 6 +- Cofing/formal.py | 6 +- Cofing/local.py | 8 +- db.py | 6 +- sanic_rest_framework/__init__.py | 6 +- sanic_rest_framework/constant.py | 4 +- sanic_rest_framework/fields.py | 436 +++++++++++++++++++++++++++- sanic_rest_framework/routes.py | 4 +- sanic_rest_framework/serializers.py | 17 ++ sanic_rest_framework/validators.py | 19 ++ 12 files changed, 490 insertions(+), 36 deletions(-) create mode 100644 sanic_rest_framework/validators.py diff --git a/Apps/__init__.py b/Apps/__init__.py index e5646ee..a005523 100644 --- a/Apps/__init__.py +++ b/Apps/__init__.py @@ -2,11 +2,9 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/15 14:01 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - 无 -@FileDoc: +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: __init__.py 工厂函数 """ diff --git a/Cofing/__init__.py b/Cofing/__init__.py index ad88d06..117b32e 100644 --- a/Cofing/__init__.py +++ b/Cofing/__init__.py @@ -2,10 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/15 14:07 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - TODO: 需要填写入口函数 +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: __init__.py 配置初始化文件 diff --git a/Cofing/develop.py b/Cofing/develop.py index aed0f16..2c11a9f 100644 --- a/Cofing/develop.py +++ b/Cofing/develop.py @@ -2,10 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/15 14:04 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - TODO: 需要填写入口函数 +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: develop.py 开发环境配置文件 diff --git a/Cofing/formal.py b/Cofing/formal.py index 033ff8d..d8f6e78 100644 --- a/Cofing/formal.py +++ b/Cofing/formal.py @@ -2,10 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/15 14:04 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - TODO: 需要填写入口函数 +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: formal.py 正式运行环境配置文件 diff --git a/Cofing/local.py b/Cofing/local.py index 276864e..ac45825 100644 --- a/Cofing/local.py +++ b/Cofing/local.py @@ -2,11 +2,9 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/15 14:03 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - TODO: 需要填写入口函数 -@FileDoc: +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: local.py 本地运行配置文件 """ diff --git a/db.py b/db.py index 3d62573..c170c4e 100644 --- a/db.py +++ b/db.py @@ -2,10 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/15 14:00 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - TODO: 需要填写入口函数 +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: db.py 文件说明 diff --git a/sanic_rest_framework/__init__.py b/sanic_rest_framework/__init__.py index a85a032..bb2d574 100644 --- a/sanic_rest_framework/__init__.py +++ b/sanic_rest_framework/__init__.py @@ -2,10 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/19 15:47 -@DependencyLibrary: - TODO: 需要填写依赖安装方法 例:python -m pip install requests -@MainFunction: - TODO: 需要填写入口函数 +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: __init__.py 文件说明 diff --git a/sanic_rest_framework/constant.py b/sanic_rest_framework/constant.py index 8ea0801..f5d90b6 100644 --- a/sanic_rest_framework/constant.py +++ b/sanic_rest_framework/constant.py @@ -2,8 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/19 15:45 -@DependencyLibrary: -@MainFunction: +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: constant.py 全局常量 diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 9a2c649..cd5ed39 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -9,6 +9,9 @@ 文件说明 """ import copy +import decimal +import re +from datetime import timezone, timedelta, datetime, date, time from typing import Any, AnyStr, Optional, List from sanic_rest_framework.exceptions import ValidationError @@ -37,7 +40,7 @@ class Field: def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, default=None, initial=empty, source=None, error_messages=None, - label=None, description=None, validators=None, ): + label=None, description=None, validators=None): """ 字段基类 :param read_only: 是否只反序列化, 供 serializers 使用 @@ -109,7 +112,7 @@ class Field: """自带的验证方法""" pass - def run_validation(self, data: Any, raise_exception: bool = False): + def run_validation(self, data: Any): """执行验证器""" errors = [] data = self.to_python(data) @@ -145,3 +148,432 @@ class Field: if self._errors is None: self._errors = [] self._errors.append(error) + + def raise_error(self, key, **kwargs): + """直接返回错误""" + try: + msg = self.error_messages[key] + except KeyError: + class_name = self.__class__.__name__ + msg = "在 {class_name} 类的 error_messages " \ + "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) + raise AssertionError(msg) + message_string = msg.format(**kwargs) + raise ValidationError(message_string, code=key) + + +class CharField(Field): + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,仅支持整字符类型', + 'max_length': '{field_name}最长支持{max_length}个字符', + 'min_length': '{field_name}至少要有{min_length}个字符', + } + + def __init__(self, *args, **kwargs): + self.max_length = kwargs.pop('max_length', None) + self.min_length = kwargs.pop('min_length', None) + self.trim_whitespace = kwargs.pop('trim_whitespace', True) + super(CharField, self).__init__(*args, *kwargs) + # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators + + def to_python(self, data: Any): + if isinstance(data, bool) or not isinstance(data, (str, int, float,)): + self.raise_error('invalid', field_name=self.field_name) + value = str(data) + return value.strip() if self.trim_whitespace else value + + def to_string(self, data: Any): + return str(data) + + +class IntegerField(Field): + """整数类型""" + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,仅支持整数类型', + 'max_value': '{field_name}仅支持小于{max_value}的整数', + 'min_value': '{field_name}仅支持大于{min_value}的整数', + 'max_string_length': '{field_name}仅支持转换长度小于{max_string_length}的整数字符串', + } + re_decimal = re.compile(r'\.0*\s*$') + MAX_STRING_LENGTH = 1000 + + def __init__(self, max_value=None, min_value=None, *args, **kwargs): + self.max_value = max_value + self.min_value = min_value + super(IntegerField, self).__init__(*args, *kwargs) + # TODO 需要将 MaxValueValidator 与 MinValueValidator 添加入 self.validators + + def to_python(self, data: Any): + if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: + self.raise_error('max_string_length', field_name=self.field_name, max_string_length=self.MAX_STRING_LENGTH) + try: + data = int(self.re_decimal.sub('', str(data))) + except (ValueError, TypeError): + self.raise_error('invalid', field_name=self.field_name) + return data + + def to_string(self, data: Any): + return int(data) + + +class FloatField(IntegerField): + """浮点类型""" + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,仅支持浮点类型', + 'max_value': '{field_name}仅支持小于{max_value}浮点', + 'min_value': '{field_name}仅支持大于{min_value}浮点', + 'max_string_length': '{field_name}仅支持转换长度小于{max_string_length}的浮点字符串', + } + MAX_STRING_LENGTH = 1000 + + def to_python(self, data: Any): + if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: + self.raise_error('max_string_length', field_name=self.field_name, max_string_length=self.MAX_STRING_LENGTH) + try: + return float(data) + except (TypeError, ValueError): + self.raise_error('invalid', field_name=self.field_name) + + def to_string(self, data: Any): + return float(data) + + +class DecimalField(Field): + """十进制类型""" + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,仅支持Decimal十进制类型', + 'max_value': '{field_name}仅支持小于{max_value}Decimal十进制类型', + 'min_value': '{field_name}仅支持大于{min_value}Decimal十进制类型', + 'max_string_length': '{field_name}仅支持转换长度小于{max_string_length}的Decimal十进制字符串', + } + MAX_STRING_LENGTH = 1000 + + def __init__(self, max_digits, decimal_places, coerce_to_string=False, max_value=None, min_value=None, rounding=None, *args, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.coerce_to_string = coerce_to_string + self.max_value = max_value + self.min_value = min_value + self.rounding = rounding + if self.max_digits is not None and self.decimal_places is not None: + self.max_whole_digits = self.max_digits - self.decimal_places + else: + self.max_whole_digits = None + super(DecimalField, self).__init__(*args, **kwargs) + # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators + + def to_python(self, data: Any): + data = str(data).strip() + + if len(data) > self.MAX_STRING_LENGTH: + self.raise_error('max_string_length', field_name=self.field_name, max_string_length=self.MAX_STRING_LENGTH) + try: + data = decimal.Decimal(data) + except decimal.DecimalException: + self.raise_error('invalid', field_name=self.field_name) + + if data.is_nan(): + self.raise_error('invalid', field_name=self.field_name) + + # 检查无穷大和负无穷大。 + if data in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): + self.raise_error('invalid', field_name=self.field_name) + + return self.quantize(self.validate_precision(data)) + + def to_string(self, data: Any): + if not isinstance(data, decimal.Decimal): + data = decimal.Decimal(str(data).strip()) + + quantized = self.quantize(data) + + if not self.coerce_to_string: + return quantized + return '{:f}'.format(quantized) + + def validate_precision(self, value): + """ + 确保数字中的位数不超过max_digits,并且小数点后超过十进制的位数。 + 覆盖此方法以禁用输入的精度验证值或以您需要的任何方式增强它。 + """ + sign, digittuple, exponent = value.as_tuple() + + if exponent >= 0: + # 1234500.0 + total_digits = len(digittuple) + exponent + whole_digits = total_digits + decimal_places = 0 + elif len(digittuple) > abs(exponent): + # 123.45 + total_digits = len(digittuple) + whole_digits = total_digits - abs(exponent) + decimal_places = abs(exponent) + else: + # 0.001234 + total_digits = abs(exponent) + whole_digits = 0 + decimal_places = total_digits + + if self.max_digits is not None and total_digits > self.max_digits: + self.raise_error('max_digits', max_digits=self.max_digits) + if self.decimal_places is not None and decimal_places > self.decimal_places: + self.raise_error('max_decimal_places', max_decimal_places=self.decimal_places) + if self.max_whole_digits is not None and whole_digits > self.max_whole_digits: + self.raise_error('max_whole_digits', max_whole_digits=self.max_whole_digits) + return value + + def quantize(self, value): + """ + 将十进制值量化为配置的精度。 + """ + if self.decimal_places is None: + return value + + context = decimal.getcontext().copy() + if self.max_digits is not None: + context.prec = self.max_digits + return value.quantize( + decimal.Decimal('.1') ** self.decimal_places, + rounding=self.rounding, + context=context + ) + + +class BooleanField(Field): + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的布尔值', + } + TRUE_VALUES = { + 't', 'T', + 'y', 'Y', 'yes', 'YES', + 'true', 'True', 'TRUE', + 'on', 'On', 'ON', + '1', 1, + True + } + FALSE_VALUES = { + 'f', 'F', + 'n', 'N', 'no', 'NO', + 'false', 'False', 'FALSE', + 'off', 'Off', 'OFF', + '0', 0, 0.0, + False + } + NULL_VALUES = {'null', 'Null', 'NULL', '', None} + + def to_python(self, data: Any): + try: + if data in self.TRUE_VALUES: + return True + elif data in self.FALSE_VALUES: + return False + elif data in self.NULL_VALUES and self.allow_null: + return None + except TypeError: + self.raise_error('invalid', field_name=self.field_name, value=data) + + def to_string(self, data: Any): + if data in self.TRUE_VALUES: + return True + elif data in self.FALSE_VALUES: + return False + if data in self.NULL_VALUES and self.allow_null: + return None + return bool(data) + + +class DateTimeField(Field): + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', + 'date': '{field_name}需要的是日期时间格式而不是日期格式', + 'overflow': '{field_name}时间超出范围' + } + + def __init__(self, output_format='%Y-%m-%d %H:%M:%S', input_formats='%Y-%m-%d %H:%M:%S', set_timezone: timezone = None, *args, **kwargs): + self.output_format = output_format + self.input_formats = input_formats + if set_timezone is not None: + self.set_timezone = set_timezone + else: + self.set_timezone = self.get_default_timezone() + super(DateTimeField, self).__init__(*args, **kwargs) + + def get_default_timezone(self): + """设置默认时区为北京时间""" + return timezone(timedelta(hours=8)) + + def enforce_timezone(self, value): + """强制设置一个时区""" + return value.astimezone(self.set_timezone) + + def to_python(self, data: Any): + if not isinstance(data, (str, data, datetime)): + self.raise_error('invalid', field_name=self.field_name) + + if isinstance(data, str): + try: + data = datetime.strptime(data, self.input_formats) + except (ValueError, TypeError): + self.raise_error('invalid', field_name=self.field_name) + if isinstance(data, date): + self.raise_error('date', field_name=self.field_name) + if isinstance(data, datetime): + data = self.enforce_timezone(data) + return data + + def to_string(self, data: Any): + if not data: + return None + if isinstance(data, str): + return data + if isinstance(data, datetime): + return data.strftime(self.output_format) + self.raise_error('invalid', field_name=self.field_name) + + +class DateField(Field): + """日期字段""" + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', + 'datetime': '{field_name}需要的是日期格式而不是日期时间格式', + } + + def __init__(self, output_format='%Y-%m-%d', input_formats='%Y-%m-%d', *args, **kwargs): + self.output_format = output_format + self.input_formats = input_formats + super(DateField, self).__init__(*args, **kwargs) + + def to_python(self, data: Any): + if isinstance(data, str): + try: + data = datetime.strptime(data, self.input_formats).date() + return data + except (ValueError, TypeError): + self.raise_error('invalid', field_name=self.field_name) + if isinstance(data, datetime): + self.raise_error('datetime', field_name=self.field_name) + if isinstance(data, date): + return data + self.raise_error('invalid', field_name=self.field_name) + + def to_string(self, data: Any): + if not data: + return data + if isinstance(data, str): + return data + if isinstance(data, date): + return data.strftime(self.output_format) + self.raise_error('invalid', field_name=self.field_name) + + +class TimeField(Field): + """时间字段""" + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', + 'date': '{field_name}需要的是时间格式而不是日期格式', + } + + def __init__(self, output_format='%H:%M:%S', input_formats='%H:%M:%S', *args, **kwargs): + self.output_format = output_format + self.input_formats = input_formats + super(TimeField, self).__init__(*args, **kwargs) + + def to_python(self, data: Any): + if isinstance(data, str): + try: + data = datetime.strptime(data, self.input_formats).time() + return data + except (ValueError, TypeError): + self.raise_error('invalid', field_name=self.field_name) + if isinstance(data, datetime): + return data.time() + if isinstance(data, date): + self.raise_error('date', field_name=self.field_name) + self.raise_error('invalid', field_name=self.field_name) + + def to_string(self, data: Any): + if not data: + return data + if isinstance(data, str): + return data + if isinstance(data, (time, datetime)): + return data.strftime(self.output_format) + self.raise_error('invalid', field_name=self.field_name) + + +class ChoiceField(Field): + """多选字段""" + default_error_messages = { + 'required': '{field_name}是必填项', + 'null': '{field_name}不能为空', + 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', + 'key': '{field_name}没有{key}选项', + } + + def __init__(self, choices, *args, **kwargs): + """ + + :param choices: (('key','value'),('key','value'),) + :param args: + :param kwargs: + """ + self.choices = choices + super(ChoiceField, self).__init__(*args, **kwargs) + + def to_python(self, data: Any): + data = self.get_choices().get(str(data)) + return data + + def to_string(self, data: Any): + return str(data) + + def choices_get_value_by_key(self, key): + """得到字符串""" + choices = self.get_choices() + if key not in choices: + self.raise_error('key', field_name=self.field_name, key=key) + return choices.get(key, None) + + def get_choices(self) -> dict: + choices = {str(key): value for key, value in self.choices} + return choices + + +class MultipleChoiceField(Field): + pass + + +class RegexField(Field): + pass + + +class EmailField(Field): + pass + + +class URLField(Field): + pass + + +class UUIDField(Field): + pass + + +class IPAddressField(Field): + pass diff --git a/sanic_rest_framework/routes.py b/sanic_rest_framework/routes.py index fd5e046..d09c107 100644 --- a/sanic_rest_framework/routes.py +++ b/sanic_rest_framework/routes.py @@ -2,8 +2,8 @@ @Author: WangYuXiang @E-mile: Hill@3io.cc @CreateTime: 2021/1/19 16:08 -@DependencyLibrary: -@MainFunction: +@DependencyLibrary: 无 +@MainFunction:无 @FileDoc: routes.py 便捷路由文件 diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 2907fdd..c1f1f4b 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -8,3 +8,20 @@ serializers.py 序列化器文件 """ +from typing import Any + +from sanic_rest_framework.fields import Field + + +class Serializer(Field): + """序列化器""" + + def __init__(self, data=None, *args, **kwargs): + super(Serializer, self).__init__(*args, **kwargs) + self.data = data + + def to_python(self, data: Any): + return data + + def to_string(self, data: Any): + return data diff --git a/sanic_rest_framework/validators.py b/sanic_rest_framework/validators.py new file mode 100644 index 0000000..42703ed --- /dev/null +++ b/sanic_rest_framework/validators.py @@ -0,0 +1,19 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/22 10:41 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + validators.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/1/22 10:41 change 'Fix bug' + +""" + + +class MaxLengthValidator(): + pass -- Gitee From 5aa121aafe4675d47e904b148646b396ef169c1f Mon Sep 17 00:00:00 2001 From: LaoSi Date: Sat, 23 Jan 2021 00:14:55 +0800 Subject: [PATCH 06/34] SerializerMetaclass --- sanic_rest_framework/fields.py | 33 ++--------- sanic_rest_framework/serializers.py | 92 +++++++++++++++++++++++++++-- 2 files changed, 91 insertions(+), 34 deletions(-) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index cd5ed39..f81dbdd 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -256,7 +256,8 @@ class DecimalField(Field): } MAX_STRING_LENGTH = 1000 - def __init__(self, max_digits, decimal_places, coerce_to_string=False, max_value=None, min_value=None, rounding=None, *args, **kwargs): + def __init__(self, max_digits, decimal_places, coerce_to_string=False, max_value=None, min_value=None, + rounding=None, *args, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places self.coerce_to_string = coerce_to_string @@ -401,7 +402,8 @@ class DateTimeField(Field): 'overflow': '{field_name}时间超出范围' } - def __init__(self, output_format='%Y-%m-%d %H:%M:%S', input_formats='%Y-%m-%d %H:%M:%S', set_timezone: timezone = None, *args, **kwargs): + def __init__(self, output_format='%Y-%m-%d %H:%M:%S', input_formats='%Y-%m-%d %H:%M:%S', + set_timezone: timezone = None, *args, **kwargs): self.output_format = output_format self.input_formats = input_formats if set_timezone is not None: @@ -518,7 +520,7 @@ class TimeField(Field): class ChoiceField(Field): - """多选字段""" + """限定可选的字段""" default_error_messages = { 'required': '{field_name}是必填项', 'null': '{field_name}不能为空', @@ -528,7 +530,6 @@ class ChoiceField(Field): def __init__(self, choices, *args, **kwargs): """ - :param choices: (('key','value'),('key','value'),) :param args: :param kwargs: @@ -553,27 +554,3 @@ class ChoiceField(Field): def get_choices(self) -> dict: choices = {str(key): value for key, value in self.choices} return choices - - -class MultipleChoiceField(Field): - pass - - -class RegexField(Field): - pass - - -class EmailField(Field): - pass - - -class URLField(Field): - pass - - -class UUIDField(Field): - pass - - -class IPAddressField(Field): - pass diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index c1f1f4b..863d798 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -8,20 +8,100 @@ serializers.py 序列化器文件 """ +import copy +from collections import OrderedDict +from functools import cached_property from typing import Any -from sanic_rest_framework.fields import Field +from sanic_rest_framework.fields import Field, empty +from .exceptions import ValidationError +from .helpers import BindingDict -class Serializer(Field): +class BaseSerializer(Field): """序列化器""" - def __init__(self, data=None, *args, **kwargs): - super(Serializer, self).__init__(*args, **kwargs) + def __init__(self, data=empty, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) self.data = data def to_python(self, data: Any): - return data + raise NotImplementedError('`to_python()` must be implemented.') def to_string(self, data: Any): - return data + raise NotImplementedError('`to_string()` must be implemented.') + + def update(self, instance, validated_data): + raise NotImplementedError('`update()` must be implemented.') + + def create(self, validated_data): + raise NotImplementedError('`create()` must be implemented.') + + def validate(self, raise_exception: bool = False): + """ + 验证函数 + :param raise_exception: 是否直接抛出错误 + :return: + """ + raise NotImplementedError('`validate()` must be implemented.') + + +class SerializerMetaclass(type): + + @classmethod + def _get_declared_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._sort_counter) + + known = set(attrs) + + def visit(name): + known.add(name) + return name + + base_fields = [] + for base in bases: + if hasattr(base, '_declared_fields'): + for name, field in base._declared_fields.items(): + if name not in known: + base_fields.append((visit(name), field)) + + return OrderedDict(base_fields + fields) + + def __new__(cls, name, bases, attrs): + attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) + return super().__new__(cls, name, bases, attrs) + + +class Serializer(BaseSerializer, metaclass=SerializerMetaclass): + pass + + @cached_property + def fields(self): + """ + 单个格式为 {field_name: field_instance}. + fields 是动态加载的 避免在导入时出现意想不到的错误 + """ + # like drf + fields = BindingDict(self) + for key, value in self.get_fields().items(): + fields[key] = value + return fields + + def get_fields(self) -> dict: + """ + 得到所有fields + :return: + """ + return copy.deepcopy(self._declared_fields) + + def get_validators(self): + """ + 得到所有属于序列化器的验证器,存在于 Mete 类中 + """ + meta = getattr(self, 'Meta', None) + validators = getattr(meta, 'validators', None) + return list(validators) if validators else [] + -- Gitee From ff62b4bd9584f4a8aadf68cabbdcc889ec1e4d9e Mon Sep 17 00:00:00 2001 From: LaoSi Date: Sun, 24 Jan 2021 23:54:17 +0800 Subject: [PATCH 07/34] =?UTF-8?q?=E4=BF=AE=E6=94=B9field=E4=BA=8Eserialize?= =?UTF-8?q?rs=E7=9A=84=E5=85=B1=E7=94=A8=E5=9F=BA=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 266 +++++++++++++++++++++++++++- sanic_rest_framework/helpers.py | 42 +++++ sanic_rest_framework/serializers.py | 31 ++++ 3 files changed, 337 insertions(+), 2 deletions(-) create mode 100644 sanic_rest_framework/helpers.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index f81dbdd..4c2202d 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -12,10 +12,16 @@ import copy import decimal import re from datetime import timezone, timedelta, datetime, date, time -from typing import Any, AnyStr, Optional, List +from typing import Any, AnyStr, Optional, List, Mapping +from functools import cached_property +from tortoise import Model +from tortoise.queryset import QuerySet +from tortoise.exceptions import DoesNotExist from sanic_rest_framework.exceptions import ValidationError +REGEX_TYPE = type(re.compile('')) + class empty: """ @@ -25,10 +31,231 @@ class empty: pass +class SkipField(Exception): + pass + + NOT_RAED_ONLY_AND_WRITE_ONLY = 'read_only 和 write_only 不能同时为True, 只能二选一' NOT_RAED_ONLY_REQUIRED_ONLY = 'read_only 为 True 时 required 不能为True , 只能二选一' +class BaseField: + _sort_counter = 0 + # 所有field都强制拥有的错误提示 + base_error_messages = { + 'required': '此字段为必填项,提交时必须携带', + 'null': '此字段不能为空', + } + default_error_messages = None + + default_validators = None + + def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, + default=empty, source=None, validators=None, error_messages=None, + label=None, description=None + ): + """ + 字段及field的基类 + :param read_only: 只反序列化 + :param write_only: 只序列化 + :param required: 序列化时必须存在此值 + :param allow_null: 序列化可以为空 + :param default: 默认值 可用于序列化和反序列化 + :param source: 反序列化是值的来源 + :param validators: 序列化时需要通过的验证 + :param error_messages: 出现错误时的自定义描述 + :param label: 字段标题 + :param description: 字段描述 + """ + self._sort_counter = BaseField._sort_counter + BaseField._sort_counter += 1 + + self.read_only = read_only + self.write_only = write_only + self.required = required + self.allow_null = allow_null + self.default = default + self.source = source + self.label = label + self.description = description + + self.validators = self.collect_validators([validators, self.default_validators]) + self.error_messages = self.collect_error_message( + [self.base_error_messages, self.default_error_messages, error_messages]) + + self.field_name = None + self.parent = None + + def __new__(cls, *args, **kwargs): + """ + 当一个字段被实例化时,我们存储所使用的参数, + 这样我就可以在 __deepcopy__ 提供他们 + """ + instance = super().__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance + + def __deepcopy__(self, memo): + """ + 当克隆字段时,我们使用参数实例化它 + 最初创建时用的,而不是复制完整的状态。 + """ + args = [ + copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item + for item in self._args + ] + kwargs = { + key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) + for key, value in self._kwargs.items() + } + return self.__class__(*args, **kwargs) + + def bind(self, field_name, parent): + + self.field_name = field_name + self.parent = parent + + if self.source is None: + self.source = self.field_name + if '.' in self.source: + self.source_attr = self.source.split('.') + else: + self.source_attr = [self.source] + + def collect_error_message(self, error_messages_list: List[dict]) -> dict: + """ + 收集错误提示 + :param error_messages_list: 错误提示列表 + :return: + """ + error_messages = {} + for error_message in error_messages_list: + if error_message is not None: + error_messages.update(error_message) + return error_messages + + def collect_validators(self, validators_list: List[list]) -> list: + """ + 收集所有验证器 + :param validators_list: 验证器列表 + :return: + """ + validators = [] + for validator_list in validators_list: + if validator_list is not None: + for validator in validator_list: + if validator not in validators: + validators.append(validator) + return validators + + def is_partial(self, root=None): + """当请求为部分修改的时候,返回 True """ + if root is None: + root = self.root + return getattr(root, 'partial', False) + + def internal_convert(self, data: Any) -> Any: + """对数据进行序列化转换并返回""" + raise NotImplementedError( + '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + ) + + def external_convert(self, data: Any) -> Any: + """对数据进行反序列化转换并返回""" + raise NotImplementedError( + '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + ) + + def get_value_to_internal(self, data: Mapping) -> Any: + """ + 从传入的外部数据中得到值 + 值用于输入验证 + :param data: *外部* 数据 + :return: + """ + if not isinstance(data, Mapping): + raise ValidationError('{field_name}传入的数据为无效数据类型,仅支持字段类型'.format(field_name=self.field_name)) + if self.field_name not in data: + if self.is_partial(): + return empty + return self.default + value = data.get(self.field_name) + return value + + async def async_get_attribute(self, instance, attr): + return await getattr(instance, attr) + + def get_value_to_external(self, instance: Any) -> Any: + """ + 从传入的内部数据中得到值 + 值用于输出 + :param instance: *内部* 数据 + :return: + """ + + for attr in self.source_attr: + try: + if isinstance(instance, Mapping): + instance = instance[attr] + else: + instance = self.async_get_attribute(instance, attr) + except DoesNotExist: + return None + return instance + + def run_validators(self, data) -> None: + """ + 使用验证器验证传入的数据 + 直接抛出错误 + :param data: + :return: 无返回值 + """ + errors = [] + for validator in self.validators: + try: + validator(data, self) + except ValidationError as exc: + if hasattr(exc, 'code') and exc.code in self.error_messages: + exc.message = self.error_messages[exc.code] + errors.extend(exc.error_list) + if errors: + raise ValidationError(errors) + + def run_validation(self, data): + """执行验证""" + value = self.internal_convert(data) + self.run_validators(value) + return value + + @property + def root(self): + """ + 得到字段的最高级父级 + """ + root = self + while root.parent is not None: + root = root.parent + return root + + def raise_error(self, key, **kwargs): + """ + 返回在 error_messages 中注册了的错误 + :param key: 错误的键 + :param kwargs: + :return: + """ + try: + msg = self.error_messages[key] + except KeyError: + class_name = self.__class__.__name__ + msg = "在 {class_name} 类的 error_messages " \ + "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) + raise AssertionError(msg) + message_string = msg.format(**kwargs) + raise ValidationError(message_string, code=key) + + class Field: _sort_counter = 0 @@ -107,7 +334,16 @@ class Field: '请勿忘记处理 read_only 时的情况'.format(cls=self.__class__.__name__, ) ) + def get_value(self, data: Any): + """得到 self.data 中的数据 数据可能来自 dict model request.query""" + if isinstance(data, Mapping): + if self.source in data: + return data[self.source] + elif isinstance(data, Model): + return await getattr(data, self.source) + # 验证处理 + def validate(self, data: Any): """自带的验证方法""" pass @@ -132,8 +368,9 @@ class Field: e.message = self.error_messages[e.code] errors.extend(e.error_list) if errors: + self._errors.extend(errors) raise ValidationError(errors) - self._errors.extend(errors) + return data def validate_empty_values(self, data): """ @@ -161,6 +398,31 @@ class Field: message_string = msg.format(**kwargs) raise ValidationError(message_string, code=key) + def __new__(cls, *args, **kwargs): + """ + 当一个字段被实例化时,我们存储所使用的参数, + 这样我就可以在 __deepcopy__ 提供他们 + """ + instance = super().__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance + + def __deepcopy__(self, memo): + """ + 当克隆字段时,我们使用参数实例化它 + 最初创建时用的,而不是复制完整的状态。 + """ + args = [ + copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item + for item in self._args + ] + kwargs = { + key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) + for key, value in self._kwargs.items() + } + return self.__class__(*args, **kwargs) + class CharField(Field): default_error_messages = { diff --git a/sanic_rest_framework/helpers.py b/sanic_rest_framework/helpers.py new file mode 100644 index 0000000..a9ddb8b --- /dev/null +++ b/sanic_rest_framework/helpers.py @@ -0,0 +1,42 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/1/22 22:48 +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: + helpers.py + +""" +from collections import OrderedDict, MutableMapping + + +class BindingDict(MutableMapping): + """ + 这个类似于dict的对象用于在序列化器中存储字段。 + 这确保了无论何时将字段添加到我们调用的序列化器中 + field.bind() 使 field_name 和 parent 属性可以正确设置。 + """ + + def __init__(self, serializer): + self.serializer = serializer + self.fields = OrderedDict() + + def __setitem__(self, key, field): + self.fields[key] = field + field.bind(field_name=key, parent=self.serializer) + + def __getitem__(self, key): + return self.fields[key] + + def __delitem__(self, key): + del self.fields[key] + + def __iter__(self): + return iter(self.fields) + + def __len__(self): + return len(self.fields) + + def __repr__(self): + return dict.__repr__(self.fields) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 863d798..320d8f3 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -97,6 +97,18 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): """ return copy.deepcopy(self._declared_fields) + @property + def _writable_fields(self): + for field in self.fields.values(): + if not field.read_only: + yield field + + @property + def _readable_fields(self): + for field in self.fields.values(): + if not field.write_only: + yield field + def get_validators(self): """ 得到所有属于序列化器的验证器,存在于 Mete 类中 @@ -105,3 +117,22 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): validators = getattr(meta, 'validators', None) return list(validators) if validators else [] + def run_validation(self, data: Any): + pass + + @property + def data(self) -> OrderedDict: + if self.initial is empty: + return OrderedDict() + data = OrderedDict( + [(field.field_name, field.to_string(self.initial)) for field in self._readable_fields] + ) + return data + + @property + def error(self): + pass + + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] -- Gitee From 9fa50ce2a2e342b27b079d6486aee600585ba4c6 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Mon, 25 Jan 2021 17:56:53 +0800 Subject: [PATCH 08/34] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20serializers=20?= =?UTF-8?q?=E6=B6=88=E9=99=A4=20drf=E6=80=9D=E6=83=B3=E5=AF=B9=E6=9C=AC?= =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E7=9A=84=E5=BD=B1=E5=93=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 543 +++++++++++++++------------- sanic_rest_framework/serializers.py | 194 +++++++--- 2 files changed, 430 insertions(+), 307 deletions(-) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 4c2202d..241e4cd 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -39,7 +39,8 @@ NOT_RAED_ONLY_AND_WRITE_ONLY = 'read_only 和 write_only 不能同时为True, NOT_RAED_ONLY_REQUIRED_ONLY = 'read_only 为 True 时 required 不能为True , 只能二选一' -class BaseField: +class Field: + """字段及序列化器基类""" _sort_counter = 0 # 所有field都强制拥有的错误提示 base_error_messages = { @@ -52,8 +53,7 @@ class BaseField: def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, default=empty, source=None, validators=None, error_messages=None, - label=None, description=None - ): + label=None, description=None): """ 字段及field的基类 :param read_only: 只反序列化 @@ -67,8 +67,8 @@ class BaseField: :param label: 字段标题 :param description: 字段描述 """ - self._sort_counter = BaseField._sort_counter - BaseField._sort_counter += 1 + self._sort_counter = Field._sort_counter + Field._sort_counter += 1 self.read_only = read_only self.write_only = write_only @@ -112,16 +112,22 @@ class BaseField: return self.__class__(*args, **kwargs) def bind(self, field_name, parent): - + """ + 提供给父级的绑定 + :param field_name: + :param parent: + :return: + """ self.field_name = field_name self.parent = parent if self.source is None: self.source = self.field_name - if '.' in self.source: - self.source_attr = self.source.split('.') + + if self.source == '*': + self.source_attrs = [] else: - self.source_attr = [self.source] + self.source_attr = self.source.split('.') def collect_error_message(self, error_messages_list: List[dict]) -> dict: """ @@ -256,181 +262,179 @@ class BaseField: raise ValidationError(message_string, code=key) -class Field: - _sort_counter = 0 - - default_error_messages = { - 'required': '{field_name}是必填项,为空', - 'null': '{field_name}不能为空' - } - default_validators = None - - def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, - default=None, initial=empty, source=None, error_messages=None, - label=None, description=None, validators=None): - """ - 字段基类 - :param read_only: 是否只反序列化, 供 serializers 使用 - :param write_only: 是否只序列化, 供 serializers 使用 - :param required: 是否必填 与 read_only 冲突,供 serializers 使用 - :param allow_null: 是否允许为空 - :param default: 默认值 - :param initial: 初始化值 - :param source: 来源 getattr(initial,source) - :param label: 易懂的语言描述 列如:机器编号, 呈现在ApiDoc内 - :param description: 详细描述 呈现在ApiDoc内 - :param validators: 自定义的验证器 - """ - # 优先检查冲突 - assert (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY - assert (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY - - self._sort_counter = Field._sort_counter - Field._sort_counter += 1 - - self.read_only = read_only - self.write_only = write_only - self.required = required - self.allow_null = allow_null - self.default = default - self.source = source - self.initial = self.initial if (initial is empty) else initial - self.label = label - self.description = description - self.error_messages = copy.deepcopy(self.default_error_messages) - self.validators = [] if self.default_validators is None else copy.deepcopy(self.default_validators) - if error_messages: - self.error_messages.update(error_messages) - if validators: - self.validators.extend(validators) - - # 为绑定做准备 .bind() - self.field_name = None - self.parent = None - # 存储错误 - self._errors: List[ValidationError] = [] - - def bind(self, field_name: str, parent): - self.field_name = field_name - self.parent = parent - if self.source is None: - self.source = field_name - - # 数据处理 - def to_python(self, data: Any): - """序列化时用到的数据处理函数 - 会在验证器验证前执行,届时验证器接受到的data - 就是 to_python 执行后返回的数据 - """ - raise NotImplementedError( - '{cls}类在继承 Field 类后内部的 .to_python() 必须重写' - '请勿忘记处理 write_only 时的情况'.format(cls=self.__class__.__name__, ) - ) - - def to_string(self, data: Any): - """反序列化时用到的数据处理函数""" - raise NotImplementedError( - '{cls}类在继承 Field 类后内部的 .to_string() 必须重写' - '请勿忘记处理 read_only 时的情况'.format(cls=self.__class__.__name__, ) - ) - - def get_value(self, data: Any): - """得到 self.data 中的数据 数据可能来自 dict model request.query""" - if isinstance(data, Mapping): - if self.source in data: - return data[self.source] - elif isinstance(data, Model): - return await getattr(data, self.source) - - # 验证处理 - - def validate(self, data: Any): - """自带的验证方法""" - pass - - def run_validation(self, data: Any): - """执行验证器""" - errors = [] - data = self.to_python(data) - try: - self.validate(data) - except ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: - e.message = self.error_messages[e.code] - errors.extend(e.error_list) - - for validator in self.validators: - if callable(validator): - try: - validator(data) - except ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: - e.message = self.error_messages[e.code] - errors.extend(e.error_list) - if errors: - self._errors.extend(errors) - raise ValidationError(errors) - return data - - def validate_empty_values(self, data): - """ - 验证空值 - """ - if self.allow_null: - if data is empty: - return True, None - return False, data - - def add_error(self, error): - if self._errors is None: - self._errors = [] - self._errors.append(error) - - def raise_error(self, key, **kwargs): - """直接返回错误""" - try: - msg = self.error_messages[key] - except KeyError: - class_name = self.__class__.__name__ - msg = "在 {class_name} 类的 error_messages " \ - "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) - raise AssertionError(msg) - message_string = msg.format(**kwargs) - raise ValidationError(message_string, code=key) - - def __new__(cls, *args, **kwargs): - """ - 当一个字段被实例化时,我们存储所使用的参数, - 这样我就可以在 __deepcopy__ 提供他们 - """ - instance = super().__new__(cls) - instance._args = args - instance._kwargs = kwargs - return instance - - def __deepcopy__(self, memo): - """ - 当克隆字段时,我们使用参数实例化它 - 最初创建时用的,而不是复制完整的状态。 - """ - args = [ - copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item - for item in self._args - ] - kwargs = { - key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) - for key, value in self._kwargs.items() - } - return self.__class__(*args, **kwargs) - +# class Field: +# _sort_counter = 0 +# +# default_error_messages = { +# 'required': '{field_name}是必填项,为空', +# 'null': '{field_name}不能为空' +# } +# default_validators = None +# +# def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, +# default=None, initial=empty, source=None, error_messages=None, +# label=None, description=None, validators=None): +# """ +# 字段基类 +# :param read_only: 是否只反序列化, 供 serializers 使用 +# :param write_only: 是否只序列化, 供 serializers 使用 +# :param required: 是否必填 与 read_only 冲突,供 serializers 使用 +# :param allow_null: 是否允许为空 +# :param default: 默认值 +# :param initial: 初始化值 +# :param source: 来源 getattr(initial,source) +# :param label: 易懂的语言描述 列如:机器编号, 呈现在ApiDoc内 +# :param description: 详细描述 呈现在ApiDoc内 +# :param validators: 自定义的验证器 +# """ +# # 优先检查冲突 +# assert (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY +# assert (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY +# +# self._sort_counter = Field._sort_counter +# Field._sort_counter += 1 +# +# self.read_only = read_only +# self.write_only = write_only +# self.required = required +# self.allow_null = allow_null +# self.default = default +# self.source = source +# self.initial = self.initial if (initial is empty) else initial +# self.label = label +# self.description = description +# self.error_messages = copy.deepcopy(self.default_error_messages) +# self.validators = [] if self.default_validators is None else copy.deepcopy(self.default_validators) +# if error_messages: +# self.error_messages.update(error_messages) +# if validators: +# self.validators.extend(validators) +# +# # 为绑定做准备 .bind() +# self.field_name = None +# self.parent = None +# # 存储错误 +# self._errors: List[ValidationError] = [] +# +# def bind(self, field_name: str, parent): +# self.field_name = field_name +# self.parent = parent +# if self.source is None: +# self.source = field_name +# +# # 数据处理 +# def to_python(self, data: Any): +# """序列化时用到的数据处理函数 +# 会在验证器验证前执行,届时验证器接受到的data +# 就是 to_python 执行后返回的数据 +# """ +# raise NotImplementedError( +# '{cls}类在继承 Field 类后内部的 .to_python() 必须重写' +# '请勿忘记处理 write_only 时的情况'.format(cls=self.__class__.__name__, ) +# ) +# +# def to_string(self, data: Any): +# """反序列化时用到的数据处理函数""" +# raise NotImplementedError( +# '{cls}类在继承 Field 类后内部的 .to_string() 必须重写' +# '请勿忘记处理 read_only 时的情况'.format(cls=self.__class__.__name__, ) +# ) +# +# def get_value(self, data: Any): +# """得到 self.data 中的数据 数据可能来自 dict model request.query""" +# if isinstance(data, Mapping): +# if self.source in data: +# return data[self.source] +# elif isinstance(data, Model): +# return await getattr(data, self.source) +# +# # 验证处理 +# +# def validate(self, data: Any): +# """自带的验证方法""" +# pass +# +# def run_validation(self, data: Any): +# """执行验证器""" +# errors = [] +# data = self.to_python(data) +# try: +# self.validate(data) +# except ValidationError as e: +# if hasattr(e, 'code') and e.code in self.error_messages: +# e.message = self.error_messages[e.code] +# errors.extend(e.error_list) +# +# for validator in self.validators: +# if callable(validator): +# try: +# validator(data) +# except ValidationError as e: +# if hasattr(e, 'code') and e.code in self.error_messages: +# e.message = self.error_messages[e.code] +# errors.extend(e.error_list) +# if errors: +# self._errors.extend(errors) +# raise ValidationError(errors) +# return data +# +# def validate_empty_values(self, data): +# """ +# 验证空值 +# """ +# if self.allow_null: +# if data is empty: +# return True, None +# return False, data +# +# def add_error(self, error): +# if self._errors is None: +# self._errors = [] +# self._errors.append(error) +# +# def raise_error(self, key, **kwargs): +# """直接返回错误""" +# try: +# msg = self.error_messages[key] +# except KeyError: +# class_name = self.__class__.__name__ +# msg = "在 {class_name} 类的 error_messages " \ +# "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) +# raise AssertionError(msg) +# message_string = msg.format(**kwargs) +# raise ValidationError(message_string, code=key) +# +# def __new__(cls, *args, **kwargs): +# """ +# 当一个字段被实例化时,我们存储所使用的参数, +# 这样我就可以在 __deepcopy__ 提供他们 +# """ +# instance = super().__new__(cls) +# instance._args = args +# instance._kwargs = kwargs +# return instance +# +# def __deepcopy__(self, memo): +# """ +# 当克隆字段时,我们使用参数实例化它 +# 最初创建时用的,而不是复制完整的状态。 +# """ +# args = [ +# copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item +# for item in self._args +# ] +# kwargs = { +# key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) +# for key, value in self._kwargs.items() +# } +# return self.__class__(*args, **kwargs) +# class CharField(Field): default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,仅支持整字符类型', - 'max_length': '{field_name}最长支持{max_length}个字符', - 'min_length': '{field_name}至少要有{min_length}个字符', + 'invalid': '出现错误的数据类型,仅支持整字符类型', + 'max_length': '最长支持{max_length}个字符', + 'min_length': '至少要有{min_length}个字符', } def __init__(self, *args, **kwargs): @@ -440,25 +444,23 @@ class CharField(Field): super(CharField, self).__init__(*args, *kwargs) # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators - def to_python(self, data: Any): + def internal_convert(self, data: Any) -> Any: if isinstance(data, bool) or not isinstance(data, (str, int, float,)): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') value = str(data) return value.strip() if self.trim_whitespace else value - def to_string(self, data: Any): + def external_convert(self, data: Any) -> Any: return str(data) class IntegerField(Field): """整数类型""" default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,仅支持整数类型', - 'max_value': '{field_name}仅支持小于{max_value}的整数', - 'min_value': '{field_name}仅支持大于{min_value}的整数', - 'max_string_length': '{field_name}仅支持转换长度小于{max_string_length}的整数字符串', + 'invalid': '出现错误的数据类型,仅支持整数类型', + 'max_value': '仅支持小于{max_value}的整数', + 'min_value': '仅支持大于{min_value}的整数', + 'max_string_length': '仅支持转换长度小于{max_string_length}的整数字符串', } re_decimal = re.compile(r'\.0*\s*$') MAX_STRING_LENGTH = 1000 @@ -469,52 +471,51 @@ class IntegerField(Field): super(IntegerField, self).__init__(*args, *kwargs) # TODO 需要将 MaxValueValidator 与 MinValueValidator 添加入 self.validators - def to_python(self, data: Any): + def internal_convert(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: - self.raise_error('max_string_length', field_name=self.field_name, max_string_length=self.MAX_STRING_LENGTH) + self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: data = int(self.re_decimal.sub('', str(data))) except (ValueError, TypeError): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') return data - def to_string(self, data: Any): + def external_convert(self, data: Any): return int(data) class FloatField(IntegerField): """浮点类型""" default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,仅支持浮点类型', - 'max_value': '{field_name}仅支持小于{max_value}浮点', - 'min_value': '{field_name}仅支持大于{min_value}浮点', - 'max_string_length': '{field_name}仅支持转换长度小于{max_string_length}的浮点字符串', + 'invalid': '出现错误的数据类型,仅支持浮点类型', + 'max_value': '仅支持小于{max_value}浮点', + 'min_value': '仅支持大于{min_value}浮点', + 'max_string_length': '仅支持转换长度小于{max_string_length}的浮点字符串', } MAX_STRING_LENGTH = 1000 - def to_python(self, data: Any): + def internal_convert(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: - self.raise_error('max_string_length', field_name=self.field_name, max_string_length=self.MAX_STRING_LENGTH) + self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: return float(data) except (TypeError, ValueError): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') - def to_string(self, data: Any): + def external_convert(self, data: Any): return float(data) class DecimalField(Field): """十进制类型""" default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,仅支持Decimal十进制类型', - 'max_value': '{field_name}仅支持小于{max_value}Decimal十进制类型', - 'min_value': '{field_name}仅支持大于{min_value}Decimal十进制类型', - 'max_string_length': '{field_name}仅支持转换长度小于{max_string_length}的Decimal十进制字符串', + 'invalid': '出现错误的数据类型,仅支持Decimal十进制类型', + 'max_value': '仅支持小于{max_value}Decimal十进制类型', + 'min_value': '仅支持大于{min_value}Decimal十进制类型', + 'max_string_length': '仅支持转换长度小于{max_string_length}的Decimal十进制字符串', + 'max_digits': '确保总数不超过{max_digits}个数字。', + 'max_decimal_places': '确保不超过{max_decimal_places}个小数位。', + 'max_whole_digits': '确保小数点前的位数不超过{max_whole_digits}个。', } MAX_STRING_LENGTH = 1000 @@ -533,26 +534,26 @@ class DecimalField(Field): super(DecimalField, self).__init__(*args, **kwargs) # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators - def to_python(self, data: Any): + def internal_convert(self, data: Any): data = str(data).strip() if len(data) > self.MAX_STRING_LENGTH: - self.raise_error('max_string_length', field_name=self.field_name, max_string_length=self.MAX_STRING_LENGTH) + self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: data = decimal.Decimal(data) except decimal.DecimalException: - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') if data.is_nan(): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') # 检查无穷大和负无穷大。 if data in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') return self.quantize(self.validate_precision(data)) - def to_string(self, data: Any): + def external_convert(self, data: Any): if not isinstance(data, decimal.Decimal): data = decimal.Decimal(str(data).strip()) @@ -612,9 +613,7 @@ class DecimalField(Field): class BooleanField(Field): default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的布尔值', + 'invalid': '出现错误的数据类型,{value}不是有效的布尔值', } TRUE_VALUES = { 't', 'T', @@ -634,7 +633,7 @@ class BooleanField(Field): } NULL_VALUES = {'null', 'Null', 'NULL', '', None} - def to_python(self, data: Any): + def internal_convert(self, data: Any) -> Any: try: if data in self.TRUE_VALUES: return True @@ -643,9 +642,9 @@ class BooleanField(Field): elif data in self.NULL_VALUES and self.allow_null: return None except TypeError: - self.raise_error('invalid', field_name=self.field_name, value=data) + self.raise_error('invalid', value=data) - def to_string(self, data: Any): + def external_convert(self, data: Any) -> Any: if data in self.TRUE_VALUES: return True elif data in self.FALSE_VALUES: @@ -657,11 +656,9 @@ class BooleanField(Field): class DateTimeField(Field): default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', - 'date': '{field_name}需要的是日期时间格式而不是日期格式', - 'overflow': '{field_name}时间超出范围' + 'convert': '日期转换异常,请确认日期格式符合 %Y-%m-%d %H:%M:%S 规则', + 'date': '需要的是日期时间格式而不是日期格式', + 'overflow': '时间超出范围' } def __init__(self, output_format='%Y-%m-%d %H:%M:%S', input_formats='%Y-%m-%d %H:%M:%S', @@ -682,38 +679,36 @@ class DateTimeField(Field): """强制设置一个时区""" return value.astimezone(self.set_timezone) - def to_python(self, data: Any): + def internal_convert(self, data: Any) -> Any: if not isinstance(data, (str, data, datetime)): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('convert') if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats) except (ValueError, TypeError): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('convert') if isinstance(data, date): - self.raise_error('date', field_name=self.field_name) + self.raise_error('date') if isinstance(data, datetime): data = self.enforce_timezone(data) return data - def to_string(self, data: Any): + def external_convert(self, data: Any) -> Any: if not data: return None if isinstance(data, str): return data if isinstance(data, datetime): return data.strftime(self.output_format) - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') class DateField(Field): """日期字段""" default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', - 'datetime': '{field_name}需要的是日期格式而不是日期时间格式', + 'invalid': '出现错误的数据类型,{value}不是有效的日期时间类型', + 'datetime': '需要的是日期格式而不是日期时间格式', } def __init__(self, output_format='%Y-%m-%d', input_formats='%Y-%m-%d', *args, **kwargs): @@ -721,36 +716,34 @@ class DateField(Field): self.input_formats = input_formats super(DateField, self).__init__(*args, **kwargs) - def to_python(self, data: Any): + def internal_convert(self, data: Any) -> Any: if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).date() return data except (ValueError, TypeError): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid', data=data) if isinstance(data, datetime): - self.raise_error('datetime', field_name=self.field_name) + self.raise_error('datetime') if isinstance(data, date): return data - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid', data=data) - def to_string(self, data: Any): + def external_convert(self, data: Any) -> Any: if not data: return data if isinstance(data, str): return data if isinstance(data, date): return data.strftime(self.output_format) - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') class TimeField(Field): """时间字段""" default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', - 'date': '{field_name}需要的是时间格式而不是日期格式', + 'invalid': '出现错误的数据类型,{value}不是有效的日期时间类型', + 'date': '需要的是时间格式而不是日期格式', } def __init__(self, output_format='%H:%M:%S', input_formats='%H:%M:%S', *args, **kwargs): @@ -758,36 +751,34 @@ class TimeField(Field): self.input_formats = input_formats super(TimeField, self).__init__(*args, **kwargs) - def to_python(self, data: Any): + def internal_convert(self, data: Any) -> Any: if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).time() return data except (ValueError, TypeError): - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') if isinstance(data, datetime): return data.time() if isinstance(data, date): - self.raise_error('date', field_name=self.field_name) - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('date') + self.raise_error('invalid') - def to_string(self, data: Any): + def external_convert(self, data: Any) -> Any: if not data: return data if isinstance(data, str): return data if isinstance(data, (time, datetime)): return data.strftime(self.output_format) - self.raise_error('invalid', field_name=self.field_name) + self.raise_error('invalid') class ChoiceField(Field): """限定可选的字段""" default_error_messages = { - 'required': '{field_name}是必填项', - 'null': '{field_name}不能为空', - 'invalid': '{field_name}出现错误的数据类型,{value}不是有效的日期时间类型', - 'key': '{field_name}没有{key}选项', + 'invalid': '出现错误的数据类型,{value}不是有效的日期时间类型', + 'key': '错误选项{key}', } def __init__(self, choices, *args, **kwargs): @@ -799,20 +790,52 @@ class ChoiceField(Field): self.choices = choices super(ChoiceField, self).__init__(*args, **kwargs) - def to_python(self, data: Any): + def internal_convert(self, data: Any) -> Any: data = self.get_choices().get(str(data)) return data - def to_string(self, data: Any): + def external_convert(self, data: Any) -> Any: return str(data) def choices_get_value_by_key(self, key): """得到字符串""" choices = self.get_choices() if key not in choices: - self.raise_error('key', field_name=self.field_name, key=key) + self.raise_error('key', key=key) return choices.get(key, None) def get_choices(self) -> dict: choices = {str(key): value for key, value in self.choices} return choices + + +class SerializerMethodField(Field): + """ + 一个只读字段,可通过在父序列化器类。调用的方法将具有以下形式 + “ get_ {field_name}”,并且应采用单个参数,即 + 对象被序列化。 + For example: + + class ExampleSerializer(self): + extra_info = SerializerMethodField() + + def get_extra_info(self, obj): + return ... # Calculate some data to return. + """ + + def __init__(self, method_name=None, **kwargs): + self.method_name = method_name + kwargs['source'] = '*' + kwargs['read_only'] = True + super().__init__(**kwargs) + + def bind(self, field_name, parent): + # The method name defaults to `get_{field_name}`. + if self.method_name is None: + self.method_name = 'get_{field_name}'.format(field_name=field_name) + + super().bind(field_name, parent) + + def to_representation(self, value): + method = getattr(self.parent, self.method_name) + return method(value) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 320d8f3..278e7db 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -11,39 +11,119 @@ import copy from collections import OrderedDict from functools import cached_property -from typing import Any +from typing import Any, Mapping -from sanic_rest_framework.fields import Field, empty +from sanic_rest_framework.fields import Field, empty, SkipField from .exceptions import ValidationError from .helpers import BindingDict +LIST_SERIALIZER_KWARGS = ( + 'read_only', + 'write_only', + 'required', + 'allow_null', + 'default', + 'source', + 'validators', + 'error_messages', + 'label', + 'description', + 'instance', + 'data', + 'partial' +) +ALL_FIELDS = '__all__' + class BaseSerializer(Field): """序列化器""" - def __init__(self, data=empty, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) - self.data = data + def __init__(self, instance=None, data=empty, **kwargs): + self.instance = instance + if data is not empty: + self.initial_data = data + self.partial = kwargs.pop('partial', False) + self._context = kwargs.pop('context', {}) + kwargs.pop('many', None) + super().__init__(**kwargs) - def to_python(self, data: Any): - raise NotImplementedError('`to_python()` must be implemented.') + def __new__(cls, *args, **kwargs): + if kwargs.pop('many', False): + return cls.many_init(*args, **kwargs) + return super().__new__(cls, *args, **kwargs) - def to_string(self, data: Any): - raise NotImplementedError('`to_string()` must be implemented.') + @classmethod + def many_init(cls, *args, **kwargs): + """""" + allow_empty = kwargs.pop('allow_empty', None) + child_serializer = cls(*args, **kwargs) + list_kwargs = { + 'child': child_serializer, + } + if allow_empty is not None: + list_kwargs['allow_empty'] = allow_empty + list_kwargs.update({ + key: value for key, value in kwargs.items() + if key in LIST_SERIALIZER_KWARGS + }) + meta = getattr(cls, 'Meta', None) + list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) + return list_serializer_class(*args, **list_kwargs) + + def internal_convert(self, data: Any) -> Any: + """对数据进行序列化转换并返回""" + raise NotImplementedError( + '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + ) - def update(self, instance, validated_data): - raise NotImplementedError('`update()` must be implemented.') + def external_convert(self, data: Any) -> Any: + """对数据进行反序列化转换并返回""" + raise NotImplementedError( + '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + ) - def create(self, validated_data): - raise NotImplementedError('`create()` must be implemented.') + def is_valid(self, raise_exception=False): + assert hasattr(self, 'initial_data'), ( + 'Cannot call `.is_valid()` as no `data=` keyword argument was ' + 'passed when instantiating the serializer instance.' + ) - def validate(self, raise_exception: bool = False): - """ - 验证函数 - :param raise_exception: 是否直接抛出错误 - :return: - """ - raise NotImplementedError('`validate()` must be implemented.') + if not hasattr(self, '_validated_data'): + try: + self._validated_data = self.run_validation(self.initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.error_dict + else: + self._errors = {} + + if self._errors and raise_exception: + raise ValidationError(self.errors) + return not bool(self._errors) + + @property + def data(self): + """对外呈现的数据""" + assert self.instance is None, '调用 .data 必须先传入 instance= ' + if not hasattr(self, '_data'): + self._data = self.external_convert(self.instance) + return self._data + + @property + def errors(self): + """对外呈现验证过程中出现的错误""" + if not hasattr(self, '_errors'): + msg = '你必须先执行 .is_valid() 方法才能调用 .errors' + raise AssertionError(msg) + return self._errors + + @property + def validated_data(self): + """对内使用的验证后的数据""" + if not hasattr(self, '_validated_data'): + msg = '你必须先执行 .is_valid() 方法才能调用 .validated_data' + raise AssertionError(msg) + return self._validated_data class SerializerMetaclass(type): @@ -76,7 +156,9 @@ class SerializerMetaclass(type): class Serializer(BaseSerializer, metaclass=SerializerMetaclass): - pass + default_error_messages = { + 'invalid': '无效数据。应该是字典,但是得到了{datatype}' + } @cached_property def fields(self): @@ -90,13 +172,6 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): fields[key] = value return fields - def get_fields(self) -> dict: - """ - 得到所有fields - :return: - """ - return copy.deepcopy(self._declared_fields) - @property def _writable_fields(self): for field in self.fields.values(): @@ -109,30 +184,55 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): if not field.write_only: yield field - def get_validators(self): + def get_fields(self) -> dict: """ - 得到所有属于序列化器的验证器,存在于 Mete 类中 + 得到所有fields + :return: """ - meta = getattr(self, 'Meta', None) - validators = getattr(meta, 'validators', None) - return list(validators) if validators else [] - - def run_validation(self, data: Any): - pass + return copy.deepcopy(self._declared_fields) - @property - def data(self) -> OrderedDict: - if self.initial is empty: - return OrderedDict() - data = OrderedDict( - [(field.field_name, field.to_string(self.initial)) for field in self._readable_fields] - ) - return data + def get_value_to_internal(self, data: Mapping) -> Any: + return data.get(self.field_name, empty) + + def _read_only_defaults(self): + fields = [ + field for field in self.fields.values() + if field.read_only and (field.default != empty) and (field.source != '*') and ('.' not in field.source) + ] + + defaults = OrderedDict() + for field in fields: + try: + default = field.get_default() + except SkipField: + continue + defaults[field.source] = default + return defaults + + def run_validators(self, value): + """ + Add read_only fields with defaults to value before running validators. + """ + if isinstance(value, dict): + to_validate = self._read_only_defaults() + to_validate.update(value) + else: + to_validate = value + super().run_validators(to_validate) - @property - def error(self): - pass + def validate(self, attrs): + return attrs def __iter__(self): for field in self.fields.values(): yield self[field.field_name] + + def __getitem__(self, key): + field = self.fields[key] + value = self.data.get(key) + error = self.errors.get(key) if hasattr(self, '_errors') else None + return { + 'field': field, + 'value': value, + 'error': error, + } -- Gitee From 7f0f069b2c168eb5a40e281a0cc1526c1e8f3b90 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 26 Jan 2021 00:00:56 +0800 Subject: [PATCH 09/34] =?UTF-8?q?external=5Fto=5Finternal=20=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E9=9C=80=E8=A6=81=E9=87=8D=E6=96=B0=E7=BC=95=E4=B8=80?= =?UTF-8?q?=E7=BC=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 229 +++++----------------------- sanic_rest_framework/serializers.py | 86 ++++++++--- 2 files changed, 104 insertions(+), 211 deletions(-) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 241e4cd..f80d93a 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -67,6 +67,9 @@ class Field: :param label: 字段标题 :param description: 字段描述 """ + assert (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY + assert (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY + self._sort_counter = Field._sort_counter Field._sort_counter += 1 @@ -161,19 +164,19 @@ class Field: root = self.root return getattr(root, 'partial', False) - def internal_convert(self, data: Any) -> Any: - """对数据进行序列化转换并返回""" + def external_to_internal(self, data: Any) -> Any: + """对数据进行反序列化转换并返回""" raise NotImplementedError( - '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) - def external_convert(self, data: Any) -> Any: - """对数据进行反序列化转换并返回""" + def internal_to_external(self, data: Any) -> Any: + """对数据进行序列化转换并返回""" raise NotImplementedError( - '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) - def get_value_to_internal(self, data: Mapping) -> Any: + def get_external_value(self, data: Mapping) -> Any: """ 从传入的外部数据中得到值 值用于输入验证 @@ -192,7 +195,7 @@ class Field: async def async_get_attribute(self, instance, attr): return await getattr(instance, attr) - def get_value_to_external(self, instance: Any) -> Any: + def get_internal_value(self, instance: Any) -> Any: """ 从传入的内部数据中得到值 值用于输出 @@ -208,6 +211,14 @@ class Field: instance = self.async_get_attribute(instance, attr) except DoesNotExist: return None + except (KeyError, AttributeError) as exc: + if self.default is not empty: + return self.default + if self.allow_null: + return None + if not self.required: + raise SkipField() + raise type(exc)('在序列化过程中字段{field_name}未能成功序列化'.format(field_name=self.field_name)) return instance def run_validators(self, data) -> None: @@ -230,7 +241,7 @@ class Field: def run_validation(self, data): """执行验证""" - value = self.internal_convert(data) + value = self.external_to_internal(data) self.run_validators(value) return value @@ -262,174 +273,6 @@ class Field: raise ValidationError(message_string, code=key) -# class Field: -# _sort_counter = 0 -# -# default_error_messages = { -# 'required': '{field_name}是必填项,为空', -# 'null': '{field_name}不能为空' -# } -# default_validators = None -# -# def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, -# default=None, initial=empty, source=None, error_messages=None, -# label=None, description=None, validators=None): -# """ -# 字段基类 -# :param read_only: 是否只反序列化, 供 serializers 使用 -# :param write_only: 是否只序列化, 供 serializers 使用 -# :param required: 是否必填 与 read_only 冲突,供 serializers 使用 -# :param allow_null: 是否允许为空 -# :param default: 默认值 -# :param initial: 初始化值 -# :param source: 来源 getattr(initial,source) -# :param label: 易懂的语言描述 列如:机器编号, 呈现在ApiDoc内 -# :param description: 详细描述 呈现在ApiDoc内 -# :param validators: 自定义的验证器 -# """ -# # 优先检查冲突 -# assert (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY -# assert (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY -# -# self._sort_counter = Field._sort_counter -# Field._sort_counter += 1 -# -# self.read_only = read_only -# self.write_only = write_only -# self.required = required -# self.allow_null = allow_null -# self.default = default -# self.source = source -# self.initial = self.initial if (initial is empty) else initial -# self.label = label -# self.description = description -# self.error_messages = copy.deepcopy(self.default_error_messages) -# self.validators = [] if self.default_validators is None else copy.deepcopy(self.default_validators) -# if error_messages: -# self.error_messages.update(error_messages) -# if validators: -# self.validators.extend(validators) -# -# # 为绑定做准备 .bind() -# self.field_name = None -# self.parent = None -# # 存储错误 -# self._errors: List[ValidationError] = [] -# -# def bind(self, field_name: str, parent): -# self.field_name = field_name -# self.parent = parent -# if self.source is None: -# self.source = field_name -# -# # 数据处理 -# def to_python(self, data: Any): -# """序列化时用到的数据处理函数 -# 会在验证器验证前执行,届时验证器接受到的data -# 就是 to_python 执行后返回的数据 -# """ -# raise NotImplementedError( -# '{cls}类在继承 Field 类后内部的 .to_python() 必须重写' -# '请勿忘记处理 write_only 时的情况'.format(cls=self.__class__.__name__, ) -# ) -# -# def to_string(self, data: Any): -# """反序列化时用到的数据处理函数""" -# raise NotImplementedError( -# '{cls}类在继承 Field 类后内部的 .to_string() 必须重写' -# '请勿忘记处理 read_only 时的情况'.format(cls=self.__class__.__name__, ) -# ) -# -# def get_value(self, data: Any): -# """得到 self.data 中的数据 数据可能来自 dict model request.query""" -# if isinstance(data, Mapping): -# if self.source in data: -# return data[self.source] -# elif isinstance(data, Model): -# return await getattr(data, self.source) -# -# # 验证处理 -# -# def validate(self, data: Any): -# """自带的验证方法""" -# pass -# -# def run_validation(self, data: Any): -# """执行验证器""" -# errors = [] -# data = self.to_python(data) -# try: -# self.validate(data) -# except ValidationError as e: -# if hasattr(e, 'code') and e.code in self.error_messages: -# e.message = self.error_messages[e.code] -# errors.extend(e.error_list) -# -# for validator in self.validators: -# if callable(validator): -# try: -# validator(data) -# except ValidationError as e: -# if hasattr(e, 'code') and e.code in self.error_messages: -# e.message = self.error_messages[e.code] -# errors.extend(e.error_list) -# if errors: -# self._errors.extend(errors) -# raise ValidationError(errors) -# return data -# -# def validate_empty_values(self, data): -# """ -# 验证空值 -# """ -# if self.allow_null: -# if data is empty: -# return True, None -# return False, data -# -# def add_error(self, error): -# if self._errors is None: -# self._errors = [] -# self._errors.append(error) -# -# def raise_error(self, key, **kwargs): -# """直接返回错误""" -# try: -# msg = self.error_messages[key] -# except KeyError: -# class_name = self.__class__.__name__ -# msg = "在 {class_name} 类的 error_messages " \ -# "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) -# raise AssertionError(msg) -# message_string = msg.format(**kwargs) -# raise ValidationError(message_string, code=key) -# -# def __new__(cls, *args, **kwargs): -# """ -# 当一个字段被实例化时,我们存储所使用的参数, -# 这样我就可以在 __deepcopy__ 提供他们 -# """ -# instance = super().__new__(cls) -# instance._args = args -# instance._kwargs = kwargs -# return instance -# -# def __deepcopy__(self, memo): -# """ -# 当克隆字段时,我们使用参数实例化它 -# 最初创建时用的,而不是复制完整的状态。 -# """ -# args = [ -# copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item -# for item in self._args -# ] -# kwargs = { -# key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value) -# for key, value in self._kwargs.items() -# } -# return self.__class__(*args, **kwargs) -# - class CharField(Field): default_error_messages = { 'invalid': '出现错误的数据类型,仅支持整字符类型', @@ -444,13 +287,13 @@ class CharField(Field): super(CharField, self).__init__(*args, *kwargs) # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: if isinstance(data, bool) or not isinstance(data, (str, int, float,)): self.raise_error('invalid') value = str(data) return value.strip() if self.trim_whitespace else value - def external_convert(self, data: Any) -> Any: + def internal_to_external(self, data: Any) -> Any: return str(data) @@ -471,7 +314,7 @@ class IntegerField(Field): super(IntegerField, self).__init__(*args, *kwargs) # TODO 需要将 MaxValueValidator 与 MinValueValidator 添加入 self.validators - def internal_convert(self, data: Any): + def external_to_internal(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: @@ -480,7 +323,7 @@ class IntegerField(Field): self.raise_error('invalid') return data - def external_convert(self, data: Any): + def internal_to_external(self, data: Any): return int(data) @@ -494,7 +337,7 @@ class FloatField(IntegerField): } MAX_STRING_LENGTH = 1000 - def internal_convert(self, data: Any): + def external_to_internal(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: @@ -502,7 +345,7 @@ class FloatField(IntegerField): except (TypeError, ValueError): self.raise_error('invalid') - def external_convert(self, data: Any): + def internal_to_external(self, data: Any): return float(data) @@ -534,7 +377,7 @@ class DecimalField(Field): super(DecimalField, self).__init__(*args, **kwargs) # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators - def internal_convert(self, data: Any): + def external_to_internal(self, data: Any): data = str(data).strip() if len(data) > self.MAX_STRING_LENGTH: @@ -553,7 +396,7 @@ class DecimalField(Field): return self.quantize(self.validate_precision(data)) - def external_convert(self, data: Any): + def internal_to_external(self, data: Any): if not isinstance(data, decimal.Decimal): data = decimal.Decimal(str(data).strip()) @@ -633,7 +476,7 @@ class BooleanField(Field): } NULL_VALUES = {'null', 'Null', 'NULL', '', None} - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: try: if data in self.TRUE_VALUES: return True @@ -644,7 +487,7 @@ class BooleanField(Field): except TypeError: self.raise_error('invalid', value=data) - def external_convert(self, data: Any) -> Any: + def internal_to_external(self, data: Any) -> Any: if data in self.TRUE_VALUES: return True elif data in self.FALSE_VALUES: @@ -679,7 +522,7 @@ class DateTimeField(Field): """强制设置一个时区""" return value.astimezone(self.set_timezone) - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: if not isinstance(data, (str, data, datetime)): self.raise_error('convert') @@ -694,7 +537,7 @@ class DateTimeField(Field): data = self.enforce_timezone(data) return data - def external_convert(self, data: Any) -> Any: + def internal_to_external(self, data: Any) -> Any: if not data: return None if isinstance(data, str): @@ -716,7 +559,7 @@ class DateField(Field): self.input_formats = input_formats super(DateField, self).__init__(*args, **kwargs) - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).date() @@ -729,7 +572,7 @@ class DateField(Field): return data self.raise_error('invalid', data=data) - def external_convert(self, data: Any) -> Any: + def internal_to_external(self, data: Any) -> Any: if not data: return data if isinstance(data, str): @@ -751,7 +594,7 @@ class TimeField(Field): self.input_formats = input_formats super(TimeField, self).__init__(*args, **kwargs) - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).time() @@ -790,7 +633,7 @@ class ChoiceField(Field): self.choices = choices super(ChoiceField, self).__init__(*args, **kwargs) - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: data = self.get_choices().get(str(data)) return data diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 278e7db..774cd84 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -35,8 +35,34 @@ LIST_SERIALIZER_KWARGS = ( ALL_FIELDS = '__all__' +def set_value(dictionary, keys, value): + """ + set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} + set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} + set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}} + """ + if not keys: + dictionary.update(value) + return + + for key in keys[:-1]: + if key not in dictionary: + dictionary[key] = {} + dictionary = dictionary[key] + + dictionary[keys[-1]] = value + + class BaseSerializer(Field): - """序列化器""" + """序列化器 + .instance -> + .get_internal_value -> + .internal_to_external() -> + .data + + .install_data -> .get_external_value -> .external_to_internal() -> .validated_data + + """ def __init__(self, instance=None, data=empty, **kwargs): self.instance = instance @@ -70,16 +96,16 @@ class BaseSerializer(Field): list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) return list_serializer_class(*args, **list_kwargs) - def internal_convert(self, data: Any) -> Any: + def external_to_internal(self, data: Any) -> Any: """对数据进行序列化转换并返回""" raise NotImplementedError( - '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) - def external_convert(self, data: Any) -> Any: + def internal_to_external(self, data: Any) -> Any: """对数据进行反序列化转换并返回""" raise NotImplementedError( - '{cls}类的 .internal_convert 方法必须重写'.format(cls=self.__class__.__name__, ) + '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) def is_valid(self, raise_exception=False): @@ -106,7 +132,7 @@ class BaseSerializer(Field): """对外呈现的数据""" assert self.instance is None, '调用 .data 必须先传入 instance= ' if not hasattr(self, '_data'): - self._data = self.external_convert(self.instance) + self._data = self.internal_to_external(self.instance) return self._data @property @@ -191,23 +217,47 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): """ return copy.deepcopy(self._declared_fields) - def get_value_to_internal(self, data: Mapping) -> Any: - return data.get(self.field_name, empty) + def internal_to_external(self, data: Any) -> Any: + """ + 内转外 + :param data: + :return: + """ + res = OrderedDict() + fields = self._readable_fields + for field in fields: + value = field.get_internal_value(data) + res[field.field_name] = field.internal_to_external(value) + return res - def _read_only_defaults(self): - fields = [ - field for field in self.fields.values() - if field.read_only and (field.default != empty) and (field.source != '*') and ('.' not in field.source) - ] + # 反序列化 - defaults = OrderedDict() + def external_to_internal(self, data: Any) -> Any: + """ + 外转内 + :param data: + :return: + """ + res = OrderedDict() + errors = OrderedDict() + fields = self._writable_fields for field in fields: + validate_method = getattr(self, 'validate_' + field.field_name, None) + primitive_value = field.get_external_value(data) try: - default = field.get_default() + validated_value = field.run_validation(primitive_value) + if validate_method is not None: + validated_value = validate_method(validated_value) + except ValidationError as exc: + errors[field.field_name] = exc.error_dict except SkipField: - continue - defaults[field.source] = default - return defaults + pass + else: + set_value(res, field.source_attrs, validated_value) + + if errors: + raise ValidationError(errors) + return res def run_validators(self, value): """ -- Gitee From adf38bbdd0b5a8c369b7d1bea906b62851c7295d Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 26 Jan 2021 17:04:48 +0800 Subject: [PATCH 10/34] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 18 +++++++++++++ sanic_rest_framework/fields.py | 42 ++++++++++++++++++++++++----- sanic_rest_framework/serializers.py | 26 +++++++++--------- 3 files changed, 65 insertions(+), 21 deletions(-) diff --git a/run.py b/run.py index f203742..c70a303 100644 --- a/run.py +++ b/run.py @@ -4,15 +4,33 @@ from tortoise.contrib.sanic import register_tortoise from db import TestModel from sanic_rest_framework.routes import Route +from sanic_rest_framework.serializers import Serializer +from sanic_rest_framework.fields import CharField from sanic_rest_framework.views import BaseAPIView app = Sanic(__name__) admin = Blueprint('admin', '/admin') +class QianTaoserializer(Serializer): + name = CharField() + doc = CharField() + + +class TestSerializer(Serializer): + id = CharField(max_length=18) + qt = QianTaoserializer() + + class TestView(BaseAPIView): async def get(self, request): + test = TestSerializer(data={ + 'id': '2', + 'qt':{} + }) + print(test.is_valid()) + test.validated_data return self.success_json_response() async def post(self, request): diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index f80d93a..600fe27 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -13,7 +13,6 @@ import decimal import re from datetime import timezone, timedelta, datetime, date, time from typing import Any, AnyStr, Optional, List, Mapping -from functools import cached_property from tortoise import Model from tortoise.queryset import QuerySet from tortoise.exceptions import DoesNotExist @@ -67,8 +66,8 @@ class Field: :param label: 字段标题 :param description: 字段描述 """ - assert (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY - assert (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY + assert not (read_only and write_only), NOT_RAED_ONLY_AND_WRITE_ONLY + assert not (read_only and required), NOT_RAED_ONLY_REQUIRED_ONLY self._sort_counter = Field._sort_counter Field._sort_counter += 1 @@ -130,7 +129,7 @@ class Field: if self.source == '*': self.source_attrs = [] else: - self.source_attr = self.source.split('.') + self.source_attrs = self.source.split('.') def collect_error_message(self, error_messages_list: List[dict]) -> dict: """ @@ -184,7 +183,7 @@ class Field: :return: """ if not isinstance(data, Mapping): - raise ValidationError('{field_name}传入的数据为无效数据类型,仅支持字段类型'.format(field_name=self.field_name)) + raise ValidationError('传入的数据为无效数据类型,仅支持字段类型'.format(self.field_name)) if self.field_name not in data: if self.is_partial(): return empty @@ -203,7 +202,7 @@ class Field: :return: """ - for attr in self.source_attr: + for attr in self.source_attrs: try: if isinstance(instance, Mapping): instance = instance[attr] @@ -239,8 +238,37 @@ class Field: if errors: raise ValidationError(errors) + def get_default(self): + if self.default is empty or getattr(self.root, 'partial', False): + raise SkipField() + if callable(self.default): + return self.default() + return self.default + + def validate_empty_values(self, data): + if self.read_only: + return True, self.get_default() + if data is empty: + if self.is_partial(): + raise SkipField() + if self.required: + self.raise_error('required') + return True, self.get_default() + + if data is None: + if not self.allow_null: + self.raise_error('null') + elif self.source == '*': + return False, None + return True, None + + return (False, data) + def run_validation(self, data): """执行验证""" + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data value = self.external_to_internal(data) self.run_validators(value) return value @@ -270,7 +298,7 @@ class Field: "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string, code=key) + raise ValidationError({self.field_name: message_string}, code=key) class CharField(Field): diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 774cd84..2b318e2 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -10,7 +10,6 @@ """ import copy from collections import OrderedDict -from functools import cached_property from typing import Any, Mapping from sanic_rest_framework.fields import Field, empty, SkipField @@ -130,7 +129,7 @@ class BaseSerializer(Field): @property def data(self): """对外呈现的数据""" - assert self.instance is None, '调用 .data 必须先传入 instance= ' + assert not self.instance is None, '调用 .data 必须先传入 instance= ' if not hasattr(self, '_data'): self._data = self.internal_to_external(self.instance) return self._data @@ -186,7 +185,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): 'invalid': '无效数据。应该是字典,但是得到了{datatype}' } - @cached_property + @property def fields(self): """ 单个格式为 {field_name: field_instance}. @@ -254,21 +253,20 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): pass else: set_value(res, field.source_attrs, validated_value) - if errors: raise ValidationError(errors) return res - def run_validators(self, value): - """ - Add read_only fields with defaults to value before running validators. - """ - if isinstance(value, dict): - to_validate = self._read_only_defaults() - to_validate.update(value) - else: - to_validate = value - super().run_validators(to_validate) + # def run_validators(self, value): + # """ + # Add read_only fields with defaults to value before running validators. + # """ + # if isinstance(value, dict): + # to_validate = self._read_only_defaults() + # to_validate.update(value) + # else: + # to_validate = value + # super().run_validators(to_validate) def validate(self, attrs): return attrs -- Gitee From ad796ea9ff1c5f84fb4e8cbd192036c57f80b622 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Wed, 27 Jan 2021 18:14:07 +0800 Subject: [PATCH 11/34] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BF=AE=E6=94=B9=20fi?= =?UTF-8?q?elds=20exceptions=20serializers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 validators --- run.py | 14 +-- sanic_rest_framework/exceptions.py | 16 ++- sanic_rest_framework/fields.py | 19 ++- sanic_rest_framework/serializers.py | 5 +- .../test/test_serializers/test_serializer.py | 40 ++++++ sanic_rest_framework/validators.py | 119 +++++++++++++++++- 6 files changed, 191 insertions(+), 22 deletions(-) create mode 100644 sanic_rest_framework/test/test_serializers/test_serializer.py diff --git a/run.py b/run.py index c70a303..0ab8ebe 100644 --- a/run.py +++ b/run.py @@ -12,22 +12,14 @@ app = Sanic(__name__) admin = Blueprint('admin', '/admin') -class QianTaoserializer(Serializer): - name = CharField() - doc = CharField() - - -class TestSerializer(Serializer): - id = CharField(max_length=18) - qt = QianTaoserializer() - - class TestView(BaseAPIView): async def get(self, request): test = TestSerializer(data={ 'id': '2', - 'qt':{} + 'qt': { + 'name': '刘文静' + } }) print(test.is_valid()) test.validated_data diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py index 907725c..5d5ec6b 100644 --- a/sanic_rest_framework/exceptions.py +++ b/sanic_rest_framework/exceptions.py @@ -8,6 +8,7 @@ exceptions.py 序列化器文件 """ +from typing import Mapping class ValidationError(Exception): @@ -24,10 +25,13 @@ class ValidationError(Exception): message, code, params = message.message, message.code, message.params if isinstance(message, dict): self.error_dict = {} - for field, messages in message.items(): - if not isinstance(messages, ValidationError): - messages = ValidationError(messages) - self.error_dict[field] = messages.error_list + for field, msg in message.items(): + if not isinstance(msg, ValidationError): + msg = ValidationError(msg) + if hasattr(msg, 'error_dict'): + self.error_dict[field] = [msg.error_dict] + else: + self.error_dict[field] = msg.error_list elif isinstance(message, list): self.error_list = [] for message in message: @@ -85,3 +89,7 @@ class ValidationError(Exception): if not isinstance(other, ValidationError): return NotImplemented return hash(self) == hash(other) + + +class ValidatorAssertError(Exception): + pass diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 600fe27..e6ba032 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -18,6 +18,9 @@ from tortoise.queryset import QuerySet from tortoise.exceptions import DoesNotExist from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.validators import ( + MaxLengthValidator, MinLengthValidator, MaxValueValidator, MinValueValidator +) REGEX_TYPE = type(re.compile('')) @@ -183,7 +186,7 @@ class Field: :return: """ if not isinstance(data, Mapping): - raise ValidationError('传入的数据为无效数据类型,仅支持字段类型'.format(self.field_name)) + raise ValidationError('传入的数据为无效数据类型,仅支持字典类型'.format(self.field_name)) if self.field_name not in data: if self.is_partial(): return empty @@ -298,7 +301,7 @@ class Field: "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError({self.field_name: message_string}, code=key) + raise ValidationError(message_string, code=key) class CharField(Field): @@ -312,8 +315,12 @@ class CharField(Field): self.max_length = kwargs.pop('max_length', None) self.min_length = kwargs.pop('min_length', None) self.trim_whitespace = kwargs.pop('trim_whitespace', True) - super(CharField, self).__init__(*args, *kwargs) + super(CharField, self).__init__(*args, **kwargs) # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators + if self.max_length is not None: + self.validators.append(MaxLengthValidator(max_length=self.max_length, error_messages={'max_length': self.error_messages['max_length']})) + if self.min_length is not None: + self.validators.append(MinLengthValidator(min_length=self.min_length, error_messages={'min_length': self.error_messages['min_length']})) def external_to_internal(self, data: Any) -> Any: if isinstance(data, bool) or not isinstance(data, (str, int, float,)): @@ -339,8 +346,12 @@ class IntegerField(Field): def __init__(self, max_value=None, min_value=None, *args, **kwargs): self.max_value = max_value self.min_value = min_value - super(IntegerField, self).__init__(*args, *kwargs) + super(IntegerField, self).__init__(*args, **kwargs) # TODO 需要将 MaxValueValidator 与 MinValueValidator 添加入 self.validators + if self.max_value is not None: + self.validators.append(MaxValueValidator(max_value=self.max_value, error_messages={'max_value': self.error_messages['max_value']})) + if self.min_value is not None: + self.validators.append(MinValueValidator(min_value=self.min_value, error_messages={'min_value': self.error_messages['min_value']})) def external_to_internal(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 2b318e2..881f999 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -248,7 +248,10 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = exc.error_dict + if isinstance(field, BaseSerializer): + errors[field.field_name] = exc.error_dict + else: + errors[field.field_name] = exc.error_list except SkipField: pass else: diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py new file mode 100644 index 0000000..0b3134a --- /dev/null +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -0,0 +1,40 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/27 9:55 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_serializer.py + 测试文件 +@ChangeHistory: + datetime action why + example: + 2021/1/27 9:55 change 'Fix bug' + +""" +from sanic_rest_framework.fields import CharField, IntegerField, FloatField +from sanic_rest_framework.serializers import Serializer + + +class QianTaoserializer(Serializer): + name = CharField() + doc = CharField() + + +class TestSerializer(Serializer): + id = CharField(max_length=18) + qt = QianTaoserializer(required=True, allow_null=True) + ages = IntegerField(max_value=6) + + +data = { + 'id': '20', + 'qt': {}, + 'ages': 2, +} + +test = TestSerializer(data=data) +test.is_valid(raise_exception=True) +print(test.errors) +print(test.validated_data) diff --git a/sanic_rest_framework/validators.py b/sanic_rest_framework/validators.py index 42703ed..f6c0f24 100644 --- a/sanic_rest_framework/validators.py +++ b/sanic_rest_framework/validators.py @@ -13,7 +13,122 @@ 2021/1/22 10:41 change 'Fix bug' """ +import copy +from typing import Dict +from .exceptions import ValidationError, ValidatorAssertError -class MaxLengthValidator(): - pass + +# class MaxLengthValidator(): +# pass +# +# +# class MinLengthValidator(): +# pass + +class BaseValidator: + """验证器基类 + 所有通用验证器都需要继承本类, + 在调用 __call__ 时抛出 ValidationError 错误 + 即代表验证失败 + """ + default_error_messages: Dict[str, str] = { + + } + + def __init__(self, error_messages: Dict[str, str] = None, code=None): + self.error_messages = copy.copy(self.default_error_messages) + if error_messages is not None: + self.error_messages.update(copy.copy(error_messages)) + self.code = code + + def __call__(self, value, serializer=None): + raise NotImplementedError('验证器必须重新定义 __call__()') + + def raise_error(self, key, **kws): + msg = self.default_error_messages[key].format(**kws) + return ValidationError(msg, code=key) + + +class MaxLengthValidator(BaseValidator): + default_error_messages: Dict[str, str] = { + 'max_length': '超出长度,最长支持{max_length}', + 'invalid': '无效的数据类型,数据类型只支持{datatypes}' + } + + def __init__(self, max_length, **kwargs): + if not isinstance(max_length, (int, float)): + raise ValidatorAssertError('max_length的值只支持数值类型') + self.max_length = max_length + + super(MaxLengthValidator, self).__init__(**kwargs) + + def __call__(self, value, serializer=None): + if not isinstance(value, (str, list, dict, type)): + self.raise_error('invalid', datatypes='str, list, dict, type') + + if len(value) > self.max_length: + self.raise_error('max_length', max_length=self.max_length) + + +class MinLengthValidator(BaseValidator): + default_error_messages: Dict[str, str] = { + 'max_length': '低于最低长度,最低为 {min_length}', + 'invalid': '无效的数据类型,数据类型只支持 {datatypes} ' + + } + + def __init__(self, min_length, **kwargs): + if not isinstance(min_length, (int, float)): + raise ValidatorAssertError('min_length的值只支持数值类型') + + self.min_length = min_length + super(MinLengthValidator, self).__init__(**kwargs) + + def __call__(self, value, serializer=None): + if not isinstance(value, (str, list, dict, type)): + self.raise_error('invalid', datatypes='str, list, dict, type') + + if len(value) < self.min_length: + self.raise_error('min_length', min_length=self.min_length) + + +class MaxValueValidator(BaseValidator): + default_error_messages: Dict[str, str] = { + 'max_value': '超出最大值,最大值支持到{max_value}', + 'invalid': '无效的数据类型,数据类型只支持{datatypes}' + + } + + def __init__(self, max_value, **kwargs): + if not isinstance(max_value, (int, float)): + raise ValidatorAssertError('max_value的值只支持数值类型') + self.max_value = max_value + super(MaxValueValidator, self).__init__(**kwargs) + + def __call__(self, value, serializer=None): + if not isinstance(value, (int, float)): + self.raise_error('invalid', datatypes='int, float') + + if value > self.max_value: + self.raise_error('max_value', min_length=self.max_value) + + +class MinValueValidator(BaseValidator): + default_error_messages: Dict[str, str] = { + 'min_value': '低于最小值,最小值至少要为{min_value}', + 'invalid': '无效的数据类型,数据类型只支持{datatypes}' + } + + def __init__(self, min_value, **kwargs): + if not isinstance(min_value, (int, float)): + raise ValidatorAssertError('min_value的值只支持数值类型') + self.min_value = min_value + super(MinValueValidator, self).__init__(**kwargs) + + def __call__(self, value, serializer=None): + if not isinstance(value, (int, float)): + self.raise_error('invalid', datatypes='int, float') + + if value < self.min_value: + self.raise_error('min_value', min_length=self.min_value) -- Gitee From 156755c05e2de672b135c0e25364b4e03a8d6a22 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 28 Jan 2021 18:05:19 +0800 Subject: [PATCH 12/34] =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 41 ++++---- sanic_rest_framework/serializers.py | 11 --- .../test/test_fields/test_char_field.py | 96 +++++++++++++++++++ .../test/test_serializers/test_serializer.py | 65 ++++++++++--- sanic_rest_framework/validators.py | 11 +-- 5 files changed, 177 insertions(+), 47 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/test_char_field.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index e6ba032..a86f236 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -42,7 +42,12 @@ NOT_RAED_ONLY_REQUIRED_ONLY = 'read_only 为 True 时 required 不能为True , class Field: - """字段及序列化器基类""" + """字段及序列化器基类 + required: 反序列化时是否必须存在,值限制写入时 + allow_null: 是否可以为 None 即存在当没值 + allow_empty: 是否可以为空 value = '' 即为空 + + """ _sort_counter = 0 # 所有field都强制拥有的错误提示 base_error_messages = { @@ -53,18 +58,19 @@ class Field: default_validators = None - def __init__(self, read_only=False, write_only=False, required=True, allow_null=False, - default=empty, source=None, validators=None, error_messages=None, + def __init__(self, read_only=False, write_only=False, required=False, allow_null=False, + allow_empty=False, default=empty, source=None, validators=None, error_messages=None, label=None, description=None): """ 字段及field的基类 - :param read_only: 只反序列化 - :param write_only: 只序列化 - :param required: 序列化时必须存在此值 - :param allow_null: 序列化可以为空 + :param read_only: 只序列化 + :param write_only: 只反序列化 + :param required: 反序列化时必须存在此值 + :param allow_null: 反序列化时可以为 None + :param allow_empty: 反序列化可以为 '' :param default: 默认值 可用于序列化和反序列化 - :param source: 反序列化是值的来源 - :param validators: 序列化时需要通过的验证 + :param source: 序列化时值的来源 + :param validators: 反序列化时数据需要通过的验证 :param error_messages: 出现错误时的自定义描述 :param label: 字段标题 :param description: 字段描述 @@ -79,6 +85,7 @@ class Field: self.write_only = write_only self.required = required self.allow_null = allow_null + self.allow_null = allow_empty self.default = default self.source = source self.label = label @@ -195,6 +202,7 @@ class Field: return value async def async_get_attribute(self, instance, attr): + """适用于 Model 对象""" return await getattr(instance, attr) def get_internal_value(self, instance: Any) -> Any: @@ -264,7 +272,6 @@ class Field: elif self.source == '*': return False, None return True, None - return (False, data) def run_validation(self, data): @@ -316,7 +323,6 @@ class CharField(Field): self.min_length = kwargs.pop('min_length', None) self.trim_whitespace = kwargs.pop('trim_whitespace', True) super(CharField, self).__init__(*args, **kwargs) - # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators if self.max_length is not None: self.validators.append(MaxLengthValidator(max_length=self.max_length, error_messages={'max_length': self.error_messages['max_length']})) if self.min_length is not None: @@ -347,7 +353,6 @@ class IntegerField(Field): self.max_value = max_value self.min_value = min_value super(IntegerField, self).__init__(*args, **kwargs) - # TODO 需要将 MaxValueValidator 与 MinValueValidator 添加入 self.validators if self.max_value is not None: self.validators.append(MaxValueValidator(max_value=self.max_value, error_messages={'max_value': self.error_messages['max_value']})) if self.min_value is not None: @@ -604,12 +609,12 @@ class DateField(Field): data = datetime.strptime(data, self.input_formats).date() return data except (ValueError, TypeError): - self.raise_error('invalid', data=data) + self.raise_error('invalid', value=data) if isinstance(data, datetime): self.raise_error('datetime') if isinstance(data, date): return data - self.raise_error('invalid', data=data) + self.raise_error('invalid', value=data) def internal_to_external(self, data: Any) -> Any: if not data: @@ -625,6 +630,7 @@ class TimeField(Field): """时间字段""" default_error_messages = { 'invalid': '出现错误的数据类型,{value}不是有效的日期时间类型', + 'format': '时间格式错误,需要格式为 %H:%M:%S ', 'date': '需要的是时间格式而不是日期格式', } @@ -639,12 +645,12 @@ class TimeField(Field): data = datetime.strptime(data, self.input_formats).time() return data except (ValueError, TypeError): - self.raise_error('invalid') + self.raise_error('format') if isinstance(data, datetime): return data.time() if isinstance(data, date): self.raise_error('date') - self.raise_error('invalid') + self.raise_error('invalid', value=type(data)) def external_convert(self, data: Any) -> Any: if not data: @@ -709,6 +715,9 @@ class SerializerMethodField(Field): self.method_name = method_name kwargs['source'] = '*' kwargs['read_only'] = True + if kwargs.get('required'): + assert 'SerializerMethodField 为只读字段,不能反序列化' + super().__init__(**kwargs) def bind(self, field_name, parent): diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 881f999..97b3dd2 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -260,17 +260,6 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): raise ValidationError(errors) return res - # def run_validators(self, value): - # """ - # Add read_only fields with defaults to value before running validators. - # """ - # if isinstance(value, dict): - # to_validate = self._read_only_defaults() - # to_validate.update(value) - # else: - # to_validate = value - # super().run_validators(to_validate) - def validate(self, attrs): return attrs diff --git a/sanic_rest_framework/test/test_fields/test_char_field.py b/sanic_rest_framework/test/test_fields/test_char_field.py new file mode 100644 index 0000000..80b2f86 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_char_field.py @@ -0,0 +1,96 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/28 16:16 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_char_field.py + 字符字段单元测试 +@ChangeHistory: + datetime action why + example: + 2021/1/28 16:16 change 'Fix bug' + +""" + +import unittest + +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import CharField + + +class TestCharField(unittest.TestCase): + def test_get_external_value(self): + data1 = {'char': 'Python'} + data2 = {'char': 66666} + char = CharField() + char.bind('char', char) + self.assertEqual(char.get_external_value(data1), 'Python') + self.assertEqual(char.get_external_value(data2), 66666) + + def test_external_to_internal(self): + data = ' Python' + char1 = CharField() + self.assertEqual(char1.external_to_internal(data), 'Python') + + def test_get_internal_value(self): + data1 = {'char': 'Python'} + data2 = {'char': 66666} + char = CharField() + char.bind('char', char) + + value1 = char.get_internal_value(data1) + self.assertEqual(value1, 'Python') + value2 = char.get_internal_value(data2) + self.assertEqual(value2, 66666) + + def test_internal_to_external(self): + data1 = {'char1': 'Python'} + data2 = {'char1': 66666} + char1 = CharField() + char1.bind('char1', char1) + + value = char1.get_internal_value(data1) + self.assertEqual(char1.internal_to_external(value), 'Python') + + value = char1.get_internal_value(data2) + self.assertEqual(char1.internal_to_external(value), '66666') + + def test_trim_whitespace(self): + data = ' Python' + char1 = CharField() + char2 = CharField(trim_whitespace=True) + char3 = CharField(trim_whitespace=False) + c1_data = char1.external_to_internal(data) + c2_data = char2.external_to_internal(data) + c3_data = char3.external_to_internal(data) + self.assertEqual(c1_data, 'Python') + self.assertEqual(c2_data, 'Python') + self.assertEqual(c3_data, ' Python') + + def test_max_length(self): + data = 'Python' + char1 = CharField() + char2 = CharField(max_length=10) + char3 = CharField(max_length=5) + self.assertEqual(char1.run_validation(data), 'Python') + self.assertEqual(char2.run_validation(data), 'Python') + + with self.assertRaises(ValidationError): + char3.run_validation(data) + + def test_min_length(self): + data = 'Python' + char1 = CharField() + char2 = CharField(min_length=5) + char3 = CharField(min_length=10) + self.assertEqual(char1.run_validation(data), 'Python') + self.assertEqual(char2.run_validation(data), 'Python') + + with self.assertRaises(ValidationError): + char3.run_validation(data) + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py index 0b3134a..68179dc 100644 --- a/sanic_rest_framework/test/test_serializers/test_serializer.py +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -13,8 +13,12 @@ 2021/1/27 9:55 change 'Fix bug' """ -from sanic_rest_framework.fields import CharField, IntegerField, FloatField + +from sanic_rest_framework.fields import ( + CharField, IntegerField, FloatField, DateField, TimeField, DecimalField, + DateTimeField, BooleanField, ChoiceField, SerializerMethodField) from sanic_rest_framework.serializers import Serializer +import unittest class QianTaoserializer(Serializer): @@ -22,19 +26,58 @@ class QianTaoserializer(Serializer): doc = CharField() -class TestSerializer(Serializer): +class BaseSerializer(Serializer): id = CharField(max_length=18) qt = QianTaoserializer(required=True, allow_null=True) ages = IntegerField(max_value=6) -data = { - 'id': '20', - 'qt': {}, - 'ages': 2, -} +class FieldRequiredSerializer(Serializer): + char_field = CharField(required=True) + integer_field = IntegerField(required=True) + float_field = FloatField(required=True) + decimal_field = DecimalField(required=True, max_digits=9, decimal_places=3) + date_field = DateField(required=True) + time_field = TimeField(required=True) + datetime_field = DateTimeField(required=True) + boolean_field = BooleanField(required=True) + choice_field = ChoiceField(required=True, choices=[(1, '老四')]) + serializer_method_field = SerializerMethodField() + + +class TestFieldRequired(unittest.TestCase): + def test_has_data(self): + data = { + 'char_field': '', + 'integer_field': '', + 'float_field': '', + 'decimal_field': '', + 'date_field': '', + 'time_field': '', + 'datetime_field': '', + 'boolean_field': '', + 'choice_field': '', + 'serializer_method_field': '', + } + frs = FieldRequiredSerializer(data=data) + self.assertEqual(frs.is_valid(), False) + + def test_success_data(self): + data = { + 'char_field': 'NiHao', + 'integer_field': 1, + 'float_field': 80.06, + 'decimal_field': 99620, + 'date_field': '2017-16-28', + 'time_field': '16:18:11', + 'datetime_field': '2017-16-28 16:18:11', + 'boolean_field': '0', + 'choice_field': '80', + 'serializer_method_field': 'aaa', + } + frs = FieldRequiredSerializer(data=data) + self.assertEqual(frs.is_valid(), True) + -test = TestSerializer(data=data) -test.is_valid(raise_exception=True) -print(test.errors) -print(test.validated_data) +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/validators.py b/sanic_rest_framework/validators.py index f6c0f24..f213272 100644 --- a/sanic_rest_framework/validators.py +++ b/sanic_rest_framework/validators.py @@ -19,13 +19,6 @@ from typing import Dict from .exceptions import ValidationError, ValidatorAssertError -# class MaxLengthValidator(): -# pass -# -# -# class MinLengthValidator(): -# pass - class BaseValidator: """验证器基类 所有通用验证器都需要继承本类, @@ -47,7 +40,7 @@ class BaseValidator: def raise_error(self, key, **kws): msg = self.default_error_messages[key].format(**kws) - return ValidationError(msg, code=key) + raise ValidationError(msg, code=key) class MaxLengthValidator(BaseValidator): @@ -73,7 +66,7 @@ class MaxLengthValidator(BaseValidator): class MinLengthValidator(BaseValidator): default_error_messages: Dict[str, str] = { - 'max_length': '低于最低长度,最低为 {min_length}', + 'min_length': '低于最低长度,最低为 {min_length}', 'invalid': '无效的数据类型,数据类型只支持 {datatypes} ' } -- Gitee From ecc66c480c8b2954af379a0f5dfb50c6de88ed2d Mon Sep 17 00:00:00 2001 From: LaoSi Date: Fri, 29 Jan 2021 18:02:56 +0800 Subject: [PATCH 13/34] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E5=B9=B6=E4=BF=AE?= =?UTF-8?q?=E6=94=B9BUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 21 +- .../test/test_fields/test_decimal_field.py | 188 ++++++++++++++++++ .../test/test_fields/test_float_field.py | 166 ++++++++++++++++ .../test/test_fields/test_integer_field.py | 150 ++++++++++++++ sanic_rest_framework/test/utils.py | 35 ++++ sanic_rest_framework/validators.py | 4 +- 6 files changed, 558 insertions(+), 6 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/test_decimal_field.py create mode 100644 sanic_rest_framework/test/test_fields/test_float_field.py create mode 100644 sanic_rest_framework/test/test_fields/test_integer_field.py create mode 100644 sanic_rest_framework/test/utils.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index a86f236..742c854 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -12,7 +12,7 @@ import copy import decimal import re from datetime import timezone, timedelta, datetime, date, time -from typing import Any, AnyStr, Optional, List, Mapping +from typing import Any, List, Mapping from tortoise import Model from tortoise.queryset import QuerySet from tortoise.exceptions import DoesNotExist @@ -243,7 +243,7 @@ class Field: try: validator(data, self) except ValidationError as exc: - if hasattr(exc, 'code') and exc.code in self.error_messages: + if hasattr(exc, 'code') and exc.code in self.error_messages and isinstance(exc.message, ValidationError): exc.message = self.error_messages[exc.code] errors.extend(exc.error_list) if errors: @@ -374,7 +374,7 @@ class IntegerField(Field): class FloatField(IntegerField): """浮点类型""" default_error_messages = { - 'invalid': '出现错误的数据类型,仅支持浮点类型', + 'invalid': '出现错误的数据类型{data_type},仅支持浮点类型', 'max_value': '仅支持小于{max_value}浮点', 'min_value': '仅支持大于{min_value}浮点', 'max_string_length': '仅支持转换长度小于{max_string_length}的浮点字符串', @@ -382,12 +382,14 @@ class FloatField(IntegerField): MAX_STRING_LENGTH = 1000 def external_to_internal(self, data: Any): + if isinstance(data, bool): + self.raise_error('invalid', data_type=type(data).__name__) if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: return float(data) except (TypeError, ValueError): - self.raise_error('invalid') + self.raise_error('invalid', data_type=type(data).__name__) def internal_to_external(self, data: Any): return float(data) @@ -408,6 +410,17 @@ class DecimalField(Field): def __init__(self, max_digits, decimal_places, coerce_to_string=False, max_value=None, min_value=None, rounding=None, *args, **kwargs): + """ + 整数位数 = max_digits - decimal_places + :param max_digits: 数字允许的最大位数 + :param decimal_places: 小数的最大位数 + :param coerce_to_string: + :param max_value: + :param min_value: + :param rounding: + :param args: + :param kwargs: + """ self.max_digits = max_digits self.decimal_places = decimal_places self.coerce_to_string = coerce_to_string diff --git a/sanic_rest_framework/test/test_fields/test_decimal_field.py b/sanic_rest_framework/test/test_fields/test_decimal_field.py new file mode 100644 index 0000000..7754f02 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_decimal_field.py @@ -0,0 +1,188 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/29 10:23 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_decimal_field.py + 测试十进制类型字段 +@ChangeHistory: + datetime action why + example: + 2021/1/29 10:23 change 'Fix bug' + +""" + +import unittest +from decimal import Decimal, ROUND_DOWN, getcontext +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import DecimalField as TestField +from sanic_rest_framework.test.utils import TestDataMixin + + +class TestDecimalField(TestDataMixin, unittest.TestCase): + + def test_validate_precision(self): + """ + 测试 max_digits, decimal_places + :return: + """ + + tf = TestField(max_digits=6, decimal_places=2) + with self.assertRaises(ValidationError): + tf.validate_precision(Decimal('99.999')) + with self.assertRaises(ValidationError): + tf.validate_precision(Decimal('99999.99')) + self.assertEqual(tf.validate_precision(Decimal('99.99')), Decimal('99.99')) + self.assertEqual(tf.validate_precision(Decimal('9999.99')), Decimal('9999.99')) + with self.assertRaises(ValidationError): + tf.validate_precision(Decimal('999.999')) + + def test_quantize(self): + tf = TestField(max_digits=6, decimal_places=2) + context = getcontext() # 获取decimal现在的上下文 + context.rounding = ROUND_DOWN + self.assertEqual(tf.quantize(Decimal('99.99999999')), Decimal('99.99')) + + def test_get_external_value(self): + """得到外部传入的值""" + data1 = {'tf': self.str_max_int} + data2 = {'tf': self.max_int} + data3 = {'tf': self.str_chinese} + tf = TestField() + tf.bind('tf', tf) + self.assertEqual(tf.get_external_value(data1), self.str_max_int) + self.assertEqual(tf.get_external_value(data2), self.max_int) + self.assertEqual(tf.get_external_value(data3), self.str_chinese) + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + float => [1,'1','1.0',1.0',1.6,'1.6'] + :return: + """ + tf1 = TestField() + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.bool_True) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_chinese) + + self.assertEqual(tf1.external_to_internal(self.str_pi), self.pi) + self.assertEqual(tf1.external_to_internal(self.pi), self.pi) + self.assertEqual(tf1.external_to_internal(self.str_max_float), self.max_float) + self.assertEqual(tf1.external_to_internal(self.max_int), self.max_int) + self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) + + def test_get_internal_value(self): + """测试由 instance 得到内部的值 """ + data1 = {'tf': self.str_chinese} + data2 = {'tf': self.str_england} + data3 = {'tf': self.str_max_float} + data4 = {'tf': self.max_float} + data6 = {'tf': self.max_int} + data7 = {'tf': self.str_max_int} + tf = TestField() + tf.bind('tf', tf) + + value = tf.get_internal_value(data1) + self.assertEqual(value, self.str_chinese) + value = tf.get_internal_value(data2) + self.assertEqual(value, self.str_england) + value = tf.get_internal_value(data3) + self.assertEqual(value, self.str_max_float) + value = tf.get_internal_value(data4) + self.assertEqual(value, self.max_float) + value = tf.get_internal_value(data6) + self.assertEqual(value, self.max_int) + value = tf.get_internal_value(data7) + self.assertEqual(value, self.str_max_int) + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要是数值类型都不报错,float(xx) + :return: + """ + data1 = {'tf1': self.str_chinese} + data2 = {'tf1': self.max_int} + data3 = {'tf1': self.str_max_int} + data4 = {'tf1': self.str_max_float} + data5 = {'tf1': self.max_float} + + tf = TestField() + tf.bind('tf1', tf) + + value = tf.get_internal_value(data1) + with self.assertRaises(ValueError): + tf.internal_to_external(value) + + value = tf.get_internal_value(data2) + self.assertEqual(tf.internal_to_external(value), self.max_int) + + value = tf.get_internal_value(data3) + self.assertEqual(tf.internal_to_external(value), self.max_int) + value = tf.get_internal_value(data4) + self.assertEqual(tf.internal_to_external(value), self.max_float) + value = tf.get_internal_value(data5) + self.assertEqual(tf.internal_to_external(value), self.max_float) + + def test_max_value(self): + tf1 = TestField() + tf2 = TestField(max_value=10.69) + + # 未设置 不存在超出限制 + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + self.assertEqual(tf1.run_validators(data=self.min_float), None) + self.assertEqual(tf1.run_validators(data=self.str_max_float), None) + self.assertEqual(tf1.run_validators(data=self.max_float), None) + + # 超出限制 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.max_int), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.max_float), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_max_int), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_max_float), None) + + self.assertEqual(tf2.run_validators(data=self.min_int), None) + self.assertEqual(tf2.run_validators(data=self.min_float), None) + + def test_min_length(self): + tf1 = TestField() + tf2 = TestField(min_value=10.69) + + # 未设置 不存在超出限制 + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + self.assertEqual(tf1.run_validators(data=self.min_float), None) + self.assertEqual(tf1.run_validators(data=self.str_max_float), None) + self.assertEqual(tf1.run_validators(data=self.max_float), None) + + # 超出限制 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.min_float), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.min_int), None) + + # 不支持其除 int float 以外的格式 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_min_float), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_min_int), None) + + self.assertEqual(tf2.run_validators(data=self.max_int), None) + self.assertEqual(tf2.run_validators(data=self.max_float), None) + + def test_max_string_length(self): + tf1 = TestField() + with self.assertRaises(ValidationError): + # 超出约定长度 + tf1.external_to_internal(self.long_str) + self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_float_field.py b/sanic_rest_framework/test/test_fields/test_float_field.py new file mode 100644 index 0000000..996cf49 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_float_field.py @@ -0,0 +1,166 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/29 10:23 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_float_field.py + 测试浮点类型字段 +@ChangeHistory: + datetime action why + example: + 2021/1/29 10:23 change 'Fix bug' + +""" + +import unittest + +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import FloatField as TestField +from sanic_rest_framework.test.utils import TestDataMixin + + +class TestFloatField(TestDataMixin, unittest.TestCase): + + def test_get_external_value(self): + """得到外部传入的值""" + data1 = {'tf': self.str_max_int} + data2 = {'tf': self.max_int} + data3 = {'tf': self.str_chinese} + tf = TestField() + tf.bind('tf', tf) + self.assertEqual(tf.get_external_value(data1), self.str_max_int) + self.assertEqual(tf.get_external_value(data2), self.max_int) + self.assertEqual(tf.get_external_value(data3), self.str_chinese) + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + float => [1,'1','1.0',1.0',1.6,'1.6'] + :return: + """ + tf1 = TestField() + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.bool_True) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_chinese) + + self.assertEqual(tf1.external_to_internal(self.str_pi), self.pi) + self.assertEqual(tf1.external_to_internal(self.pi), self.pi) + self.assertEqual(tf1.external_to_internal(self.str_max_float), self.max_float) + self.assertEqual(tf1.external_to_internal(self.max_int), self.max_int) + self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) + + def test_get_internal_value(self): + """测试由 instance 得到内部的值 """ + data1 = {'tf': self.str_chinese} + data2 = {'tf': self.str_england} + data3 = {'tf': self.str_max_float} + data4 = {'tf': self.max_float} + data6 = {'tf': self.max_int} + data7 = {'tf': self.str_max_int} + tf = TestField() + tf.bind('tf', tf) + + value = tf.get_internal_value(data1) + self.assertEqual(value, self.str_chinese) + value = tf.get_internal_value(data2) + self.assertEqual(value, self.str_england) + value = tf.get_internal_value(data3) + self.assertEqual(value, self.str_max_float) + value = tf.get_internal_value(data4) + self.assertEqual(value, self.max_float) + value = tf.get_internal_value(data6) + self.assertEqual(value, self.max_int) + value = tf.get_internal_value(data7) + self.assertEqual(value, self.str_max_int) + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要是数值类型都不报错,float(xx) + :return: + """ + data1 = {'tf1': self.str_chinese} + data2 = {'tf1': self.max_int} + data3 = {'tf1': self.str_max_int} + data4 = {'tf1': self.str_max_float} + data5 = {'tf1': self.max_float} + + tf = TestField() + tf.bind('tf1', tf) + + value = tf.get_internal_value(data1) + with self.assertRaises(ValueError): + tf.internal_to_external(value) + + value = tf.get_internal_value(data2) + self.assertEqual(tf.internal_to_external(value), self.max_int) + + value = tf.get_internal_value(data3) + self.assertEqual(tf.internal_to_external(value), self.max_int) + value = tf.get_internal_value(data4) + self.assertEqual(tf.internal_to_external(value), self.max_float) + value = tf.get_internal_value(data5) + self.assertEqual(tf.internal_to_external(value), self.max_float) + + def test_max_value(self): + tf1 = TestField() + tf2 = TestField(max_value=10.69) + + # 未设置 不存在超出限制 + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + self.assertEqual(tf1.run_validators(data=self.min_float), None) + self.assertEqual(tf1.run_validators(data=self.str_max_float), None) + self.assertEqual(tf1.run_validators(data=self.max_float), None) + + # 超出限制 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.max_int), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.max_float), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_max_int), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_max_float), None) + + self.assertEqual(tf2.run_validators(data=self.min_int), None) + self.assertEqual(tf2.run_validators(data=self.min_float), None) + + def test_min_length(self): + tf1 = TestField() + tf2 = TestField(min_value=10.69) + + # 未设置 不存在超出限制 + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + self.assertEqual(tf1.run_validators(data=self.min_float), None) + self.assertEqual(tf1.run_validators(data=self.str_max_float), None) + self.assertEqual(tf1.run_validators(data=self.max_float), None) + + # 超出限制 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.min_float), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.min_int), None) + + # 不支持其除 int float 以外的格式 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_min_float), None) + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_min_int), None) + + self.assertEqual(tf2.run_validators(data=self.max_int), None) + self.assertEqual(tf2.run_validators(data=self.max_float), None) + + def test_max_string_length(self): + tf1 = TestField() + with self.assertRaises(ValidationError): + # 超出约定长度 + tf1.external_to_internal(self.long_str) + self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_integer_field.py b/sanic_rest_framework/test/test_fields/test_integer_field.py new file mode 100644 index 0000000..8c7654e --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_integer_field.py @@ -0,0 +1,150 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/29 10:23 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_integer_field.py + 测试整数类型字段 +@ChangeHistory: + datetime action why + example: + 2021/1/29 10:23 change 'Fix bug' + +""" + +import unittest + +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import IntegerField as TestField +from sanic_rest_framework.test.utils import TestDataMixin + + +class TestCharField(TestDataMixin, unittest.TestCase): + + def test_get_external_value(self): + data1 = {'tf': self.str_max_float} + data2 = {'tf': self.max_float} + data3 = {'tf': self.str_chinese} + tf = TestField() + tf.bind('tf', tf) + self.assertEqual(tf.get_external_value(data1), self.str_max_float) + self.assertEqual(tf.get_external_value(data2), self.max_float) + self.assertEqual(tf.get_external_value(data3), self.str_chinese) + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + int => [1,'1','1.0',1.0',1.] + :return: + """ + + tf = TestField() + with self.assertRaises(ValidationError): + tf.external_to_internal(self.max_float) + with self.assertRaises(ValidationError): + tf.external_to_internal(self.str_max_float) + with self.assertRaises(ValidationError): + tf.external_to_internal(self.str_chinese) + with self.assertRaises(ValidationError): + tf.external_to_internal(self.bool_True) + + self.assertEqual(tf.external_to_internal(self.max_int), self.max_int) + self.assertEqual(tf.external_to_internal(self.str_max_int), self.max_int) + + def test_get_internal_value(self): + data1 = {'tf': 'Python'} + data2 = {'tf': 66666} + tf = TestField() + tf.bind('tf', tf) + + value1 = tf.get_internal_value(data1) + self.assertEqual(value1, 'Python') + value2 = tf.get_internal_value(data2) + self.assertEqual(value2, 66666) + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要是数值类型都不报错,int(xx) + :return: + """ + data1 = {'tf1': self.str_chinese} + data2 = {'tf1': self.max_int} + data3 = {'tf1': self.str_max_int} + data4 = {'tf1': self.str_max_float} + data5 = {'tf1': self.max_float} + + tf = TestField() + tf.bind('tf1', tf) + + value = tf.get_internal_value(data1) + with self.assertRaises(ValueError): + tf.internal_to_external(value) + + value = tf.get_internal_value(data2) + self.assertEqual(tf.internal_to_external(value), self.max_int) + + value = tf.get_internal_value(data3) + self.assertEqual(tf.internal_to_external(value), self.max_int) + value = tf.get_internal_value(data4) + + with self.assertRaises(ValueError): + # str_float 不可被转换 + tf.internal_to_external(value) + + value = tf.get_internal_value(data5) + self.assertEqual(tf.internal_to_external(value), self.max_int) + + def test_max_value(self): + tf1 = TestField() + tf2 = TestField(max_value=10) + + with self.assertRaises(ValidationError): + # 格式不支持str + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + self.assertEqual(tf2.run_validators(data=self.str_max_int), None) + # 超出限制 + + # 未设置 不存在超出限制 + self.assertEqual(tf1.run_validators(data=self.max_int), None) + self.assertEqual(tf1.run_validators(data=self.min_int), None) + + # 超出限制 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.max_int), None) + self.assertEqual(tf2.run_validators(data=self.min_int), None) + + def test_min_length(self): + tf1 = TestField() + tf2 = TestField(min_value=10) + + with self.assertRaises(ValidationError): + # 格式不支持str + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.str_max_int), None) + # 超出限制 + + # 未设置 不存在超出限制 + self.assertEqual(tf1.run_validators(data=self.max_int), None) + self.assertEqual(tf1.run_validators(data=self.min_int), None) + + # 低于最小值 + with self.assertRaises(ValidationError): + self.assertEqual(tf2.run_validators(data=self.min_int), None) + self.assertEqual(tf2.run_validators(data=self.max_int), None) + + def test_max_string_length(self): + tf1 = TestField() + with self.assertRaises(ValidationError): + # 超出约定长度 + tf1.external_to_internal(self.long_str) + self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/utils.py b/sanic_rest_framework/test/utils.py new file mode 100644 index 0000000..b2b5b8f --- /dev/null +++ b/sanic_rest_framework/test/utils.py @@ -0,0 +1,35 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/1/29 16:38 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + utils.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/1/29 16:38 change 'Fix bug' + +""" +class TestDataMixin: + min_int = 1 + max_int = 9999 + min_float = 0.01 + max_float = 9999.99 + pi = 3.1415926 + str_min_int = '1' + str_max_int = '9999' + str_min_float = '0.01' + str_max_float = '9999.99' + str_pi = '3.1415926' + str_chinese_w = ' 测试字符 ' + str_chinese = '测试字符' + str_england = 'Test String' + str_england_w = ' Test String ' + bool_True = True + bool_False = False + str_bool_True = 'True' + str_bool_False = 'False' + long_str = '996' * 600 \ No newline at end of file diff --git a/sanic_rest_framework/validators.py b/sanic_rest_framework/validators.py index f213272..4e7daa3 100644 --- a/sanic_rest_framework/validators.py +++ b/sanic_rest_framework/validators.py @@ -104,7 +104,7 @@ class MaxValueValidator(BaseValidator): self.raise_error('invalid', datatypes='int, float') if value > self.max_value: - self.raise_error('max_value', min_length=self.max_value) + self.raise_error('max_value', max_value=self.max_value) class MinValueValidator(BaseValidator): @@ -124,4 +124,4 @@ class MinValueValidator(BaseValidator): self.raise_error('invalid', datatypes='int, float') if value < self.min_value: - self.raise_error('min_value', min_length=self.min_value) + self.raise_error('min_value', min_value=self.min_value) -- Gitee From 5b52958c7172788e2d483bc9440a35d492372940 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 2 Feb 2021 17:56:10 +0800 Subject: [PATCH 14/34] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 10 +- .../test/test_fields/test_base_field.py | 135 ++++++++++++++++++ .../test/test_fields/test_decimal_field.py | 68 +++++---- .../test/test_fields/test_float_field.py | 2 +- .../test/test_fields/test_integer_field.py | 7 +- 5 files changed, 190 insertions(+), 32 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/test_base_field.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 742c854..786a5db 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -257,6 +257,11 @@ class Field: return self.default def validate_empty_values(self, data): + """ + 判断是否为空值 + :param data: Any + :return: (bool,Any) => (是否为空 ,data) + """ if self.read_only: return True, self.get_default() if data is empty: @@ -432,7 +437,10 @@ class DecimalField(Field): else: self.max_whole_digits = None super(DecimalField, self).__init__(*args, **kwargs) - # TODO 需要将 MaxLengthValidator 与 MinLengthValidator 添加入 self.validators + if self.max_value is not None: + self.validators.append(MaxValueValidator(max_value=self.max_value, error_messages={'max_value': self.error_messages['max_value']})) + if self.min_value is not None: + self.validators.append(MinValueValidator(min_value=self.min_value, error_messages={'min_value': self.error_messages['min_value']})) def external_to_internal(self, data: Any): data = str(data).strip() diff --git a/sanic_rest_framework/test/test_fields/test_base_field.py b/sanic_rest_framework/test/test_fields/test_base_field.py new file mode 100644 index 0000000..e87c243 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_base_field.py @@ -0,0 +1,135 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/2/2 15:58 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_base_field.py + 字段 +@ChangeHistory: + datetime action why + example: + 2021/2/2 15:58 change 'Fix bug' + +""" +import unittest +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import Field as TestField, empty, SkipField +from sanic_rest_framework.test.utils import TestDataMixin +from sanic_rest_framework.validators import MaxValueValidator + + +class TestBaseField(TestDataMixin, unittest.TestCase): + """测试基类的基本功能""" + + def test_bing(self): + tf1 = TestField() + tf2 = TestField() + tf1.bind('test1', tf2) + self.assertEqual(tf1.field_name, 'test1') + self.assertEqual(tf1.source, 'test1') + self.assertEqual(tf1.source_attrs, ['test1']) + self.assertEqual(tf1.parent, tf2) + + def test_get_external_value(self): + data1 = {'test': self.str_max_int} + data2 = {'test': self.max_int} + data3 = {'test': self.str_chinese} + data4 = {'test1': self.str_chinese} + base_tf = TestField() + tf1 = TestField() + tf2 = TestField(default=1) + tf1.bind('test', base_tf) + tf2.bind('test1', base_tf) + + self.assertEqual(tf1.get_external_value(data1), self.str_max_int) + self.assertEqual(tf1.get_external_value(data2), self.max_int) + self.assertEqual(tf1.get_external_value(data3), self.str_chinese) + self.assertEqual(tf1.get_external_value(data4), empty) + self.assertEqual(tf1.get_external_value({}), empty) + + self.assertEqual(tf2.get_external_value(data1), 1) + self.assertEqual(tf2.get_external_value(data2), 1) + self.assertEqual(tf2.get_external_value(data3), 1) + self.assertEqual(tf2.get_external_value(data4), self.str_chinese) + + def test_get_internal_value(self): + # 未进行 Model 类型测试 + test_data = [ + [self.str_chinese, {'tf': self.str_chinese}], + [self.str_england, {'tf': self.str_england}], + [self.max_float, {'tf': self.max_float}], + [self.str_max_float, {'tf': self.str_max_float}], + [self.max_int, {'tf': self.max_int}], + [self.str_max_int, {'tf': self.str_max_int}], + [self.pi, {'tf': self.pi}], + [self.str_pi, {'tf': self.str_pi}], + [self.bool_True, {'tf': self.bool_True}], + [self.bool_False, {'tf': self.bool_False}], + [self.str_bool_True, {'tf': self.str_bool_True}], + [self.str_bool_False, {'tf': self.str_bool_False}], + ] + + tf = TestField() + tf.bind('tf', tf) + for value, data in test_data: + self.assertEqual(tf.get_internal_value(data), value) + + def test_run_validators(self): + tf = TestField(validators=[MaxValueValidator(1000)]) + tf.bind('tf', tf) + tf.run_validators(10) + tf.run_validators(1000) + with self.assertRaises(ValidationError): + tf.run_validators(1001) + + def test_get_default(self): + def test(): + return '5' + + tf1 = TestField() + tf2 = TestField(default=0) + tf3 = TestField(default=test) + with self.assertRaises(SkipField): + tf1.get_default() + self.assertEqual(tf2.get_default(), 0) + self.assertEqual(tf3.get_default(), '5') + + def test_validate_empty_values(self): + base_tf = TestField() + + tf1 = TestField() + tf1.bind('t1', base_tf) + self.assertEqual(tf1.validate_empty_values(10), (False, 10)) + with self.assertRaises(SkipField): + tf1.validate_empty_values(empty) + with self.assertRaises(ValidationError): + tf1.validate_empty_values(None) + + tf2 = TestField(default=0) + tf2.bind('t2', base_tf) + self.assertEqual(tf2.validate_empty_values(10), (False, 10)) + self.assertEqual(tf2.validate_empty_values(empty), (True, 0)) + with self.assertRaises(ValidationError): + tf2.validate_empty_values(None) + + tf3 = TestField(default=0, read_only=True) + tf3.bind('t3', base_tf) + self.assertEqual(tf3.validate_empty_values(10), (True, 0)) + self.assertEqual(tf3.validate_empty_values(empty), (True, 0)) + self.assertEqual(tf3.validate_empty_values(None), (True, 0)) + + tf4 = TestField(default=0, required=True) + tf4.bind('t4', base_tf) + self.assertEqual(tf4.validate_empty_values(10), (False, 10)) + with self.assertRaises(ValidationError): + tf4.validate_empty_values(empty) + with self.assertRaises(ValidationError): + tf4.validate_empty_values(None) + + tf5 = TestField(default=0, allow_null=True) + tf5.bind('t5', base_tf) + self.assertEqual(tf5.validate_empty_values(10), (False, 10)) + self.assertEqual(tf5.validate_empty_values(empty), (True, 0)) + self.assertEqual(tf5.validate_empty_values(empty), (True, 0)) diff --git a/sanic_rest_framework/test/test_fields/test_decimal_field.py b/sanic_rest_framework/test/test_fields/test_decimal_field.py index 7754f02..2a29c07 100644 --- a/sanic_rest_framework/test/test_fields/test_decimal_field.py +++ b/sanic_rest_framework/test/test_fields/test_decimal_field.py @@ -15,7 +15,7 @@ """ import unittest -from decimal import Decimal, ROUND_DOWN, getcontext +from decimal import Decimal, ROUND_DOWN, getcontext, InvalidOperation from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import DecimalField as TestField from sanic_rest_framework.test.utils import TestDataMixin @@ -42,7 +42,7 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): def test_quantize(self): tf = TestField(max_digits=6, decimal_places=2) context = getcontext() # 获取decimal现在的上下文 - context.rounding = ROUND_DOWN + context.rounding = ROUND_DOWN self.assertEqual(tf.quantize(Decimal('99.99999999')), Decimal('99.99')) def test_get_external_value(self): @@ -50,7 +50,7 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): data1 = {'tf': self.str_max_int} data2 = {'tf': self.max_int} data3 = {'tf': self.str_chinese} - tf = TestField() + tf = TestField(max_digits=6, decimal_places=2) tf.bind('tf', tf) self.assertEqual(tf.get_external_value(data1), self.str_max_int) self.assertEqual(tf.get_external_value(data2), self.max_int) @@ -63,17 +63,26 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): float => [1,'1','1.0',1.0',1.6,'1.6'] :return: """ - tf1 = TestField() + tf1 = TestField(max_digits=6, decimal_places=2) + + # 布尔类型 不行 with self.assertRaises(ValidationError): tf1.external_to_internal(self.bool_True) + # 字符类型 中文 不行 with self.assertRaises(ValidationError): tf1.external_to_internal(self.str_chinese) + # 字符类型 小数π 超长不行 + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_pi) + # 浮点类型 小数π 超长不行 + with self.assertRaises(ValidationError): + self.assertEqual(tf1.external_to_internal(self.pi), self.pi) - self.assertEqual(tf1.external_to_internal(self.str_pi), self.pi) - self.assertEqual(tf1.external_to_internal(self.pi), self.pi) - self.assertEqual(tf1.external_to_internal(self.str_max_float), self.max_float) - self.assertEqual(tf1.external_to_internal(self.max_int), self.max_int) - self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) + self.assertEqual(tf1.external_to_internal(self.str_max_float), Decimal(self.str_max_float)) + + self.assertEqual(tf1.external_to_internal(self.max_int), Decimal(self.max_int)) + + self.assertEqual(tf1.external_to_internal(self.str_max_int), Decimal(self.max_int)) def test_get_internal_value(self): """测试由 instance 得到内部的值 """ @@ -83,7 +92,7 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): data4 = {'tf': self.max_float} data6 = {'tf': self.max_int} data7 = {'tf': self.str_max_int} - tf = TestField() + tf = TestField(max_digits=6, decimal_places=2) tf.bind('tf', tf) value = tf.get_internal_value(data1) @@ -105,32 +114,39 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): 只要是数值类型都不报错,float(xx) :return: """ - data1 = {'tf1': self.str_chinese} - data2 = {'tf1': self.max_int} - data3 = {'tf1': self.str_max_int} - data4 = {'tf1': self.str_max_float} - data5 = {'tf1': self.max_float} + data1 = {'tf': self.str_chinese} + data2 = {'tf': self.max_int} + data3 = {'tf': self.str_max_int} + data4 = {'tf': self.str_max_float} + data5 = {'tf': self.max_float} - tf = TestField() - tf.bind('tf1', tf) + tf = TestField(max_digits=6, decimal_places=2) + tf.bind('tf', tf) value = tf.get_internal_value(data1) - with self.assertRaises(ValueError): + with self.assertRaises(InvalidOperation): tf.internal_to_external(value) value = tf.get_internal_value(data2) self.assertEqual(tf.internal_to_external(value), self.max_int) + # Decimal(int) == int + # Decimal(str_int) == int value = tf.get_internal_value(data3) self.assertEqual(tf.internal_to_external(value), self.max_int) + + # Decimal(float) == float + # Decimal(str_float) != float value = tf.get_internal_value(data4) - self.assertEqual(tf.internal_to_external(value), self.max_float) + self.assertNotEqual(tf.internal_to_external(value), self.max_float) + self.assertEqual(tf.internal_to_external(value), Decimal(self.str_max_float)) + value = tf.get_internal_value(data5) - self.assertEqual(tf.internal_to_external(value), self.max_float) + self.assertEqual(tf.internal_to_external(value), Decimal(self.str_max_float)) def test_max_value(self): - tf1 = TestField() - tf2 = TestField(max_value=10.69) + tf1 = TestField(max_digits=6, decimal_places=2) + tf2 = TestField(max_digits=6, decimal_places=2, max_value=10.69) # 未设置 不存在超出限制 self.assertEqual(tf1.run_validators(data=self.str_max_int), None) @@ -151,9 +167,9 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): self.assertEqual(tf2.run_validators(data=self.min_int), None) self.assertEqual(tf2.run_validators(data=self.min_float), None) - def test_min_length(self): - tf1 = TestField() - tf2 = TestField(min_value=10.69) + def test_min_value(self): + tf1 = TestField(max_digits=6, decimal_places=2) + tf2 = TestField(max_digits=6, decimal_places=2, min_value=10.69) # 未设置 不存在超出限制 self.assertEqual(tf1.run_validators(data=self.str_max_int), None) @@ -177,7 +193,7 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): self.assertEqual(tf2.run_validators(data=self.max_float), None) def test_max_string_length(self): - tf1 = TestField() + tf1 = TestField(max_digits=6, decimal_places=2) with self.assertRaises(ValidationError): # 超出约定长度 tf1.external_to_internal(self.long_str) diff --git a/sanic_rest_framework/test/test_fields/test_float_field.py b/sanic_rest_framework/test/test_fields/test_float_field.py index 996cf49..5d62388 100644 --- a/sanic_rest_framework/test/test_fields/test_float_field.py +++ b/sanic_rest_framework/test/test_fields/test_float_field.py @@ -129,7 +129,7 @@ class TestFloatField(TestDataMixin, unittest.TestCase): self.assertEqual(tf2.run_validators(data=self.min_int), None) self.assertEqual(tf2.run_validators(data=self.min_float), None) - def test_min_length(self): + def test_min_value(self): tf1 = TestField() tf2 = TestField(min_value=10.69) diff --git a/sanic_rest_framework/test/test_fields/test_integer_field.py b/sanic_rest_framework/test/test_fields/test_integer_field.py index 8c7654e..1f0d026 100644 --- a/sanic_rest_framework/test/test_fields/test_integer_field.py +++ b/sanic_rest_framework/test/test_fields/test_integer_field.py @@ -117,13 +117,12 @@ class TestCharField(TestDataMixin, unittest.TestCase): self.assertEqual(tf2.run_validators(data=self.max_int), None) self.assertEqual(tf2.run_validators(data=self.min_int), None) - def test_min_length(self): + def test_min_value(self): tf1 = TestField() tf2 = TestField(min_value=10) - with self.assertRaises(ValidationError): - # 格式不支持str - self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + # 无限制,则无验证器所以不报错 + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) with self.assertRaises(ValidationError): self.assertEqual(tf2.run_validators(data=self.str_max_int), None) -- Gitee From c905ed4295a36aadffc5b6c8b9d03c1374eb04e8 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Wed, 3 Feb 2021 18:19:13 +0800 Subject: [PATCH 15/34] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=8F=8A=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 8 +- .../test/test_fields/test_bool_field.py | 93 +++++++++++++++++++ .../test/test_fields/test_char_field.py | 22 +---- .../test/test_fields/test_date_field.py | 74 +++++++++++++++ .../test/test_fields/test_datetime_field.py | 87 +++++++++++++++++ .../test/test_fields/test_decimal_field.py | 42 +-------- .../test/test_fields/test_float_field.py | 37 +------- .../test/test_fields/test_integer_field.py | 31 +------ sanic_rest_framework/test/utils.py | 23 ++++- 9 files changed, 292 insertions(+), 125 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/test_bool_field.py create mode 100644 sanic_rest_framework/test/test_fields/test_date_field.py create mode 100644 sanic_rest_framework/test/test_fields/test_datetime_field.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 786a5db..9600854 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -564,6 +564,7 @@ class BooleanField(Field): class DateTimeField(Field): default_error_messages = { + 'invalid': '错误的数据类型{data_type}, 不能转换为字符格式', 'convert': '日期转换异常,请确认日期格式符合 %Y-%m-%d %H:%M:%S 规则', 'date': '需要的是日期时间格式而不是日期格式', 'overflow': '时间超出范围' @@ -590,14 +591,13 @@ class DateTimeField(Field): def external_to_internal(self, data: Any) -> Any: if not isinstance(data, (str, data, datetime)): self.raise_error('convert') - + if isinstance(data, date): + self.raise_error('date') if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats) except (ValueError, TypeError): self.raise_error('convert') - if isinstance(data, date): - self.raise_error('date') if isinstance(data, datetime): data = self.enforce_timezone(data) return data @@ -609,7 +609,7 @@ class DateTimeField(Field): return data if isinstance(data, datetime): return data.strftime(self.output_format) - self.raise_error('invalid') + self.raise_error('invalid', data_type=type(data).__name__) class DateField(Field): diff --git a/sanic_rest_framework/test/test_fields/test_bool_field.py b/sanic_rest_framework/test/test_fields/test_bool_field.py new file mode 100644 index 0000000..878a8a2 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_bool_field.py @@ -0,0 +1,93 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/2/3 10:07 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_bool_field.py + 布尔类型字段测试 +@ChangeHistory: + datetime action why + example: + 2021/2/3 10:07 change 'Fix bug' + +""" +import unittest +from decimal import Decimal, ROUND_DOWN, getcontext, InvalidOperation +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import BooleanField as TestField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField + + +class TestDecimalField(TestBaseField): + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + :return: + """ + TRUE_VALUES = { + 't', 'T', + 'y', 'Y', 'yes', 'YES', + 'true', 'True', 'TRUE', + 'on', 'On', 'ON', + '1', 1, + True + } + FALSE_VALUES = { + 'f', 'F', + 'n', 'N', 'no', 'NO', + 'false', 'False', 'FALSE', + 'off', 'Off', 'OFF', + '0', 0, 0.0, + False + } + NULL_VALUES = {'null', 'Null', 'NULL', '', None} + tf1 = TestField() + for i in TRUE_VALUES: + self.assertEqual(tf1.external_to_internal(i), True) + for i in FALSE_VALUES: + self.assertEqual(tf1.external_to_internal(i), False) + for i in NULL_VALUES: + self.assertEqual(tf1.external_to_internal(i), None) + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要类型正确都不报错 + :return: + """ + TRUE_VALUES = { + 't', 'T', + 'y', 'Y', 'yes', 'YES', + 'true', 'True', 'TRUE', + 'on', 'On', 'ON', + '1', 1, + True + } + FALSE_VALUES = { + 'f', 'F', + 'n', 'N', 'no', 'NO', + 'false', 'False', 'FALSE', + 'off', 'Off', 'OFF', + '0', 0, 0.0, + False + } + tf1 = TestField() + for i in TRUE_VALUES: + self.assertEqual(tf1.internal_to_external(i), True) + for i in FALSE_VALUES: + self.assertEqual(tf1.internal_to_external(i), False) + self.assertEqual(tf1.internal_to_external(None), False) + self.assertEqual(tf1.internal_to_external(''), False) + self.assertEqual(tf1.internal_to_external('null'), True) + self.assertEqual(tf1.internal_to_external('any'), True) + self.assertEqual(tf1.internal_to_external('NULL'), True) + self.assertEqual(tf1.internal_to_external('yyyy'), True) + self.assertEqual(tf1.internal_to_external('曹凯'), True) + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_char_field.py b/sanic_rest_framework/test/test_fields/test_char_field.py index 80b2f86..7e8a3eb 100644 --- a/sanic_rest_framework/test/test_fields/test_char_field.py +++ b/sanic_rest_framework/test/test_fields/test_char_field.py @@ -18,33 +18,15 @@ import unittest from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import CharField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField -class TestCharField(unittest.TestCase): - def test_get_external_value(self): - data1 = {'char': 'Python'} - data2 = {'char': 66666} - char = CharField() - char.bind('char', char) - self.assertEqual(char.get_external_value(data1), 'Python') - self.assertEqual(char.get_external_value(data2), 66666) - +class TestCharField(TestBaseField): def test_external_to_internal(self): data = ' Python' char1 = CharField() self.assertEqual(char1.external_to_internal(data), 'Python') - def test_get_internal_value(self): - data1 = {'char': 'Python'} - data2 = {'char': 66666} - char = CharField() - char.bind('char', char) - - value1 = char.get_internal_value(data1) - self.assertEqual(value1, 'Python') - value2 = char.get_internal_value(data2) - self.assertEqual(value2, 66666) - def test_internal_to_external(self): data1 = {'char1': 'Python'} data2 = {'char1': 66666} diff --git a/sanic_rest_framework/test/test_fields/test_date_field.py b/sanic_rest_framework/test/test_fields/test_date_field.py new file mode 100644 index 0000000..90540d5 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_date_field.py @@ -0,0 +1,74 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/2/3 11:56 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_date_field.py + 测试日期格式字段 +@ChangeHistory: + datetime action why + example: + 2021/2/3 11:56 change 'Fix bug' + +""" + +import unittest +from datetime import date + +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import DateField as TestField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField + + +class TestDateTimeField(TestBaseField): + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + :return: + """ + + # 正常测试 + tf1 = TestField() + self.assertEqual(tf1.external_to_internal(self.str_date), self.obj_date) + self.assertEqual(tf1.external_to_internal(self.obj_date), self.obj_date) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_y) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_ym) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_d) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_h) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_hm) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_s) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_time) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime) + + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要类型正确都不报错 + :return: + """ + tf1 = TestField() + self.assertEqual(tf1.internal_to_external(None), None) + self.assertEqual(tf1.internal_to_external(self.obj_datetime), self.str_datetime) + with self.assertRaises(ValidationError): + tf1.internal_to_external(self.obj_date) + with self.assertRaises(ValidationError): + tf1.internal_to_external(self.obj_time) + with self.assertRaises(ValidationError): + tf1.internal_to_external(1) + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_datetime_field.py b/sanic_rest_framework/test/test_fields/test_datetime_field.py new file mode 100644 index 0000000..e39971d --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_datetime_field.py @@ -0,0 +1,87 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/2/3 10:34 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_datetime_field.py + 测试日期时间类型字段 +@ChangeHistory: + datetime action why + example: + 2021/2/3 10:34 change 'Fix bug' + +""" + +import unittest +from datetime import timezone, timedelta, datetime +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import DateTimeField as TestField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField + + +class TestDateTimeField(TestBaseField): + + def test_get_default_timezone(self): + tf = TestField() + tf.get_default_timezone() + self.assertEqual(tf.get_default_timezone(), timezone(timedelta(hours=8))) + + def test_enforce_timezone(self): + time1 = datetime.now() + tf1 = TestField() + tf2 = TestField(set_timezone=timezone(timedelta(hours=6))) + self.assertEqual(tf1.enforce_timezone(time1).tzinfo, tf1.get_default_timezone()) + self.assertEqual(tf2.enforce_timezone(time1).tzinfo, timezone(timedelta(hours=6))) + self.assertNotEqual(tf2.enforce_timezone(time1).tzinfo, tf1.get_default_timezone()) + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + :return: + """ + + # 正常测试 + tf1 = TestField() + tf2 = TestField(set_timezone=timezone(timedelta(hours=5))) + self.assertEqual(tf1.external_to_internal(self.str_datetime), datetime(2019, 12, 18, 8, 21, 25, 0, tf1.get_default_timezone())) + self.assertNotEqual(tf2.external_to_internal(self.str_datetime).tzinfo, tf2.get_default_timezone()) + self.assertEqual(tf2.external_to_internal(self.str_datetime).tzinfo, timezone(timedelta(hours=5))) + self.assertEqual(tf2.external_to_internal(self.str_datetime), datetime(2019, 12, 18, 5, 21, 25, 0, timezone(timedelta(hours=5)))) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_y) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_ym) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_d) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_h) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_hm) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_s) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_time) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_date) + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要类型正确都不报错 + :return: + """ + tf1 = TestField() + self.assertEqual(tf1.internal_to_external(None), None) + self.assertEqual(tf1.internal_to_external(self.obj_datetime), self.str_datetime) + with self.assertRaises(ValidationError): + tf1.internal_to_external(self.obj_date) + with self.assertRaises(ValidationError): + tf1.internal_to_external(self.obj_time) + with self.assertRaises(ValidationError): + tf1.internal_to_external(1) + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_decimal_field.py b/sanic_rest_framework/test/test_fields/test_decimal_field.py index 2a29c07..65cbd97 100644 --- a/sanic_rest_framework/test/test_fields/test_decimal_field.py +++ b/sanic_rest_framework/test/test_fields/test_decimal_field.py @@ -18,10 +18,10 @@ import unittest from decimal import Decimal, ROUND_DOWN, getcontext, InvalidOperation from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import DecimalField as TestField -from sanic_rest_framework.test.utils import TestDataMixin +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField -class TestDecimalField(TestDataMixin, unittest.TestCase): +class TestDecimalField(TestBaseField): def test_validate_precision(self): """ @@ -45,22 +45,10 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): context.rounding = ROUND_DOWN self.assertEqual(tf.quantize(Decimal('99.99999999')), Decimal('99.99')) - def test_get_external_value(self): - """得到外部传入的值""" - data1 = {'tf': self.str_max_int} - data2 = {'tf': self.max_int} - data3 = {'tf': self.str_chinese} - tf = TestField(max_digits=6, decimal_places=2) - tf.bind('tf', tf) - self.assertEqual(tf.get_external_value(data1), self.str_max_int) - self.assertEqual(tf.get_external_value(data2), self.max_int) - self.assertEqual(tf.get_external_value(data3), self.str_chinese) - def test_external_to_internal(self): """ 外转内 str -> dict 是严格的,不符合类型的都应该报错, 一切都要经过验证 - float => [1,'1','1.0',1.0',1.6,'1.6'] :return: """ tf1 = TestField(max_digits=6, decimal_places=2) @@ -84,34 +72,10 @@ class TestDecimalField(TestDataMixin, unittest.TestCase): self.assertEqual(tf1.external_to_internal(self.str_max_int), Decimal(self.max_int)) - def test_get_internal_value(self): - """测试由 instance 得到内部的值 """ - data1 = {'tf': self.str_chinese} - data2 = {'tf': self.str_england} - data3 = {'tf': self.str_max_float} - data4 = {'tf': self.max_float} - data6 = {'tf': self.max_int} - data7 = {'tf': self.str_max_int} - tf = TestField(max_digits=6, decimal_places=2) - tf.bind('tf', tf) - - value = tf.get_internal_value(data1) - self.assertEqual(value, self.str_chinese) - value = tf.get_internal_value(data2) - self.assertEqual(value, self.str_england) - value = tf.get_internal_value(data3) - self.assertEqual(value, self.str_max_float) - value = tf.get_internal_value(data4) - self.assertEqual(value, self.max_float) - value = tf.get_internal_value(data6) - self.assertEqual(value, self.max_int) - value = tf.get_internal_value(data7) - self.assertEqual(value, self.str_max_int) - def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, - 只要是数值类型都不报错,float(xx) + 只要类型正确都不报错 :return: """ data1 = {'tf': self.str_chinese} diff --git a/sanic_rest_framework/test/test_fields/test_float_field.py b/sanic_rest_framework/test/test_fields/test_float_field.py index 5d62388..1157363 100644 --- a/sanic_rest_framework/test/test_fields/test_float_field.py +++ b/sanic_rest_framework/test/test_fields/test_float_field.py @@ -18,21 +18,12 @@ import unittest from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import FloatField as TestField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField from sanic_rest_framework.test.utils import TestDataMixin -class TestFloatField(TestDataMixin, unittest.TestCase): +class TestFloatField(TestBaseField): - def test_get_external_value(self): - """得到外部传入的值""" - data1 = {'tf': self.str_max_int} - data2 = {'tf': self.max_int} - data3 = {'tf': self.str_chinese} - tf = TestField() - tf.bind('tf', tf) - self.assertEqual(tf.get_external_value(data1), self.str_max_int) - self.assertEqual(tf.get_external_value(data2), self.max_int) - self.assertEqual(tf.get_external_value(data3), self.str_chinese) def test_external_to_internal(self): """ @@ -53,30 +44,6 @@ class TestFloatField(TestDataMixin, unittest.TestCase): self.assertEqual(tf1.external_to_internal(self.max_int), self.max_int) self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) - def test_get_internal_value(self): - """测试由 instance 得到内部的值 """ - data1 = {'tf': self.str_chinese} - data2 = {'tf': self.str_england} - data3 = {'tf': self.str_max_float} - data4 = {'tf': self.max_float} - data6 = {'tf': self.max_int} - data7 = {'tf': self.str_max_int} - tf = TestField() - tf.bind('tf', tf) - - value = tf.get_internal_value(data1) - self.assertEqual(value, self.str_chinese) - value = tf.get_internal_value(data2) - self.assertEqual(value, self.str_england) - value = tf.get_internal_value(data3) - self.assertEqual(value, self.str_max_float) - value = tf.get_internal_value(data4) - self.assertEqual(value, self.max_float) - value = tf.get_internal_value(data6) - self.assertEqual(value, self.max_int) - value = tf.get_internal_value(data7) - self.assertEqual(value, self.str_max_int) - def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, diff --git a/sanic_rest_framework/test/test_fields/test_integer_field.py b/sanic_rest_framework/test/test_fields/test_integer_field.py index 1f0d026..38df7be 100644 --- a/sanic_rest_framework/test/test_fields/test_integer_field.py +++ b/sanic_rest_framework/test/test_fields/test_integer_field.py @@ -18,20 +18,10 @@ import unittest from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import IntegerField as TestField -from sanic_rest_framework.test.utils import TestDataMixin +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField -class TestCharField(TestDataMixin, unittest.TestCase): - - def test_get_external_value(self): - data1 = {'tf': self.str_max_float} - data2 = {'tf': self.max_float} - data3 = {'tf': self.str_chinese} - tf = TestField() - tf.bind('tf', tf) - self.assertEqual(tf.get_external_value(data1), self.str_max_float) - self.assertEqual(tf.get_external_value(data2), self.max_float) - self.assertEqual(tf.get_external_value(data3), self.str_chinese) +class TestCharField(TestBaseField): def test_external_to_internal(self): """ @@ -54,17 +44,6 @@ class TestCharField(TestDataMixin, unittest.TestCase): self.assertEqual(tf.external_to_internal(self.max_int), self.max_int) self.assertEqual(tf.external_to_internal(self.str_max_int), self.max_int) - def test_get_internal_value(self): - data1 = {'tf': 'Python'} - data2 = {'tf': 66666} - tf = TestField() - tf.bind('tf', tf) - - value1 = tf.get_internal_value(data1) - self.assertEqual(value1, 'Python') - value2 = tf.get_internal_value(data2) - self.assertEqual(value2, 66666) - def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, @@ -102,11 +81,11 @@ class TestCharField(TestDataMixin, unittest.TestCase): tf1 = TestField() tf2 = TestField(max_value=10) + self.assertEqual(tf1.run_validators(data=self.str_max_int), None) + + # 超出限制 with self.assertRaises(ValidationError): - # 格式不支持str - self.assertEqual(tf1.run_validators(data=self.str_max_int), None) self.assertEqual(tf2.run_validators(data=self.str_max_int), None) - # 超出限制 # 未设置 不存在超出限制 self.assertEqual(tf1.run_validators(data=self.max_int), None) diff --git a/sanic_rest_framework/test/utils.py b/sanic_rest_framework/test/utils.py index b2b5b8f..27b8eae 100644 --- a/sanic_rest_framework/test/utils.py +++ b/sanic_rest_framework/test/utils.py @@ -13,6 +13,9 @@ 2021/1/29 16:38 change 'Fix bug' """ +import datetime + + class TestDataMixin: min_int = 1 max_int = 9999 @@ -32,4 +35,22 @@ class TestDataMixin: bool_False = False str_bool_True = 'True' str_bool_False = 'False' - long_str = '996' * 600 \ No newline at end of file + long_str = '996' * 600 + str_datetime = '2019-12-18 08:21:25' + str_datetime_bad_y = '99999-12-18 08:21:25' + str_datetime_bad_ym = '2019-13-18 08:21:25' + str_datetime_bad_d = '2019-12-32 08:21:25' + str_datetime_bad_h = '2019-12-18 25:21:25' + str_datetime_bad_hm = '2019-12-18 08:61:25' + str_datetime_bad_s = '2019-12-18 08:21:61' + str_date = '2019-12-18' + str_date_bad_y = '99999-12-18' + str_date_bad_m = '2019-13-18' + str_date_bad_d = '2019-12-32' + str_time = '08:21:25' + str_time_bad_h = '25:59:25' + str_time_bad_m = '08:61:25' + str_time_bad_s = '08:59:61' + obj_datetime = datetime.datetime(2019, 12, 18, 8, 21, 25) + obj_date = datetime.date(2019, 12, 18) + obj_time = datetime.time(8, 21, 25) -- Gitee From 054d10f4ec779e50588a0f72858930b3c4e9f082 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 4 Feb 2021 17:24:01 +0800 Subject: [PATCH 16/34] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 5 +- .../test/test_fields/test_date_field.py | 1 - .../test/test_fields/test_time_field.py | 74 +++++++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/test_time_field.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 9600854..ceb3e72 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -653,6 +653,7 @@ class TimeField(Field): 'invalid': '出现错误的数据类型,{value}不是有效的日期时间类型', 'format': '时间格式错误,需要格式为 %H:%M:%S ', 'date': '需要的是时间格式而不是日期格式', + 'datetime': '需要的是时间格式而不是日期时间格式', } def __init__(self, output_format='%H:%M:%S', input_formats='%H:%M:%S', *args, **kwargs): @@ -668,7 +669,9 @@ class TimeField(Field): except (ValueError, TypeError): self.raise_error('format') if isinstance(data, datetime): - return data.time() + self.raise_error('datetime') + if isinstance(data, time): + return data if isinstance(data, date): self.raise_error('date') self.raise_error('invalid', value=type(data)) diff --git a/sanic_rest_framework/test/test_fields/test_date_field.py b/sanic_rest_framework/test/test_fields/test_date_field.py index 90540d5..f9d3936 100644 --- a/sanic_rest_framework/test/test_fields/test_date_field.py +++ b/sanic_rest_framework/test/test_fields/test_date_field.py @@ -15,7 +15,6 @@ """ import unittest -from datetime import date from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import DateField as TestField diff --git a/sanic_rest_framework/test/test_fields/test_time_field.py b/sanic_rest_framework/test/test_fields/test_time_field.py new file mode 100644 index 0000000..f42f726 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_time_field.py @@ -0,0 +1,74 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/2/4 14:35 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_time_field.py + 测试时间格式字段 +@ChangeHistory: + datetime action why + example: + 2021/2/4 14:35 change 'Fix bug' + +""" + +import unittest +from datetime import date + +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import TimeField as TestField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField + + +class TestDateTimeField(TestBaseField): + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + :return: + """ + + # 正常测试 + tf1 = TestField() + + self.assertEqual(tf1.external_to_internal(self.str_time), self.obj_time) + self.assertEqual(tf1.external_to_internal(self.obj_time), self.obj_time) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_y) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_ym) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_d) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_h) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_hm) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime_bad_s) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_time) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_datetime) + + def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要类型正确都不报错 + :return: + """ + tf1 = TestField() + self.assertEqual(tf1.internal_to_external(None), None) + self.assertEqual(tf1.internal_to_external(self.obj_datetime), self.str_datetime) + with self.assertRaises(ValidationError): + tf1.internal_to_external(self.obj_date) + with self.assertRaises(ValidationError): + tf1.internal_to_external(self.obj_time) + with self.assertRaises(ValidationError): + tf1.internal_to_external(1) + + +if __name__ == '__main__': + unittest.main() -- Gitee From a3797dc61062b3cddacb6d33b074c30605c98144 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 25 Feb 2021 16:58:26 +0800 Subject: [PATCH 17/34] ListSerializer --- sanic_rest_framework/serializers.py | 90 ++++++++++++++++++- .../test/test_fields/test_base_field.py | 4 +- .../test/test_serializers/test_serializer.py | 37 ++++++-- 3 files changed, 119 insertions(+), 12 deletions(-) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 97b3dd2..f783627 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -9,9 +9,12 @@ 序列化器文件 """ import copy +import inspect from collections import OrderedDict from typing import Any, Mapping +from tortoise import models + from sanic_rest_framework.fields import Field, empty, SkipField from .exceptions import ValidationError from .helpers import BindingDict @@ -107,6 +110,14 @@ class BaseSerializer(Field): '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) + def validate(self, data): + return data + + def run_validation(self, data): + value = super(BaseSerializer, self).run_validation(data) + value = self.validate(value) + return value + def is_valid(self, raise_exception=False): assert hasattr(self, 'initial_data'), ( 'Cannot call `.is_valid()` as no `data=` keyword argument was ' @@ -248,10 +259,11 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - if isinstance(field, BaseSerializer): - errors[field.field_name] = exc.error_dict - else: - errors[field.field_name] = exc.error_list + errors[field.field_name] = exc.error_dict if hasattr(exc, 'error_dict') else exc.error_list + # if isinstance(field, BaseSerializer): + # errors[field.field_name] = exc.error_dict + # else: + # errors[field.field_name] = exc.error_list except SkipField: pass else: @@ -276,3 +288,73 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): 'value': value, 'error': error, } + + +class ListSerializer(BaseSerializer): + child = None + many = True + + default_error_messages = { + 'not_a_list': '预期项目列表,但类型为“ {input_type}”。', + 'empty': '此列表不能为空。' + } + + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + self.allow_empty = kwargs.pop('allow_empty', True) + assert self.child is not None, '`child` 是必填参数。' + assert not inspect.isclass(self.child), '`child` 尚未实例化。' + super().__init__(*args, **kwargs) + self.child.bind(field_name='', parent=self) + + async def internal_to_external(self, data: Any) -> Any: + """ + 内转外 + :param data: + :return: + """ + iterable = await data.all() if isinstance(data, models.Model) else data + + return [ + self.child.internal_to_external(item) for item in iterable + ] + + def external_to_internal(self, data: Any) -> Any: + """ + 外转内 + :param data: + :return: + """ + if not isinstance(data, list): + raise self.raise_error('not_a_list', input_type=type(data).__name__) + + if not self.allow_empty and len(data) == 0: + raise self.raise_error('empty') + + ret = [] + errors = [] + + for item in data: + try: + validated = self.child.run_validation(item) + except ValidationError as exc: + errors.append(exc) + else: + ret.append(validated) + errors.append({}) + if any(errors): + raise ValidationError(errors) + return ret + + def run_validation(self, data=empty): + """ + 我们覆盖默认的`run_validation`,因为验证由验证者执行, + 而.validate()方法应使用“non_fields_error”键被强制为错误字典。 + """ + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + value = self.external_to_internal(data) + self.run_validators(value) + value = self.validate(value) + return value diff --git a/sanic_rest_framework/test/test_fields/test_base_field.py b/sanic_rest_framework/test/test_fields/test_base_field.py index e87c243..ece1e54 100644 --- a/sanic_rest_framework/test/test_fields/test_base_field.py +++ b/sanic_rest_framework/test/test_fields/test_base_field.py @@ -13,14 +13,14 @@ 2021/2/2 15:58 change 'Fix bug' """ -import unittest +from tortoise.contrib import test from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import Field as TestField, empty, SkipField from sanic_rest_framework.test.utils import TestDataMixin from sanic_rest_framework.validators import MaxValueValidator -class TestBaseField(TestDataMixin, unittest.TestCase): +class TestBaseField(TestDataMixin, test.TestCase): """测试基类的基本功能""" def test_bing(self): diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py index 68179dc..3d71638 100644 --- a/sanic_rest_framework/test/test_serializers/test_serializer.py +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -13,12 +13,22 @@ 2021/1/27 9:55 change 'Fix bug' """ +import asyncio +from tortoise.contrib.test import initializer, finalizer, TestCase +from tortoise import fields + +from db import TestModel from sanic_rest_framework.fields import ( CharField, IntegerField, FloatField, DateField, TimeField, DecimalField, DateTimeField, BooleanField, ChoiceField, SerializerMethodField) from sanic_rest_framework.serializers import Serializer -import unittest + + +class IToESerializer(Serializer): + birthday = fields.DatetimeField() + name = fields.CharField(80) + ages = fields.IntField() class QianTaoserializer(Serializer): @@ -45,11 +55,15 @@ class FieldRequiredSerializer(Serializer): serializer_method_field = SerializerMethodField() -class TestFieldRequired(unittest.TestCase): +class FieldManySerializer(Serializer): + qt = QianTaoserializer(many=True) + + +class TestFieldRequired(TestCase): def test_has_data(self): data = { - 'char_field': '', - 'integer_field': '', + # 'char_field': '', + # 'integer_field': '', 'float_field': '', 'decimal_field': '', 'date_field': '', @@ -78,6 +92,17 @@ class TestFieldRequired(unittest.TestCase): frs = FieldRequiredSerializer(data=data) self.assertEqual(frs.is_valid(), True) + def test_many_data(self): + data = { + 'qt': [{'name': '刘文静', 'doc': 'py开发'}, {'name': '陈掏灰', 'doc': 'html开发'}] + } + fms = FieldManySerializer(data=data, instance=TestModel) + fms.is_valid(raise_exception=True) + print(fms.validated_data) + + async def test_o(self): + initializer(['db', ], db_url="sqlite://G:/Codes/Python/test/sanic_rest_framework/db.sqlite", loop=asyncio.get_event_loop()) + print(await TestModel.get(pk=1)) + IToESerializer(instance=await TestModel.get(pk=1)) -if __name__ == '__main__': - unittest.main() + finalizer() -- Gitee From 5fb7436add16e7d02ecb2555bdff2d695a31e044 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Mon, 1 Mar 2021 17:14:34 +0800 Subject: [PATCH 18/34] =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 23 +++++++------ sanic_rest_framework/serializers.py | 14 ++++---- sanic_rest_framework/test/__init__.py | 0 sanic_rest_framework/test/models.py | 9 +++++ .../test/test_serializers/test_serializer.py | 33 ++++++++++--------- 5 files changed, 47 insertions(+), 32 deletions(-) create mode 100644 sanic_rest_framework/test/__init__.py create mode 100644 sanic_rest_framework/test/models.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index ceb3e72..9e5e0db 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -179,7 +179,7 @@ class Field: '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: """对数据进行序列化转换并返回""" raise NotImplementedError( '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) @@ -205,7 +205,7 @@ class Field: """适用于 Model 对象""" return await getattr(instance, attr) - def get_internal_value(self, instance: Any) -> Any: + async def get_internal_value(self, instance: Any) -> Any: """ 从传入的内部数据中得到值 值用于输出 @@ -218,7 +218,10 @@ class Field: if isinstance(instance, Mapping): instance = instance[attr] else: - instance = self.async_get_attribute(instance, attr) + if '.' in attr: + instance = await self.async_get_attribute(instance, attr) + else: + instance = getattr(instance, attr) except DoesNotExist: return None except (KeyError, AttributeError) as exc: @@ -339,7 +342,7 @@ class CharField(Field): value = str(data) return value.strip() if self.trim_whitespace else value - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: return str(data) @@ -372,7 +375,7 @@ class IntegerField(Field): self.raise_error('invalid') return data - def internal_to_external(self, data: Any): + async def internal_to_external(self, data: Any): return int(data) @@ -396,7 +399,7 @@ class FloatField(IntegerField): except (TypeError, ValueError): self.raise_error('invalid', data_type=type(data).__name__) - def internal_to_external(self, data: Any): + async def internal_to_external(self, data: Any): return float(data) @@ -461,7 +464,7 @@ class DecimalField(Field): return self.quantize(self.validate_precision(data)) - def internal_to_external(self, data: Any): + async def internal_to_external(self, data: Any): if not isinstance(data, decimal.Decimal): data = decimal.Decimal(str(data).strip()) @@ -552,7 +555,7 @@ class BooleanField(Field): except TypeError: self.raise_error('invalid', value=data) - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: if data in self.TRUE_VALUES: return True elif data in self.FALSE_VALUES: @@ -602,7 +605,7 @@ class DateTimeField(Field): data = self.enforce_timezone(data) return data - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: if not data: return None if isinstance(data, str): @@ -637,7 +640,7 @@ class DateField(Field): return data self.raise_error('invalid', value=data) - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: if not data: return data if isinstance(data, str): diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index f783627..953bb5b 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -104,7 +104,7 @@ class BaseSerializer(Field): '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: """对数据进行反序列化转换并返回""" raise NotImplementedError( '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) @@ -138,11 +138,11 @@ class BaseSerializer(Field): return not bool(self._errors) @property - def data(self): + async def data(self): """对外呈现的数据""" assert not self.instance is None, '调用 .data 必须先传入 instance= ' if not hasattr(self, '_data'): - self._data = self.internal_to_external(self.instance) + self._data = await self.internal_to_external(self.instance) return self._data @property @@ -227,7 +227,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): """ return copy.deepcopy(self._declared_fields) - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: """ 内转外 :param data: @@ -236,8 +236,8 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): res = OrderedDict() fields = self._readable_fields for field in fields: - value = field.get_internal_value(data) - res[field.field_name] = field.internal_to_external(value) + value = await field.get_internal_value(data) + res[field.field_name] = await field.internal_to_external(value) return res # 反序列化 @@ -316,7 +316,7 @@ class ListSerializer(BaseSerializer): iterable = await data.all() if isinstance(data, models.Model) else data return [ - self.child.internal_to_external(item) for item in iterable + await self.child.internal_to_external(item) for item in iterable ] def external_to_internal(self, data: Any) -> Any: diff --git a/sanic_rest_framework/test/__init__.py b/sanic_rest_framework/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py new file mode 100644 index 0000000..e064c1d --- /dev/null +++ b/sanic_rest_framework/test/models.py @@ -0,0 +1,9 @@ +from datetime import date +from tortoise import fields +from tortoise import Model + + +class TestModel(Model): + name = fields.CharField(max_length=30) + ages = fields.IntField() + birthday = fields.DateField(default=date.today) diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py index 3d71638..df890a3 100644 --- a/sanic_rest_framework/test/test_serializers/test_serializer.py +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -4,31 +4,29 @@ @CreateTime:2021/1/27 9:55 @DependencyLibrary:无 @MainFunction:无 -@FileDoc: +@FileDoc: test_serializer.py 测试文件 @ChangeHistory: datetime action why example: 2021/1/27 9:55 change 'Fix bug' - + """ import asyncio - -from tortoise.contrib.test import initializer, finalizer, TestCase -from tortoise import fields - -from db import TestModel from sanic_rest_framework.fields import ( CharField, IntegerField, FloatField, DateField, TimeField, DecimalField, DateTimeField, BooleanField, ChoiceField, SerializerMethodField) from sanic_rest_framework.serializers import Serializer +from tortoise.contrib.test import initializer, finalizer, TestCase +from sanic_rest_framework.test.models import TestModel + class IToESerializer(Serializer): - birthday = fields.DatetimeField() - name = fields.CharField(80) - ages = fields.IntField() + birthday = DateField() + name = CharField(80) + ages = IntegerField() class QianTaoserializer(Serializer): @@ -59,6 +57,11 @@ class FieldManySerializer(Serializer): qt = QianTaoserializer(many=True) +initializer(['sanic_rest_framework.test.models', ], + db_url="sqlite://./db.sqlite", + loop=asyncio.get_event_loop()) + + class TestFieldRequired(TestCase): def test_has_data(self): data = { @@ -94,15 +97,15 @@ class TestFieldRequired(TestCase): def test_many_data(self): data = { - 'qt': [{'name': '刘文静', 'doc': 'py开发'}, {'name': '陈掏灰', 'doc': 'html开发'}] + 'qt': [{'name': '刘文静', 'doc': 'py开发'}, + {'name': '陈掏灰', 'doc': 'html开发'}] } fms = FieldManySerializer(data=data, instance=TestModel) fms.is_valid(raise_exception=True) print(fms.validated_data) async def test_o(self): - initializer(['db', ], db_url="sqlite://G:/Codes/Python/test/sanic_rest_framework/db.sqlite", loop=asyncio.get_event_loop()) - print(await TestModel.get(pk=1)) - IToESerializer(instance=await TestModel.get(pk=1)) - + await TestModel(name='刘文静', ages=22, birthday='2016-12-12').save() + ts = await TestModel.get(pk=1) + print(await IToESerializer(instance=ts).data) finalizer() -- Gitee From 702b1a20f0c193206cf8b4d2b0e75e0b7a6b89ee Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 2 Mar 2021 02:21:39 +0800 Subject: [PATCH 19/34] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=A1=88=E4=BE=8B=20=E5=9F=BA=E4=BA=8E=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=A1=88=E4=BE=8B=E4=BF=AE=E5=A4=8D=E5=BA=8F=E5=88=97=E5=8C=96?= =?UTF-8?q?=E5=99=A8=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 12 +- sanic_rest_framework/serializers.py | 16 +- sanic_rest_framework/test/models.py | 18 +- .../test/test_fields/__init__.py | 0 .../test/test_serializers/__init__.py | 0 .../test/test_serializers/test_serializer.py | 169 ++++++++++-------- 6 files changed, 130 insertions(+), 85 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/__init__.py create mode 100644 sanic_rest_framework/test/test_serializers/__init__.py diff --git a/requirements.txt b/requirements.txt index 01050ec..625ce45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,22 @@ aiofiles==0.6.0 -aiosqlite==0.16.0 +aiosqlite==0.16.1 +asynctest==0.13.0 +autopep8==1.5.5 certifi==2020.12.5 h11==0.9.0 httpcore==0.11.1 httptools==0.1.1 httpx==0.15.4 idna==3.1 -iso8601==0.1.13 +iso8601==0.1.14 multidict==5.1.0 +pycodestyle==2.6.0 PyPika==0.44.1 pytz==2020.5 rfc3986==1.4.0 -sanic==20.12.1 +sanic==20.12.2 sniffio==1.2.0 -tortoise-orm==0.16.19 +toml==0.10.2 +tortoise-orm==0.16.21 typing-extensions==3.7.4.3 websockets==8.1 diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 953bb5b..4943296 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -13,7 +13,9 @@ import inspect from collections import OrderedDict from typing import Any, Mapping -from tortoise import models +from tortoise import models, Model +from tortoise.fields import relational +from tortoise.queryset import QuerySet from sanic_rest_framework.fields import Field, empty, SkipField from .exceptions import ValidationError @@ -253,8 +255,8 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): fields = self._writable_fields for field in fields: validate_method = getattr(self, 'validate_' + field.field_name, None) - primitive_value = field.get_external_value(data) try: + primitive_value = field.get_external_value(data) validated_value = field.run_validation(primitive_value) if validate_method is not None: validated_value = validate_method(validated_value) @@ -307,13 +309,21 @@ class ListSerializer(BaseSerializer): super().__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) + async def get_internal_value(self, instance: Any) -> Any: + if isinstance(instance, Mapping): + data = instance.get(self.field_name) + if not isinstance(data, list): + data = [data] + return data + return await instance.all() + async def internal_to_external(self, data: Any) -> Any: """ 内转外 :param data: :return: """ - iterable = await data.all() if isinstance(data, models.Model) else data + iterable = await data.all() if isinstance(data, (QuerySet, Model, relational.RelationalField)) else data return [ await self.child.internal_to_external(item) for item in iterable diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index e064c1d..ad6a4c4 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -3,7 +3,17 @@ from tortoise import fields from tortoise import Model -class TestModel(Model): - name = fields.CharField(max_length=30) - ages = fields.IntField() - birthday = fields.DateField(default=date.today) +class UserModel(Model): + name = fields.CharField(max_length=8, null=False) + birthday = fields.DateField() + phone = fields.CharField(max_length=11) + balance = fields.DecimalField(13, 2) + address: fields.ManyToManyRelation["AddressModel"] = fields.ManyToManyField( + 'models.AddressModel', through='user2address', related_name='user') + + +class AddressModel(Model): + phone = fields.CharField(12, null=False) + address = fields.CharField(100) + house_number = fields.CharField(100) + # user: fields.ManyToManyRelation[UserModel] diff --git a/sanic_rest_framework/test/test_fields/__init__.py b/sanic_rest_framework/test/test_fields/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sanic_rest_framework/test/test_serializers/__init__.py b/sanic_rest_framework/test/test_serializers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py index df890a3..b9ba763 100644 --- a/sanic_rest_framework/test/test_serializers/test_serializer.py +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -14,47 +14,32 @@ """ import asyncio +import unittest +from collections import OrderedDict +from copy import deepcopy + from sanic_rest_framework.fields import ( CharField, IntegerField, FloatField, DateField, TimeField, DecimalField, DateTimeField, BooleanField, ChoiceField, SerializerMethodField) from sanic_rest_framework.serializers import Serializer from tortoise.contrib.test import initializer, finalizer, TestCase -from sanic_rest_framework.test.models import TestModel - - -class IToESerializer(Serializer): - birthday = DateField() - name = CharField(80) - ages = IntegerField() - - -class QianTaoserializer(Serializer): - name = CharField() - doc = CharField() - - -class BaseSerializer(Serializer): - id = CharField(max_length=18) - qt = QianTaoserializer(required=True, allow_null=True) - ages = IntegerField(max_value=6) +from sanic_rest_framework.test.models import UserModel, AddressModel +from sanic_rest_framework.exceptions import ValidationError -class FieldRequiredSerializer(Serializer): - char_field = CharField(required=True) - integer_field = IntegerField(required=True) - float_field = FloatField(required=True) - decimal_field = DecimalField(required=True, max_digits=9, decimal_places=3) - date_field = DateField(required=True) - time_field = TimeField(required=True) - datetime_field = DateTimeField(required=True) - boolean_field = BooleanField(required=True) - choice_field = ChoiceField(required=True, choices=[(1, '老四')]) - serializer_method_field = SerializerMethodField() +class AddressSerializer(Serializer): + phone = CharField(max_length=11, required=True) + address = CharField(max_length=100) + house_number = CharField(max_length=100) -class FieldManySerializer(Serializer): - qt = QianTaoserializer(many=True) +class UserSerializer(Serializer): + name = CharField(max_length=8, required=True) + birthday = DateField() + phone = CharField(max_length=11) + balance = DecimalField(13, 2) + address = AddressSerializer(many=True) initializer(['sanic_rest_framework.test.models', ], @@ -62,50 +47,86 @@ initializer(['sanic_rest_framework.test.models', ], loop=asyncio.get_event_loop()) -class TestFieldRequired(TestCase): - def test_has_data(self): - data = { - # 'char_field': '', - # 'integer_field': '', - 'float_field': '', - 'decimal_field': '', - 'date_field': '', - 'time_field': '', - 'datetime_field': '', - 'boolean_field': '', - 'choice_field': '', - 'serializer_method_field': '', - } - frs = FieldRequiredSerializer(data=data) - self.assertEqual(frs.is_valid(), False) - - def test_success_data(self): - data = { - 'char_field': 'NiHao', - 'integer_field': 1, - 'float_field': 80.06, - 'decimal_field': 99620, - 'date_field': '2017-16-28', - 'time_field': '16:18:11', - 'datetime_field': '2017-16-28 16:18:11', - 'boolean_field': '0', - 'choice_field': '80', - 'serializer_method_field': 'aaa', +class TestOrdinarySerializer(TestCase): + def setUp(self) -> None: + self.mapping_data = { + 'phone': '17674707036', + 'address': '长沙市IFS', + 'house_number': '67L' } - frs = FieldRequiredSerializer(data=data) - self.assertEqual(frs.is_valid(), True) + self.res = OrderedDict(self.mapping_data) + + async def test_serializer(self): + """测试序列化""" + addr_model = AddressModel(**self.mapping_data) + await addr_model.save() + as_map = AddressSerializer(instance=self.mapping_data) + as_model = AddressSerializer(instance=await AddressModel.get(pk=1)) + self.assertEqual(await as_map.data, self.res) + self.assertEqual(await as_model.data, self.res) + + async def test_deserializer(self): + """测试反序列化""" + + addr_model = AddressModel(**self.mapping_data) + await addr_model.save() + as_map = AddressSerializer(data=self.mapping_data) + as_model = AddressSerializer(data=await AddressModel.get(pk=1)) + self.assertIs(as_map.is_valid(), True) + self.assertEqual(as_map.validated_data, self.res) + + self.assertIs(as_model.is_valid(), False) + print(as_model.errors) + # print(as_model.validated_data) + + @classmethod + async def tearDownClass(cls) -> None: + finalizer() + - def test_many_data(self): - data = { - 'qt': [{'name': '刘文静', 'doc': 'py开发'}, - {'name': '陈掏灰', 'doc': 'html开发'}] +class TestM2MNestedSerializer(TestCase): + """测试单个嵌套""" + + def setUp(self) -> None: + self.mapping_data = { + 'name': '老四', + 'birthday': '2010-11-16', + 'phone': '17674707037', + 'balance': 99.009, + 'address': [{ + 'phone': '17674707036', + 'address': '长沙市IFS', + 'house_number': '67L' + }, { + 'phone': '17674707037', + 'address': '长沙市HPT', + 'house_number': '4L' + }] } - fms = FieldManySerializer(data=data, instance=TestModel) - fms.is_valid(raise_exception=True) - print(fms.validated_data) - - async def test_o(self): - await TestModel(name='刘文静', ages=22, birthday='2016-12-12').save() - ts = await TestModel.get(pk=1) - print(await IToESerializer(instance=ts).data) + + async def test_serializer(self): + """测试序列化""" + user_info = deepcopy(self.mapping_data) + address = user_info.pop('address') + user = UserModel(**user_info) + await user.save() + for i in address: + addr = AddressModel(**i) + await addr.save() + await user.address.add(addr) + + us_map = UserSerializer(instance=self.mapping_data) + us_model = UserSerializer(instance=await UserModel.get(pk=1)) + print(await us_map.data) + print(await us_model.data) + + async def test_deserializer(self): + """测试反序列化""" + + @classmethod + async def tearDownClass(cls) -> None: finalizer() + + +if __name__ == '__main__': + unittest.main() -- Gitee From 996b9c0751cac5a7a946f54b726160ef9905e598 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 2 Mar 2021 17:59:25 +0800 Subject: [PATCH 20/34] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=EF=BC=8C=E4=BF=AE=E6=94=B9=E5=B5=8C=E5=A5=97?= =?UTF-8?q?=E5=BA=8F=E5=88=97=E5=8C=96=E6=97=B6=E5=87=BA=E7=8E=B0=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 10 +- sanic_rest_framework/serializers.py | 22 +- sanic_rest_framework/test/models.py | 19 +- .../test/test_serializers/test_serializer.py | 253 +++++++++++++++++- 4 files changed, 282 insertions(+), 22 deletions(-) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 9e5e0db..482fd67 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -15,6 +15,7 @@ from datetime import timezone, timedelta, datetime, date, time from typing import Any, List, Mapping from tortoise import Model from tortoise.queryset import QuerySet +from tortoise.fields.relational import RelationalField, ManyToManyField from tortoise.exceptions import DoesNotExist from sanic_rest_framework.exceptions import ValidationError @@ -218,10 +219,9 @@ class Field: if isinstance(instance, Mapping): instance = instance[attr] else: - if '.' in attr: - instance = await self.async_get_attribute(instance, attr) - else: - instance = getattr(instance, attr) + instance = getattr(instance, attr) + if isinstance(instance, QuerySet): + instance = await instance except DoesNotExist: return None except (KeyError, AttributeError) as exc: @@ -416,7 +416,7 @@ class DecimalField(Field): } MAX_STRING_LENGTH = 1000 - def __init__(self, max_digits, decimal_places, coerce_to_string=False, max_value=None, min_value=None, + def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, rounding=None, *args, **kwargs): """ 整数位数 = max_digits - decimal_places diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 4943296..29c38b8 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -234,6 +234,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): 内转外 :param data: :return: + """ res = OrderedDict() fields = self._readable_fields @@ -262,10 +263,6 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): validated_value = validate_method(validated_value) except ValidationError as exc: errors[field.field_name] = exc.error_dict if hasattr(exc, 'error_dict') else exc.error_list - # if isinstance(field, BaseSerializer): - # errors[field.field_name] = exc.error_dict - # else: - # errors[field.field_name] = exc.error_list except SkipField: pass else: @@ -310,12 +307,14 @@ class ListSerializer(BaseSerializer): self.child.bind(field_name='', parent=self) async def get_internal_value(self, instance: Any) -> Any: - if isinstance(instance, Mapping): - data = instance.get(self.field_name) - if not isinstance(data, list): - data = [data] - return data - return await instance.all() + for attr in self.source_attrs: + if isinstance(instance, Mapping): + instance = instance.get(attr, []) + if not isinstance(instance, list): + instance = [instance] + else: + instance = await getattr(instance, attr) + return instance async def internal_to_external(self, data: Any) -> Any: """ @@ -323,7 +322,7 @@ class ListSerializer(BaseSerializer): :param data: :return: """ - iterable = await data.all() if isinstance(data, (QuerySet, Model, relational.RelationalField)) else data + iterable = await data.all() if isinstance(data, (QuerySet, relational.RelationalField)) else data return [ await self.child.internal_to_external(item) for item in iterable @@ -368,3 +367,4 @@ class ListSerializer(BaseSerializer): self.run_validators(value) value = self.validate(value) return value + diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index ad6a4c4..760ba4a 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -1,13 +1,14 @@ from datetime import date from tortoise import fields from tortoise import Model +from tortoise.fields import ForeignKeyRelation, ReverseRelation class UserModel(Model): name = fields.CharField(max_length=8, null=False) birthday = fields.DateField() phone = fields.CharField(max_length=11) - balance = fields.DecimalField(13, 2) + balance = fields.DecimalField(13, 3) address: fields.ManyToManyRelation["AddressModel"] = fields.ManyToManyField( 'models.AddressModel', through='user2address', related_name='user') @@ -17,3 +18,19 @@ class AddressModel(Model): address = fields.CharField(100) house_number = fields.CharField(100) # user: fields.ManyToManyRelation[UserModel] + + +class SchoolModel(Model): + name = fields.CharField(12) + address: fields.OneToOneRelation["AddressModel"] = fields.OneToOneField("models.AddressModel", 'school') + + +class ClassRoomModel(Model): + room_number = fields.CharField(18) + student_count = fields.IntField() + students: ReverseRelation['StudentModel'] + + +class StudentModel(Model): + name = fields.CharField(max_length=12, null=False) + class_room: ForeignKeyRelation["ClassRoomModel"] = fields.ForeignKeyField('models.ClassRoomModel', 'students') diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py index b9ba763..ea4d0be 100644 --- a/sanic_rest_framework/test/test_serializers/test_serializer.py +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -14,9 +14,11 @@ """ import asyncio +import datetime import unittest from collections import OrderedDict from copy import deepcopy +from decimal import Decimal from sanic_rest_framework.fields import ( CharField, IntegerField, FloatField, DateField, TimeField, DecimalField, @@ -24,7 +26,7 @@ from sanic_rest_framework.fields import ( from sanic_rest_framework.serializers import Serializer from tortoise.contrib.test import initializer, finalizer, TestCase -from sanic_rest_framework.test.models import UserModel, AddressModel +from sanic_rest_framework.test.models import UserModel, AddressModel, SchoolModel, StudentModel, ClassRoomModel from sanic_rest_framework.exceptions import ValidationError @@ -38,10 +40,36 @@ class UserSerializer(Serializer): name = CharField(max_length=8, required=True) birthday = DateField() phone = CharField(max_length=11) - balance = DecimalField(13, 2) + balance = DecimalField(13, 3) address = AddressSerializer(many=True) +class SchoolSerializer(Serializer): + name = CharField(max_length=12) + address = AddressSerializer() + + +class StudentSerializer(Serializer): + name = CharField(max_length=12) + # class_room = ClassRoomSerializer() + + +class ClassRoomSerializer(Serializer): + room_number = CharField(max_length=12) + student_count = IntegerField() + students = StudentSerializer(many=True) + + +class M2OClassRoomSerializer(Serializer): + room_number = CharField(max_length=12) + student_count = IntegerField() + + +class M2OStudentSerializer(Serializer): + name = CharField(max_length=12) + class_room = M2OClassRoomSerializer() + + initializer(['sanic_rest_framework.test.models', ], db_url="sqlite://./db.sqlite", loop=asyncio.get_event_loop()) @@ -85,14 +113,14 @@ class TestOrdinarySerializer(TestCase): class TestM2MNestedSerializer(TestCase): - """测试单个嵌套""" + """测试多对多嵌套""" def setUp(self) -> None: self.mapping_data = { 'name': '老四', 'birthday': '2010-11-16', 'phone': '17674707037', - 'balance': 99.009, + 'balance': '99.009', 'address': [{ 'phone': '17674707036', 'address': '长沙市IFS', @@ -103,6 +131,12 @@ class TestM2MNestedSerializer(TestCase): 'house_number': '4L' }] } + self.des_res = OrderedDict({'name': '老四', 'birthday': datetime.date(2010, 11, 16), + 'phone': '17674707037', 'balance': Decimal('99.009'), + 'address': [ + OrderedDict({'phone': '17674707036', 'address': '长沙市IFS', 'house_number': '67L'}), + OrderedDict({'phone': '17674707037', 'address': '长沙市HPT', 'house_number': '4L'})]}) + self.res = OrderedDict(self.mapping_data) async def test_serializer(self): """测试序列化""" @@ -117,11 +151,220 @@ class TestM2MNestedSerializer(TestCase): us_map = UserSerializer(instance=self.mapping_data) us_model = UserSerializer(instance=await UserModel.get(pk=1)) - print(await us_map.data) print(await us_model.data) + self.assertEqual(await us_map.data, self.res) + self.assertEqual(await us_model.data, self.res) async def test_deserializer(self): """测试反序列化""" + user_info = deepcopy(self.mapping_data) + address = user_info.pop('address') + user = UserModel(**user_info) + await user.save() + for i in address: + addr = AddressModel(**i) + await addr.save() + await user.address.add(addr) + us_map = UserSerializer(data=self.mapping_data) + us_model = UserSerializer(data=await UserModel.get(pk=1)) + self.assertIs(us_map.is_valid(), True) + self.assertEqual(us_map.validated_data, self.des_res) + + self.assertIs(us_model.is_valid(), False) + print(us_model.errors) + + @classmethod + async def tearDownClass(cls) -> None: + finalizer() + + +class TestO2OSerializer(TestCase): + """测试一对一序列化""" + + def setUp(self) -> None: + self.mapping_data = { + 'name': '老四', + 'address': { + 'phone': '17674707036', + 'address': '长沙市IFS', + 'house_number': '67L' + } + } + self.res = OrderedDict(self.mapping_data) + + async def test_serializer(self): + """测试序列化""" + school_info = deepcopy(self.mapping_data) + address = school_info.pop('address') + address_model = await AddressModel.create(**address) + school_model = await SchoolModel.create(address=address_model, **school_info) + + ss_map = SchoolSerializer(instance=self.mapping_data) + us_model = SchoolSerializer(instance=await SchoolModel.get(pk=1)) + self.assertEqual(await ss_map.data, self.res) + self.assertEqual(await us_model.data, self.res) + + async def test_deserializer(self): + """测试反序列化""" + school_info = deepcopy(self.mapping_data) + address = school_info.pop('address') + address_model = await AddressModel.create(**address) + school_model = await SchoolModel.create(address=address_model, **school_info) + + ss_map = SchoolSerializer(data=self.mapping_data) + ss_model = SchoolSerializer(data=await SchoolModel.get(pk=1)) + self.assertIs(ss_map.is_valid(), True) + self.assertEqual(ss_map.validated_data, self.res) + + self.assertIs(ss_model.is_valid(), False) + print(ss_model.errors) + + @classmethod + async def tearDownClass(cls) -> None: + finalizer() + + +class TestO2MSerializer(TestCase): + """测试一对多序列化器""" + + def setUp(self) -> None: + self.mapping_data = { + 'room_number': '1024', + 'student_count': 80, + 'students': [{'name': '刘文静'}, {'name': '马冬梅'}, {'name': '光明'}] + } + self.res = OrderedDict(self.mapping_data) + + async def test_serializer(self): + """测试序列化""" + + class_room_info = deepcopy(self.mapping_data) + students = class_room_info.pop('students') + class_model = await ClassRoomModel.create(**class_room_info) + await StudentModel.bulk_create([StudentModel(class_room=class_model, **student) for student in students]) + + cs_map = ClassRoomSerializer(instance=self.mapping_data) + cs_model = ClassRoomSerializer(instance=class_model) + print(await cs_map.data) + print(await cs_model.data) + self.assertEqual(await cs_map.data, self.res) + self.assertEqual(await cs_model.data, self.res) + + async def test_deserializer(self): + """测试反序列化""" + class_room_info = deepcopy(self.mapping_data) + students = class_room_info.pop('students') + class_model = await ClassRoomModel.create(**class_room_info) + await StudentModel.bulk_create([StudentModel(class_room=class_model, **student) for student in students]) + + cs_map = ClassRoomSerializer(data=self.mapping_data) + cs_model = ClassRoomSerializer(data=class_model) + self.assertIs(cs_map.is_valid(), True) + self.assertEqual(cs_map.validated_data, self.res) + + self.assertIs(cs_model.is_valid(), False) + print(cs_model.errors) + + @classmethod + async def tearDownClass(cls) -> None: + finalizer() + + +class TestM2OSerializer(TestCase): + """测试多对一序列化器""" + + def setUp(self) -> None: + self.mapping_data = [ + { + 'name': '马冬梅', + 'class_room': { + 'room_number': '1024', + 'student_count': 40 + } + }, { + 'name': '光明', + 'class_room': { + 'room_number': '1024', + 'student_count': 40 + } + }, { + 'name': '李焕英', + 'class_room': { + 'room_number': '1024', + 'student_count': 40 + } + }, + + ] + self.res_dict = { + '马冬梅': { + 'name': '马冬梅', + 'class_room': { + 'room_number': '1024', + 'student_count': 40 + } + }, '光明': { + 'name': '光明', + 'class_room': { + 'room_number': '1024', + 'student_count': 40 + } + }, '李焕英': { + 'name': '李焕英', + 'class_room': { + 'room_number': '1024', + 'student_count': 40 + } + }, + } + self.res = OrderedDict(self.mapping_data) + + async def test_serializer(self): + """测试序列化""" + + student_model_list = [] + students = deepcopy(self.mapping_data) + # 初始化数据库数据 + for student_info in students: + class_room_info = student_info.pop('class_room') + class_room_model, status = await ClassRoomModel.get_or_create(**class_room_info) + student_model = await StudentModel.create(class_room=class_room_model, **student_info) + student_model_list.append(student_model) + + # 验证模型类型的数据是否正常 + for student_model in student_model_list: + m2oss_model = M2OStudentSerializer(instance=student_model) + self.assertEqual(await m2oss_model.data, OrderedDict(self.res_dict[student_model.name])) + + # 验证字段类型的数据是否正常 + students = deepcopy(self.mapping_data) + for student_info in students: + m2oss_map = M2OStudentSerializer(instance=student_info) + self.assertEqual(await m2oss_map.data, OrderedDict(self.res_dict[student_info['name']])) + + async def test_deserializer(self): + """测试反序列化""" + student_model_list = [] + students = deepcopy(self.mapping_data) + # 初始化数据库数据 + for student_info in students: + class_room_info = student_info.pop('class_room') + class_room_model, status = await ClassRoomModel.get_or_create(**class_room_info) + student_model = await StudentModel.create(class_room=class_room_model, **student_info) + student_model_list.append(student_model) + + # 验证模型类型的数据是否正常 + for student_model in student_model_list: + m2oss_model = M2OStudentSerializer(data=student_model) + self.assertIs(m2oss_model.is_valid(), False) + print(m2oss_model.errors) + + # 验证字段类型的数据是否正常 + students = deepcopy(self.mapping_data) + for student_info in students: + m2oss_map = M2OStudentSerializer(data=student_info) + self.assertIs(m2oss_map.is_valid(), True) + self.assertEqual(m2oss_map.validated_data, OrderedDict(self.res_dict[student_info['name']])) @classmethod async def tearDownClass(cls) -> None: -- Gitee From 6f5b639c77b4cc35b73cf4dbdd57b65a2355559b Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 4 Mar 2021 17:23:38 +0800 Subject: [PATCH 21/34] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=EF=BC=8C=E5=85=A8=E9=83=A8=E8=B7=91=E9=80=9A?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E6=94=B9BUG=20=E6=96=B0=E5=A2=9E=20ModelSeri?= =?UTF-8?q?alizers=20=E5=B0=9A=E6=9C=AA=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/fields.py | 79 +++++----- sanic_rest_framework/serializers.py | 140 +++++++++++++++++- sanic_rest_framework/test/models.py | 3 + .../test/test_fields/test_base_field.py | 16 +- .../test/test_fields/test_bool_field.py | 28 ++-- .../test/test_fields/test_char_field.py | 18 ++- .../test/test_fields/test_choice_field.py | 69 +++++++++ .../test/test_fields/test_date_field.py | 27 ++-- .../test/test_fields/test_datetime_field.py | 13 +- .../test/test_fields/test_decimal_field.py | 26 ++-- .../test/test_fields/test_float_field.py | 22 +-- .../test/test_fields/test_integer_field.py | 22 +-- .../test/test_fields/test_time_field.py | 22 ++- .../test/test_serializers/test_serializer.py | 11 +- 14 files changed, 379 insertions(+), 117 deletions(-) create mode 100644 sanic_rest_framework/test/test_fields/test_choice_field.py diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 482fd67..f01cef4 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -301,22 +301,22 @@ class Field: root = root.parent return root - def raise_error(self, key, **kwargs): + def raise_error(self, _key, **kwargs): """ 返回在 error_messages 中注册了的错误 - :param key: 错误的键 + :param _key: 错误的键 :param kwargs: :return: """ try: - msg = self.error_messages[key] + msg = self.error_messages[_key] except KeyError: class_name = self.__class__.__name__ msg = "在 {class_name} 类的 error_messages " \ - "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=key) + "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=_key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string, code=key) + raise ValidationError(message_string, code=_key) class CharField(Field): @@ -416,7 +416,7 @@ class DecimalField(Field): } MAX_STRING_LENGTH = 1000 - def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, + def __init__(self, max_digits, decimal_places, coerce_to_string=False, max_value=None, min_value=None, rounding=None, *args, **kwargs): """ 整数位数 = max_digits - decimal_places @@ -472,7 +472,7 @@ class DecimalField(Field): if not self.coerce_to_string: return quantized - return '{:f}'.format(quantized) + return ('{:%sf}' % self.decimal_places).format(quantized) def validate_precision(self, value): """ @@ -592,24 +592,25 @@ class DateTimeField(Field): return value.astimezone(self.set_timezone) def external_to_internal(self, data: Any) -> Any: - if not isinstance(data, (str, data, datetime)): + if not isinstance(data, (str, date, datetime)): self.raise_error('convert') - if isinstance(data, date): + if type(data) == date: self.raise_error('date') if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats) except (ValueError, TypeError): self.raise_error('convert') - if isinstance(data, datetime): + if type(data) == datetime: data = self.enforce_timezone(data) return data async def internal_to_external(self, data: Any) -> Any: - if not data: - return None if isinstance(data, str): - return data + try: + data = datetime.strptime(data, self.input_formats) + except (ValueError, TypeError): + self.raise_error('convert') if isinstance(data, datetime): return data.strftime(self.output_format) self.raise_error('invalid', data_type=type(data).__name__) @@ -631,7 +632,6 @@ class DateField(Field): if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).date() - return data except (ValueError, TypeError): self.raise_error('invalid', value=data) if isinstance(data, datetime): @@ -641,13 +641,14 @@ class DateField(Field): self.raise_error('invalid', value=data) async def internal_to_external(self, data: Any) -> Any: - if not data: - return data if isinstance(data, str): - return data + try: + data = datetime.strptime(data, self.input_formats).date() + except (ValueError, TypeError): + self.raise_error('invalid', value=data) if isinstance(data, date): return data.strftime(self.output_format) - self.raise_error('invalid') + self.raise_error('invalid', value=data) class TimeField(Field): @@ -679,14 +680,17 @@ class TimeField(Field): self.raise_error('date') self.raise_error('invalid', value=type(data)) - def external_convert(self, data: Any) -> Any: - if not data: - return data + async def internal_to_external(self, data: Any) -> Any: if isinstance(data, str): - return data - if isinstance(data, (time, datetime)): + try: + data = datetime.strptime(data, self.input_formats).time() + except (ValueError, TypeError): + self.raise_error('format') + if isinstance(data, datetime): + data = data.time() + if isinstance(data, time): return data.strftime(self.output_format) - self.raise_error('invalid') + self.raise_error('invalid', value=data) class ChoiceField(Field): @@ -706,18 +710,22 @@ class ChoiceField(Field): super(ChoiceField, self).__init__(*args, **kwargs) def external_to_internal(self, data: Any) -> Any: - data = self.get_choices().get(str(data)) - return data + return self.choices_get_value_by_key(data) - def external_convert(self, data: Any) -> Any: - return str(data) + async def internal_to_external(self, data: Any) -> Any: + return self.choices_get_value_by_key(data) def choices_get_value_by_key(self, key): """得到字符串""" - choices = self.get_choices() - if key not in choices: - self.raise_error('key', key=key) - return choices.get(key, None) + if self.check_key_choices(key): + choices_dict = self.get_choices() + value = choices_dict[str(key)] + return value + self.raise_error('key', key=key) + + def check_key_choices(self, key): + choices_dict = self.get_choices() + return key in choices_dict def get_choices(self) -> dict: choices = {str(key): value for key, value in self.choices} @@ -754,6 +762,9 @@ class SerializerMethodField(Field): super().bind(field_name, parent) - def to_representation(self, value): + def internal_to_external(self, data: Any) -> Any: method = getattr(self.parent, self.method_name) - return method(value) + return method(data) + + def external_to_internal(self, data: Any) -> Any: + raise ValidationError('SerializerMethodField 不支持反序列化') diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 29c38b8..1af96b9 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -14,10 +14,13 @@ from collections import OrderedDict from typing import Any, Mapping from tortoise import models, Model -from tortoise.fields import relational from tortoise.queryset import QuerySet +from tortoise import fields -from sanic_rest_framework.fields import Field, empty, SkipField +from sanic_rest_framework.fields import ( + empty, SkipField, + Field, CharField, IntegerField, FloatField, DecimalField, BooleanField, DateTimeField, DateField, TimeField, ChoiceField, SerializerMethodField +) from .exceptions import ValidationError from .helpers import BindingDict @@ -322,7 +325,7 @@ class ListSerializer(BaseSerializer): :param data: :return: """ - iterable = await data.all() if isinstance(data, (QuerySet, relational.RelationalField)) else data + iterable = await data.all() if isinstance(data, (QuerySet, fields.relational.RelationalField)) else data return [ await self.child.internal_to_external(item) for item in iterable @@ -368,3 +371,134 @@ class ListSerializer(BaseSerializer): value = self.validate(value) return value + + +class ModelSerializer(Serializer): + """ + class Meta: + model = None + depth = 0 + extra_kwargs = {} + fields = () # 冲突,不能与exclude共存 + exclude = () # 冲突,不能与fields共存 + read_only_fields = () # 字段与write_only_fields冲突 + write_only_fields = () # 字段与read_only_fields冲突 + + """ + convert_mapping = { + fields.BigIntField: IntegerField, + fields.BinaryField: None, # TODO : 暂时无法解决 + fields.BooleanField: BooleanField, + fields.CharEnumField: ChoiceField, + fields.CharField: CharField, + fields.DateField: DateField, + fields.DatetimeField: DateTimeField, + fields.DecimalField: DecimalField, + fields.FloatField: FloatField, + fields.IntEnumField: ChoiceField, + fields.IntField: IntegerField, + fields.JSONField: CharField, + fields.SmallIntField: IntegerField, + fields.TextField: CharField, + fields.TimeDeltaField: None, # TODO : 需要为其单独创建一个对应的字段 + fields.UUIDField: CharField, + } + + @property + def fields(self): + """ + 单个格式为 {field_name: field_instance}. + fields 是动态加载的 避免在导入时出现意想不到的错误 + """ + assert hasattr(self, 'Meta'), ( + '{serializer_class} 类没有 "Meta" 属性'.format( + serializer_class=self.__class__.__name__ + ) + ) + assert hasattr(self.Meta, 'model'), ( + '{serializer_class} 类没有 "Meta.model" 属性'.format( + serializer_class=self.__class__.__name__ + ) + ) + if self.Meta.model.Meta.abstract: + raise ValueError('不能将ModelSerializer与抽象模型一起使用。') + + declared_fields = copy.deepcopy(self._declared_fields) + model = getattr(self.Meta, 'model') + depth = getattr(self.Meta, 'depth', 0) + + model_fields = self._clean_model_field(model) + model_basis_fields = self._get_model_basis_fields(model_fields) + serializer_fields = {} + for basis_field_name, basis_field_class in model_basis_fields.items(): + if basis_field_name in declared_fields: + current_field_class = declared_fields[basis_field_name] + else: + current_field_class = self.convert_mapping[basis_field_class.__class__] + serializer_fields[basis_field_name] = current_field_class + + # + # + # # like drf + # fields = BindingDict(self) + # for key, value in self.get_fields().items(): + # fields[key] = value + # return fields + + def _clean_model_field(self, model): + """ + 清除不需要的字段如 fk_id + :param model: + :return: + """ + clean_field_names = [] + field_dict = {} + fields_map = copy.deepcopy(model._meta.fields_map()) + for field_name, field_class in fields_map.items(): + if isinstance(field_class, (fields.relational.ForeignKeyFieldInstance, fields.relational.OneToOneFieldInstance)): + clean_field_names.append(field_class.source_field) + field_dict[field_name] = field_class + + for clean_field_name in clean_field_names: + if clean_field_name in field_dict: + field_dict.pop(clean_field_name) + return field_dict + + def _get_model_basis_fields(self, model_fields): + """ + 得到基础字段,非关系字段 + :param model: + :return: + """ + return {field_name: field_class for field_name, field_class in model_fields.items() if not isinstance(field_class, fields.relational.RelationalField)} + + def _get_model_M2M_fields(self, model_fields): + """ + 得到多对多字段 + :param model: + :return: + """ + return {field_name: field_class for field_name, field_class in model_fields.items() if isinstance(field_class, fields.relational.ManyToManyFieldInstance)} + + def _get_model_O2O_fields(self, model_fields): + """得到一对一字段""" + return {field_name: field_class for field_name, field_class in model_fields.items() if + isinstance(field_class, (fields.relational.BackwardOneToOneRelation, fields.relational.OneToOneFieldInstance))} + + def _get_model_M2O_fields(self, model_fields): + """ + 得到多对一字段 + :param model: + :return: + """ + return {field_name: field_class for field_name, field_class in model_fields.items() if + isinstance(field_class, fields.relational.ForeignKeyFieldInstance)} + + def _get_model_O2M_fields(self, model_fields): + """ + 得到一对多字段 + :param model: + :return: + """ + return {field_name: field_class for field_name, field_class in model_fields.items() if + isinstance(field_class, fields.relational.BackwardFKRelation)} diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index 760ba4a..4708db3 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -34,3 +34,6 @@ class ClassRoomModel(Model): class StudentModel(Model): name = fields.CharField(max_length=12, null=False) class_room: ForeignKeyRelation["ClassRoomModel"] = fields.ForeignKeyField('models.ClassRoomModel', 'students') + +class DateSeriesModel(Model): + name = fields.TimeDeltaField() \ No newline at end of file diff --git a/sanic_rest_framework/test/test_fields/test_base_field.py b/sanic_rest_framework/test/test_fields/test_base_field.py index ece1e54..5afa85a 100644 --- a/sanic_rest_framework/test/test_fields/test_base_field.py +++ b/sanic_rest_framework/test/test_fields/test_base_field.py @@ -13,16 +13,28 @@ 2021/2/2 15:58 change 'Fix bug' """ +import asyncio + from tortoise.contrib import test +from tortoise.contrib.test import initializer, finalizer + from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import Field as TestField, empty, SkipField from sanic_rest_framework.test.utils import TestDataMixin from sanic_rest_framework.validators import MaxValueValidator +initializer(['sanic_rest_framework.test.models', ], + # db_url="sqlite://./db.sqlite", + loop=asyncio.get_event_loop()) + class TestBaseField(TestDataMixin, test.TestCase): """测试基类的基本功能""" + @classmethod + async def tearDownClass(cls) -> None: + finalizer() + def test_bing(self): tf1 = TestField() tf2 = TestField() @@ -54,7 +66,7 @@ class TestBaseField(TestDataMixin, test.TestCase): self.assertEqual(tf2.get_external_value(data3), 1) self.assertEqual(tf2.get_external_value(data4), self.str_chinese) - def test_get_internal_value(self): + async def test_get_internal_value(self): # 未进行 Model 类型测试 test_data = [ [self.str_chinese, {'tf': self.str_chinese}], @@ -74,7 +86,7 @@ class TestBaseField(TestDataMixin, test.TestCase): tf = TestField() tf.bind('tf', tf) for value, data in test_data: - self.assertEqual(tf.get_internal_value(data), value) + self.assertEqual(await tf.get_internal_value(data), value) def test_run_validators(self): tf = TestField(validators=[MaxValueValidator(1000)]) diff --git a/sanic_rest_framework/test/test_fields/test_bool_field.py b/sanic_rest_framework/test/test_fields/test_bool_field.py index 878a8a2..02d9b1f 100644 --- a/sanic_rest_framework/test/test_fields/test_bool_field.py +++ b/sanic_rest_framework/test/test_fields/test_bool_field.py @@ -13,12 +13,20 @@ 2021/2/3 10:07 change 'Fix bug' """ +import asyncio import unittest from decimal import Decimal, ROUND_DOWN, getcontext, InvalidOperation + +from tortoise.contrib.test import initializer + from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import BooleanField as TestField from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField +initializer(['sanic_rest_framework.test.models', ], + # db_url="sqlite://./db.sqlite", + loop=asyncio.get_event_loop()) + class TestDecimalField(TestBaseField): @@ -53,7 +61,7 @@ class TestDecimalField(TestBaseField): for i in NULL_VALUES: self.assertEqual(tf1.external_to_internal(i), None) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要类型正确都不报错 @@ -77,16 +85,16 @@ class TestDecimalField(TestBaseField): } tf1 = TestField() for i in TRUE_VALUES: - self.assertEqual(tf1.internal_to_external(i), True) + self.assertEqual(await tf1.internal_to_external(i), True) for i in FALSE_VALUES: - self.assertEqual(tf1.internal_to_external(i), False) - self.assertEqual(tf1.internal_to_external(None), False) - self.assertEqual(tf1.internal_to_external(''), False) - self.assertEqual(tf1.internal_to_external('null'), True) - self.assertEqual(tf1.internal_to_external('any'), True) - self.assertEqual(tf1.internal_to_external('NULL'), True) - self.assertEqual(tf1.internal_to_external('yyyy'), True) - self.assertEqual(tf1.internal_to_external('曹凯'), True) + self.assertEqual(await tf1.internal_to_external(i), False) + self.assertEqual(await tf1.internal_to_external(None), False) + self.assertEqual(await tf1.internal_to_external(''), False) + self.assertEqual(await tf1.internal_to_external('null'), True) + self.assertEqual(await tf1.internal_to_external('any'), True) + self.assertEqual(await tf1.internal_to_external('NULL'), True) + self.assertEqual(await tf1.internal_to_external('yyyy'), True) + self.assertEqual(await tf1.internal_to_external('曹凯'), True) if __name__ == '__main__': diff --git a/sanic_rest_framework/test/test_fields/test_char_field.py b/sanic_rest_framework/test/test_fields/test_char_field.py index 7e8a3eb..449d5a4 100644 --- a/sanic_rest_framework/test/test_fields/test_char_field.py +++ b/sanic_rest_framework/test/test_fields/test_char_field.py @@ -13,13 +13,19 @@ 2021/1/28 16:16 change 'Fix bug' """ - +import asyncio import unittest +from tortoise.contrib.test import initializer + from sanic_rest_framework.exceptions import ValidationError from sanic_rest_framework.fields import CharField from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField +initializer(['sanic_rest_framework.test.models', ], + # db_url="sqlite://./db.sqlite", + loop=asyncio.get_event_loop()) + class TestCharField(TestBaseField): def test_external_to_internal(self): @@ -27,17 +33,17 @@ class TestCharField(TestBaseField): char1 = CharField() self.assertEqual(char1.external_to_internal(data), 'Python') - def test_internal_to_external(self): + async def test_internal_to_external(self): data1 = {'char1': 'Python'} data2 = {'char1': 66666} char1 = CharField() char1.bind('char1', char1) - value = char1.get_internal_value(data1) - self.assertEqual(char1.internal_to_external(value), 'Python') + value = await char1.get_internal_value(data1) + self.assertEqual(await char1.internal_to_external(value), 'Python') - value = char1.get_internal_value(data2) - self.assertEqual(char1.internal_to_external(value), '66666') + value = await char1.get_internal_value(data2) + self.assertEqual(await char1.internal_to_external(value), '66666') def test_trim_whitespace(self): data = ' Python' diff --git a/sanic_rest_framework/test/test_fields/test_choice_field.py b/sanic_rest_framework/test/test_fields/test_choice_field.py new file mode 100644 index 0000000..8fc9277 --- /dev/null +++ b/sanic_rest_framework/test/test_fields/test_choice_field.py @@ -0,0 +1,69 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/4 11:20 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_choice_field.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/4 11:20 change 'Fix bug' + +""" + +import unittest + +from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.fields import ChoiceField as TestField +from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField + + +class TestDateTimeField(TestBaseField): + def setUp(self) -> None: + self.choice = ( + ('刘文静', '开发人员'), + ('光明', '播音员'), + ('李焕英', '纺织厂女工') + ) + + def test_external_to_internal(self): + """ + 外转内 str -> dict 是严格的,不符合类型的都应该报错, + 一切都要经过验证 + :return: + """ + + # 正常测试 + tf1 = TestField(choices=self.choice) + self.assertEqual(tf1.external_to_internal('刘文静'), '开发人员') + self.assertEqual(tf1.external_to_internal('光明'), '播音员') + self.assertEqual(tf1.external_to_internal('李焕英'), '纺织厂女工') + + with self.assertRaises(ValidationError): + tf1.external_to_internal(None) + with self.assertRaises(ValidationError): + tf1.external_to_internal('曹凯') + + async def test_internal_to_external(self): + """ + 内转外 str -> dict 是宽松的, + 只要类型正确都不报错 + :return: + """ + # 正常测试 + tf1 = TestField(choices=self.choice) + self.assertEqual(await tf1.internal_to_external('刘文静'), '开发人员') + self.assertEqual(await tf1.internal_to_external('光明'), '播音员') + self.assertEqual(await tf1.internal_to_external('李焕英'), '纺织厂女工') + + with self.assertRaises(ValidationError): + await tf1.internal_to_external(None) + with self.assertRaises(ValidationError): + await tf1.internal_to_external('曹凯') + + +if __name__ == '__main__': + unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_date_field.py b/sanic_rest_framework/test/test_fields/test_date_field.py index f9d3936..fe2d7fd 100644 --- a/sanic_rest_framework/test/test_fields/test_date_field.py +++ b/sanic_rest_framework/test/test_fields/test_date_field.py @@ -17,7 +17,7 @@ import unittest from sanic_rest_framework.exceptions import ValidationError -from sanic_rest_framework.fields import DateField as TestField +from sanic_rest_framework.fields import DateTimeField as TestField from sanic_rest_framework.test.test_fields.test_base_field import TestBaseField @@ -32,8 +32,9 @@ class TestDateTimeField(TestBaseField): # 正常测试 tf1 = TestField() - self.assertEqual(tf1.external_to_internal(self.str_date), self.obj_date) - self.assertEqual(tf1.external_to_internal(self.obj_date), self.obj_date) + + self.assertEqual(tf1.external_to_internal(self.str_datetime), tf1.enforce_timezone(self.obj_datetime)) + self.assertEqual(tf1.external_to_internal(self.obj_datetime), tf1.enforce_timezone(self.obj_datetime)) with self.assertRaises(ValidationError): tf1.external_to_internal(self.str_datetime_bad_y) with self.assertRaises(ValidationError): @@ -49,24 +50,28 @@ class TestDateTimeField(TestBaseField): with self.assertRaises(ValidationError): tf1.external_to_internal(self.str_time) with self.assertRaises(ValidationError): - tf1.external_to_internal(self.str_datetime) - + tf1.external_to_internal(self.obj_time) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.str_date) + with self.assertRaises(ValidationError): + tf1.external_to_internal(self.obj_date) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要类型正确都不报错 :return: """ tf1 = TestField() - self.assertEqual(tf1.internal_to_external(None), None) - self.assertEqual(tf1.internal_to_external(self.obj_datetime), self.str_datetime) with self.assertRaises(ValidationError): - tf1.internal_to_external(self.obj_date) + await tf1.internal_to_external(None) + self.assertEqual(await tf1.internal_to_external(self.obj_datetime), self.str_datetime) + with self.assertRaises(ValidationError): + await tf1.internal_to_external(self.obj_date) with self.assertRaises(ValidationError): - tf1.internal_to_external(self.obj_time) + await tf1.internal_to_external(self.obj_time) with self.assertRaises(ValidationError): - tf1.internal_to_external(1) + await tf1.internal_to_external(1) if __name__ == '__main__': diff --git a/sanic_rest_framework/test/test_fields/test_datetime_field.py b/sanic_rest_framework/test/test_fields/test_datetime_field.py index e39971d..fd53cb8 100644 --- a/sanic_rest_framework/test/test_fields/test_datetime_field.py +++ b/sanic_rest_framework/test/test_fields/test_datetime_field.py @@ -67,21 +67,22 @@ class TestDateTimeField(TestBaseField): with self.assertRaises(ValidationError): tf1.external_to_internal(self.str_date) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要类型正确都不报错 :return: """ tf1 = TestField() - self.assertEqual(tf1.internal_to_external(None), None) - self.assertEqual(tf1.internal_to_external(self.obj_datetime), self.str_datetime) with self.assertRaises(ValidationError): - tf1.internal_to_external(self.obj_date) + await tf1.internal_to_external(None) + self.assertEqual(await tf1.internal_to_external(self.obj_datetime), self.str_datetime) with self.assertRaises(ValidationError): - tf1.internal_to_external(self.obj_time) + await tf1.internal_to_external(self.obj_date) with self.assertRaises(ValidationError): - tf1.internal_to_external(1) + await tf1.internal_to_external(self.obj_time) + with self.assertRaises(ValidationError): + await tf1.internal_to_external(1) if __name__ == '__main__': unittest.main() diff --git a/sanic_rest_framework/test/test_fields/test_decimal_field.py b/sanic_rest_framework/test/test_fields/test_decimal_field.py index 65cbd97..38fae19 100644 --- a/sanic_rest_framework/test/test_fields/test_decimal_field.py +++ b/sanic_rest_framework/test/test_fields/test_decimal_field.py @@ -72,7 +72,7 @@ class TestDecimalField(TestBaseField): self.assertEqual(tf1.external_to_internal(self.str_max_int), Decimal(self.max_int)) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要类型正确都不报错 @@ -84,29 +84,29 @@ class TestDecimalField(TestBaseField): data4 = {'tf': self.str_max_float} data5 = {'tf': self.max_float} - tf = TestField(max_digits=6, decimal_places=2) + tf = TestField(max_digits=6, decimal_places=2, coerce_to_string=False) tf.bind('tf', tf) - value = tf.get_internal_value(data1) + value = await tf.get_internal_value(data1) with self.assertRaises(InvalidOperation): - tf.internal_to_external(value) + await tf.internal_to_external(value) - value = tf.get_internal_value(data2) - self.assertEqual(tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data2) + self.assertEqual(await tf.internal_to_external(value), 9999.00) # Decimal(int) == int # Decimal(str_int) == int - value = tf.get_internal_value(data3) - self.assertEqual(tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data3) + self.assertEqual(await tf.internal_to_external(value), self.max_int) # Decimal(float) == float # Decimal(str_float) != float - value = tf.get_internal_value(data4) - self.assertNotEqual(tf.internal_to_external(value), self.max_float) - self.assertEqual(tf.internal_to_external(value), Decimal(self.str_max_float)) + value = await tf.get_internal_value(data4) + self.assertNotEqual(await tf.internal_to_external(value), self.max_float) + self.assertEqual(await tf.internal_to_external(value), Decimal(self.str_max_float)) - value = tf.get_internal_value(data5) - self.assertEqual(tf.internal_to_external(value), Decimal(self.str_max_float)) + value = await tf.get_internal_value(data5) + self.assertEqual(await tf.internal_to_external(value), Decimal(self.str_max_float)) def test_max_value(self): tf1 = TestField(max_digits=6, decimal_places=2) diff --git a/sanic_rest_framework/test/test_fields/test_float_field.py b/sanic_rest_framework/test/test_fields/test_float_field.py index 1157363..4664363 100644 --- a/sanic_rest_framework/test/test_fields/test_float_field.py +++ b/sanic_rest_framework/test/test_fields/test_float_field.py @@ -44,7 +44,7 @@ class TestFloatField(TestBaseField): self.assertEqual(tf1.external_to_internal(self.max_int), self.max_int) self.assertEqual(tf1.external_to_internal(self.str_max_int), self.max_int) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要是数值类型都不报错,float(xx) @@ -59,19 +59,19 @@ class TestFloatField(TestBaseField): tf = TestField() tf.bind('tf1', tf) - value = tf.get_internal_value(data1) + value = await tf.get_internal_value(data1) with self.assertRaises(ValueError): - tf.internal_to_external(value) + await tf.internal_to_external(value) - value = tf.get_internal_value(data2) - self.assertEqual(tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data2) + self.assertEqual(await tf.internal_to_external(value), self.max_int) - value = tf.get_internal_value(data3) - self.assertEqual(tf.internal_to_external(value), self.max_int) - value = tf.get_internal_value(data4) - self.assertEqual(tf.internal_to_external(value), self.max_float) - value = tf.get_internal_value(data5) - self.assertEqual(tf.internal_to_external(value), self.max_float) + value = await tf.get_internal_value(data3) + self.assertEqual(await tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data4) + self.assertEqual(await tf.internal_to_external(value), self.max_float) + value = await tf.get_internal_value(data5) + self.assertEqual(await tf.internal_to_external(value), self.max_float) def test_max_value(self): tf1 = TestField() diff --git a/sanic_rest_framework/test/test_fields/test_integer_field.py b/sanic_rest_framework/test/test_fields/test_integer_field.py index 38df7be..18e6c8e 100644 --- a/sanic_rest_framework/test/test_fields/test_integer_field.py +++ b/sanic_rest_framework/test/test_fields/test_integer_field.py @@ -44,7 +44,7 @@ class TestCharField(TestBaseField): self.assertEqual(tf.external_to_internal(self.max_int), self.max_int) self.assertEqual(tf.external_to_internal(self.str_max_int), self.max_int) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要是数值类型都不报错,int(xx) @@ -59,23 +59,23 @@ class TestCharField(TestBaseField): tf = TestField() tf.bind('tf1', tf) - value = tf.get_internal_value(data1) + value = await tf.get_internal_value(data1) with self.assertRaises(ValueError): - tf.internal_to_external(value) + await tf.internal_to_external(value) - value = tf.get_internal_value(data2) - self.assertEqual(tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data2) + self.assertEqual(await tf.internal_to_external(value), self.max_int) - value = tf.get_internal_value(data3) - self.assertEqual(tf.internal_to_external(value), self.max_int) - value = tf.get_internal_value(data4) + value = await tf.get_internal_value(data3) + self.assertEqual(await tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data4) with self.assertRaises(ValueError): # str_float 不可被转换 - tf.internal_to_external(value) + await tf.internal_to_external(value) - value = tf.get_internal_value(data5) - self.assertEqual(tf.internal_to_external(value), self.max_int) + value = await tf.get_internal_value(data5) + self.assertEqual(await tf.internal_to_external(value), self.max_int) def test_max_value(self): tf1 = TestField() diff --git a/sanic_rest_framework/test/test_fields/test_time_field.py b/sanic_rest_framework/test/test_fields/test_time_field.py index f42f726..fcc7821 100644 --- a/sanic_rest_framework/test/test_fields/test_time_field.py +++ b/sanic_rest_framework/test/test_fields/test_time_field.py @@ -49,25 +49,33 @@ class TestDateTimeField(TestBaseField): with self.assertRaises(ValidationError): tf1.external_to_internal(self.str_datetime_bad_s) with self.assertRaises(ValidationError): - tf1.external_to_internal(self.str_time) + tf1.external_to_internal(self.str_date) with self.assertRaises(ValidationError): tf1.external_to_internal(self.str_datetime) - def test_internal_to_external(self): + async def test_internal_to_external(self): """ 内转外 str -> dict 是宽松的, 只要类型正确都不报错 :return: """ tf1 = TestField() - self.assertEqual(tf1.internal_to_external(None), None) - self.assertEqual(tf1.internal_to_external(self.obj_datetime), self.str_datetime) with self.assertRaises(ValidationError): - tf1.internal_to_external(self.obj_date) + await tf1.internal_to_external(None) + # 可以转换对象类型的 datetime + self.assertEqual(await tf1.internal_to_external(self.obj_datetime), self.str_time) + self.assertEqual(await tf1.internal_to_external(self.obj_time), self.str_time) + self.assertEqual(await tf1.internal_to_external(self.str_time), self.str_time) + + # 不可转换字符类型的 datetime + with self.assertRaises(ValidationError): + await tf1.internal_to_external(self.str_datetime) + with self.assertRaises(ValidationError): + await tf1.internal_to_external(self.obj_date) with self.assertRaises(ValidationError): - tf1.internal_to_external(self.obj_time) + await tf1.internal_to_external(self.str_date) with self.assertRaises(ValidationError): - tf1.internal_to_external(1) + await tf1.internal_to_external(1) if __name__ == '__main__': diff --git a/sanic_rest_framework/test/test_serializers/test_serializer.py b/sanic_rest_framework/test/test_serializers/test_serializer.py index ea4d0be..8898532 100644 --- a/sanic_rest_framework/test/test_serializers/test_serializer.py +++ b/sanic_rest_framework/test/test_serializers/test_serializer.py @@ -71,7 +71,7 @@ class M2OStudentSerializer(Serializer): initializer(['sanic_rest_framework.test.models', ], - db_url="sqlite://./db.sqlite", + # db_url="sqlite://./db.sqlite", loop=asyncio.get_event_loop()) @@ -136,6 +136,11 @@ class TestM2MNestedSerializer(TestCase): 'address': [ OrderedDict({'phone': '17674707036', 'address': '长沙市IFS', 'house_number': '67L'}), OrderedDict({'phone': '17674707037', 'address': '长沙市HPT', 'house_number': '4L'})]}) + self.s_res = OrderedDict({'name': '老四', 'birthday': '2010-11-16', + 'phone': '17674707037', 'balance': Decimal('99.009'), + 'address': [ + OrderedDict({'phone': '17674707036', 'address': '长沙市IFS', 'house_number': '67L'}), + OrderedDict({'phone': '17674707037', 'address': '长沙市HPT', 'house_number': '4L'})]}) self.res = OrderedDict(self.mapping_data) async def test_serializer(self): @@ -152,8 +157,8 @@ class TestM2MNestedSerializer(TestCase): us_map = UserSerializer(instance=self.mapping_data) us_model = UserSerializer(instance=await UserModel.get(pk=1)) print(await us_model.data) - self.assertEqual(await us_map.data, self.res) - self.assertEqual(await us_model.data, self.res) + self.assertEqual(await us_map.data, self.s_res) + self.assertEqual(await us_model.data, self.s_res) async def test_deserializer(self): """测试反序列化""" -- Gitee From 6f31b8b6ceee8ea8e95c4c869f4dba7162f53dcd Mon Sep 17 00:00:00 2001 From: LaoSi Date: Fri, 5 Mar 2021 17:03:57 +0800 Subject: [PATCH 22/34] =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E5=99=A8=E7=BC=96?= =?UTF-8?q?=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/converter.py | 133 ++++++++++++++++++++++++++++ sanic_rest_framework/serializers.py | 6 +- 2 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 sanic_rest_framework/converter.py diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py new file mode 100644 index 0000000..ff7a5a0 --- /dev/null +++ b/sanic_rest_framework/converter.py @@ -0,0 +1,133 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/5 9:44 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + converter.py + 转换器 +@ChangeHistory: + datetime action why + example: + 2021/3/5 9:44 change 'Fix bug' + +""" +from tortoise import fields +from sanic_rest_framework.fields import ( + CharField, ChoiceField, IntegerField, BooleanField, + DecimalField, DateTimeField, DateField, TimeField, + FloatField +) + + +def converts(*args): + def _inner(func): + func._converter_for = frozenset(args) + return func + + return _inner + + +class ConverterException(Exception): + pass + + +class ModelConverterBase(object): + def __init__(self, converters): + + if not converters: + converters = {} + + for name in dir(self): + obj = getattr(self, name) + if hasattr(obj, '_converter_for'): + for classname in obj._converter_for: + converters[classname] = obj + + self.converters = converters + + +class ModelConverter(ModelConverterBase): + """模型转换器""" + + def convert(self, serializer, model_field, **field_kwargs): + model = serializer.Meta.model + read_only_fields = serializer.Meta.read_only_fields if hasattr(serializer.Meta,'read_only_fields') else () + write_only_fields = serializer.Meta.write_only_fields if hasattr(serializer.Meta,'write_only_fields') else () + + + + kwargs = { + 'read_only': False, + 'write_only': False, + 'required': model_field.required, + 'allow_null': model_field.null, + # 'allow_empty': False, # M2M O2M + 'source': None, + 'description': model_field.description + } + if not isinstance(model_field, fields.relational.RelationalField): + type_name = model_field.__class__.__name__ + if model_field.default is not None: + kwargs['default'] = model_field.default, + if hasattr(model.Meta, ) + converter = self.converters[type_name] + else: + type_name = model_field.__class__.__name__ + converter = self.converters[type_name] + kwargs.update(field_kwargs) + + return converter(model, model_field, field_kwargs, **kwargs) + + @converts(fields.CharField) + def _(self, model, model_field, *field_args, **field_kws): + max_length = model_field.max_length + if max_length is not None: + field_kws['max_length'] = max_length + return CharField(*field_args, **field_kws) + + @converts(fields.UUIDField, fields.JSONField, fields.TextField) + def _(self, model, model_field, *field_args, **field_kws): + return CharField(*field_args, **field_kws) + + @converts(fields.IntField, fields.BigIntField, fields.SmallIntField) + def _(self, model, model_field, *field_args, **field_kws): + return IntegerField(*field_args, **field_kws) + + @converts(fields.BooleanField) + def _(self, model, model_field, *field_args, **field_kws): + return BooleanField(*field_args, **field_kws) + + @converts(fields.CharEnumField) + def _(self, model, model_field, *field_args, **field_kws): + max_length = model_field.max_length + if max_length is not None: + field_kws['max_length'] = max_length + choices = ((i.name, i.value) for i in model_field.enum_type) + field_kws['choices'] = choices + return ChoiceField(*field_args, **field_kws) + + @converts(fields.IntEnumField) + def _(self, model, model_field, *field_args, **field_kws): + choices = ((i.name, i.value) for i in model_field.enum_type) + field_kws['choices'] = choices + return ChoiceField(*field_args, **field_kws) + + @converts(fields.DecimalField) + def _(self, model, model_field, *field_args, **field_kws): + field_kws['max_digits'] = model_field.max_digits + field_kws['decimal_places'] = model_field.decimal_places + return DecimalField(*field_args, **field_kws) + + @converts(fields.DatetimeField) + def _(self, model, model_field, *field_args, **field_kws): + return DateTimeField(*field_args, **field_kws) + + @converts(fields.DateField) + def _(self, model, model_field, *field_args, **field_kws): + return DateField(*field_args, **field_kws) + + @converts(fields.FloatField) + def _(self, model, model_field, *field_args, **field_kws): + return FloatField(*field_args, **field_kws) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 1af96b9..d339bd3 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -372,7 +372,6 @@ class ListSerializer(BaseSerializer): return value - class ModelSerializer(Serializer): """ class Meta: @@ -436,7 +435,7 @@ class ModelSerializer(Serializer): else: current_field_class = self.convert_mapping[basis_field_class.__class__] serializer_fields[basis_field_name] = current_field_class - + # serializer_fields = # # # # like drf @@ -445,6 +444,9 @@ class ModelSerializer(Serializer): # fields[key] = value # return fields + def _get_model_field_extra_kwargs(self, model_field) -> dict: + return {} + def _clean_model_field(self, model): """ 清除不需要的字段如 fk_id -- Gitee From 2c9eef3def1b7bea6800709eb1432323aed99462 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Mon, 8 Mar 2021 18:03:31 +0800 Subject: [PATCH 23/34] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E5=B5=8C=E5=A5=97?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/converter.py | 92 ++++++++++++------- sanic_rest_framework/serializers.py | 12 ++- sanic_rest_framework/test/models.py | 51 +++++++++- .../test_model_serializers.py | 47 ++++++++++ 4 files changed, 164 insertions(+), 38 deletions(-) create mode 100644 sanic_rest_framework/test/test_serializers/test_model_serializers.py diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index ff7a5a0..6fed424 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -34,8 +34,8 @@ class ConverterException(Exception): class ModelConverterBase(object): - def __init__(self, converters): - + def __init__(self, nested_field_class, converters=None, ): + self.nested_field_class = nested_field_class if not converters: converters = {} @@ -53,10 +53,8 @@ class ModelConverter(ModelConverterBase): def convert(self, serializer, model_field, **field_kwargs): model = serializer.Meta.model - read_only_fields = serializer.Meta.read_only_fields if hasattr(serializer.Meta,'read_only_fields') else () - write_only_fields = serializer.Meta.write_only_fields if hasattr(serializer.Meta,'write_only_fields') else () - - + read_only_fields = serializer.Meta.read_only_fields if hasattr(serializer.Meta, 'read_only_fields') else () + write_only_fields = serializer.Meta.write_only_fields if hasattr(serializer.Meta, 'write_only_fields') else () kwargs = { 'read_only': False, @@ -71,63 +69,91 @@ class ModelConverter(ModelConverterBase): type_name = model_field.__class__.__name__ if model_field.default is not None: kwargs['default'] = model_field.default, - if hasattr(model.Meta, ) - converter = self.converters[type_name] + converter = self.converters[type_name] else: type_name = model_field.__class__.__name__ + if hasattr(serializer.Meta, 'nested_depth'): + nested_depth = serializer.Meta.nested_depth + else: + nested_depth = 10 + kwargs['nested_depth'] = nested_depth converter = self.converters[type_name] kwargs.update(field_kwargs) - return converter(model, model_field, field_kwargs, **kwargs) + return converter(model, model_field, **kwargs) - @converts(fields.CharField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('CharField') + def convert_charfield(self, model, model_field, *field_args, **field_kws): max_length = model_field.max_length if max_length is not None: field_kws['max_length'] = max_length return CharField(*field_args, **field_kws) - @converts(fields.UUIDField, fields.JSONField, fields.TextField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('UUIDField', 'JSONField', 'TextField') + def convert_textfield(self, model, model_field, *field_args, **field_kws): return CharField(*field_args, **field_kws) - @converts(fields.IntField, fields.BigIntField, fields.SmallIntField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('IntField', 'BigIntField', 'SmallIntField') + def convert_integerfield(self, model, model_field, *field_args, **field_kws): return IntegerField(*field_args, **field_kws) - @converts(fields.BooleanField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('BooleanField') + def convert_booleanfield(self, model, model_field, *field_args, **field_kws): return BooleanField(*field_args, **field_kws) - @converts(fields.CharEnumField) - def _(self, model, model_field, *field_args, **field_kws): - max_length = model_field.max_length - if max_length is not None: - field_kws['max_length'] = max_length - choices = ((i.name, i.value) for i in model_field.enum_type) + @converts('CharEnumFieldInstance') + def convert_charenumfield(self, model, model_field, *field_args, **field_kws): + # max_length = model_field.max_length + # if max_length is not None: + # field_kws['max_length'] = max_length + choices = [(i.name, i.value) for i in model_field.enum_type] field_kws['choices'] = choices return ChoiceField(*field_args, **field_kws) - @converts(fields.IntEnumField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('IntEnumFieldInstance') + def convert_intenumfield(self, model, model_field, *field_args, **field_kws): choices = ((i.name, i.value) for i in model_field.enum_type) field_kws['choices'] = choices return ChoiceField(*field_args, **field_kws) - @converts(fields.DecimalField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('DecimalField') + def convert_decimalfield(self, model, model_field, *field_args, **field_kws): field_kws['max_digits'] = model_field.max_digits field_kws['decimal_places'] = model_field.decimal_places return DecimalField(*field_args, **field_kws) - @converts(fields.DatetimeField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('DatetimeField') + def convert_datetimefield(self, model, model_field, *field_args, **field_kws): return DateTimeField(*field_args, **field_kws) - @converts(fields.DateField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('DateField') + def convert_datefield(self, model, model_field, *field_args, **field_kws): return DateField(*field_args, **field_kws) - @converts(fields.FloatField) - def _(self, model, model_field, *field_args, **field_kws): + @converts('FloatField') + def convert_floatfield(self, model, model_field, *field_args, **field_kws): return FloatField(*field_args, **field_kws) + + @converts('ManyToManyFieldInstance', 'BackwardFKRelation') + def convert_MTM(self, model, model_field, *field_args, **field_kws): + nested_depth = field_kws.get('nested_depth', 10) + + class NestedSerializer(self.nested_field_class): + class Meta: + model = model_field.related_model + depth = nested_depth - 1 + fields = '__all__' + + return NestedSerializer(many=True) + + @converts('ForeignKeyFieldInstance', 'OneToOneFieldInstance', 'BackwardOneToOneRelation') + def convert_ManyToOne(self, model, model_field, *field_args, **field_kws): + nested_depth = field_kws.get('nested_depth', 10) + + class NestedSerializer(self.nested_field_class): + class Meta: + model = model_field.related_model + depth = nested_depth - 1 + fields = '__all__' + + return NestedSerializer() diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index d339bd3..66d13c3 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -21,6 +21,7 @@ from sanic_rest_framework.fields import ( empty, SkipField, Field, CharField, IntegerField, FloatField, DecimalField, BooleanField, DateTimeField, DateField, TimeField, ChoiceField, SerializerMethodField ) +from sanic_rest_framework.converter import ModelConverter from .exceptions import ValidationError from .helpers import BindingDict @@ -419,22 +420,25 @@ class ModelSerializer(Serializer): serializer_class=self.__class__.__name__ ) ) - if self.Meta.model.Meta.abstract: + if self.Meta.model._meta.abstract: raise ValueError('不能将ModelSerializer与抽象模型一起使用。') declared_fields = copy.deepcopy(self._declared_fields) model = getattr(self.Meta, 'model') depth = getattr(self.Meta, 'depth', 0) + converter = ModelConverter(ModelSerializer) model_fields = self._clean_model_field(model) model_basis_fields = self._get_model_basis_fields(model_fields) serializer_fields = {} - for basis_field_name, basis_field_class in model_basis_fields.items(): + for basis_field_name, basis_field_class in model_fields.items(): if basis_field_name in declared_fields: current_field_class = declared_fields[basis_field_name] else: - current_field_class = self.convert_mapping[basis_field_class.__class__] + current_field_class = converter.convert(self, basis_field_class) serializer_fields[basis_field_name] = current_field_class + return serializer_fields + # # serializer_fields = # # @@ -455,7 +459,7 @@ class ModelSerializer(Serializer): """ clean_field_names = [] field_dict = {} - fields_map = copy.deepcopy(model._meta.fields_map()) + fields_map = copy.deepcopy(model._meta.fields_map) for field_name, field_class in fields_map.items(): if isinstance(field_class, (fields.relational.ForeignKeyFieldInstance, fields.relational.OneToOneFieldInstance)): clean_field_names.append(field_class.source_field) diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index 4708db3..093a465 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -1,9 +1,57 @@ from datetime import date +from enum import Enum + from tortoise import fields from tortoise import Model from tortoise.fields import ForeignKeyRelation, ReverseRelation +class IntEnum(Enum): + OK = 1 + BAD = 2 + + +class CharEnum(Enum): + OK = '好' + BAD = '坏' + + +class TestModel(Model): + char_field = fields.CharField(max_length=8, null=False) + float_field = fields.FloatField() + date_field = fields.DateField() + int_field = fields.IntField() + decimal_field = fields.DecimalField(max_digits=13, decimal_places=3) + datetime_field = fields.DatetimeField() + int_enum_field = fields.IntEnumField(enum_type=IntEnum) + char_enum_field = fields.CharEnumField(enum_type=CharEnum) + boolean_field = fields.BooleanField() + small_int_field = fields.SmallIntField() + big_int_field = fields.BigIntField() + text_field = fields.TextField() + json_field = fields.JSONField() + uuid_field = fields.UUIDField() + one_to_many: ForeignKeyRelation["TestManyToOneModel"] = fields.ForeignKeyField('models.TestManyToOneModel', related_name='many_to_one') + one_to_one: fields.OneToOneRelation["TestOneToOneModel"] = fields.OneToOneField('models.TestOneToOneModel', related_name='one_to_one') + many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.TestModel", related_name="many_2_many", through="many_many") + + +class TestOneTOManyModel(Model): + one_to_many: ForeignKeyRelation["TestModel"] = fields.ForeignKeyField('models.TestModel', related_name='many_to_one') + + +class TestManyToOneModel(Model): + many_to_one: fields.ReverseRelation["TestModel"] + + +class TestOneToOneModel(Model): + name = fields.CharField(max_length=8, null=False) + + +class ManyToManyModel(Model): + many_2_many: fields.ManyToManyRelation["TestModel"] + + class UserModel(Model): name = fields.CharField(max_length=8, null=False) birthday = fields.DateField() @@ -35,5 +83,6 @@ class StudentModel(Model): name = fields.CharField(max_length=12, null=False) class_room: ForeignKeyRelation["ClassRoomModel"] = fields.ForeignKeyField('models.ClassRoomModel', 'students') + class DateSeriesModel(Model): - name = fields.TimeDeltaField() \ No newline at end of file + name = fields.TimeDeltaField() diff --git a/sanic_rest_framework/test/test_serializers/test_model_serializers.py b/sanic_rest_framework/test/test_serializers/test_model_serializers.py new file mode 100644 index 0000000..ee8982c --- /dev/null +++ b/sanic_rest_framework/test/test_serializers/test_model_serializers.py @@ -0,0 +1,47 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/8 11:59 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + test_model_serializers.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/8 11:59 change 'Fix bug' + +""" +import asyncio + +from tortoise.contrib.test import initializer, TestCase, finalizer + +from sanic_rest_framework.converter import ModelConverter +from sanic_rest_framework.test.models import TestModel +from sanic_rest_framework.serializers import ModelSerializer + + +class TestModelSerializer(ModelSerializer): + class Meta: + model = TestModel + + +initializer(['sanic_rest_framework.test.models', ], + # db_url="sqlite://./db.sqlite", + loop=asyncio.get_event_loop()) + + +class TestOrdinarySerializer(TestCase): + def setUp(self) -> None: + pass + + async def test_serializer(self): + # for name, field in UserModel._meta.fields_map.items(): + # ModelConverter().convert(TestModelSerializer, field, name) + tms = TestModelSerializer() + print(tms.fields) + + @classmethod + async def tearDownClass(cls) -> None: + finalizer() -- Gitee From cb8ac66a94a5b34e16a10535c53a6e35e8189029 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Tue, 9 Mar 2021 17:43:11 +0800 Subject: [PATCH 24/34] =?UTF-8?q?=E5=BA=8F=E5=88=97=E5=8C=96=E5=99=A8?= =?UTF-8?q?=E5=A4=A7=E8=87=B4=E5=AE=8C=E6=88=90=EF=BC=8C=E5=BE=85=E5=AE=8C?= =?UTF-8?q?=E5=96=84=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/converter.py | 43 +++-- sanic_rest_framework/fields.py | 29 ++- sanic_rest_framework/serializers.py | 168 +++++++++--------- sanic_rest_framework/test/models.py | 16 +- .../test_model_serializers.py | 52 +++++- 5 files changed, 192 insertions(+), 116 deletions(-) diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index 6fed424..f8936c0 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -17,7 +17,7 @@ from tortoise import fields from sanic_rest_framework.fields import ( CharField, ChoiceField, IntegerField, BooleanField, DecimalField, DateTimeField, DateField, TimeField, - FloatField + FloatField, EnumChoiceField ) @@ -53,16 +53,18 @@ class ModelConverter(ModelConverterBase): def convert(self, serializer, model_field, **field_kwargs): model = serializer.Meta.model - read_only_fields = serializer.Meta.read_only_fields if hasattr(serializer.Meta, 'read_only_fields') else () - write_only_fields = serializer.Meta.write_only_fields if hasattr(serializer.Meta, 'write_only_fields') else () - + read_only = field_kwargs.get('read_only', False) + required = not model_field.null + if read_only and required: + raise ValueError('{}序列化器内字段{}为必填项不能使用read_only属性'.format( + type(serializer).__name__, model_field.model_field_name)) kwargs = { - 'read_only': False, + 'read_only': read_only, 'write_only': False, - 'required': model_field.required, + 'required': required, 'allow_null': model_field.null, # 'allow_empty': False, # M2M O2M - 'source': None, + # 'source': None, 'description': model_field.description } if not isinstance(model_field, fields.relational.RelationalField): @@ -72,11 +74,11 @@ class ModelConverter(ModelConverterBase): converter = self.converters[type_name] else: type_name = model_field.__class__.__name__ - if hasattr(serializer.Meta, 'nested_depth'): - nested_depth = serializer.Meta.nested_depth + if hasattr(serializer.Meta, 'depth'): + nested_depth = serializer.Meta.depth else: nested_depth = 10 - kwargs['nested_depth'] = nested_depth + kwargs['nested_depth'] = nested_depth converter = self.converters[type_name] kwargs.update(field_kwargs) @@ -103,18 +105,15 @@ class ModelConverter(ModelConverterBase): @converts('CharEnumFieldInstance') def convert_charenumfield(self, model, model_field, *field_args, **field_kws): - # max_length = model_field.max_length - # if max_length is not None: - # field_kws['max_length'] = max_length - choices = [(i.name, i.value) for i in model_field.enum_type] - field_kws['choices'] = choices - return ChoiceField(*field_args, **field_kws) + field_kws['enum_type'] = model_field.enum_type + field_kws['value_type'] = str + return EnumChoiceField(*field_args, **field_kws) @converts('IntEnumFieldInstance') def convert_intenumfield(self, model, model_field, *field_args, **field_kws): - choices = ((i.name, i.value) for i in model_field.enum_type) - field_kws['choices'] = choices - return ChoiceField(*field_args, **field_kws) + field_kws['enum_type'] = model_field.enum_type + field_kws['value_type'] = int + return EnumChoiceField(*field_args, **field_kws) @converts('DecimalField') def convert_decimalfield(self, model, model_field, *field_args, **field_kws): @@ -134,8 +133,8 @@ class ModelConverter(ModelConverterBase): def convert_floatfield(self, model, model_field, *field_args, **field_kws): return FloatField(*field_args, **field_kws) - @converts('ManyToManyFieldInstance', 'BackwardFKRelation') - def convert_MTM(self, model, model_field, *field_args, **field_kws): + @converts('ManyToManyFieldInstance', 'BackwardFKRelation', 'ManyToManyRelation') + def convert_manytomany(self, model, model_field, *field_args, **field_kws): nested_depth = field_kws.get('nested_depth', 10) class NestedSerializer(self.nested_field_class): @@ -147,7 +146,7 @@ class ModelConverter(ModelConverterBase): return NestedSerializer(many=True) @converts('ForeignKeyFieldInstance', 'OneToOneFieldInstance', 'BackwardOneToOneRelation') - def convert_ManyToOne(self, model, model_field, *field_args, **field_kws): + def convert_manytoone(self, model, model_field, *field_args, **field_kws): nested_depth = field_kws.get('nested_depth', 10) class NestedSerializer(self.nested_field_class): diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index f01cef4..5c22ec2 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -12,6 +12,7 @@ import copy import decimal import re from datetime import timezone, timedelta, datetime, date, time +from enum import Enum from typing import Any, List, Mapping from tortoise import Model from tortoise.queryset import QuerySet @@ -719,7 +720,7 @@ class ChoiceField(Field): """得到字符串""" if self.check_key_choices(key): choices_dict = self.get_choices() - value = choices_dict[str(key)] + value = choices_dict[key] return value self.raise_error('key', key=key) @@ -728,10 +729,34 @@ class ChoiceField(Field): return key in choices_dict def get_choices(self) -> dict: - choices = {str(key): value for key, value in self.choices} + choices = {key: value for key, value in self.choices} return choices +class EnumChoiceField(Field): + """枚举类型字段""" + + def __init__(self, enum_type, value_type, *args, **kwargs): + """ + :param enum_type: 枚举类 + :param args: + :param kwargs: + """ + self.enum_type = enum_type + self.value_type = value_type + super(EnumChoiceField, self).__init__(*args, **kwargs) + + def external_to_internal(self, data: Any) -> Any: + return self.enum_type(data) if data is not None else None + + async def internal_to_external(self, data: Any) -> Any: + if isinstance(data, Enum): + return self.value_type(data.value) + if isinstance(data, self.value_type): + return self.value_type(self.enum_type(data).value) + return self.value_type(data) + + class SerializerMethodField(Field): """ 一个只读字段,可通过在父序列化器类。调用的方法将具有以下形式 diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 66d13c3..45e01e1 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -15,7 +15,7 @@ from typing import Any, Mapping from tortoise import models, Model from tortoise.queryset import QuerySet -from tortoise import fields +from tortoise import fields as tortoise_fields from sanic_rest_framework.fields import ( empty, SkipField, @@ -135,7 +135,10 @@ class BaseSerializer(Field): self._validated_data = self.run_validation(self.initial_data) except ValidationError as exc: self._validated_data = {} - self._errors = exc.error_dict + if hasattr(exc, 'error_dict'): + self._errors = exc.error_dict + else: + self._errors = {self.field_name, exc.error_list} else: self._errors = {} @@ -326,7 +329,7 @@ class ListSerializer(BaseSerializer): :param data: :return: """ - iterable = await data.all() if isinstance(data, (QuerySet, fields.relational.RelationalField)) else data + iterable = await data.all() if isinstance(data, (QuerySet, tortoise_fields.relational.RelationalField)) else data return [ await self.child.internal_to_external(item) for item in iterable @@ -385,24 +388,6 @@ class ModelSerializer(Serializer): write_only_fields = () # 字段与read_only_fields冲突 """ - convert_mapping = { - fields.BigIntField: IntegerField, - fields.BinaryField: None, # TODO : 暂时无法解决 - fields.BooleanField: BooleanField, - fields.CharEnumField: ChoiceField, - fields.CharField: CharField, - fields.DateField: DateField, - fields.DatetimeField: DateTimeField, - fields.DecimalField: DecimalField, - fields.FloatField: FloatField, - fields.IntEnumField: ChoiceField, - fields.IntField: IntegerField, - fields.JSONField: CharField, - fields.SmallIntField: IntegerField, - fields.TextField: CharField, - fields.TimeDeltaField: None, # TODO : 需要为其单独创建一个对应的字段 - fields.UUIDField: CharField, - } @property def fields(self): @@ -425,31 +410,51 @@ class ModelSerializer(Serializer): declared_fields = copy.deepcopy(self._declared_fields) model = getattr(self.Meta, 'model') - depth = getattr(self.Meta, 'depth', 0) - converter = ModelConverter(ModelSerializer) + depth = getattr(self.Meta, 'depth', 10) + converter = ModelConverter(ModelSerializer) model_fields = self._clean_model_field(model) - model_basis_fields = self._get_model_basis_fields(model_fields) - serializer_fields = {} - for basis_field_name, basis_field_class in model_fields.items(): - if basis_field_name in declared_fields: - current_field_class = declared_fields[basis_field_name] + effective_field = self.get_effective_field(model_fields) + serializer_fields = BindingDict(self) + + for field_name, field_class in effective_field.items(): + if field_name in declared_fields: + current_field_class = declared_fields[field_name] else: - current_field_class = converter.convert(self, basis_field_class) - serializer_fields[basis_field_name] = current_field_class + current_field_class = converter.convert(self, field_class, **self.get_field_kws_by_meta(field_name)) + serializer_fields[field_name] = current_field_class return serializer_fields - # - # serializer_fields = - # - # - # # like drf - # fields = BindingDict(self) - # for key, value in self.get_fields().items(): - # fields[key] = value - # return fields - - def _get_model_field_extra_kwargs(self, model_field) -> dict: - return {} + + def get_effective_field(self, model_fields) -> dict: + """ + 得到有效的字段 + :param model_fields: 模型字段 + :return: + """ + meta_fields = getattr(self.Meta, 'fields', None) + meta_exclude = getattr(self.Meta, 'exclude', None) + if meta_exclude and meta_fields: + raise ValueError('class ”{}“ ’Meta.fields‘ 和 ’Meta.exclude‘ 不可以共存 '.format(self.__class__.__name__)) + + if meta_exclude is not None: + return {k: v for k, v in model_fields.items() if k not in meta_exclude} + elif meta_exclude is None and meta_fields is None: + return model_fields + else: + return {k: v for k, v in model_fields.items() if k in meta_fields} + + def get_field_kws_by_meta(self, field_name): + read_only_fields = getattr(self.Meta, 'read_only_fields', []) + write_only_fields = getattr(self.Meta, 'write_only_fields', []) + if field_name in read_only_fields and field_name in write_only_fields: + raise ValueError('字段 {} 不可用同时存在于类 ”{}“ 的 ’Meta.read_only_fields‘ ' + '和 ’Meta.write_only_fields‘ 属性中'.format(field_name, self.__class__.__name__)) + if field_name in read_only_fields: + return {'read_only': True, 'write_only': False} + elif field_name in write_only_fields: + return {'read_only': False, 'write_only': True} + else: + return {'read_only': False, 'write_only': False} def _clean_model_field(self, model): """ @@ -461,7 +466,7 @@ class ModelSerializer(Serializer): field_dict = {} fields_map = copy.deepcopy(model._meta.fields_map) for field_name, field_class in fields_map.items(): - if isinstance(field_class, (fields.relational.ForeignKeyFieldInstance, fields.relational.OneToOneFieldInstance)): + if isinstance(field_class, (tortoise_fields.relational.ForeignKeyFieldInstance, tortoise_fields.relational.OneToOneFieldInstance)): clean_field_names.append(field_class.source_field) field_dict[field_name] = field_class @@ -470,41 +475,42 @@ class ModelSerializer(Serializer): field_dict.pop(clean_field_name) return field_dict - def _get_model_basis_fields(self, model_fields): - """ - 得到基础字段,非关系字段 - :param model: - :return: - """ - return {field_name: field_class for field_name, field_class in model_fields.items() if not isinstance(field_class, fields.relational.RelationalField)} - - def _get_model_M2M_fields(self, model_fields): - """ - 得到多对多字段 - :param model: - :return: - """ - return {field_name: field_class for field_name, field_class in model_fields.items() if isinstance(field_class, fields.relational.ManyToManyFieldInstance)} - - def _get_model_O2O_fields(self, model_fields): - """得到一对一字段""" - return {field_name: field_class for field_name, field_class in model_fields.items() if - isinstance(field_class, (fields.relational.BackwardOneToOneRelation, fields.relational.OneToOneFieldInstance))} - - def _get_model_M2O_fields(self, model_fields): - """ - 得到多对一字段 - :param model: - :return: - """ - return {field_name: field_class for field_name, field_class in model_fields.items() if - isinstance(field_class, fields.relational.ForeignKeyFieldInstance)} - - def _get_model_O2M_fields(self, model_fields): - """ - 得到一对多字段 - :param model: - :return: - """ - return {field_name: field_class for field_name, field_class in model_fields.items() if - isinstance(field_class, fields.relational.BackwardFKRelation)} + # + # def _get_model_basis_fields(self, model_fields): + # """ + # 得到基础字段,非关系字段 + # :param model: + # :return: + # """ + # return {field_name: field_class for field_name, field_class in model_fields.items() if not isinstance(field_class, tortoise_fields.relational.RelationalField)} + # + # def _get_model_M2M_fields(self, model_fields): + # """ + # 得到多对多字段 + # :param model: + # :return: + # """ + # return {field_name: field_class for field_name, field_class in model_fields.items() if isinstance(field_class, tortoise_fields.relational.ManyToManyFieldInstance)} + # + # def _get_model_O2O_fields(self, model_fields): + # """得到一对一字段""" + # return {field_name: field_class for field_name, field_class in model_fields.items() if + # isinstance(field_class, (tortoise_fields.relational.BackwardOneToOneRelation, tortoise_fields.relational.OneToOneFieldInstance))} + # + # def _get_model_M2O_fields(self, model_fields): + # """ + # 得到多对一字段 + # :param model: + # :return: + # """ + # return {field_name: field_class for field_name, field_class in model_fields.items() if + # isinstance(field_class, tortoise_fields.relational.ForeignKeyFieldInstance)} + # + # def _get_model_O2M_fields(self, model_fields): + # """ + # 得到一对多字段 + # :param model: + # :return: + # """ + # return {field_name: field_class for field_name, field_class in model_fields.items() if + # isinstance(field_class, tortoise_fields.relational.BackwardFKRelation)} diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index 093a465..175ed62 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -1,12 +1,12 @@ from datetime import date -from enum import Enum +from enum import Enum,IntEnum from tortoise import fields from tortoise import Model from tortoise.fields import ForeignKeyRelation, ReverseRelation -class IntEnum(Enum): +class Enum1(IntEnum): OK = 1 BAD = 2 @@ -17,13 +17,13 @@ class CharEnum(Enum): class TestModel(Model): - char_field = fields.CharField(max_length=8, null=False) + char_field = fields.CharField(max_length=8, null=True) float_field = fields.FloatField() date_field = fields.DateField() int_field = fields.IntField() decimal_field = fields.DecimalField(max_digits=13, decimal_places=3) datetime_field = fields.DatetimeField() - int_enum_field = fields.IntEnumField(enum_type=IntEnum) + int_enum_field = fields.IntEnumField(enum_type=Enum1) char_enum_field = fields.CharEnumField(enum_type=CharEnum) boolean_field = fields.BooleanField() small_int_field = fields.SmallIntField() @@ -31,9 +31,9 @@ class TestModel(Model): text_field = fields.TextField() json_field = fields.JSONField() uuid_field = fields.UUIDField() - one_to_many: ForeignKeyRelation["TestManyToOneModel"] = fields.ForeignKeyField('models.TestManyToOneModel', related_name='many_to_one') - one_to_one: fields.OneToOneRelation["TestOneToOneModel"] = fields.OneToOneField('models.TestOneToOneModel', related_name='one_to_one') - many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.TestModel", related_name="many_2_many", through="many_many") + one_to_many: ForeignKeyRelation["TestManyToOneModel"] = fields.ForeignKeyField('models.TestManyToOneModel', related_name='many_to_one', null=True) + one_to_one: fields.OneToOneRelation["TestOneToOneModel"] = fields.OneToOneField('models.TestOneToOneModel', related_name='one_to_one', null=True) + many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.TestModel", related_name="many_2_many", through="many_many", null=True) class TestOneTOManyModel(Model): @@ -45,7 +45,7 @@ class TestManyToOneModel(Model): class TestOneToOneModel(Model): - name = fields.CharField(max_length=8, null=False) + name = fields.CharField(max_length=8, null=True) class ManyToManyModel(Model): diff --git a/sanic_rest_framework/test/test_serializers/test_model_serializers.py b/sanic_rest_framework/test/test_serializers/test_model_serializers.py index ee8982c..d9de6fe 100644 --- a/sanic_rest_framework/test/test_serializers/test_model_serializers.py +++ b/sanic_rest_framework/test/test_serializers/test_model_serializers.py @@ -14,17 +14,23 @@ """ import asyncio +import uuid +from datetime import datetime from tortoise.contrib.test import initializer, TestCase, finalizer from sanic_rest_framework.converter import ModelConverter -from sanic_rest_framework.test.models import TestModel +from sanic_rest_framework.test.models import TestModel, TestOneTOManyModel, TestOneToOneModel, CharEnum, Enum1 from sanic_rest_framework.serializers import ModelSerializer class TestModelSerializer(ModelSerializer): class Meta: model = TestModel + # fields = ('char_field',) + exclude = ('id',) + read_only_fields = ('char_field',) + write_only_fields = ('float_field',) initializer(['sanic_rest_framework.test.models', ], @@ -34,12 +40,52 @@ initializer(['sanic_rest_framework.test.models', ], class TestOrdinarySerializer(TestCase): def setUp(self) -> None: - pass + self.data = { + 'char_field': '老四', + 'float_field': 1.36, + 'date_field': '2016-11-1', + 'int_field': 999, + 'decimal_field': 999.9, + 'datetime_field': '2016-12-1 11:1:1', + 'int_enum_field': 1, + 'char_enum_field': '好', + 'boolean_field': False, + 'small_int_field': 1, + 'big_int_field': 99999, + 'text_field': 'PE', + 'json_field': '{"a":1}', + 'uuid_field': '91a4c540-80b5-11eb-b03f-e0d55e47dfb2', + 'one_to_many': [], + 'one_to_one': {}, + 'many_to_many': [], + } async def test_serializer(self): # for name, field in UserModel._meta.fields_map.items(): # ModelConverter().convert(TestModelSerializer, field, name) - tms = TestModelSerializer() + + tm = TestModel() + tm.char_field = 1 + tm.float_field = 1 + tm.date_field = '2016-11-11' + tm.int_field = 1 + tm.decimal_field = 1 + tm.datetime_field = datetime(2016, 11, 11, 0, 0, 0) + tm.int_enum_field = Enum1.OK + tm.char_enum_field = CharEnum.OK + tm.boolean_field = True + tm.small_int_field = 1 + tm.big_int_field = 1 + tm.text_field = '1' + tm.json_field = {} + tm.uuid_field = str(uuid.uuid1()) + tm.one_to_one = await TestOneToOneModel(name='1').save() + await tm.save() + tms = TestModelSerializer(instance=await TestModel.all(), many=True) + print(await tms.data) + tms = TestModelSerializer(data=self.data, partial=True) + tms.is_valid() + print(tms.validated_data) print(tms.fields) @classmethod -- Gitee From 2466eb0e8de739fae1bf2bdbcfcc28c54bb4afa7 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Wed, 10 Mar 2021 17:49:54 +0800 Subject: [PATCH 25/34] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B7=BB=E5=8A=A0Views?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/exceptions.py | 19 +++ sanic_rest_framework/fields.py | 7 ++ sanic_rest_framework/filter.py | 20 ++++ sanic_rest_framework/status.py | 10 ++ sanic_rest_framework/views.py | 184 +++++++++++++++++++++++++++-- 5 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 sanic_rest_framework/filter.py diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py index 5d5ec6b..cb1270f 100644 --- a/sanic_rest_framework/exceptions.py +++ b/sanic_rest_framework/exceptions.py @@ -10,6 +10,10 @@ """ from typing import Mapping +from sanic.response import json + +from sanic_rest_framework.status import HttpStatus, RuleStatus + class ValidationError(Exception): """验证器通用错误类 发生错误即抛出此类""" @@ -93,3 +97,18 @@ class ValidationError(Exception): class ValidatorAssertError(Exception): pass + + +class APIException(Exception): + def __init__(self, message, status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR, *args, **kwargs): + self.message = message + self.status = status + self.http_status = http_status + + def response_data(self): + return { + 'message': self.message, + 'status': self.status, + 'http_status': self.http_status + } + diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 5c22ec2..6c6f318 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -302,6 +302,13 @@ class Field: root = root.parent return root + @property + def context(self): + """ + 返回初始化时传递给根序列化程序的上下文。 + """ + return getattr(self.root, '_context', {}) + def raise_error(self, _key, **kwargs): """ 返回在 error_messages 中注册了的错误 diff --git a/sanic_rest_framework/filter.py b/sanic_rest_framework/filter.py new file mode 100644 index 0000000..9a264fe --- /dev/null +++ b/sanic_rest_framework/filter.py @@ -0,0 +1,20 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/10 17:25 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + filter.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/10 17:25 change 'Fix bug' + +""" + + +class SimpleFilter(): + def filter_queryset(self, request, queryset, view): + pass diff --git a/sanic_rest_framework/status.py b/sanic_rest_framework/status.py index e482042..d40f94e 100644 --- a/sanic_rest_framework/status.py +++ b/sanic_rest_framework/status.py @@ -99,3 +99,13 @@ class HttpStatus: HTTP_509_BANDWIDTH_LIMIT_EXCEEDED = 509 HTTP_510_NOT_EXTENDED = 510 HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 + +class c(Exception): + pass + +def aa(): + try: + raise c('发财务') + except Exception as ex: + print(ex) +aa() \ No newline at end of file diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index a502360..7e6837b 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -9,22 +9,35 @@ 基础视图文件 """ from sanic.response import json +from tortoise.queryset import QuerySet + from sanic_rest_framework.constant import ALL_METHOD +from sanic_rest_framework.exceptions import APIException from sanic_rest_framework.status import RuleStatus, HttpStatus class BaseAPIView: """基础API视图""" - detail = False + authentication_classes = () + permission_classes = () def dispatch(self, request, *args, **kwargs): """分发路由""" + request.user = None method = request.method if method not in self.licensed_methods: return self.json_response(msg='发生错误:未找到%s方法' % method, status=RuleStatus.STATUS_0_FAIL, - response_status=HttpStatus.HTTP_405_METHOD_NOT_ALLOWED) + http_status=HttpStatus.HTTP_405_METHOD_NOT_ALLOWED) handler = getattr(self, method.lower(), None) - return handler(request, *args, **kwargs) + + try: + self.initial(request, *args, **kwargs) + response = handler(request, *args, **kwargs) + except APIException as exc: + response = self.handle_exception(exc) + except Exception as exc: + response = self.handle_uncaught_exception(exc) + return response @classmethod def as_view(cls, methods=None, *class_args, **class_kwargs): @@ -43,20 +56,27 @@ class BaseAPIView: return self.dispatch(request, *args, **kwargs) view.base_class = cls - view.API_DOC_CONFIG = class_kwargs.get('API_DOC_CONFIG') # 未来的API文档配置属性 + view.API_DOC_CONFIG = class_kwargs.get('API_DOC_CONFIG') # 未来的API文档配置属性+ view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ view.__name__ = cls.__name__ return view + def handle_exception(self, exc: APIException): + return self.json_response(**exc.response_data()) + + def handle_uncaught_exception(self, exc): + """处理未知错误""" + return self.json_response(msg=exc.args[0], status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR) + def json_response(self, data=None, msg="OK", status=RuleStatus.STATUS_1_SUCCESS, - response_status=HttpStatus.HTTP_200_OK): + http_status=HttpStatus.HTTP_200_OK): """ Json 相应体 :param data: 返回的数据主题 :param msg: 前台提示字符串 :param status: 前台约定状态,供前台判断是否成功 - :param response_status: Http响应数据 + :param http_status: Http响应数据 :return: """ if data is None: @@ -66,7 +86,7 @@ class BaseAPIView: 'message': msg, 'status': status } - return json(body=response_body, status=response_status) + return json(body=response_body, status=http_status) def success_json_response(self, data=None, msg="Success"): """ @@ -85,10 +105,158 @@ class BaseAPIView: :return: json """ return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_0_FAIL, - response_status=HttpStatus.HTTP_400_BAD_REQUEST) + http_status=HttpStatus.HTTP_400_BAD_REQUEST) + + def get_authenticators(self): + """ + 实例化并返回此视图可以使用的身份验证器列表 + """ + return [auth() for auth in self.authentication_classes] + + def check_authentication(self, request): + """ + 检查权限 查看是否拥有权限,并在此处为Request.User 赋值 + :param request: 请求 + :return: + """ + for authenticators in self.get_authenticators(): + request.user = authenticators.authenticate(request, request.user) + + def get_permissions(self): + """ + 实例化并返回此视图所需的权限列表 + """ + return [permission() for permission in self.permission_classes] + + def check_permissions(self, request): + """ + 检查是否应允许该请求,如果不允许该请求, + 则在 has_permission 中引发一个适当的异常。 + :param request: 当前请求 + :return: + """ + for permission in self.get_permissions(): + permission.has_permission(request, self) + + def check_object_permissions(self, request, obj): + """ + 检查是否应允许给定对象的请求, 如果不允许该请求, + 则在 has_object_permission 中引发一个适当的异常。 + 常用于 get_object() 方法 + :param request: 当前请求 + :param obj: 需要鉴权的模型对象 + :return: + """ + for permission in self.get_permissions(): + permission.has_object_permission(request, self, obj) + + def check_throttles(self, request): + """ + 检查范围频率。 + 则引发一个 APIException 异常。 + :param request: + :return: + """ + pass + + def initial(self, request, *args, **kwargs): + """ + 在请求分发之前执行初始化操作,用于检查权限及检查基础内容 + """ + self.check_authentication(request) + self.check_permissions(request) + self.check_throttles(request) # class ViewJsonHelperMixin: # class APIView(ViewJsonHelperMixin, BaseAPIView): + + +class APIView(BaseAPIView): + detail = False + queryset = None + lookup_field = 'pk' + serializer_class = None + pagination_class = None + lookup_url_kwarg = None + filter_class = None + search_fields = None + + async def get_queryset(self): + assert self.queryset is not None, ( + "'%s'应该包含一个'queryset'属性," + "或重写`get_queryset()`方法。" + % self.__class__.__name__ + ) + + queryset = self.queryset + if isinstance(queryset, QuerySet): + queryset = await queryset.all() + return queryset + + def filter_queryset(self, queryset): + self.filter_class().filter_queryset(self.request, queryset) + + def get_serializer(self, *args, **kwargs): + """ + 返回应该用于验证和验证的序列化程序实例 + 对输入进行反序列化,并对输出进行序列化。 + """ + serializer_class = self.get_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + return serializer_class(*args, **kwargs) + + def get_serializer_class(self): + """ + 返回用于序列化器的类。 + 默认使用`self.serializer_class`。 + + 如果您需要提供其他信息,则可能要覆盖此设置 + 序列化取决于传入的请求。 + + (例如,管理员获得完整的序列化,其他获得基本的序列化) + """ + assert self.serializer_class is not None, ( + "'%s' should either include a `serializer_class` attribute, " + "or override the `get_serializer_class()` method." + % self.__class__.__name__ + ) + return self.serializer_class + + def get_serializer_context(self): + """ + 提供给序列化程序类的额外上下文。 + """ + return { + 'request': self.request, + 'view': self + } + + @property + def paginator(self): + """ + 与视图关联的分页器实例,或“None”。 + """ + if not hasattr(self, '_paginator'): + if self.pagination_class is None: + self._paginator = None + else: + self._paginator = self.pagination_class() + return self._paginator + + # def paginate_queryset(self, queryset): + # """ + # Return a single page of results, or `None` if pagination is disabled. + # """ + # if self.paginator is None: + # return None + # return self.paginator.paginate_queryset(queryset, self.request, view=self) + # + # def get_paginated_response(self, data): + # """ + # Return a paginated style `Response` object for the given output data. + # """ + # assert self.paginator is not None + # return self.paginator.get_paginated_response(data) -- Gitee From ce9e34fff0676eb52679bb4aacc19577fe77fe04 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 11 Mar 2021 17:55:09 +0800 Subject: [PATCH 26/34] =?UTF-8?q?=E9=85=8D=E7=BD=AE=E9=9C=80=E8=A6=81?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=EF=BC=8C=E5=AE=9E=E7=8E=B0=E4=BA=86=E9=80=9A?= =?UTF-8?q?=E7=94=A8=E8=BF=87=E6=BB=A4=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 12 +--- sanic_rest_framework/filters.py | 72 +++++++++++++++++++ .../{filter.py => paginations.py} | 11 ++- sanic_rest_framework/request.py | 23 ++++++ sanic_rest_framework/routes.py | 2 +- sanic_rest_framework/setting.py | 28 ++++++++ sanic_rest_framework/status.py | 10 --- sanic_rest_framework/utils.py | 25 +++++++ sanic_rest_framework/views.py | 7 +- 9 files changed, 160 insertions(+), 30 deletions(-) create mode 100644 sanic_rest_framework/filters.py rename sanic_rest_framework/{filter.py => paginations.py} (52%) create mode 100644 sanic_rest_framework/request.py create mode 100644 sanic_rest_framework/setting.py create mode 100644 sanic_rest_framework/utils.py diff --git a/run.py b/run.py index 0ab8ebe..45927a5 100644 --- a/run.py +++ b/run.py @@ -3,26 +3,18 @@ from sanic.blueprints import Blueprint from tortoise.contrib.sanic import register_tortoise from db import TestModel +from sanic_rest_framework.request import SRFRequest from sanic_rest_framework.routes import Route from sanic_rest_framework.serializers import Serializer from sanic_rest_framework.fields import CharField from sanic_rest_framework.views import BaseAPIView - -app = Sanic(__name__) +app = Sanic(__name__, request_class=SRFRequest) admin = Blueprint('admin', '/admin') class TestView(BaseAPIView): async def get(self, request): - test = TestSerializer(data={ - 'id': '2', - 'qt': { - 'name': '刘文静' - } - }) - print(test.is_valid()) - test.validated_data return self.success_json_response() async def post(self, request): diff --git a/sanic_rest_framework/filters.py b/sanic_rest_framework/filters.py new file mode 100644 index 0000000..ca64fc6 --- /dev/null +++ b/sanic_rest_framework/filters.py @@ -0,0 +1,72 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/10 17:25 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + filters.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/10 17:25 change 'Fix bug' + +""" +LOOKUP_SEP = '__' + + +class SimpleFilter: + """简单过滤器""" + + def filter_queryset(self, request, queryset, view): + raise NotImplementedError(".filter_queryset() must be overridden.") + + def construct_orm_filter(self, field_name, view, request): + return '' + + +class SearchFilter(SimpleFilter): + lookup_prefixes = { + '^': 'istartswith', + '$': 'iendswith', + '>': 'gt', + '<': 'lt', + '>=': 'gte', + '<=': 'lte' + } + + def get_search_fields(self, view, request): + """ + 搜索字段是从视图获取的,但请求始终是 + 传递给此方法。子类可以重写此方法以 + 根据请求内容动态更改搜索字段。 + """ + return getattr(view, 'search_fields', None) + + def filter_queryset(self, request, queryset, view): + search_fields = self.get_search_fields(view, request) + if not search_fields: + return queryset + orm_filters = {} + for search_field in search_fields: + orm_filters.update(self.construct_orm_filter(search_field, view, queryset)) + return queryset.filters(**orm_filters) + + def construct_orm_filter(self, field_name, view, request): + lookup_suffix_keys = list(self.lookup_prefixes.keys()) + lookup_suffix = None + for key in lookup_suffix_keys: + if key in field_name: + field_name = field_name[len(key):] + lookup_suffix_key = field_name[:len(key)] + lookup_suffix = self.lookup_prefixes.get(lookup_suffix_key) + break + if not lookup_suffix: + return {} + orm_lookup = LOOKUP_SEP.join([field_name, lookup_suffix]) + return {orm_lookup: self.get_filter_value(request, field_name)} + + def get_filter_value(self, request, field_name): + values = request.args.get(field_name) + return ''.join(values) diff --git a/sanic_rest_framework/filter.py b/sanic_rest_framework/paginations.py similarity index 52% rename from sanic_rest_framework/filter.py rename to sanic_rest_framework/paginations.py index 9a264fe..cc89ff4 100644 --- a/sanic_rest_framework/filter.py +++ b/sanic_rest_framework/paginations.py @@ -1,20 +1,19 @@ """ @Author:WangYuXiang @E-mile:Hill@3io.cc -@CreateTime:2021/3/10 17:25 +@CreateTime:2021/3/11 17:37 @DependencyLibrary:无 @MainFunction:无 @FileDoc: - filter.py + paginations.py 文件说明 @ChangeHistory: datetime action why example: - 2021/3/10 17:25 change 'Fix bug' + 2021/3/11 17:37 change 'Fix bug' """ -class SimpleFilter(): - def filter_queryset(self, request, queryset, view): - pass +class BasePagination: + pass diff --git a/sanic_rest_framework/request.py b/sanic_rest_framework/request.py new file mode 100644 index 0000000..cd9f180 --- /dev/null +++ b/sanic_rest_framework/request.py @@ -0,0 +1,23 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/11 15:46 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + request.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/11 15:46 change 'Fix bug' + +""" + +from sanic.request import Request as SanicRequest + + +class SRFRequest(SanicRequest): + def __init__(self, *args, **kwargs): + super(SRFRequest, self).__init__(*args, **kwargs) + self.user = None diff --git a/sanic_rest_framework/routes.py b/sanic_rest_framework/routes.py index d09c107..cfb15a9 100644 --- a/sanic_rest_framework/routes.py +++ b/sanic_rest_framework/routes.py @@ -36,7 +36,7 @@ class Route: dynamic_uri = '/{prefix}/' static_uri = '/{prefix}' base_method_group = LIST_METHOD_GROUP - if viewset.detail: + if hasattr(viewset, 'detail') and viewset.detail: base_method_group = DETAIL_METHOD_GROUP viewset_methods = self.get_viewset_methods(viewset) diff --git a/sanic_rest_framework/setting.py b/sanic_rest_framework/setting.py new file mode 100644 index 0000000..262afe3 --- /dev/null +++ b/sanic_rest_framework/setting.py @@ -0,0 +1,28 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/11 16:40 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + setting.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/11 16:40 change 'Fix bug' + +""" +from sanic_rest_framework.filters import SearchFilter +from sanic_rest_framework.utils import ObjectDict +from sanic.app import Sanic + +srf_config = ObjectDict({ + 'VIEW_DEFAULT_FILTER': SearchFilter +}) +app.config.get('SANIC_REST_FRAMEWORK_CONFIG', {}) +srf_config = srf_config.update(srf_config) + +__all__ = ( + 'srf_config', +) diff --git a/sanic_rest_framework/status.py b/sanic_rest_framework/status.py index d40f94e..e482042 100644 --- a/sanic_rest_framework/status.py +++ b/sanic_rest_framework/status.py @@ -99,13 +99,3 @@ class HttpStatus: HTTP_509_BANDWIDTH_LIMIT_EXCEEDED = 509 HTTP_510_NOT_EXTENDED = 510 HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 - -class c(Exception): - pass - -def aa(): - try: - raise c('发财务') - except Exception as ex: - print(ex) -aa() \ No newline at end of file diff --git a/sanic_rest_framework/utils.py b/sanic_rest_framework/utils.py new file mode 100644 index 0000000..5918fb6 --- /dev/null +++ b/sanic_rest_framework/utils.py @@ -0,0 +1,25 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/11 16:59 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + utils.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/11 16:59 change 'Fix bug' + +""" + + +class ObjectDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if isinstance(value, dict): + self[key] = ObjectDict(value) + self[key] = value diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index 7e6837b..9dd4704 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -13,6 +13,7 @@ from tortoise.queryset import QuerySet from sanic_rest_framework.constant import ALL_METHOD from sanic_rest_framework.exceptions import APIException +from sanic_rest_framework.setting import srf_config from sanic_rest_framework.status import RuleStatus, HttpStatus @@ -179,12 +180,12 @@ class APIView(BaseAPIView): queryset = None lookup_field = 'pk' serializer_class = None - pagination_class = None + pagination_class = srf_config.VIEW_DEFAULT_FILTER lookup_url_kwarg = None filter_class = None search_fields = None - async def get_queryset(self): + def get_queryset(self): assert self.queryset is not None, ( "'%s'应该包含一个'queryset'属性," "或重写`get_queryset()`方法。" @@ -193,7 +194,7 @@ class APIView(BaseAPIView): queryset = self.queryset if isinstance(queryset, QuerySet): - queryset = await queryset.all() + queryset = queryset.all() return queryset def filter_queryset(self, queryset): -- Gitee From 4b5b780a162a2b81bb491038ac576ec8972ae639 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Fri, 12 Mar 2021 17:02:20 +0800 Subject: [PATCH 27/34] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BF=AE=E6=94=B9=20?= =?UTF-8?q?=E8=B7=AF=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- db.py | 86 ++++++++++++- run.py | 17 +-- sanic_rest_framework/paginations.py | 67 ++++++++++- sanic_rest_framework/routes.py | 4 +- sanic_rest_framework/setting.py | 28 ----- sanic_rest_framework/utils.py | 49 ++++++++ sanic_rest_framework/views.py | 180 ++++++++++++++++++++++++---- 7 files changed, 358 insertions(+), 73 deletions(-) delete mode 100644 sanic_rest_framework/setting.py diff --git a/db.py b/db.py index c170c4e..2ff7457 100644 --- a/db.py +++ b/db.py @@ -9,12 +9,90 @@ 文件说明 """ from datetime import date +from enum import Enum, IntEnum -from tortoise.fields import CharField, IntField, DateField +from tortoise import fields from tortoise import Model +from tortoise.fields import ForeignKeyRelation, ReverseRelation + + +class Enum1(IntEnum): + OK = 1 + BAD = 2 + + +class CharEnum(Enum): + OK = '好' + BAD = '坏' class TestModel(Model): - name = CharField(max_length=30) - ages = IntField() - birthday = DateField(default=date.today) + char_field = fields.CharField(max_length=8, null=True) + float_field = fields.FloatField() + date_field = fields.DateField() + int_field = fields.IntField() + decimal_field = fields.DecimalField(max_digits=13, decimal_places=3) + datetime_field = fields.DatetimeField() + int_enum_field = fields.IntEnumField(enum_type=Enum1) + char_enum_field = fields.CharEnumField(enum_type=CharEnum) + boolean_field = fields.BooleanField() + small_int_field = fields.SmallIntField() + big_int_field = fields.BigIntField() + text_field = fields.TextField() + json_field = fields.JSONField() + uuid_field = fields.UUIDField() + one_to_many: ForeignKeyRelation["TestManyToOneModel"] = fields.ForeignKeyField('models.TestManyToOneModel', related_name='many_to_one', null=True) + one_to_one: fields.OneToOneRelation["TestOneToOneModel"] = fields.OneToOneField('models.TestOneToOneModel', related_name='one_to_one', null=True) + many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.TestModel", related_name="many_2_many", through="many_many", null=True) + + +class TestOneTOManyModel(Model): + one_to_many: ForeignKeyRelation["TestModel"] = fields.ForeignKeyField('models.TestModel', related_name='many_to_one') + + +class TestManyToOneModel(Model): + many_to_one: fields.ReverseRelation["TestModel"] + + +class TestOneToOneModel(Model): + name = fields.CharField(max_length=8, null=True) + + +class ManyToManyModel(Model): + many_2_many: fields.ManyToManyRelation["TestModel"] + + +class UserModel(Model): + name = fields.CharField(max_length=8, null=False) + birthday = fields.DateField() + phone = fields.CharField(max_length=11) + balance = fields.DecimalField(13, 3) + address: fields.ManyToManyRelation["AddressModel"] = fields.ManyToManyField( + 'models.AddressModel', through='user2address', related_name='user') + + +class AddressModel(Model): + phone = fields.CharField(12, null=False) + address = fields.CharField(100) + house_number = fields.CharField(100) + # user: fields.ManyToManyRelation[UserModel] + + +class SchoolModel(Model): + name = fields.CharField(12) + address: fields.OneToOneRelation["AddressModel"] = fields.OneToOneField("models.AddressModel", 'school') + + +class ClassRoomModel(Model): + room_number = fields.CharField(18) + student_count = fields.IntField() + students: ReverseRelation['StudentModel'] + + +class StudentModel(Model): + name = fields.CharField(max_length=12, null=False) + class_room: ForeignKeyRelation["ClassRoomModel"] = fields.ForeignKeyField('models.ClassRoomModel', 'students') + + +class DateSeriesModel(Model): + name = fields.TimeDeltaField() diff --git a/run.py b/run.py index 45927a5..e56aaf7 100644 --- a/run.py +++ b/run.py @@ -7,21 +7,16 @@ from sanic_rest_framework.request import SRFRequest from sanic_rest_framework.routes import Route from sanic_rest_framework.serializers import Serializer from sanic_rest_framework.fields import CharField -from sanic_rest_framework.views import BaseAPIView +from sanic_rest_framework.views import ( + APIView, RetrieveModelMixin, ListModelMixin, CreateModelMixin, DestroyModelMixin, UpdateModelMixin +) + app = Sanic(__name__, request_class=SRFRequest) admin = Blueprint('admin', '/admin') -class TestView(BaseAPIView): - - async def get(self, request): - return self.success_json_response() - - async def post(self, request): - return self.success_json_response() - - async def put(self, request, pk): - return self.success_json_response() +class TestView(APIView, RetrieveModelMixin): + pass route = Route() diff --git a/sanic_rest_framework/paginations.py b/sanic_rest_framework/paginations.py index cc89ff4..af6be82 100644 --- a/sanic_rest_framework/paginations.py +++ b/sanic_rest_framework/paginations.py @@ -13,7 +13,70 @@ 2021/3/11 17:37 change 'Fix bug' """ +from sanic_rest_framework.exceptions import APIException +from sanic_rest_framework.status import HttpStatus +from sanic.request import Request +from sanic_rest_framework.utils import replace_query_param -class BasePagination: - pass + +class LimitOffsetPagination: + page_size = 60 + page_query_param = 'page' + page_size_query_param = 'page_size' + max_page_size = 10000 + + @property + def count(self): + assert not hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.count' + return self._count + + def get_next_link(self, request: Request): + assert not hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.next()' + if self.page * self.page_size + self.page_size >= self.count: + return None + uri = request.server_path + query_string = request.query_string + query_string = replace_query_param(query_string, self.page_query_param, self.page) + query_string = replace_query_param(query_string, self.page_size_query_param, self.page_size) + return '%s?%s' % (uri, query_string) + + def get_previous_link(self, request: Request): + assert not hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.next()' + if self.page * self.page_size <= 0: + return None + uri = request.server_path + query_string = request.query_string + query_string = replace_query_param(query_string, self.page_query_param, self.page) + query_string = replace_query_param(query_string, self.page_size_query_param, self.page_size) + return '%s?%s' % (uri, query_string) + + def paginate_queryset(self, queryset, request, view): + self.page = self.get_query_page(request) + self.page_size = self.get_query_page_size(request) + self._count = queryset.count() + return queryset.limit(self.page_size).offset(self.page * self.page_size) + + def get_query_page(self, request): + try: + page = int(request.get(self.page_query_param, 0)) + except ValueError as exc: + raise APIException('发生错误的分页数据', http_status=HttpStatus.HTTP_400_BAD_REQUEST) + return page + + def get_query_page_size(self, request): + try: + page = int(request.get(self.page_size_query_param, self.page_size)) + if page > self.max_page_size: + raise APIException('分页内容大小超出最大限制', http_status=HttpStatus.HTTP_400_BAD_REQUEST) + except ValueError as exc: + raise APIException('发生错误的分页数据', http_status=HttpStatus.HTTP_400_BAD_REQUEST) + return page + + def response(self, request): + return { + 'count': self.count, + 'next': None, + 'previous': None, + 'results': None, + } diff --git a/sanic_rest_framework/routes.py b/sanic_rest_framework/routes.py index cfb15a9..94a901c 100644 --- a/sanic_rest_framework/routes.py +++ b/sanic_rest_framework/routes.py @@ -33,7 +33,7 @@ class Route: if name is None: name = prefix - dynamic_uri = '/{prefix}/' + dynamic_uri = '/{prefix}/<{lookup_field}:string>' static_uri = '/{prefix}' base_method_group = LIST_METHOD_GROUP if hasattr(viewset, 'detail') and viewset.detail: @@ -46,7 +46,7 @@ class Route: if viewset_dynamic_method: self.routes.append({ 'handler': viewset.as_view(viewset_dynamic_method), - 'uri': dynamic_uri.format(prefix=prefix), + 'uri': dynamic_uri.format(prefix=prefix, lookup_field=viewset.lookup_field), 'name': '{name}_detail'.format(name=name) }) if viewset_static_method: diff --git a/sanic_rest_framework/setting.py b/sanic_rest_framework/setting.py deleted file mode 100644 index 262afe3..0000000 --- a/sanic_rest_framework/setting.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -@Author:WangYuXiang -@E-mile:Hill@3io.cc -@CreateTime:2021/3/11 16:40 -@DependencyLibrary:无 -@MainFunction:无 -@FileDoc: - setting.py - 文件说明 -@ChangeHistory: - datetime action why - example: - 2021/3/11 16:40 change 'Fix bug' - -""" -from sanic_rest_framework.filters import SearchFilter -from sanic_rest_framework.utils import ObjectDict -from sanic.app import Sanic - -srf_config = ObjectDict({ - 'VIEW_DEFAULT_FILTER': SearchFilter -}) -app.config.get('SANIC_REST_FRAMEWORK_CONFIG', {}) -srf_config = srf_config.update(srf_config) - -__all__ = ( - 'srf_config', -) diff --git a/sanic_rest_framework/utils.py b/sanic_rest_framework/utils.py index 5918fb6..52177da 100644 --- a/sanic_rest_framework/utils.py +++ b/sanic_rest_framework/utils.py @@ -13,6 +13,15 @@ 2021/3/11 16:59 change 'Fix bug' """ +import datetime +from decimal import Decimal +from urllib import parse + +from sanic_rest_framework.exceptions import APIException + +_PROTECTED_TYPES = ( + type(None), int, float, Decimal, datetime.datetime, datetime.date, datetime.time, +) class ObjectDict(dict): @@ -23,3 +32,43 @@ class ObjectDict(dict): if isinstance(value, dict): self[key] = ObjectDict(value) self[key] = value + + +def is_protected_type(obj): + """确定对象实例是否为受保护的类型。 + 受保护类型的对象在传递给时会原样保留 + force_str(strings_only = True)。 + """ + return isinstance(obj, _PROTECTED_TYPES) + + +def force_str(s, encoding='utf-8', strings_only=False, errors='strict'): + """ + 与smart_str()类似,除了将懒实例解析为 + 字符串,而不是保留为惰性对象。 + 如果strings_only为True,请不要转换(某些)非字符串类对象。 + """ + # 出于性能原因,请先处理常见情况。 + if issubclass(type(s), str): + return s + if strings_only and is_protected_type(s): + return s + try: + if isinstance(s, bytes): + s = str(s, encoding, errors) + else: + s = str(s) + except UnicodeDecodeError as e: + raise APIException('{value}出现解码错误'.format(value=s)) + return s + + +def replace_query_param(url, key, val): + """ + 给定一个URL和一个键/值对,在URL的查询参数中设置或替换一个项目,然后返回新的URL。 + """ + (scheme, netloc, path, query, fragment) = parse.urlsplit(force_str(url)) + query_dict = parse.parse_qs(query, keep_blank_values=True) + query_dict[force_str(key)] = [force_str(val)] + query = parse.urlencode(sorted(list(query_dict.items())), doseq=True) + return parse.urlunsplit((scheme, netloc, path, query, fragment)) diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index 9dd4704..d3d9a1b 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -10,10 +10,12 @@ """ from sanic.response import json from tortoise.queryset import QuerySet +from sanic.log import logger from sanic_rest_framework.constant import ALL_METHOD from sanic_rest_framework.exceptions import APIException -from sanic_rest_framework.setting import srf_config +from sanic_rest_framework.filters import SearchFilter +from sanic_rest_framework.paginations import LimitOffsetPagination from sanic_rest_framework.status import RuleStatus, HttpStatus @@ -36,7 +38,10 @@ class BaseAPIView: response = handler(request, *args, **kwargs) except APIException as exc: response = self.handle_exception(exc) + except AssertionError as exc: + raise exc except Exception as exc: + logger.error('--捕获未知错误--', exc) response = self.handle_uncaught_exception(exc) return response @@ -108,6 +113,31 @@ class BaseAPIView: return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_400_BAD_REQUEST) + def get_object(self): + """ + 返回视图显示的对象。 + 如果您需要提供非标准的内容,则可能要覆盖此设置 + queryset查找。 + """ + queryset = self.filter_queryset(self.get_queryset()) + + lookup_field = self.lookup_field + + assert lookup_field in self.kwargs, ( + '%s 不存在于 %s 的 Url配置中的关键词内 ' % + (lookup_field, self.__class__.__name__,) + ) + + filter_kwargs = {lookup_field: self.kwargs[lookup_field]} + obj = queryset.get_or_none(**filter_kwargs) + if obj is None: + raise APIException('不存在%s为%s的数据'.format(lookup_field, self.kwargs[lookup_field])) + + # May raise a permission denied + self.check_object_permissions(self.request, obj) + + return obj + def get_authenticators(self): """ 实例化并返回此视图可以使用的身份验证器列表 @@ -169,25 +199,18 @@ class BaseAPIView: self.check_throttles(request) -# class ViewJsonHelperMixin: - - -# class APIView(ViewJsonHelperMixin, BaseAPIView): - - class APIView(BaseAPIView): detail = False queryset = None lookup_field = 'pk' serializer_class = None - pagination_class = srf_config.VIEW_DEFAULT_FILTER - lookup_url_kwarg = None - filter_class = None + pagination_class = None + filter_class = SearchFilter search_fields = None def get_queryset(self): assert self.queryset is not None, ( - "'%s'应该包含一个'queryset'属性," + "'%s'应该包含一个'queryset'属性," "或重写`get_queryset()`方法。" % self.__class__.__name__ ) @@ -198,7 +221,7 @@ class APIView(BaseAPIView): return queryset def filter_queryset(self, queryset): - self.filter_class().filter_queryset(self.request, queryset) + return self.filter_class().filter_queryset(self.request, queryset, self) def get_serializer(self, *args, **kwargs): """ @@ -247,17 +270,122 @@ class APIView(BaseAPIView): self._paginator = self.pagination_class() return self._paginator - # def paginate_queryset(self, queryset): - # """ - # Return a single page of results, or `None` if pagination is disabled. - # """ - # if self.paginator is None: - # return None - # return self.paginator.paginate_queryset(queryset, self.request, view=self) - # - # def get_paginated_response(self, data): - # """ - # Return a paginated style `Response` object for the given output data. - # """ - # assert self.paginator is not None - # return self.paginator.get_paginated_response(data) + def paginate_queryset(self, queryset): + """ + 返回单页结果,如果禁用了分页,则返回“无”。 + """ + if self.paginator is None: + return None + return self.paginator.paginate_queryset(queryset, self.request, view=self) + + def get_paginated_response(self, data): + """ + 返回给定输出数据的分页样式`Response`对象。 + """ + assert self.paginator is not None + return self.paginator.get_paginated_response(data) + + +class ListModelMixin: + """ + 适用于输出列表类型数据 + """ + pagination_class = LimitOffsetPagination + detail = False + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + data = self.get_paginated_response(serializer.data) + return self.success_json_response(data=data) + + serializer = self.get_serializer(queryset, many=True) + return self.success_json_response(data=serializer.data) + + +class CreateModelMixin: + """ + 适用于快速创建内容 + 占用 post 方法 + """ + + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + return self.success_json_response(data=serializer.data, http_status=HttpStatus.HTTP_201_CREATED) + + def perform_create(self, serializer): + serializer.save() + + +class RetrieveModelMixin: + """ + 适用于查询指定PK的内容 + """ + detail = True + + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return self.success_json_response(data=serializer.data) + + +class UpdateModelMixin: + """ + 适用于快速创建更新操作 + """ + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + # if getattr(instance, '_prefetched_objects_cache', None): + # instance._prefetched_objects_cache = {} + + return self.success_json_response(serializer.data) + + def perform_update(self, serializer): + serializer.save() + + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + +class DestroyModelMixin: + """ + 用于快速删除 + """ + + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return self.success_json_response(status=HttpStatus.HTTP_204_NO_CONTENT) + + def perform_destroy(self, instance): + instance.delete() -- Gitee From 06a90b619698aa683bb1b428e6d57c1631ed943a Mon Sep 17 00:00:00 2001 From: LaoSi Date: Sat, 13 Mar 2021 01:12:07 +0800 Subject: [PATCH 28/34] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=BF=87=E6=BB=A4?= =?UTF-8?q?=E5=99=A8,=E5=88=86=E9=A1=B5=E5=99=A8=E5=9C=A8=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E6=97=B6=E5=87=BA=E7=8E=B0=E7=9A=84=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E5=B9=B6=E4=BF=AE=E6=94=B9=20class=20APIView=20=E4=B8=BA?= =?UTF-8?q?=E5=8F=AA=E6=94=AF=E6=8C=81=E5=BC=82=E6=AD=A5=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +- README.md | 104 +++++++++++++++++++++------- run.py | 45 +++++++++--- sanic_rest_framework/exceptions.py | 6 +- sanic_rest_framework/filters.py | 64 +++++++++++++---- sanic_rest_framework/paginations.py | 58 ++++++++++------ sanic_rest_framework/serializers.py | 1 + sanic_rest_framework/views.py | 60 ++++++++-------- 8 files changed, 242 insertions(+), 101 deletions(-) diff --git a/.gitignore b/.gitignore index 51cb344..fa98b5a 100644 --- a/.gitignore +++ b/.gitignore @@ -114,4 +114,7 @@ dmypy.json # Pyre type checker .pyre/ .idea/ -.vscode/ \ No newline at end of file +.vscode/ +/*.sqlite +/*.sqlite-shm +/*.sqlite-wal diff --git a/README.md b/README.md index 24112b6..b3b2f7b 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,94 @@ # sanic_rest_framework #### 介绍 -Source of inspiration DjangoRestFramework -基于Sanic的rest api插件,集大成之作,长期维护。 -#### 软件架构 -软件架构说明 +Source of inspiration DjangoRestFramework 基于Sanic的rest api插件,集大成之作,长期维护。 +#### 简单案例 -#### 安装教程 +> models.py -1. xxxx -2. xxxx -3. xxxx +```python +from tortoise.models import Model +from tortoise import fields -#### 使用说明 -1. xxxx -2. xxxx -3. xxxx +class AddressModel(Model): + phone = fields.CharField(12, null=False) + address = fields.CharField(100) + house_number = fields.CharField(100) +``` -#### 参与贡献 +> serializers.py -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request +```python +from sanic_rest_framework.serializers import ModelSerializer +from models import AddressModel -#### 特技 +class TestSerializer(ModelSerializer): + class Meta: + model = AddressModel +``` + +> run.py + +```python +from sanic.app import Sanic +from tortoise.contrib.sanic import register_tortoise +from sanic_rest_framework.request import SRFRequest +from sanic_rest_framework.routes import Route +from sanic_rest_framework.views import RetrieveModelMixin, ListModelMixin, APIView +from models import AddressModel +from serializers import TestSerializer + +app = Sanic(__name__, request_class=SRFRequest) + + +class Test1View(RetrieveModelMixin, APIView): + queryset = AddressModel + serializer_class = TestSerializer + search_fields = ('phone',) + + async def retrieve(self, request, *args, **kwargs): + for i in range(100): + adr = AddressModel() + adr.phone = 1767470900 + i + adr.house_number = '房间%s' % i + adr.address = '地址%s' % i + await adr.save() + return self.success_json_response(msg='添加数据成功,请访问 http://127.0.0.1:8000/test2') + + +class Test2View(ListModelMixin, APIView): + queryset = AddressModel + serializer_class = TestSerializer + search_fields = ('@phone',) + + +route = Route() +route.register_route('test1', Test1View) +route.register_route('test2', Test2View) +route.initialize(app) + +register_tortoise( + app, + db_url="sqlite://./db.sqlite", + # db_url="sqlite://:memory:", + modules={"models": ["models"]}, + generate_schemas=True +) + +if __name__ == '__main__': + app.run(host="127.0.0.1", port=8000, debug=True, auto_reload=True) + +# 访问即可看效果 +# http://127.0.0.1:8000/test1 (先访问) +# http://127.0.0.1:8000/test2 +# http://127.0.0.1:8000/test2?phone='0' +# http://127.0.0.1:8000/test2?phone='li' +# http://127.0.0.1:8000/test2?phone='1767470901' +# http://127.0.0.1:8000/test2?page=1 +# http://127.0.0.1:8000/test2?page_size=10 +``` -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) diff --git a/run.py b/run.py index e56aaf7..06686b1 100644 --- a/run.py +++ b/run.py @@ -2,10 +2,10 @@ from sanic import Sanic from sanic.blueprints import Blueprint from tortoise.contrib.sanic import register_tortoise -from db import TestModel +from db import AddressModel from sanic_rest_framework.request import SRFRequest from sanic_rest_framework.routes import Route -from sanic_rest_framework.serializers import Serializer +from sanic_rest_framework.serializers import Serializer, ModelSerializer from sanic_rest_framework.fields import CharField from sanic_rest_framework.views import ( APIView, RetrieveModelMixin, ListModelMixin, CreateModelMixin, DestroyModelMixin, UpdateModelMixin @@ -15,17 +15,44 @@ app = Sanic(__name__, request_class=SRFRequest) admin = Blueprint('admin', '/admin') -class TestView(APIView, RetrieveModelMixin): - pass +class TestSerializer(ModelSerializer): + class Meta: + model = AddressModel + + +class Test1View(RetrieveModelMixin, APIView): + queryset = AddressModel + serializer_class = TestSerializer + search_fields = ('phone',) + + async def retrieve(self, request, *args, **kwargs): + for i in range(100): + adr = AddressModel() + adr.phone = 1767470900 + i + adr.house_number = '房间%s' % i + adr.address = '地址%s' % i + await adr.save() + + +class Test2View(ListModelMixin, APIView): + queryset = AddressModel + serializer_class = TestSerializer + search_fields = ('@phone',) route = Route() -route.register_route('test', TestView) -route.initialize(admin) +route.register_route('test1', Test1View) +route.register_route('test2', Test2View) +# route.initialize(admin) +route.initialize(app) app.blueprint(admin) register_tortoise( - app, db_url="sqlite://./db.sqlite", modules={"models": ["db"]}, generate_schemas=True + app, + db_url="sqlite://./db.sqlite", + # db_url="sqlite://:memory:", + modules={"models": ["db"]}, + generate_schemas=True ) - -app.run(host="127.0.0.1", port=8000, debug=True, auto_reload=True) +if __name__ == '__main__': + app.run(host="127.0.0.1", port=8000, debug=True, auto_reload=True) diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py index cb1270f..8e799b9 100644 --- a/sanic_rest_framework/exceptions.py +++ b/sanic_rest_framework/exceptions.py @@ -100,15 +100,15 @@ class ValidatorAssertError(Exception): class APIException(Exception): - def __init__(self, message, status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR, *args, **kwargs): + def __init__(self, message, status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR, + *args, **kwargs): self.message = message self.status = status self.http_status = http_status def response_data(self): return { - 'message': self.message, + 'msg': self.message, 'status': self.status, 'http_status': self.http_status } - diff --git a/sanic_rest_framework/filters.py b/sanic_rest_framework/filters.py index ca64fc6..a680f60 100644 --- a/sanic_rest_framework/filters.py +++ b/sanic_rest_framework/filters.py @@ -33,10 +33,12 @@ class SearchFilter(SimpleFilter): '>': 'gt', '<': 'lt', '>=': 'gte', - '<=': 'lte' + '<=': 'lte', + '=': 'contains', + '@': 'icontains' } - def get_search_fields(self, view, request): + def get_search_fields(self, request, view): """ 搜索字段是从视图获取的,但请求始终是 传递给此方法。子类可以重写此方法以 @@ -45,28 +47,62 @@ class SearchFilter(SimpleFilter): return getattr(view, 'search_fields', None) def filter_queryset(self, request, queryset, view): - search_fields = self.get_search_fields(view, request) + """ + 根据定义的搜索字段过滤传入的queryset + :param request: 当前请求 + :param queryset: 查询对象 + :param view: 当前视图 + :return: + """ + search_fields = self.get_search_fields(request, view) if not search_fields: return queryset orm_filters = {} for search_field in search_fields: - orm_filters.update(self.construct_orm_filter(search_field, view, queryset)) - return queryset.filters(**orm_filters) + orm_filters.update(self.construct_orm_filter(search_field, request, view)) + return queryset.filter(**orm_filters) - def construct_orm_filter(self, field_name, view, request): + def dismantle_search_field(self, search_field): + """ + 拆解带有特殊字符的搜索字段 + :param search_field: 搜索字段 + :return: (field_name, lookup_suffix) + """ lookup_suffix_keys = list(self.lookup_prefixes.keys()) lookup_suffix = None - for key in lookup_suffix_keys: - if key in field_name: - field_name = field_name[len(key):] - lookup_suffix_key = field_name[:len(key)] - lookup_suffix = self.lookup_prefixes.get(lookup_suffix_key) - break - if not lookup_suffix: + field_name = None + for lookup_suffix_key in lookup_suffix_keys: + if lookup_suffix_key in search_field: + lookup_suffix = self.lookup_prefixes[lookup_suffix_key] + field_name = search_field[len(lookup_suffix_key):] + return field_name, lookup_suffix + return field_name, lookup_suffix + + def construct_orm_filter(self, search_field, request, view): + """ + 构造适用于orm的过滤参数 + :param search_field: 搜索字段 + :param request: 当前请求 + :param view: 视图 + :return: + """ + field_name, lookup_suffix = self.dismantle_search_field(search_field) + args = request.args + + if field_name not in args: return {} - orm_lookup = LOOKUP_SEP.join([field_name, lookup_suffix]) + if lookup_suffix: + orm_lookup = LOOKUP_SEP.join([field_name, lookup_suffix]) + else: + orm_lookup = field_name return {orm_lookup: self.get_filter_value(request, field_name)} def get_filter_value(self, request, field_name): + """ + 根据字段名从请求中得到值 + :param request: 当前请求 + :param field_name: 字段名 + :return: + """ values = request.args.get(field_name) return ''.join(values) diff --git a/sanic_rest_framework/paginations.py b/sanic_rest_framework/paginations.py index af6be82..b0f3106 100644 --- a/sanic_rest_framework/paginations.py +++ b/sanic_rest_framework/paginations.py @@ -6,21 +6,35 @@ @MainFunction:无 @FileDoc: paginations.py - 文件说明 + 分页器 @ChangeHistory: datetime action why example: 2021/3/11 17:37 change 'Fix bug' """ -from sanic_rest_framework.exceptions import APIException -from sanic_rest_framework.status import HttpStatus from sanic.request import Request +from tortoise import Model +from sanic_rest_framework.exceptions import APIException +from sanic_rest_framework.status import HttpStatus from sanic_rest_framework.utils import replace_query_param -class LimitOffsetPagination: +class BasePagination: + """抽象基类""" + + async def paginate_queryset(self, queryset, request, view): + raise NotImplementedError( + '必须在 `%s` 中实现异步的 `.paginate_queryset()` 方法' % self.__class__.__name__) + + def response(self, request, data): + raise NotImplementedError( + '必须在 `%s` 中实现 `.response()` 方法' % self.__class__.__name__) + + +class GeneralPagination(BasePagination): + """通用分页器""" page_size = 60 page_query_param = 'page' page_size_query_param = 'page_size' @@ -28,55 +42,57 @@ class LimitOffsetPagination: @property def count(self): - assert not hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.count' + assert hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.count' return self._count def get_next_link(self, request: Request): - assert not hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.next()' + assert hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.get_next_link()' if self.page * self.page_size + self.page_size >= self.count: return None uri = request.server_path - query_string = request.query_string - query_string = replace_query_param(query_string, self.page_query_param, self.page) + query_string = '?' + request.query_string + query_string = replace_query_param(query_string, self.page_query_param, self.page + 1) query_string = replace_query_param(query_string, self.page_size_query_param, self.page_size) - return '%s?%s' % (uri, query_string) + return uri + query_string def get_previous_link(self, request: Request): - assert not hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.next()' + assert hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.get_previous_link()' if self.page * self.page_size <= 0: return None uri = request.server_path - query_string = request.query_string - query_string = replace_query_param(query_string, self.page_query_param, self.page) + query_string = '?' + request.query_string + query_string = replace_query_param(query_string, self.page_query_param, self.page - 1) query_string = replace_query_param(query_string, self.page_size_query_param, self.page_size) - return '%s?%s' % (uri, query_string) + return uri + query_string - def paginate_queryset(self, queryset, request, view): + async def paginate_queryset(self, queryset, request, view): self.page = self.get_query_page(request) self.page_size = self.get_query_page_size(request) - self._count = queryset.count() + if not isinstance(queryset, Model): + queryset = queryset.filter() + self._count = await queryset.count() return queryset.limit(self.page_size).offset(self.page * self.page_size) def get_query_page(self, request): try: - page = int(request.get(self.page_query_param, 0)) + page = int(request.args.get(self.page_query_param, 0)) except ValueError as exc: raise APIException('发生错误的分页数据', http_status=HttpStatus.HTTP_400_BAD_REQUEST) return page def get_query_page_size(self, request): try: - page = int(request.get(self.page_size_query_param, self.page_size)) + page = int(request.args.get(self.page_size_query_param, self.page_size)) if page > self.max_page_size: raise APIException('分页内容大小超出最大限制', http_status=HttpStatus.HTTP_400_BAD_REQUEST) except ValueError as exc: raise APIException('发生错误的分页数据', http_status=HttpStatus.HTTP_400_BAD_REQUEST) return page - def response(self, request): + def response(self, request, data): return { 'count': self.count, - 'next': None, - 'previous': None, - 'results': None, + 'next': self.get_next_link(request), + 'previous': self.get_previous_link(request), + 'results': data, } diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 45e01e1..228194c 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -314,6 +314,7 @@ class ListSerializer(BaseSerializer): self.child.bind(field_name='', parent=self) async def get_internal_value(self, instance: Any) -> Any: + for attr in self.source_attrs: if isinstance(instance, Mapping): instance = instance.get(attr, []) diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index d3d9a1b..5c6aead 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -15,7 +15,7 @@ from sanic.log import logger from sanic_rest_framework.constant import ALL_METHOD from sanic_rest_framework.exceptions import APIException from sanic_rest_framework.filters import SearchFilter -from sanic_rest_framework.paginations import LimitOffsetPagination +from sanic_rest_framework.paginations import GeneralPagination from sanic_rest_framework.status import RuleStatus, HttpStatus @@ -24,7 +24,7 @@ class BaseAPIView: authentication_classes = () permission_classes = () - def dispatch(self, request, *args, **kwargs): + async def dispatch(self, request, *args, **kwargs): """分发路由""" request.user = None method = request.method @@ -35,7 +35,7 @@ class BaseAPIView: try: self.initial(request, *args, **kwargs) - response = handler(request, *args, **kwargs) + response = await handler(request, *args, **kwargs) except APIException as exc: response = self.handle_exception(exc) except AssertionError as exc: @@ -73,7 +73,9 @@ class BaseAPIView: def handle_uncaught_exception(self, exc): """处理未知错误""" - return self.json_response(msg=exc.args[0], status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR) + message = '{}:{}'.format(exc.__class__.__name__, '|'.join(exc.args)) + return self.json_response(msg=message, status=RuleStatus.STATUS_0_FAIL, + http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR) def json_response(self, data=None, msg="OK", status=RuleStatus.STATUS_1_SUCCESS, http_status=HttpStatus.HTTP_200_OK): @@ -113,7 +115,7 @@ class BaseAPIView: return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_0_FAIL, http_status=HttpStatus.HTTP_400_BAD_REQUEST) - def get_object(self): + async def get_object(self): """ 返回视图显示的对象。 如果您需要提供非标准的内容,则可能要覆盖此设置 @@ -129,9 +131,9 @@ class BaseAPIView: ) filter_kwargs = {lookup_field: self.kwargs[lookup_field]} - obj = queryset.get_or_none(**filter_kwargs) + obj = await queryset.get_or_none(**filter_kwargs) if obj is None: - raise APIException('不存在%s为%s的数据'.format(lookup_field, self.kwargs[lookup_field])) + raise APIException('不存在%s为%s的数据' % (lookup_field, self.kwargs[lookup_field])) # May raise a permission denied self.check_object_permissions(self.request, obj) @@ -270,43 +272,43 @@ class APIView(BaseAPIView): self._paginator = self.pagination_class() return self._paginator - def paginate_queryset(self, queryset): + async def paginate_queryset(self, queryset): """ 返回单页结果,如果禁用了分页,则返回“无”。 """ if self.paginator is None: return None - return self.paginator.paginate_queryset(queryset, self.request, view=self) + return await self.paginator.paginate_queryset(queryset, self.request, view=self) def get_paginated_response(self, data): """ 返回给定输出数据的分页样式`Response`对象。 """ assert self.paginator is not None - return self.paginator.get_paginated_response(data) + return self.paginator.response(self.request, data) class ListModelMixin: """ 适用于输出列表类型数据 """ - pagination_class = LimitOffsetPagination + pagination_class = GeneralPagination detail = False - def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) + async def get(self, request, *args, **kwargs): + return await self.list(request, *args, **kwargs) - def list(self, request, *args, **kwargs): + async def list(self, request, *args, **kwargs): queryset = self.filter_queryset(self.get_queryset()) - page = self.paginate_queryset(queryset) + page = await self.paginate_queryset(queryset) if page is not None: serializer = self.get_serializer(page, many=True) - data = self.get_paginated_response(serializer.data) + data = self.get_paginated_response(await serializer.data) return self.success_json_response(data=data) serializer = self.get_serializer(queryset, many=True) - return self.success_json_response(data=serializer.data) + return self.success_json_response(data=await serializer.data) class CreateModelMixin: @@ -315,17 +317,17 @@ class CreateModelMixin: 占用 post 方法 """ - def post(self, request, *args, **kwargs): - return self.create(request, *args, **kwargs) + async def post(self, request, *args, **kwargs): + return await self.create(request, *args, **kwargs) - def create(self, request, *args, **kwargs): + async def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - self.perform_create(serializer) - return self.success_json_response(data=serializer.data, http_status=HttpStatus.HTTP_201_CREATED) + await self.perform_create(serializer) + return self.success_json_response(data=await serializer.data, http_status=HttpStatus.HTTP_201_CREATED) - def perform_create(self, serializer): - serializer.save() + async def perform_create(self, serializer): + await serializer.save() class RetrieveModelMixin: @@ -334,13 +336,13 @@ class RetrieveModelMixin: """ detail = True - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + async def get(self, request, *args, **kwargs): + return await self.retrieve(request, *args, **kwargs) - def retrieve(self, request, *args, **kwargs): - instance = self.get_object() + async def retrieve(self, request, *args, **kwargs): + instance = await self.get_object() serializer = self.get_serializer(instance) - return self.success_json_response(data=serializer.data) + return self.success_json_response(data=await serializer.data) class UpdateModelMixin: -- Gitee From 0ce8141c214c8c3c5e8a576869e1e7b66e2bd668 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Mon, 15 Mar 2021 17:01:56 +0800 Subject: [PATCH 29/34] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E9=92=88=E5=AF=B9?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=BA=8F=E5=88=97=E5=8C=96=E5=99=A8=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E7=9A=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 7 ++ sanic_rest_framework/converter.py | 8 +- sanic_rest_framework/fields.py | 1 + sanic_rest_framework/request.py | 10 ++ sanic_rest_framework/serializers.py | 151 +++++++++++++++++++++++++++- sanic_rest_framework/test/models.py | 2 +- sanic_rest_framework/views.py | 17 ++-- 7 files changed, 184 insertions(+), 12 deletions(-) diff --git a/run.py b/run.py index 06686b1..5926f3f 100644 --- a/run.py +++ b/run.py @@ -40,9 +40,16 @@ class Test2View(ListModelMixin, APIView): search_fields = ('@phone',) +class Test3View(CreateModelMixin, APIView): + queryset = AddressModel + serializer_class = TestSerializer + search_fields = ('@phone',) + + route = Route() route.register_route('test1', Test1View) route.register_route('test2', Test2View) +route.register_route('test3', Test3View) # route.initialize(admin) route.initialize(app) app.blueprint(admin) diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index f8936c0..faab20b 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -54,16 +54,19 @@ class ModelConverter(ModelConverterBase): def convert(self, serializer, model_field, **field_kwargs): model = serializer.Meta.model read_only = field_kwargs.get('read_only', False) + write_only = field_kwargs.get('write_only', False) required = not model_field.null if read_only and required: raise ValueError('{}序列化器内字段{}为必填项不能使用read_only属性'.format( type(serializer).__name__, model_field.model_field_name)) + if model_field.pk: + field_kwargs['read_only'] = True + field_kwargs['required'] = False kwargs = { 'read_only': read_only, - 'write_only': False, + 'write_only': write_only, 'required': required, 'allow_null': model_field.null, - # 'allow_empty': False, # M2M O2M # 'source': None, 'description': model_field.description } @@ -78,6 +81,7 @@ class ModelConverter(ModelConverterBase): nested_depth = serializer.Meta.depth else: nested_depth = 10 + kwargs['allow_empty'] = model_field.null kwargs['nested_depth'] = nested_depth converter = self.converters[type_name] kwargs.update(field_kwargs) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index 6c6f318..c77a82c 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -269,6 +269,7 @@ class Field: if self.read_only: return True, self.get_default() if data is empty: + # if self.is_partial(): raise SkipField() if self.required: diff --git a/sanic_rest_framework/request.py b/sanic_rest_framework/request.py index cd9f180..0261545 100644 --- a/sanic_rest_framework/request.py +++ b/sanic_rest_framework/request.py @@ -13,7 +13,9 @@ 2021/3/11 15:46 change 'Fix bug' """ +from collections import OrderedDict +from sanic.exceptions import InvalidUsage from sanic.request import Request as SanicRequest @@ -21,3 +23,11 @@ class SRFRequest(SanicRequest): def __init__(self, *args, **kwargs): super(SRFRequest, self).__init__(*args, **kwargs) self.user = None + + @property + def data(self): + try: + data = self.json + except InvalidUsage as exc: + data = self.form + return data diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 228194c..7db3c8e 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -10,8 +10,9 @@ """ import copy import inspect +from asyncio import coroutine from collections import OrderedDict -from typing import Any, Mapping +from typing import Any, Mapping, Coroutine from tortoise import models, Model from tortoise.queryset import QuerySet @@ -124,10 +125,56 @@ class BaseSerializer(Field): value = self.validate(value) return value + async def update(self, instance, validated_data): + raise NotImplementedError('`update()` 必须实现.') + + async def create(self, validated_data): + raise NotImplementedError('`create()` 必须实现.') + + async def save(self, **kwargs): + assert hasattr(self, '_errors'), ( + '您必须先调用`.is_valid()`,然后再调用`.save()`。' + ) + + assert not self.errors, ( + '您不能在未通过效验的序列化器上调用`.save()`。' + ) + + # Guard against incorrect use of `serializer.save(commit=False)` + assert 'commit' not in kwargs, ( + "`commit`不是`save()`方法的有效关键字参数。" + "如果需要在提交数据库之前访问数据,则" + "而是检查`serializer.validated_data`." + "如果您还可以将其他关键字参数传递给`save()`," + "需要在保存的模型实例上设置额外的属性。" + "例如:'serializer.save(owner = request.user)'。" + ) + + assert not hasattr(self, '_data'), ( + '访问`serializer.data`后,您不能调用`.save()`。' + '如果需要在提交数据库之前访问数据,请访问' + '`serializer.validated_data`' + ) + + validated_data = dict(list(self.validated_data.items()) + list(kwargs.items())) + + if self.instance is not None: + self.instance = await self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` 没有返回对象实例。' + ) + else: + self.instance = await self.create(validated_data) + assert self.instance is not None, ( + '`create()` 没有返回对象实例。' + ) + + return self.instance + def is_valid(self, raise_exception=False): assert hasattr(self, 'initial_data'), ( - 'Cannot call `.is_valid()` as no `data=` keyword argument was ' - 'passed when instantiating the serializer instance.' + '无法调用`.is_valid()`,因为' + '类实例化时没有传入`data =`关键字参数' ) if not hasattr(self, '_validated_data'): @@ -376,6 +423,49 @@ class ListSerializer(BaseSerializer): value = self.validate(value) return value + async def update(self, instance, validated_data): + raise NotImplementedError( + '当 many=True 时,有些序列化器不支持更新操作,' + '所以在必须使用时请继承ListSerializer并覆盖' + '`.update()`,不提供默认实现方式') + + async def create(self, validated_data): + return [ + await self.child.create(attrs) for attrs in validated_data + ] + + async def save(self, **kwargs): + """ + 保存实例 + """ + # 防止错误使用`serializer.save(commit = False)` + assert 'commit' not in kwargs, ( + "`commit`不是`save()`方法的有效关键字参数。" + "如果需要在提交数据库之前访问数据,则" + "而是检查`serializer.validated_data`." + "如果您还可以将其他关键字参数传递给`save()`," + "需要在保存的模型实例上设置额外的属性。" + "例如:'serializer.save(owner = request.user)'。" + ) + + validated_data = [ + dict(list(attrs.items()) + list(kwargs.items())) + for attrs in self.validated_data + ] + + if self.instance is not None: + self.instance = await self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` 没有返回对象实例。' + ) + else: + self.instance = await self.create(validated_data) + assert self.instance is not None, ( + '`create()` 没有返回对象实例。' + ) + + return self.instance + class ModelSerializer(Serializer): """ @@ -476,6 +566,61 @@ class ModelSerializer(Serializer): field_dict.pop(clean_field_name) return field_dict + async def create(self, validated_data): + """ + 根据验证后的数据进行创建, + """ + # raise_errors_on_nested_writes('create', self, validated_data) + + ModelClass = self.Meta.model + ModelClassMeta = ModelClass._meta + many_to_many = {} + one_to_one = {} + one_to_many = {} + + for m2m_field in ModelClass._meta.m2m_fields: + if m2m_field in self.fields: + if m2m_field in validated_data: + many_to_many[m2m_field] = validated_data.pop(m2m_field) + o2o_field_names = ModelClassMeta.backward_o2o_fields + o2o_field_names.update() + + + # for m2m_field in ModelClass._meta.m2m_fields: + # if m2m_field in self.fields: + # if m2m_field in validated_data: + # many_to_many[m2m_field] = validated_data.pop(m2m_field) + + try: + instance = await ModelClass.create(**validated_data) + except TypeError as exc: + print(exc) + raise exc + if many_to_many: + for field_name, value in many_to_many.items(): + field = getattr(instance, field_name) + field.set(value) + + return instance + + def update(self, instance, validated_data): + ModelClass = self.Meta.model + + m2m_fields = [] + for attr, value in validated_data.items(): + if attr in ModelClass._meta.m2m_fields: + m2m_fields.append((attr, value)) + else: + setattr(instance, attr, value) + + instance.save() + + for attr, value in m2m_fields: + field = getattr(instance, attr) + field.set(value) + + return instance + # # def _get_model_basis_fields(self, model_fields): # """ diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index 175ed62..6194b4e 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -33,7 +33,7 @@ class TestModel(Model): uuid_field = fields.UUIDField() one_to_many: ForeignKeyRelation["TestManyToOneModel"] = fields.ForeignKeyField('models.TestManyToOneModel', related_name='many_to_one', null=True) one_to_one: fields.OneToOneRelation["TestOneToOneModel"] = fields.OneToOneField('models.TestOneToOneModel', related_name='one_to_one', null=True) - many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.TestModel", related_name="many_2_many", through="many_many", null=True) + many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.ManyToManyModel", related_name="many_2_many", through="many_many", null=True) class TestOneTOManyModel(Model): diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index 5c6aead..9de0256 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -13,7 +13,7 @@ from tortoise.queryset import QuerySet from sanic.log import logger from sanic_rest_framework.constant import ALL_METHOD -from sanic_rest_framework.exceptions import APIException +from sanic_rest_framework.exceptions import APIException, ValidationError from sanic_rest_framework.filters import SearchFilter from sanic_rest_framework.paginations import GeneralPagination from sanic_rest_framework.status import RuleStatus, HttpStatus @@ -38,6 +38,8 @@ class BaseAPIView: response = await handler(request, *args, **kwargs) except APIException as exc: response = self.handle_exception(exc) + except ValidationError as exc: + response = self.error_json_response(exc.message_dict, '数据验证失败') except AssertionError as exc: raise exc except Exception as exc: @@ -96,24 +98,27 @@ class BaseAPIView: } return json(body=response_body, status=http_status) - def success_json_response(self, data=None, msg="Success"): + def success_json_response(self, data=None, msg="Success", **kwargs): """ 快捷的成功的json响应体 :param data: 返回的数据主题 :param msg: 前台提示字符串 :return: json """ - return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_1_SUCCESS) + status = kwargs.pop('status', RuleStatus.STATUS_1_SUCCESS) + http_status = kwargs.pop('http_status', HttpStatus.HTTP_200_OK) + return self.json_response(data=data, msg=msg, status=status, http_status=http_status) - def error_json_response(self, data=None, msg="Fail"): + def error_json_response(self, data=None, msg="Fail", **kwargs): """ 快捷的失败的json响应体 :param data: 返回的数据主题 :param msg: 前台提示字符串 :return: json """ - return self.json_response(data=data, msg=msg, status=RuleStatus.STATUS_0_FAIL, - http_status=HttpStatus.HTTP_400_BAD_REQUEST) + status = kwargs.pop('status', RuleStatus.STATUS_0_FAIL) + http_status = kwargs.pop('http_status', HttpStatus.HTTP_400_BAD_REQUEST) + return self.json_response(data=data, msg=msg, status=status, http_status=http_status) async def get_object(self): """ -- Gitee From 93b0a99f78aaad06db77bbf7de012e04b59bbb91 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Mon, 15 Mar 2021 23:00:48 +0800 Subject: [PATCH 30/34] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E5=99=A8=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/converter.py | 28 ++++++++- .../test/test_models/__init__.py | 15 +++++ .../test/test_models/test_o2o.py | 55 +++++++++++++++++ .../test/test_models/test_otm_mto.py | 61 +++++++++++++++++++ 4 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 sanic_rest_framework/test/test_models/__init__.py create mode 100644 sanic_rest_framework/test/test_models/test_o2o.py create mode 100644 sanic_rest_framework/test/test_models/test_otm_mto.py diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index faab20b..e39ab74 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -149,6 +149,19 @@ class ModelConverter(ModelConverterBase): return NestedSerializer(many=True) + @converts('BackwardFKRelation') + def convert_manytomany(self, model, model_field, *field_args, **field_kws): + nested_depth = field_kws.get('nested_depth', 10) + + class NestedSerializer(self.nested_field_class): + class Meta: + model = model_field.related_model + depth = nested_depth - 1 + fields = '__all__' + + field_kws['read_only'] = True + return NestedSerializer(many=True, **field_kws) + @converts('ForeignKeyFieldInstance', 'OneToOneFieldInstance', 'BackwardOneToOneRelation') def convert_manytoone(self, model, model_field, *field_args, **field_kws): nested_depth = field_kws.get('nested_depth', 10) @@ -159,4 +172,17 @@ class ModelConverter(ModelConverterBase): depth = nested_depth - 1 fields = '__all__' - return NestedSerializer() + return NestedSerializer(**field_kws) + + @converts('BackwardOneToOneRelation') + def convert_manytoone(self, model, model_field, *field_args, **field_kws): + nested_depth = field_kws.get('nested_depth', 10) + + class NestedSerializer(self.nested_field_class): + class Meta: + model = model_field.related_model + depth = nested_depth - 1 + fields = '__all__' + + field_kws['read_only'] = True + return NestedSerializer(**field_kws) diff --git a/sanic_rest_framework/test/test_models/__init__.py b/sanic_rest_framework/test/test_models/__init__.py new file mode 100644 index 0000000..e1bc933 --- /dev/null +++ b/sanic_rest_framework/test/test_models/__init__.py @@ -0,0 +1,15 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/3/15 19:18 +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: + __init__.py.py + 文件描述 +@ChangeHistory: + datetime action why + example: + 2021/2/4 14:35 change 'Fix bug' + +""" diff --git a/sanic_rest_framework/test/test_models/test_o2o.py b/sanic_rest_framework/test/test_models/test_o2o.py new file mode 100644 index 0000000..6be73f8 --- /dev/null +++ b/sanic_rest_framework/test/test_models/test_o2o.py @@ -0,0 +1,55 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/3/15 19:05 +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: + test_o2o.py + 一对一模型 +@ChangeHistory: + datetime action why + example: + 2021/2/4 14:35 change 'Fix bug' + +""" + +from tortoise import Model, Tortoise, run_async +from tortoise import fields + + +class User(Model): + account = fields.CharField(12) + password = fields.CharField(128) + + +class Students(Model): + name = fields.CharField(30) + user: fields.OneToOneRelation[User] = fields.OneToOneField('models.User', 'student', null=True) + + +async def run(): + await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]}) + await Tortoise.generate_schemas() + + # 测试正向创建 + user = User(account='17674707037', password='123456') + await user.save() + std = Students(name='刘桂香', user=user) + await std.save() + + # 测试逆向创建 + std = Students(name='刘桂香') + user = User(account='17674707037', password='123456') + std.user = user + await std.save() + await user.save() + # 结论 一对一只能由单边发起 + # 列如在 Students 类中设置了一对一关系,那么必须先创建对应的 User 然后才能创建 Students + # 在 Students.user 允许为空的情况下,可以先创建 Students 再创建 User + # 先创建 Students 再创建 User ,如果需要关联一对一那么只能 Students.user=User, + # 不可以 User.student = Students ,因为不支持反向设置一对一关系 + + +if __name__ == "__main__": + run_async(run()) diff --git a/sanic_rest_framework/test/test_models/test_otm_mto.py b/sanic_rest_framework/test/test_models/test_otm_mto.py new file mode 100644 index 0000000..0e3cfb8 --- /dev/null +++ b/sanic_rest_framework/test/test_models/test_otm_mto.py @@ -0,0 +1,61 @@ +""" +@Author: WangYuXiang +@E-mile: Hill@3io.cc +@CreateTime: 2021/3/15 19:45 +@DependencyLibrary: 无 +@MainFunction:无 +@FileDoc: + test_otm_mto.py + 文件描述 +@ChangeHistory: + datetime action why + example: + 2021/2/4 14:35 change 'Fix bug' + +""" + +from tortoise import Model, Tortoise, run_async +from tortoise import fields + + +class ClassRoom(Model): + class_num = fields.IntField() + class_name = fields.CharField(128) + students = fields.ReverseRelation['Students'] + + +class Students(Model): + name = fields.CharField(30) + class_room: fields.ForeignKeyRelation[ClassRoom] = fields.ForeignKeyField('models.ClassRoom', 'students', + null=True) + + +async def run(): + await Tortoise.init(db_url="sqlite::memory:", modules={"models": ["__main__"]}) + await Tortoise.generate_schemas() + + # 测试正向创建 + cr = await ClassRoom.create(**{ + 'class_num': 1, + 'class_name': '1000' + }) + std = Students(name='刘桂香', class_room=cr) + await std.save() + cr = await ClassRoom.create(**{ + 'class_num': 1, + 'class_name': '1000' + }) + print(await cr.students) + std = Students(name='刘桂香') + std.class_room = cr + await std.save() + # 结论多对一关系和一对多关系都只能由单边发起 + # 列如在 Students 类中设置了多对一关系,那么必须先创建对应的 ClassRoom 然后才能创建 Students + # 在 Students.class_room 允许为空的情况下,可以先创建 Students 再创建 ClassRoom + # 先创建 Students 再创建 ClassRoom ,如果需要关联多对一那么只能 Students.class_room=ClassRoom, + # 不可以 ClassRoom.students = Students ,因为不支持反向设置多对一关系 + # 另外 ClassRoom.students 是一个方向查询对象 可以使用 .all() .filter() 等方法生成 Queryset + + +if __name__ == "__main__": + run_async(run()) -- Gitee From 4327e15f83f23c4bae746792d3346f93054c3af5 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Wed, 17 Mar 2021 18:00:21 +0800 Subject: [PATCH 31/34] =?UTF-8?q?=E5=BE=85=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 11 +++ sanic_rest_framework/converter.py | 41 ++++++---- sanic_rest_framework/fields.py | 4 +- sanic_rest_framework/serializers.py | 74 ++++++++++--------- sanic_rest_framework/test/models.py | 4 +- .../test/test_models/test_o2o.py | 6 +- sanic_rest_framework/views.py | 24 +++--- 7 files changed, 100 insertions(+), 64 deletions(-) diff --git a/run.py b/run.py index 5926f3f..5ddde79 100644 --- a/run.py +++ b/run.py @@ -6,6 +6,7 @@ from db import AddressModel from sanic_rest_framework.request import SRFRequest from sanic_rest_framework.routes import Route from sanic_rest_framework.serializers import Serializer, ModelSerializer +from sanic_rest_framework.fields import IntegerField from sanic_rest_framework.fields import CharField from sanic_rest_framework.views import ( APIView, RetrieveModelMixin, ListModelMixin, CreateModelMixin, DestroyModelMixin, UpdateModelMixin @@ -18,6 +19,9 @@ admin = Blueprint('admin', '/admin') class TestSerializer(ModelSerializer): class Meta: model = AddressModel + exclude = ('school',) + + school_id = IntegerField(label='ap') class Test1View(RetrieveModelMixin, APIView): @@ -46,10 +50,17 @@ class Test3View(CreateModelMixin, APIView): search_fields = ('@phone',) +class Test4View(UpdateModelMixin, APIView): + queryset = AddressModel + serializer_class = TestSerializer + search_fields = ('@phone',) + + route = Route() route.register_route('test1', Test1View) route.register_route('test2', Test2View) route.register_route('test3', Test3View) +route.register_route('test4', Test4View) # route.initialize(admin) route.initialize(app) app.blueprint(admin) diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index e39ab74..5eeaf20 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -137,9 +137,10 @@ class ModelConverter(ModelConverterBase): def convert_floatfield(self, model, model_field, *field_args, **field_kws): return FloatField(*field_args, **field_kws) - @converts('ManyToManyFieldInstance', 'BackwardFKRelation', 'ManyToManyRelation') + @converts('ManyToManyFieldInstance', 'ManyToManyRelation') def convert_manytomany(self, model, model_field, *field_args, **field_kws): - nested_depth = field_kws.get('nested_depth', 10) + """多对多""" + nested_depth = field_kws.pop('nested_depth', 10) class NestedSerializer(self.nested_field_class): class Meta: @@ -149,9 +150,10 @@ class ModelConverter(ModelConverterBase): return NestedSerializer(many=True) - @converts('BackwardFKRelation') - def convert_manytomany(self, model, model_field, *field_args, **field_kws): - nested_depth = field_kws.get('nested_depth', 10) + @converts('BackwardFKRelation', ) + def convert_backwardfkrelation(self, model, model_field, *field_args, **field_kws): + """反向一对多""" + nested_depth = field_kws.pop('nested_depth', 10) class NestedSerializer(self.nested_field_class): class Meta: @@ -159,12 +161,25 @@ class ModelConverter(ModelConverterBase): depth = nested_depth - 1 fields = '__all__' - field_kws['read_only'] = True - return NestedSerializer(many=True, **field_kws) + return NestedSerializer(many=True) + + @converts('ForeignKeyFieldInstance') + def convert_foreignkeyfieldinstance(self, model, model_field, *field_args, **field_kws): + """正向多对一""" + nested_depth = field_kws.pop('nested_depth', 10) + + class NestedSerializer(self.nested_field_class): + class Meta: + model = model_field.related_model + depth = nested_depth - 1 + fields = '__all__' + + return NestedSerializer(**field_kws) - @converts('ForeignKeyFieldInstance', 'OneToOneFieldInstance', 'BackwardOneToOneRelation') - def convert_manytoone(self, model, model_field, *field_args, **field_kws): - nested_depth = field_kws.get('nested_depth', 10) + @converts('OneToOneFieldInstance') + def convert_onetoonefieldinstance(self, model, model_field, *field_args, **field_kws): + """正向 一对一""" + nested_depth = field_kws.pop('nested_depth', 10) class NestedSerializer(self.nested_field_class): class Meta: @@ -175,8 +190,9 @@ class ModelConverter(ModelConverterBase): return NestedSerializer(**field_kws) @converts('BackwardOneToOneRelation') - def convert_manytoone(self, model, model_field, *field_args, **field_kws): - nested_depth = field_kws.get('nested_depth', 10) + def convert_backwardonetoonerelation(self, model, model_field, *field_args, **field_kws): + """反向一对一""" + nested_depth = field_kws.pop('nested_depth', 10) class NestedSerializer(self.nested_field_class): class Meta: @@ -184,5 +200,4 @@ class ModelConverter(ModelConverterBase): depth = nested_depth - 1 fields = '__all__' - field_kws['read_only'] = True return NestedSerializer(**field_kws) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index c77a82c..aceb5d6 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -13,6 +13,7 @@ import decimal import re from datetime import timezone, timedelta, datetime, date, time from enum import Enum +from inspect import isawaitable, iscoroutine from typing import Any, List, Mapping from tortoise import Model from tortoise.queryset import QuerySet @@ -214,14 +215,13 @@ class Field: :param instance: *内部* 数据 :return: """ - for attr in self.source_attrs: try: if isinstance(instance, Mapping): instance = instance[attr] else: instance = getattr(instance, attr) - if isinstance(instance, QuerySet): + if isawaitable(instance): instance = await instance except DoesNotExist: return None diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 7db3c8e..162b01b 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -293,8 +293,14 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): res = OrderedDict() fields = self._readable_fields for field in fields: - value = await field.get_internal_value(data) - res[field.field_name] = await field.internal_to_external(value) + try: + value = await field.get_internal_value(data) + except SkipField: + continue + if value is None: + res[field.field_name] = None + else: + res[field.field_name] = await field.internal_to_external(value) return res # 反序列化 @@ -369,6 +375,9 @@ class ListSerializer(BaseSerializer): instance = [instance] else: instance = await getattr(instance, attr) + # instance = getattr(instance, attr) + # if inspect.iscoroutine(instance): + # instance = await instance return instance async def internal_to_external(self, data: Any) -> Any: @@ -377,7 +386,7 @@ class ListSerializer(BaseSerializer): :param data: :return: """ - iterable = await data.all() if isinstance(data, (QuerySet, tortoise_fields.relational.RelationalField)) else data + iterable = await data.all() if inspect.isawaitable(data) else data return [ await self.child.internal_to_external(item) for item in iterable @@ -400,11 +409,11 @@ class ListSerializer(BaseSerializer): for item in data: try: - validated = self.child.run_validation(item) + value = self.child.run_validation(item) except ValidationError as exc: errors.append(exc) else: - ret.append(validated) + ret.append(value) errors.append({}) if any(errors): raise ValidationError(errors) @@ -502,10 +511,14 @@ class ModelSerializer(Serializer): declared_fields = copy.deepcopy(self._declared_fields) model = getattr(self.Meta, 'model') depth = getattr(self.Meta, 'depth', 10) + if depth is not None: + assert depth >= 0, "'depth' may not be negative." + assert depth <= 10, "'depth' may not be greater than 10." converter = ModelConverter(ModelSerializer) - model_fields = self._clean_model_field(model) - effective_field = self.get_effective_field(model_fields) + model_original_fields = model._meta.fields_map + model_clean_fields = self._clean_model_field(model_original_fields) + effective_field = self.get_effective_field(model_clean_fields) serializer_fields = BindingDict(self) for field_name, field_class in effective_field.items(): @@ -529,7 +542,7 @@ class ModelSerializer(Serializer): if meta_exclude is not None: return {k: v for k, v in model_fields.items() if k not in meta_exclude} - elif meta_exclude is None and meta_fields is None: + elif meta_exclude is None and (meta_fields is None or meta_fields is ALL_FIELDS): return model_fields else: return {k: v for k, v in model_fields.items() if k in meta_fields} @@ -547,7 +560,7 @@ class ModelSerializer(Serializer): else: return {'read_only': False, 'write_only': False} - def _clean_model_field(self, model): + def _clean_model_field(self, model_original_fields): """ 清除不需要的字段如 fk_id :param model: @@ -555,7 +568,7 @@ class ModelSerializer(Serializer): """ clean_field_names = [] field_dict = {} - fields_map = copy.deepcopy(model._meta.fields_map) + fields_map = copy.deepcopy(model_original_fields) for field_name, field_class in fields_map.items(): if isinstance(field_class, (tortoise_fields.relational.ForeignKeyFieldInstance, tortoise_fields.relational.OneToOneFieldInstance)): clean_field_names.append(field_class.source_field) @@ -575,52 +588,45 @@ class ModelSerializer(Serializer): ModelClass = self.Meta.model ModelClassMeta = ModelClass._meta many_to_many = {} - one_to_one = {} - one_to_many = {} - for m2m_field in ModelClass._meta.m2m_fields: + for m2m_field in ModelClassMeta.m2m_fields: if m2m_field in self.fields: if m2m_field in validated_data: many_to_many[m2m_field] = validated_data.pop(m2m_field) - o2o_field_names = ModelClassMeta.backward_o2o_fields - o2o_field_names.update() - - - # for m2m_field in ModelClass._meta.m2m_fields: - # if m2m_field in self.fields: - # if m2m_field in validated_data: - # many_to_many[m2m_field] = validated_data.pop(m2m_field) - try: instance = await ModelClass.create(**validated_data) except TypeError as exc: - print(exc) + raise exc if many_to_many: - for field_name, value in many_to_many.items(): + for field_name, values in many_to_many.items(): field = getattr(instance, field_name) - field.set(value) - + for value in values: + await field.add(value) return instance - def update(self, instance, validated_data): + async def update(self, instance, validated_data): + """更新""" ModelClass = self.Meta.model + ModelClassMeta = ModelClass._meta m2m_fields = [] for attr, value in validated_data.items(): - if attr in ModelClass._meta.m2m_fields: + if attr in ModelClassMeta.m2m_fields: m2m_fields.append((attr, value)) else: setattr(instance, attr, value) - - instance.save() - - for attr, value in m2m_fields: + await instance.save() + for attr, values in m2m_fields: field = getattr(instance, attr) - field.set(value) - + for value in values: + value, _ = await field.remote_model.get_or_create(**value) + await field.add(value) return instance + def check_relationship(self): + """检查关系字段 目前不能为关系字段提供自动转换功能""" + # # def _get_model_basis_fields(self, model_fields): # """ diff --git a/sanic_rest_framework/test/models.py b/sanic_rest_framework/test/models.py index 6194b4e..7b43ab0 100644 --- a/sanic_rest_framework/test/models.py +++ b/sanic_rest_framework/test/models.py @@ -1,5 +1,5 @@ from datetime import date -from enum import Enum,IntEnum +from enum import Enum, IntEnum from tortoise import fields from tortoise import Model @@ -58,7 +58,7 @@ class UserModel(Model): phone = fields.CharField(max_length=11) balance = fields.DecimalField(13, 3) address: fields.ManyToManyRelation["AddressModel"] = fields.ManyToManyField( - 'models.AddressModel', through='user2address', related_name='user') + 'models.AddressModel', through='user2address', related_name='user', null=True) class AddressModel(Model): diff --git a/sanic_rest_framework/test/test_models/test_o2o.py b/sanic_rest_framework/test/test_models/test_o2o.py index 6be73f8..9beab9e 100644 --- a/sanic_rest_framework/test/test_models/test_o2o.py +++ b/sanic_rest_framework/test/test_models/test_o2o.py @@ -29,7 +29,7 @@ class Students(Model): async def run(): - await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]}) + await Tortoise.init(db_url="sqlite://./db.sqlite", modules={"models": ["__main__"]}) await Tortoise.generate_schemas() # 测试正向创建 @@ -37,6 +37,10 @@ async def run(): await user.save() std = Students(name='刘桂香', user=user) await std.save() + print(type(user.student)) + c = await user.student + print(await user.student) + print(await user.student.model) # 测试逆向创建 std = Students(name='刘桂香') diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index 9de0256..ed333f2 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -355,30 +355,30 @@ class UpdateModelMixin: 适用于快速创建更新操作 """ - def put(self, request, *args, **kwargs): - return self.update(request, *args, **kwargs) + async def put(self, request, *args, **kwargs): + return await self.update(request, *args, **kwargs) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) + async def patch(self, request, *args, **kwargs): + return await self.partial_update(request, *args, **kwargs) - def update(self, request, *args, **kwargs): + async def update(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) - instance = self.get_object() + instance = await self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) - self.perform_update(serializer) + await self.perform_update(serializer) # if getattr(instance, '_prefetched_objects_cache', None): # instance._prefetched_objects_cache = {} - return self.success_json_response(serializer.data) + return self.success_json_response(data=await serializer.data) - def perform_update(self, serializer): - serializer.save() + async def perform_update(self, serializer): + await serializer.save() - def partial_update(self, request, *args, **kwargs): + async def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return await self.update(request, *args, **kwargs) class DestroyModelMixin: -- Gitee From f5db9f61cd3d027c61b8094271ddcffbae1ec062 Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 18 Mar 2021 18:06:52 +0800 Subject: [PATCH 32/34] =?UTF-8?q?=E4=B8=8D=E8=87=AA=E5=8A=A8=E7=94=9F?= =?UTF-8?q?=E6=88=90=E5=A4=96=E9=94=AE=EF=BC=8C=E4=B9=9F=E4=B8=8D=E4=B8=BB?= =?UTF-8?q?=E5=8A=A8=E5=A4=84=E7=90=86=E5=A4=96=E9=94=AE=E4=B8=8E=E5=A4=9A?= =?UTF-8?q?=E5=AF=B9=E5=A4=9A=E7=9A=84=E5=80=BC=20=E6=8F=90=E4=BE=9B=20bef?= =?UTF-8?q?ore=5Fcreate=20=E5=92=8C=20before=5Fupdate=20=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=A4=84=E7=90=86=E4=BB=A3=E7=A0=81=20=E9=9C=80=E8=A6=81?= =?UTF-8?q?=E4=B8=BB=E5=8A=A8=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 57 +++++++++-------------- sanic_rest_framework/fields.py | 25 +++++----- sanic_rest_framework/filters.py | 2 +- sanic_rest_framework/serializers.py | 72 +++++++++++++++++------------ sanic_rest_framework/views.py | 30 ++++++------ 5 files changed, 95 insertions(+), 91 deletions(-) diff --git a/run.py b/run.py index 5ddde79..ff25ba6 100644 --- a/run.py +++ b/run.py @@ -2,66 +2,53 @@ from sanic import Sanic from sanic.blueprints import Blueprint from tortoise.contrib.sanic import register_tortoise -from db import AddressModel +from db import AddressModel, UserModel from sanic_rest_framework.request import SRFRequest from sanic_rest_framework.routes import Route from sanic_rest_framework.serializers import Serializer, ModelSerializer from sanic_rest_framework.fields import IntegerField from sanic_rest_framework.fields import CharField from sanic_rest_framework.views import ( - APIView, RetrieveModelMixin, ListModelMixin, CreateModelMixin, DestroyModelMixin, UpdateModelMixin + APIView, DetailModelMixin, ListModelMixin, CreateModelMixin, DestroyModelMixin, UpdateModelMixin ) app = Sanic(__name__, request_class=SRFRequest) admin = Blueprint('admin', '/admin') -class TestSerializer(ModelSerializer): +class UserSerializer(ModelSerializer): class Meta: - model = AddressModel + model = UserModel exclude = ('school',) - school_id = IntegerField(label='ap') - - -class Test1View(RetrieveModelMixin, APIView): - queryset = AddressModel - serializer_class = TestSerializer - search_fields = ('phone',) - - async def retrieve(self, request, *args, **kwargs): - for i in range(100): - adr = AddressModel() - adr.phone = 1767470900 + i - adr.house_number = '房间%s' % i - adr.address = '地址%s' % i - await adr.save() - -class Test2View(ListModelMixin, APIView): - queryset = AddressModel - serializer_class = TestSerializer - search_fields = ('@phone',) +class AddressSerializer(ModelSerializer): + class Meta: + model = AddressModel + exclude = ('school',) + user = UserSerializer(many=True) + school_id = IntegerField(label='ap') -class Test3View(CreateModelMixin, APIView): - queryset = AddressModel - serializer_class = TestSerializer - search_fields = ('@phone',) + async def before_create(self, validated_data, instance): + users = validated_data.pop('user') + user_obj = [] + for user in users: + us = UserSerializer(data=user) + await us.is_valid(True) + user_obj.append(await us.save()) + validated_data['user'] = user_obj + return validated_data, instance -class Test4View(UpdateModelMixin, APIView): +class Test1View(ListModelMixin, CreateModelMixin, UpdateModelMixin, DestroyModelMixin, APIView): queryset = AddressModel - serializer_class = TestSerializer - search_fields = ('@phone',) + serializer_class = AddressSerializer + search_fields = ('phone', 'user__name', 'address') route = Route() route.register_route('test1', Test1View) -route.register_route('test2', Test2View) -route.register_route('test3', Test3View) -route.register_route('test4', Test4View) -# route.initialize(admin) route.initialize(app) app.blueprint(admin) diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index aceb5d6..fd9758a 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -176,7 +176,7 @@ class Field: root = self.root return getattr(root, 'partial', False) - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: """对数据进行反序列化转换并返回""" raise NotImplementedError( '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) @@ -284,12 +284,12 @@ class Field: return True, None return (False, data) - def run_validation(self, data): + async def run_validation(self, data): """执行验证""" (is_empty_value, data) = self.validate_empty_values(data) if is_empty_value: return data - value = self.external_to_internal(data) + value = await self.external_to_internal(data) self.run_validators(value) return value @@ -327,6 +327,9 @@ class Field: message_string = msg.format(**kwargs) raise ValidationError(message_string, code=_key) + def __str__(self): + return super(Field, self).__str__() + self.field_name + class CharField(Field): default_error_messages = { @@ -345,7 +348,7 @@ class CharField(Field): if self.min_length is not None: self.validators.append(MinLengthValidator(min_length=self.min_length, error_messages={'min_length': self.error_messages['min_length']})) - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: if isinstance(data, bool) or not isinstance(data, (str, int, float,)): self.raise_error('invalid') value = str(data) @@ -375,7 +378,7 @@ class IntegerField(Field): if self.min_value is not None: self.validators.append(MinValueValidator(min_value=self.min_value, error_messages={'min_value': self.error_messages['min_value']})) - def external_to_internal(self, data: Any): + async def external_to_internal(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: self.raise_error('max_string_length', max_string_length=self.MAX_STRING_LENGTH) try: @@ -398,7 +401,7 @@ class FloatField(IntegerField): } MAX_STRING_LENGTH = 1000 - def external_to_internal(self, data: Any): + async def external_to_internal(self, data: Any): if isinstance(data, bool): self.raise_error('invalid', data_type=type(data).__name__) if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: @@ -454,7 +457,7 @@ class DecimalField(Field): if self.min_value is not None: self.validators.append(MinValueValidator(min_value=self.min_value, error_messages={'min_value': self.error_messages['min_value']})) - def external_to_internal(self, data: Any): + async def external_to_internal(self, data: Any): data = str(data).strip() if len(data) > self.MAX_STRING_LENGTH: @@ -553,7 +556,7 @@ class BooleanField(Field): } NULL_VALUES = {'null', 'Null', 'NULL', '', None} - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: try: if data in self.TRUE_VALUES: return True @@ -600,7 +603,7 @@ class DateTimeField(Field): """强制设置一个时区""" return value.astimezone(self.set_timezone) - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: if not isinstance(data, (str, date, datetime)): self.raise_error('convert') if type(data) == date: @@ -637,7 +640,7 @@ class DateField(Field): self.input_formats = input_formats super(DateField, self).__init__(*args, **kwargs) - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).date() @@ -674,7 +677,7 @@ class TimeField(Field): self.input_formats = input_formats super(TimeField, self).__init__(*args, **kwargs) - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: if isinstance(data, str): try: data = datetime.strptime(data, self.input_formats).time() diff --git a/sanic_rest_framework/filters.py b/sanic_rest_framework/filters.py index a680f60..76a313d 100644 --- a/sanic_rest_framework/filters.py +++ b/sanic_rest_framework/filters.py @@ -70,7 +70,7 @@ class SearchFilter(SimpleFilter): """ lookup_suffix_keys = list(self.lookup_prefixes.keys()) lookup_suffix = None - field_name = None + field_name = search_field for lookup_suffix_key in lookup_suffix_keys: if lookup_suffix_key in search_field: lookup_suffix = self.lookup_prefixes[lookup_suffix_key] diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 162b01b..a4197d3 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -105,7 +105,7 @@ class BaseSerializer(Field): list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) return list_serializer_class(*args, **list_kwargs) - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: """对数据进行序列化转换并返回""" raise NotImplementedError( '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) @@ -117,12 +117,14 @@ class BaseSerializer(Field): '{cls}类的 .external_to_internal 方法必须重写'.format(cls=self.__class__.__name__, ) ) - def validate(self, data): - return data + async def validate(self, attr): + return attr - def run_validation(self, data): - value = super(BaseSerializer, self).run_validation(data) + async def run_validation(self, data): + value = await super(BaseSerializer, self).run_validation(data) value = self.validate(value) + if inspect.isawaitable(value): + value = await value return value async def update(self, instance, validated_data): @@ -157,13 +159,14 @@ class BaseSerializer(Field): ) validated_data = dict(list(self.validated_data.items()) + list(kwargs.items())) - if self.instance is not None: + validated_data, self.instance = await self.before_update(validated_data, self.instance) self.instance = await self.update(self.instance, validated_data) assert self.instance is not None, ( '`update()` 没有返回对象实例。' ) else: + validated_data, self.instance = await self.before_create(validated_data, self.instance) self.instance = await self.create(validated_data) assert self.instance is not None, ( '`create()` 没有返回对象实例。' @@ -171,7 +174,13 @@ class BaseSerializer(Field): return self.instance - def is_valid(self, raise_exception=False): + async def before_update(self, validated_data, instance): + return validated_data, instance + + async def before_create(self, validated_data, instance): + return validated_data, instance + + async def is_valid(self, raise_exception=False): assert hasattr(self, 'initial_data'), ( '无法调用`.is_valid()`,因为' '类实例化时没有传入`data =`关键字参数' @@ -179,7 +188,7 @@ class BaseSerializer(Field): if not hasattr(self, '_validated_data'): try: - self._validated_data = self.run_validation(self.initial_data) + self._validated_data = await self.run_validation(self.initial_data) except ValidationError as exc: self._validated_data = {} if hasattr(exc, 'error_dict'): @@ -305,7 +314,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): # 反序列化 - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: """ 外转内 :param data: @@ -318,7 +327,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): validate_method = getattr(self, 'validate_' + field.field_name, None) try: primitive_value = field.get_external_value(data) - validated_value = field.run_validation(primitive_value) + validated_value = await field.run_validation(primitive_value) if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: @@ -331,7 +340,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): raise ValidationError(errors) return res - def validate(self, attrs): + async def validate(self, attrs): return attrs def __iter__(self): @@ -392,7 +401,7 @@ class ListSerializer(BaseSerializer): await self.child.internal_to_external(item) for item in iterable ] - def external_to_internal(self, data: Any) -> Any: + async def external_to_internal(self, data: Any) -> Any: """ 外转内 :param data: @@ -409,7 +418,7 @@ class ListSerializer(BaseSerializer): for item in data: try: - value = self.child.run_validation(item) + value = await self.child.run_validation(item) except ValidationError as exc: errors.append(exc) else: @@ -419,7 +428,7 @@ class ListSerializer(BaseSerializer): raise ValidationError(errors) return ret - def run_validation(self, data=empty): + async def run_validation(self, data=empty): """ 我们覆盖默认的`run_validation`,因为验证由验证者执行, 而.validate()方法应使用“non_fields_error”键被强制为错误字典。 @@ -427,9 +436,9 @@ class ListSerializer(BaseSerializer): (is_empty_value, data) = self.validate_empty_values(data) if is_empty_value: return data - value = self.external_to_internal(data) + value = await self.external_to_internal(data) self.run_validators(value) - value = self.validate(value) + value = await self.validate(value) return value async def update(self, instance, validated_data): @@ -522,11 +531,9 @@ class ModelSerializer(Serializer): serializer_fields = BindingDict(self) for field_name, field_class in effective_field.items(): - if field_name in declared_fields: - current_field_class = declared_fields[field_name] - else: - current_field_class = converter.convert(self, field_class, **self.get_field_kws_by_meta(field_name)) + current_field_class = converter.convert(self, field_class, **self.get_field_kws_by_meta(field_name)) serializer_fields[field_name] = current_field_class + serializer_fields.update(declared_fields) return serializer_fields def get_effective_field(self, model_fields) -> dict: @@ -548,6 +555,11 @@ class ModelSerializer(Serializer): return {k: v for k, v in model_fields.items() if k in meta_fields} def get_field_kws_by_meta(self, field_name): + """ + 判断当前字段是否在Meat内设置了只读或只写 + :param field_name: + :return: + """ read_only_fields = getattr(self.Meta, 'read_only_fields', []) write_only_fields = getattr(self.Meta, 'write_only_fields', []) if field_name in read_only_fields and field_name in write_only_fields: @@ -563,13 +575,15 @@ class ModelSerializer(Serializer): def _clean_model_field(self, model_original_fields): """ 清除不需要的字段如 fk_id - :param model: + :param model_original_fields: :return: """ clean_field_names = [] field_dict = {} fields_map = copy.deepcopy(model_original_fields) - for field_name, field_class in fields_map.items(): + basis_fields = self._get_model_basis_fields(fields_map) + + for field_name, field_class in basis_fields.items(): if isinstance(field_class, (tortoise_fields.relational.ForeignKeyFieldInstance, tortoise_fields.relational.OneToOneFieldInstance)): clean_field_names.append(field_class.source_field) field_dict[field_name] = field_class @@ -628,13 +642,13 @@ class ModelSerializer(Serializer): """检查关系字段 目前不能为关系字段提供自动转换功能""" # - # def _get_model_basis_fields(self, model_fields): - # """ - # 得到基础字段,非关系字段 - # :param model: - # :return: - # """ - # return {field_name: field_class for field_name, field_class in model_fields.items() if not isinstance(field_class, tortoise_fields.relational.RelationalField)} + def _get_model_basis_fields(self, model_fields): + """ + 得到基础字段,非关系字段 + :param model: + :return: + """ + return {field_name: field_class for field_name, field_class in model_fields.items() if not isinstance(field_class, tortoise_fields.relational.RelationalField)} # # def _get_model_M2M_fields(self, model_fields): # """ diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index ed333f2..d18829b 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -17,6 +17,7 @@ from sanic_rest_framework.exceptions import APIException, ValidationError from sanic_rest_framework.filters import SearchFilter from sanic_rest_framework.paginations import GeneralPagination from sanic_rest_framework.status import RuleStatus, HttpStatus +from simplejson import dumps class BaseAPIView: @@ -96,7 +97,7 @@ class BaseAPIView: 'message': msg, 'status': status } - return json(body=response_body, status=http_status) + return json(body=response_body, status=http_status, dumps=dumps) def success_json_response(self, data=None, msg="Success", **kwargs): """ @@ -138,7 +139,7 @@ class BaseAPIView: filter_kwargs = {lookup_field: self.kwargs[lookup_field]} obj = await queryset.get_or_none(**filter_kwargs) if obj is None: - raise APIException('不存在%s为%s的数据' % (lookup_field, self.kwargs[lookup_field])) + raise APIException('不存在%s为%s的数据' % (lookup_field, self.kwargs[lookup_field]), http_status=HttpStatus.HTTP_200_OK) # May raise a permission denied self.check_object_permissions(self.request, obj) @@ -207,7 +208,6 @@ class BaseAPIView: class APIView(BaseAPIView): - detail = False queryset = None lookup_field = 'pk' serializer_class = None @@ -327,7 +327,7 @@ class CreateModelMixin: async def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) + await serializer.is_valid(raise_exception=True) await self.perform_create(serializer) return self.success_json_response(data=await serializer.data, http_status=HttpStatus.HTTP_201_CREATED) @@ -335,16 +335,16 @@ class CreateModelMixin: await serializer.save() -class RetrieveModelMixin: +class DetailModelMixin: """ 适用于查询指定PK的内容 """ detail = True async def get(self, request, *args, **kwargs): - return await self.retrieve(request, *args, **kwargs) + return await self.detail(request, *args, **kwargs) - async def retrieve(self, request, *args, **kwargs): + async def detail(self, request, *args, **kwargs): instance = await self.get_object() serializer = self.get_serializer(instance) return self.success_json_response(data=await serializer.data) @@ -365,7 +365,7 @@ class UpdateModelMixin: partial = kwargs.pop('partial', False) instance = await self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) - serializer.is_valid(raise_exception=True) + await serializer.is_valid(raise_exception=True) await self.perform_update(serializer) # if getattr(instance, '_prefetched_objects_cache', None): @@ -386,13 +386,13 @@ class DestroyModelMixin: 用于快速删除 """ - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) + async def delete(self, request, *args, **kwargs): + return await self.destroy(request, *args, **kwargs) - def destroy(self, request, *args, **kwargs): - instance = self.get_object() - self.perform_destroy(instance) + async def destroy(self, request, *args, **kwargs): + instance = await self.get_object() + await self.perform_destroy(instance) return self.success_json_response(status=HttpStatus.HTTP_204_NO_CONTENT) - def perform_destroy(self, instance): - instance.delete() + async def perform_destroy(self, instance): + await instance.delete() -- Gitee From 742b1d46424f7dcf49131e684d6c39232099443c Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 8 Apr 2021 17:59:00 +0800 Subject: [PATCH 33/34] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=A4=A7=E9=87=8FBUG?= =?UTF-8?q?=20ViewSet=E5=8F=8ARoute=E6=8F=90=E4=BE=9B=E6=9B=B4=E5=A5=BD?= =?UTF-8?q?=E7=9A=84=E5=AE=9E=E7=8E=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sanic_rest_framework/authentication.py | 52 +++ sanic_rest_framework/constant.py | 1 + sanic_rest_framework/converter.py | 13 +- sanic_rest_framework/document/route.md | 62 +++ sanic_rest_framework/exceptions.py | 20 + sanic_rest_framework/fields.py | 47 ++- sanic_rest_framework/mixins.py | 126 ++++++ sanic_rest_framework/paginations.py | 32 +- sanic_rest_framework/request.py | 2 +- sanic_rest_framework/routes.py | 178 +++++++-- sanic_rest_framework/serializers.py | 64 +-- sanic_rest_framework/status.py | 1 - .../test_model_serializers.py | 2 +- sanic_rest_framework/validators.py | 6 +- sanic_rest_framework/views.py | 367 +++++++++--------- 15 files changed, 709 insertions(+), 264 deletions(-) create mode 100644 sanic_rest_framework/authentication.py create mode 100644 sanic_rest_framework/document/route.md create mode 100644 sanic_rest_framework/mixins.py diff --git a/sanic_rest_framework/authentication.py b/sanic_rest_framework/authentication.py new file mode 100644 index 0000000..330a459 --- /dev/null +++ b/sanic_rest_framework/authentication.py @@ -0,0 +1,52 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/31 16:21 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + authentication.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/31 16:21 change 'Fix bug' + +""" +import jwt +from jwt import ExpiredSignatureError + +from apps.auth.models import UserModel +from sanic_rest_framework.exceptions import APIException +from sanic_rest_framework.request import SRFRequest +from sanic_rest_framework.status import HttpStatus + + +class BaseAuthenticate: + def authenticate(self, request: SRFRequest, **kwargs): + """验证权限并返回User对象""" + # request.headers[''] + + +class TokenAuthenticate(BaseAuthenticate): + token_key = 'X-Token' + + async def authenticate(self, request: SRFRequest, **kwargs): + """验证逻辑""" + token = request.headers.get(self.token_key) + if token is None: + raise APIException(message='授权错误:请求头{}不存在'.format(self.token_key), http_status=HttpStatus.HTTP_401_UNAUTHORIZED) + token_secret = request.app.config.TOKEN_SECRET + try: + token_info = self.authentication_token(token, token_secret) + except ExpiredSignatureError: + raise APIException(message='授权已过期,请重新登录', http_status=HttpStatus.HTTP_401_UNAUTHORIZED) + await self._authenticate(request, token_info, **kwargs) + + async def _authenticate(self, request: SRFRequest, token_info: dict, **kwargs): + """主要处理逻辑""" + pass + + def authentication_token(self, token, token_secret): + token_info = jwt.decode(token, token_secret, algorithms=['HS256']) + return token_info diff --git a/sanic_rest_framework/constant.py b/sanic_rest_framework/constant.py index f5d90b6..7b6da38 100644 --- a/sanic_rest_framework/constant.py +++ b/sanic_rest_framework/constant.py @@ -17,3 +17,4 @@ LIST_METHOD_GROUP = { 'dynamic_method': ['PUT', 'DELETE', 'PATCH'], 'static_method': ['GET', 'POST', 'OPTION'] } +DEFAULT_METHOD_MAP = {'get': 'get', 'post': 'post', 'put': 'put', 'patch': 'patch', 'delete': 'delete', 'head': 'head', 'options': 'options'} diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index 5eeaf20..987d04c 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -56,20 +56,21 @@ class ModelConverter(ModelConverterBase): read_only = field_kwargs.get('read_only', False) write_only = field_kwargs.get('write_only', False) required = not model_field.null - if read_only and required: - raise ValueError('{}序列化器内字段{}为必填项不能使用read_only属性'.format( - type(serializer).__name__, model_field.model_field_name)) + # if read_only and required: + # raise ValueError('{}序列化器内字段{}为必填项不能使用read_only属性'.format( + # type(serializer).__name__, model_field.model_field_name)) if model_field.pk: field_kwargs['read_only'] = True field_kwargs['required'] = False kwargs = { 'read_only': read_only, 'write_only': write_only, - 'required': required, 'allow_null': model_field.null, - # 'source': None, 'description': model_field.description } + if not read_only: + kwargs['required'] = required + if not isinstance(model_field, fields.relational.RelationalField): type_name = model_field.__class__.__name__ if model_field.default is not None: @@ -81,7 +82,7 @@ class ModelConverter(ModelConverterBase): nested_depth = serializer.Meta.depth else: nested_depth = 10 - kwargs['allow_empty'] = model_field.null + # kwargs['allow_empty'] = model_field.null kwargs['nested_depth'] = nested_depth converter = self.converters[type_name] kwargs.update(field_kwargs) diff --git a/sanic_rest_framework/document/route.md b/sanic_rest_framework/document/route.md new file mode 100644 index 0000000..42b3235 --- /dev/null +++ b/sanic_rest_framework/document/route.md @@ -0,0 +1,62 @@ +# `Route`类介绍 +#### 注意:`Route`类仅适用于 `ViewSetView`类视图 +已实现的基础类有 +* BaseRoute [抽象] +* DefaultRoute [常用] + +上述类最终都会生成可适用于 `spot.add_route()` 使用的数据,格式如下 +> 由于Route类是为 urls.py 文件服务的所以先看一下urls.py -> urls 需要的数据格式 +```python +# urls.py -> urls 如下格式的数据 List[Dict] +[{ + 'handler': TestView, + 'uri': /test, + 'name': 'test', + 'is_base':False, # 缺省为 False + 'methods':['GET','POST'] # 缺省为 ALL_METHOD +}] + +# DefaultRoute.urls 生成的数据格式 +[{ + 'handler': TestView, + 'uri': /test, + 'name': 'test', + 'is_base':False, +}] + +# 所以在具体使用时只需要如下即可 +route = DefaultRoute() +route.register_route(TestView, '/test', test) +'....' +urls = route.urls + +``` +> 源代码 + +```python + +class DefaultRoute(BaseRoute): + def register_route(self, viewset: object, prefix: str, name: str = None, is_base: bool = False): + """ + 注册路由到路由管理类 + :param viewset: 视图,无需 + :param prefix: + :param name: + :param is_base: + :return: + """ + + def get_viewset_method_list(self, viewset): + """ + 得到viewSet所有请求方法 + :param viewset: 类视图 + :return: + """ + + @property + def urls(self): + """得到 urls.py 需要的数据""" + + def initialize(self, destination: Union[Sanic, Blueprint]): + """注册路由到destination中""" +``` diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py index 8e799b9..f6a110c 100644 --- a/sanic_rest_framework/exceptions.py +++ b/sanic_rest_framework/exceptions.py @@ -112,3 +112,23 @@ class APIException(Exception): 'status': self.status, 'http_status': self.http_status } + + +class ValidationException(Exception): + """验证错误类""" + default_detail = '无效的输入' + default_code = 'invalid' + + def __init__(self, error_detail=None, code=None): + if error_detail is None: + error_detail = self.default_detail + if code is None: + code = self.default_code + + # For validation failures, we may collect many errors together, + # so the details should always be coerced to a list if not already. + if not isinstance(error_detail, dict) and not isinstance(error_detail, list): + error_detail = [error_detail] + + self.code = code + self.error_detail = error_detail diff --git a/sanic_rest_framework/fields.py b/sanic_rest_framework/fields.py index fd9758a..d8be403 100644 --- a/sanic_rest_framework/fields.py +++ b/sanic_rest_framework/fields.py @@ -10,6 +10,7 @@ """ import copy import decimal +import inspect import re from datetime import timezone, timedelta, datetime, date, time from enum import Enum @@ -20,7 +21,7 @@ from tortoise.queryset import QuerySet from tortoise.fields.relational import RelationalField, ManyToManyField from tortoise.exceptions import DoesNotExist -from sanic_rest_framework.exceptions import ValidationError +from sanic_rest_framework.exceptions import ValidationException from sanic_rest_framework.validators import ( MaxLengthValidator, MinLengthValidator, MaxValueValidator, MinValueValidator ) @@ -48,7 +49,6 @@ class Field: """字段及序列化器基类 required: 反序列化时是否必须存在,值限制写入时 allow_null: 是否可以为 None 即存在当没值 - allow_empty: 是否可以为空 value = '' 即为空 """ _sort_counter = 0 @@ -62,15 +62,14 @@ class Field: default_validators = None def __init__(self, read_only=False, write_only=False, required=False, allow_null=False, - allow_empty=False, default=empty, source=None, validators=None, error_messages=None, + default=empty, source=None, validators=None, error_messages=None, label=None, description=None): """ 字段及field的基类 :param read_only: 只序列化 :param write_only: 只反序列化 :param required: 反序列化时必须存在此值 - :param allow_null: 反序列化时可以为 None - :param allow_empty: 反序列化可以为 '' + :param allow_null: 反序列化时可以为 None, '' :param default: 默认值 可用于序列化和反序列化 :param source: 序列化时值的来源 :param validators: 反序列化时数据需要通过的验证 @@ -88,7 +87,7 @@ class Field: self.write_only = write_only self.required = required self.allow_null = allow_null - self.allow_null = allow_empty + # self.allow_null = allow_empty self.default = default self.source = source self.label = label @@ -196,7 +195,7 @@ class Field: :return: """ if not isinstance(data, Mapping): - raise ValidationError('传入的数据为无效数据类型,仅支持字典类型'.format(self.field_name)) + raise ValidationException('传入的数据为无效数据类型,仅支持字典类型'.format(self.field_name)) if self.field_name not in data: if self.is_partial(): return empty @@ -246,12 +245,10 @@ class Field: for validator in self.validators: try: validator(data, self) - except ValidationError as exc: - if hasattr(exc, 'code') and exc.code in self.error_messages and isinstance(exc.message, ValidationError): - exc.message = self.error_messages[exc.code] - errors.extend(exc.error_list) + except ValidationException as exc: + errors.extend(exc.error_detail) if errors: - raise ValidationError(errors) + raise ValidationException(errors) def get_default(self): if self.default is empty or getattr(self.root, 'partial', False): @@ -325,7 +322,7 @@ class Field: "属性中未能找到 Key 为 {key} 的错误描述".format(class_name=class_name, key=_key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string, code=_key) + raise ValidationException(message_string, code=_key) def __str__(self): return super(Field, self).__str__() + self.field_name @@ -344,9 +341,11 @@ class CharField(Field): self.trim_whitespace = kwargs.pop('trim_whitespace', True) super(CharField, self).__init__(*args, **kwargs) if self.max_length is not None: - self.validators.append(MaxLengthValidator(max_length=self.max_length, error_messages={'max_length': self.error_messages['max_length']})) + self.validators.append(MaxLengthValidator(max_length=self.max_length, + error_messages={'max_length': self.error_messages['max_length']})) if self.min_length is not None: - self.validators.append(MinLengthValidator(min_length=self.min_length, error_messages={'min_length': self.error_messages['min_length']})) + self.validators.append(MinLengthValidator(min_length=self.min_length, + error_messages={'min_length': self.error_messages['min_length']})) async def external_to_internal(self, data: Any) -> Any: if isinstance(data, bool) or not isinstance(data, (str, int, float,)): @@ -374,9 +373,11 @@ class IntegerField(Field): self.min_value = min_value super(IntegerField, self).__init__(*args, **kwargs) if self.max_value is not None: - self.validators.append(MaxValueValidator(max_value=self.max_value, error_messages={'max_value': self.error_messages['max_value']})) + self.validators.append(MaxValueValidator(max_value=self.max_value, + error_messages={'max_value': self.error_messages['max_value']})) if self.min_value is not None: - self.validators.append(MinValueValidator(min_value=self.min_value, error_messages={'min_value': self.error_messages['min_value']})) + self.validators.append(MinValueValidator(min_value=self.min_value, + error_messages={'min_value': self.error_messages['min_value']})) async def external_to_internal(self, data: Any): if isinstance(data, str) and len(data) > self.MAX_STRING_LENGTH: @@ -453,9 +454,11 @@ class DecimalField(Field): self.max_whole_digits = None super(DecimalField, self).__init__(*args, **kwargs) if self.max_value is not None: - self.validators.append(MaxValueValidator(max_value=self.max_value, error_messages={'max_value': self.error_messages['max_value']})) + self.validators.append(MaxValueValidator(max_value=self.max_value, + error_messages={'max_value': self.error_messages['max_value']})) if self.min_value is not None: - self.validators.append(MinValueValidator(min_value=self.min_value, error_messages={'min_value': self.error_messages['min_value']})) + self.validators.append(MinValueValidator(min_value=self.min_value, + error_messages={'min_value': self.error_messages['min_value']})) async def external_to_internal(self, data: Any): data = str(data).strip() @@ -798,9 +801,11 @@ class SerializerMethodField(Field): super().bind(field_name, parent) - def internal_to_external(self, data: Any) -> Any: + async def internal_to_external(self, data: Any) -> Any: method = getattr(self.parent, self.method_name) + if inspect.iscoroutinefunction(method): + return await method(data) return method(data) def external_to_internal(self, data: Any) -> Any: - raise ValidationError('SerializerMethodField 不支持反序列化') + raise ValidationException('SerializerMethodField 不支持反序列化') diff --git a/sanic_rest_framework/mixins.py b/sanic_rest_framework/mixins.py new file mode 100644 index 0000000..04deaaf --- /dev/null +++ b/sanic_rest_framework/mixins.py @@ -0,0 +1,126 @@ +""" +@Author:WangYuXiang +@E-mile:Hill@3io.cc +@CreateTime:2021/3/26 14:43 +@DependencyLibrary:无 +@MainFunction:无 +@FileDoc: + mixins.py + 文件说明 +@ChangeHistory: + datetime action why + example: + 2021/3/26 14:43 change 'Fix bug' + +""" +from sanic_rest_framework.paginations import GeneralPagination +from sanic_rest_framework.status import HttpStatus + +__all__ = [ + 'ListModelMixin', 'CreateModelMixin', 'RetrieveModelMixin', 'UpdateModelMixin', 'DestroyModelMixin' +] + + +class ListModelMixin: + """ + 适用于输出列表类型数据 + """ + pagination_class = GeneralPagination + detail = False + + async def get(self, request, *args, **kwargs): + return await self.list(request, *args, **kwargs) + + async def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + + page = await self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + data = self.get_paginated_response(await serializer.data) + return self.success_json_response(data=data) + + serializer = self.get_serializer(queryset, many=True) + return self.success_json_response(data=await serializer.data) + + +class CreateModelMixin: + """ + 适用于快速创建内容 + 占用 post 方法 + """ + + async def post(self, request, *args, **kwargs): + return await self.create(request, *args, **kwargs) + + async def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + await serializer.is_valid(raise_exception=True) + await self.perform_create(serializer) + return self.success_json_response(data=await serializer.data, http_status=HttpStatus.HTTP_201_CREATED) + + async def perform_create(self, serializer): + await serializer.save() + + +class RetrieveModelMixin: + """ + 适用于查询指定PK的内容 + """ + detail = True + + async def get(self, request, *args, **kwargs): + return await self.retrieve(request, *args, **kwargs) + + async def retrieve(self, request, *args, **kwargs): + instance = await self.get_object() + serializer = self.get_serializer(instance) + return self.success_json_response(data=await serializer.data) + + +class UpdateModelMixin: + """ + 适用于快速创建更新操作 + """ + + async def put(self, request, *args, **kwargs): + return await self.update(request, *args, **kwargs) + + async def patch(self, request, *args, **kwargs): + return await self.partial_update(request, *args, **kwargs) + + async def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = await self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + await serializer.is_valid(raise_exception=True) + await self.perform_update(serializer) + + # if getattr(instance, '_prefetched_objects_cache', None): + # instance._prefetched_objects_cache = {} + + return self.success_json_response(data=await serializer.data) + + async def perform_update(self, serializer): + await serializer.save() + + async def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return await self.update(request, *args, **kwargs) + + +class DestroyModelMixin: + """ + 用于快速删除 + """ + + async def delete(self, request, *args, **kwargs): + return await self.destroy(request, *args, **kwargs) + + async def destroy(self, request, *args, **kwargs): + instance = await self.get_object() + await self.perform_destroy(instance) + return self.success_json_response(status=HttpStatus.HTTP_204_NO_CONTENT) + + async def perform_destroy(self, instance): + await instance.delete() diff --git a/sanic_rest_framework/paginations.py b/sanic_rest_framework/paginations.py index b0f3106..8c11983 100644 --- a/sanic_rest_framework/paginations.py +++ b/sanic_rest_framework/paginations.py @@ -46,25 +46,49 @@ class GeneralPagination(BasePagination): return self._count def get_next_link(self, request: Request): + """ + 得到下一页的请求地址 + :param request: + :return: None or String + """ assert hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.get_next_link()' - if self.page * self.page_size + self.page_size >= self.count: + page = self.get_next_page() + if not page: return None uri = request.server_path query_string = '?' + request.query_string - query_string = replace_query_param(query_string, self.page_query_param, self.page + 1) + query_string = replace_query_param(query_string, self.page_query_param, page) query_string = replace_query_param(query_string, self.page_size_query_param, self.page_size) return uri + query_string + def get_next_page(self): + """得到下一页的页码,不存在则返回None""" + if self.page * self.page_size + self.page_size >= self.count: + return None + return self.page + 1 + def get_previous_link(self, request: Request): + """ + 得到上一页的请求地址 + :param request: + :return: None or String + """ assert hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.get_previous_link()' - if self.page * self.page_size <= 0: + page = self.get_previous_page() + if not page: return None uri = request.server_path query_string = '?' + request.query_string - query_string = replace_query_param(query_string, self.page_query_param, self.page - 1) + query_string = replace_query_param(query_string, self.page_query_param, page) query_string = replace_query_param(query_string, self.page_size_query_param, self.page_size) return uri + query_string + def get_previous_page(self): + """得到上一页页码,不存在则返回None""" + if self.page * self.page_size <= 0: + return None + return self.page - 1 + async def paginate_queryset(self, queryset, request, view): self.page = self.get_query_page(request) self.page_size = self.get_query_page_size(request) diff --git a/sanic_rest_framework/request.py b/sanic_rest_framework/request.py index 0261545..148aee8 100644 --- a/sanic_rest_framework/request.py +++ b/sanic_rest_framework/request.py @@ -30,4 +30,4 @@ class SRFRequest(SanicRequest): data = self.json except InvalidUsage as exc: data = self.form - return data + return {} if data is None else data diff --git a/sanic_rest_framework/routes.py b/sanic_rest_framework/routes.py index 94a901c..32b5913 100644 --- a/sanic_rest_framework/routes.py +++ b/sanic_rest_framework/routes.py @@ -8,64 +8,196 @@ routes.py 便捷路由文件 """ +from collections import namedtuple from typing import List, Type, Union from sanic import Sanic, Blueprint from .constant import ALL_METHOD, DETAIL_METHOD_GROUP, LIST_METHOD_GROUP - # 默认分组 +from .views import BaseView, ViewSetView + +Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs', 'is_base']) + + +class BaseRoute: + + def __init__(self, trailing_slash=False): + """ + 基础类 + :param trailing_slash: + """ + self.trailing_slash = '/' if trailing_slash else '' + self.registry = [] + + def register(self, viewset: object, prefix: str, basename: str = None, is_base: bool = False): + if basename is None: + basename = prefix.replace('/', '_') + self.registry.append((prefix, viewset, basename, is_base)) + + # invalidate the urls cache + if hasattr(self, '_urls'): + del self._urls + def get_urls(self): + pass -class Route: - def __init__(self): - self.routes = [] + @property + def urls(self): + if not hasattr(self, '_urls'): + self._urls = self.get_urls() + return self._urls - def register_route(self, prefix, viewset, name=None): + +class ViewSetRouter(BaseRoute): + routes = [ + # List route. + Route( + url=r'{prefix}{trailing_slash}', + mapping={ + 'get': 'list', + 'post': 'create' + }, + name='{basename}-list', + detail=False, + initkwargs={'suffix': 'List'}, + is_base=False, + ), + Route( + url=r'{prefix}/<{lookup}:string>{trailing_slash}', + mapping={ + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, + name='{basename}-detail', + detail=True, + initkwargs={'suffix': 'Instance'}, + is_base=False, + ), + ] + + def get_lookup(self, viewset): """ - 注册路由 - :param prefix: url 前缀 - :param viewset: 视图类 - :param name: 供 url_for 使用的名称 + 得到主键字段名 + :param viewset: 集合视图 :return: """ - if name is None: - name = prefix + return getattr(viewset, 'lookup_field', 'pk') + + def get_urls(self): + """ + Use the registered viewsets to generate a list of URL patterns. + """ + ret = [] + + for prefix, viewset, basename, is_base in self.registry: + lookup = self.get_lookup(viewset) + + for route in self.routes: + + # Only actions which actually exist on the viewset will be bound + mapping = self.get_method_map(viewset, route.mapping) + if not mapping: + continue - dynamic_uri = '/{prefix}/<{lookup_field}:string>' - static_uri = '/{prefix}' + # Build the url pattern + uri = route.url.format( + prefix=prefix, + lookup=lookup, + trailing_slash=self.trailing_slash + ) + + # If there is no prefix, the first part of the url is probably + # controlled by project's urls.py and the router is in an app, + # so a slash in the beginning will (A) cause Django to give + # warnings and (B) generate URLS that will require using '//'. + if not prefix and uri[:2] == '^/': + uri = '^' + uri[2:] + + initkwargs = route.initkwargs.copy() + initkwargs.update({ + 'basename': basename, + 'detail': route.detail, + }) + + view = viewset.as_view(mapping, **initkwargs) + name = route.name.format(basename=basename) + ret.append({ + 'handler': view, + 'uri': uri, + 'name': name, + 'is_base': is_base, + + }) + + return ret + + def get_method_map(self, viewset, method_map): + """得到可用的模型""" + bound_methods = {} + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + +class DefaultRoute(BaseRoute): + def register_route(self, viewset: object, prefix: str, name: str = None, is_base: bool = False): + """ + 注册路由到路由管理类 + :param viewset: 视图,无需 + :param prefix: + :param name: + :param is_base: + :return: + """ + if name is None: + name = prefix.replace('/', '_') base_method_group = LIST_METHOD_GROUP if hasattr(viewset, 'detail') and viewset.detail: base_method_group = DETAIL_METHOD_GROUP - viewset_methods = self.get_viewset_methods(viewset) - viewset_dynamic_method = [i for i in viewset_methods if i in base_method_group['dynamic_method']] - viewset_static_method = [i for i in viewset_methods if i in base_method_group['static_method']] + viewset_method_list = self.get_viewset_method_list(viewset) + viewset_dynamic_method = [i for i in viewset_method_list if i in base_method_group['dynamic_method']] + viewset_static_method = [i for i in viewset_method_list if i in base_method_group['static_method']] if viewset_dynamic_method: self.routes.append({ 'handler': viewset.as_view(viewset_dynamic_method), - 'uri': dynamic_uri.format(prefix=prefix, lookup_field=viewset.lookup_field), - 'name': '{name}_detail'.format(name=name) + 'uri': self.dynamic_uri.format(prefix=prefix, lookup_field=viewset.lookup_field), + 'name': '{name}_detail'.format(name=name), + 'is_base': is_base + }) if viewset_static_method: self.routes.append({ 'handler': viewset.as_view(viewset_static_method), - 'uri': static_uri.format(prefix=prefix), - 'name': '{name}_list'.format(name=name) + 'uri': self.static_uri.format(prefix=prefix), + 'name': '{name}_list'.format(name=name), + 'is_base': is_base }) - def get_viewset_methods(self, viewset): - """得到viewSet所有请求方法""" + def get_viewset_method_list(self, viewset): + """ + 得到viewSet所有请求方法 + :param viewset: 类视图 + :return: + """ methods = [] for method in ALL_METHOD: if hasattr(viewset, method.lower()): methods.append(method) return methods + @property + def urls(self): + return self.routes + def initialize(self, destination: Union[Sanic, Blueprint]): """注册路由""" for route in self.routes: - route['methods'] = ALL_METHOD + route.pop('is_base') destination.add_route(**route) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index a4197d3..6dc2d7c 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -17,13 +17,15 @@ from typing import Any, Mapping, Coroutine from tortoise import models, Model from tortoise.queryset import QuerySet from tortoise import fields as tortoise_fields +from tortoise.fields.relational import ForeignKeyFieldInstance, OneToOneFieldInstance, ManyToManyFieldInstance, ManyToManyRelation from sanic_rest_framework.fields import ( empty, SkipField, - Field, CharField, IntegerField, FloatField, DecimalField, BooleanField, DateTimeField, DateField, TimeField, ChoiceField, SerializerMethodField + Field, CharField, IntegerField, FloatField, DecimalField, BooleanField, DateTimeField, DateField, TimeField, + ChoiceField, SerializerMethodField ) from sanic_rest_framework.converter import ModelConverter -from .exceptions import ValidationError +from .exceptions import ValidationException from .helpers import BindingDict LIST_SERIALIZER_KWARGS = ( @@ -70,7 +72,6 @@ class BaseSerializer(Field): .data .install_data -> .get_external_value -> .external_to_internal() -> .validated_data - """ def __init__(self, instance=None, data=empty, **kwargs): @@ -89,14 +90,14 @@ class BaseSerializer(Field): @classmethod def many_init(cls, *args, **kwargs): - """""" - allow_empty = kwargs.pop('allow_empty', None) + """初始化多值组件""" + # allow_empty = kwargs.pop('allow_empty', None) child_serializer = cls(*args, **kwargs) list_kwargs = { 'child': child_serializer, } - if allow_empty is not None: - list_kwargs['allow_empty'] = allow_empty + # if allow_empty is not None: + # list_kwargs['allow_empty'] = allow_empty list_kwargs.update({ key: value for key, value in kwargs.items() if key in LIST_SERIALIZER_KWARGS @@ -189,17 +190,14 @@ class BaseSerializer(Field): if not hasattr(self, '_validated_data'): try: self._validated_data = await self.run_validation(self.initial_data) - except ValidationError as exc: + except ValidationException as exc: self._validated_data = {} - if hasattr(exc, 'error_dict'): - self._errors = exc.error_dict - else: - self._errors = {self.field_name, exc.error_list} + self._errors = exc.error_detail else: self._errors = {} if self._errors and raise_exception: - raise ValidationError(self.errors) + raise ValidationException(self.errors) return not bool(self._errors) @property @@ -329,15 +327,19 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): primitive_value = field.get_external_value(data) validated_value = await field.run_validation(primitive_value) if validate_method is not None: - validated_value = validate_method(validated_value) - except ValidationError as exc: - errors[field.field_name] = exc.error_dict if hasattr(exc, 'error_dict') else exc.error_list + if inspect.iscoroutinefunction(validate_method): + validated_value = await validate_method(validated_value) + else: + validated_value = validate_method(validated_value) + + except ValidationException as exc: + errors[field.field_name] = exc.error_detail except SkipField: pass else: set_value(res, field.source_attrs, validated_value) if errors: - raise ValidationError(errors) + raise ValidationException(errors) return res async def validate(self, attrs): @@ -364,29 +366,26 @@ class ListSerializer(BaseSerializer): default_error_messages = { 'not_a_list': '预期项目列表,但类型为“ {input_type}”。', - 'empty': '此列表不能为空。' + 'null': '此列表不能为空。' } def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) - self.allow_empty = kwargs.pop('allow_empty', True) + # self.allow_empty = kwargs.pop('allow_empty', True) assert self.child is not None, '`child` 是必填参数。' assert not inspect.isclass(self.child), '`child` 尚未实例化。' super().__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) async def get_internal_value(self, instance: Any) -> Any: - + """目的得到值""" for attr in self.source_attrs: if isinstance(instance, Mapping): instance = instance.get(attr, []) if not isinstance(instance, list): instance = [instance] else: - instance = await getattr(instance, attr) - # instance = getattr(instance, attr) - # if inspect.iscoroutine(instance): - # instance = await instance + instance = getattr(instance, attr) return instance async def internal_to_external(self, data: Any) -> Any: @@ -410,8 +409,8 @@ class ListSerializer(BaseSerializer): if not isinstance(data, list): raise self.raise_error('not_a_list', input_type=type(data).__name__) - if not self.allow_empty and len(data) == 0: - raise self.raise_error('empty') + if not self.allow_null and len(data) == 0: + raise self.raise_error('null') ret = [] errors = [] @@ -419,13 +418,13 @@ class ListSerializer(BaseSerializer): for item in data: try: value = await self.child.run_validation(item) - except ValidationError as exc: - errors.append(exc) + except ValidationException as exc: + errors.append(exc.error_detail) else: ret.append(value) errors.append({}) if any(errors): - raise ValidationError(errors) + raise ValidationException(errors) return ret async def run_validation(self, data=empty): @@ -584,7 +583,9 @@ class ModelSerializer(Serializer): basis_fields = self._get_model_basis_fields(fields_map) for field_name, field_class in basis_fields.items(): - if isinstance(field_class, (tortoise_fields.relational.ForeignKeyFieldInstance, tortoise_fields.relational.OneToOneFieldInstance)): + if isinstance(field_class, ( + tortoise_fields.relational.ForeignKeyFieldInstance, + tortoise_fields.relational.OneToOneFieldInstance)): clean_field_names.append(field_class.source_field) field_dict[field_name] = field_class @@ -648,7 +649,8 @@ class ModelSerializer(Serializer): :param model: :return: """ - return {field_name: field_class for field_name, field_class in model_fields.items() if not isinstance(field_class, tortoise_fields.relational.RelationalField)} + return {field_name: field_class for field_name, field_class in model_fields.items() if + not isinstance(field_class, tortoise_fields.relational.RelationalField)} # # def _get_model_M2M_fields(self, model_fields): # """ diff --git a/sanic_rest_framework/status.py b/sanic_rest_framework/status.py index e482042..189a6fc 100644 --- a/sanic_rest_framework/status.py +++ b/sanic_rest_framework/status.py @@ -8,7 +8,6 @@ status.py Http status describe file """ -from enum import Enum def is_informational(code): diff --git a/sanic_rest_framework/test/test_serializers/test_model_serializers.py b/sanic_rest_framework/test/test_serializers/test_model_serializers.py index d9de6fe..9b71234 100644 --- a/sanic_rest_framework/test/test_serializers/test_model_serializers.py +++ b/sanic_rest_framework/test/test_serializers/test_model_serializers.py @@ -84,7 +84,7 @@ class TestOrdinarySerializer(TestCase): tms = TestModelSerializer(instance=await TestModel.all(), many=True) print(await tms.data) tms = TestModelSerializer(data=self.data, partial=True) - tms.is_valid() + await tms.is_valid() print(tms.validated_data) print(tms.fields) diff --git a/sanic_rest_framework/validators.py b/sanic_rest_framework/validators.py index 4e7daa3..89b37db 100644 --- a/sanic_rest_framework/validators.py +++ b/sanic_rest_framework/validators.py @@ -16,7 +16,9 @@ import copy from typing import Dict -from .exceptions import ValidationError, ValidatorAssertError +from .exceptions import ValidationException, ValidatorAssertError + +__all__ = ['BaseValidator', 'MaxLengthValidator', 'MinLengthValidator', 'MaxValueValidator', 'MinValueValidator'] class BaseValidator: @@ -40,7 +42,7 @@ class BaseValidator: def raise_error(self, key, **kws): msg = self.default_error_messages[key].format(**kws) - raise ValidationError(msg, code=key) + raise ValidationException(msg, code=key) class MaxLengthValidator(BaseValidator): diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index d18829b..07ac140 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -7,79 +7,186 @@ @FileDoc: views.py 基础视图文件 + BaseView 只实现路由分发的基础视图 + GeneralView 通用视图,可以基于其实现增删改查,提供权限套件 + ViewSetView 视图集视图,可以配合Mixin实现复杂的视图集, + 数据来源基于模型查询集,可以配合Route组件实现便捷的路由管理 + + + """ -from sanic.response import json -from tortoise.queryset import QuerySet +import inspect +from datetime import datetime + from sanic.log import logger +from sanic.response import json, HTTPResponse -from sanic_rest_framework.constant import ALL_METHOD -from sanic_rest_framework.exceptions import APIException, ValidationError +from sanic_rest_framework import mixins +from sanic_rest_framework.constant import ALL_METHOD, DEFAULT_METHOD_MAP +from sanic_rest_framework.exceptions import APIException, ValidationException from sanic_rest_framework.filters import SearchFilter -from sanic_rest_framework.paginations import GeneralPagination +from sanic_rest_framework.mixins import CreateModelMixin, ListModelMixin, DestroyModelMixin, UpdateModelMixin, \ + RetrieveModelMixin from sanic_rest_framework.status import RuleStatus, HttpStatus from simplejson import dumps +from tortoise.queryset import QuerySet +__all__ = ['BaseView', 'GeneralViewView', 'ViewSetView', 'CRUDView', 'CLUDView'] -class BaseAPIView: - """基础API视图""" - authentication_classes = () - permission_classes = () + +class BaseView: + """只实现路由分发的基础视图 + 在使用时应当开放全部路由 ALL_METHOD + app.add_route('/test', BaseView.as_view(), 'test', ALL_METHOD) + 如需限制路由则在其他地方注明 + app.add_route('/test', BaseView.as_view(), 'test', ALL_METHOD) + 注意以上方法的报错是不可控的 + """ + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) async def dispatch(self, request, *args, **kwargs): """分发路由""" request.user = None method = request.method - if method not in self.licensed_methods: - return self.json_response(msg='发生错误:未找到%s方法' % method, status=RuleStatus.STATUS_0_FAIL, - http_status=HttpStatus.HTTP_405_METHOD_NOT_ALLOWED) + if method.lower() not in self.method_map: + return HTTPResponse('405请求方法错误', status=405) handler = getattr(self, method.lower(), None) - - try: - self.initial(request, *args, **kwargs) - response = await handler(request, *args, **kwargs) - except APIException as exc: - response = self.handle_exception(exc) - except ValidationError as exc: - response = self.error_json_response(exc.message_dict, '数据验证失败') - except AssertionError as exc: - raise exc - except Exception as exc: - logger.error('--捕获未知错误--', exc) - response = self.handle_uncaught_exception(exc) + response = handler(request, *args, **kwargs) + if inspect.isawaitable(response): + response = await response return response @classmethod - def as_view(cls, methods=None, *class_args, **class_kwargs): + def get_method_map(cls): + methods = {} + for method in ALL_METHOD: + method = method.lower() + if hasattr(cls, method): + methods[method] = method + return methods - # 许可的方法 - if methods is None: - methods = ALL_METHOD + @classmethod + def as_view(cls, method_map=DEFAULT_METHOD_MAP, *class_args, **class_kwargs): # 返回的响应方法闭包 def view(request, *args, **kwargs): self = view.base_class(*class_args, **class_kwargs) - self.licensed_methods = methods + self.method_map = method_map + + for method, action in method_map.items(): + handler = getattr(self, action) + setattr(self, method, handler) + self.request = request self.args = args self.kwargs = kwargs + self.app = request.app return self.dispatch(request, *args, **kwargs) view.base_class = cls + view.methods = list(method_map.keys()) view.API_DOC_CONFIG = class_kwargs.get('API_DOC_CONFIG') # 未来的API文档配置属性+ view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ view.__name__ = cls.__name__ return view + +# +# class BaseView: +# """只实现路由分发的基础视图 +# 在使用时应当开放全部路由 ALL_METHOD +# app.add_route('/test', BaseView.as_view(), 'test', ALL_METHOD) +# 如需限制路由则 +# app.add_route('/test', BaseView.as_view(['GET','POST']), 'test', ALL_METHOD) +# OR +# app.add_route('/test', BaseView.as_view(), 'test', ['GET','POST']) +# 注意以上方法的报错是不可控的 +# """ +# +# async def dispatch(self, request, *args, **kwargs): +# """分发路由""" +# request.user = None +# method = request.method +# if method not in self.licensed_methods: +# return HTTPResponse('405请求方法错误', status=405) +# handler = getattr(self, method.lower(), None) +# response = handler(request, *args, **kwargs) +# if inspect.isawaitable(response): +# response = await response +# return response +# +# @classmethod +# def get_methods(cls): +# methods = [] +# for method in ALL_METHOD: +# if hasattr(cls, method.lower()): +# methods.append(method) +# return methods +# +# @classmethod +# def as_view(cls, methods=None, *class_args, **class_kwargs): +# +# # 许可的方法 +# if methods is None: +# methods = cls.get_methods() +# +# # 返回的响应方法闭包 +# def view(request, *args, **kwargs): +# self = view.base_class(*class_args, **class_kwargs) +# self.licensed_methods = methods +# self.request = request +# self.args = args +# self.kwargs = kwargs +# self.app = request.app +# return self.dispatch(request, *args, **kwargs) +# +# view.base_class = cls +# view.API_DOC_CONFIG = class_kwargs.get('API_DOC_CONFIG') # 未来的API文档配置属性+ +# view.__doc__ = cls.__doc__ +# view.__module__ = cls.__module__ +# view.__name__ = cls.__name__ +# return view + + +class GeneralViewView(BaseView): + """通用视图,可以基于其实现增删改查,提供权限套件""" + authentication_classes = () + permission_classes = () + + async def dispatch(self, request, *args, **kwargs): + """分发路由""" + request.user = None + method = request.method + if method.lower() not in self.method_map: + return self.json_response(msg='发生错误:未找到%s方法' % method, status=RuleStatus.STATUS_0_FAIL, + http_status=HttpStatus.HTTP_405_METHOD_NOT_ALLOWED) + handler = getattr(self, method.lower(), None) + + try: + await self.initial(request, *args, **kwargs) + response = handler(request, *args, **kwargs) + if inspect.isawaitable(response): + response = await response + except APIException as exc: + response = self.handle_exception(exc) + except ValidationException as exc: + response = self.error_json_response(exc.error_detail, '数据验证失败') + except AssertionError as exc: + raise exc + except Exception as exc: + logger.error('--捕获未知错误--', exc) + msg = '发生致命的未知错误,请在服务器查看时间为{}的日志'.format(datetime.now().strftime('%F %T')) + response = self.json_response(msg=msg, status=RuleStatus.STATUS_0_FAIL, + http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR) + return response + def handle_exception(self, exc: APIException): return self.json_response(**exc.response_data()) - def handle_uncaught_exception(self, exc): - """处理未知错误""" - message = '{}:{}'.format(exc.__class__.__name__, '|'.join(exc.args)) - return self.json_response(msg=message, status=RuleStatus.STATUS_0_FAIL, - http_status=HttpStatus.HTTP_500_INTERNAL_SERVER_ERROR) - def json_response(self, data=None, msg="OK", status=RuleStatus.STATUS_1_SUCCESS, http_status=HttpStatus.HTTP_200_OK): """ @@ -121,45 +228,20 @@ class BaseAPIView: http_status = kwargs.pop('http_status', HttpStatus.HTTP_400_BAD_REQUEST) return self.json_response(data=data, msg=msg, status=status, http_status=http_status) - async def get_object(self): - """ - 返回视图显示的对象。 - 如果您需要提供非标准的内容,则可能要覆盖此设置 - queryset查找。 - """ - queryset = self.filter_queryset(self.get_queryset()) - - lookup_field = self.lookup_field - - assert lookup_field in self.kwargs, ( - '%s 不存在于 %s 的 Url配置中的关键词内 ' % - (lookup_field, self.__class__.__name__,) - ) - - filter_kwargs = {lookup_field: self.kwargs[lookup_field]} - obj = await queryset.get_or_none(**filter_kwargs) - if obj is None: - raise APIException('不存在%s为%s的数据' % (lookup_field, self.kwargs[lookup_field]), http_status=HttpStatus.HTTP_200_OK) - - # May raise a permission denied - self.check_object_permissions(self.request, obj) - - return obj - def get_authenticators(self): """ 实例化并返回此视图可以使用的身份验证器列表 """ return [auth() for auth in self.authentication_classes] - def check_authentication(self, request): + async def check_authentication(self, request): """ 检查权限 查看是否拥有权限,并在此处为Request.User 赋值 :param request: 请求 :return: """ for authenticators in self.get_authenticators(): - request.user = authenticators.authenticate(request, request.user) + await authenticators.authenticate(request) def get_permissions(self): """ @@ -167,7 +249,7 @@ class BaseAPIView: """ return [permission() for permission in self.permission_classes] - def check_permissions(self, request): + async def check_permissions(self, request): """ 检查是否应允许该请求,如果不允许该请求, 则在 has_permission 中引发一个适当的异常。 @@ -175,9 +257,9 @@ class BaseAPIView: :return: """ for permission in self.get_permissions(): - permission.has_permission(request, self) + await permission.has_permission(request, self) - def check_object_permissions(self, request, obj): + async def check_object_permissions(self, request, obj): """ 检查是否应允许给定对象的请求, 如果不允许该请求, 则在 has_object_permission 中引发一个适当的异常。 @@ -187,9 +269,9 @@ class BaseAPIView: :return: """ for permission in self.get_permissions(): - permission.has_object_permission(request, self, obj) + await permission.has_object_permission(request, self, obj) - def check_throttles(self, request): + async def check_throttles(self, request): """ 检查范围频率。 则引发一个 APIException 异常。 @@ -198,16 +280,20 @@ class BaseAPIView: """ pass - def initial(self, request, *args, **kwargs): + async def initial(self, request, *args, **kwargs): """ 在请求分发之前执行初始化操作,用于检查权限及检查基础内容 """ - self.check_authentication(request) - self.check_permissions(request) - self.check_throttles(request) + await self.check_authentication(request) + await self.check_permissions(request) + await self.check_throttles(request) -class APIView(BaseAPIView): +class ViewSetView(GeneralViewView): + """ + 视图集视图,可以配合Mixin实现复杂的视图集, + 数据来源基于模型查询集,可以配合Route组件实现便捷的路由管理 + """ queryset = None lookup_field = 'pk' serializer_class = None @@ -215,6 +301,32 @@ class APIView(BaseAPIView): filter_class = SearchFilter search_fields = None + async def get_object(self): + """ + 返回视图显示的对象。 + 如果您需要提供非标准的内容,则可能要覆盖此设置 + queryset查找。 + """ + queryset = self.filter_queryset(self.get_queryset()) + + lookup_field = self.lookup_field + + assert lookup_field in self.kwargs, ( + '%s 不存在于 %s 的 Url配置中的关键词内 ' % + (lookup_field, self.__class__.__name__,) + ) + + filter_kwargs = {lookup_field: self.kwargs[lookup_field]} + obj = await queryset.get_or_none(**filter_kwargs) + if obj is None: + raise APIException('不存在%s为%s的数据' % (lookup_field, self.kwargs[lookup_field]), + http_status=HttpStatus.HTTP_200_OK) + + # May raise a permission denied + await self.check_object_permissions(self.request, obj) + + return obj + def get_queryset(self): assert self.queryset is not None, ( "'%s'应该包含一个'queryset'属性," @@ -293,106 +405,13 @@ class APIView(BaseAPIView): return self.paginator.response(self.request, data) -class ListModelMixin: +class ModelViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + ViewSetView): """ - 适用于输出列表类型数据 + `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()`, `list()` actions. """ - pagination_class = GeneralPagination - detail = False - - async def get(self, request, *args, **kwargs): - return await self.list(request, *args, **kwargs) - - async def list(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - - page = await self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - data = self.get_paginated_response(await serializer.data) - return self.success_json_response(data=data) - - serializer = self.get_serializer(queryset, many=True) - return self.success_json_response(data=await serializer.data) - - -class CreateModelMixin: - """ - 适用于快速创建内容 - 占用 post 方法 - """ - - async def post(self, request, *args, **kwargs): - return await self.create(request, *args, **kwargs) - - async def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) - await serializer.is_valid(raise_exception=True) - await self.perform_create(serializer) - return self.success_json_response(data=await serializer.data, http_status=HttpStatus.HTTP_201_CREATED) - - async def perform_create(self, serializer): - await serializer.save() - - -class DetailModelMixin: - """ - 适用于查询指定PK的内容 - """ - detail = True - - async def get(self, request, *args, **kwargs): - return await self.detail(request, *args, **kwargs) - - async def detail(self, request, *args, **kwargs): - instance = await self.get_object() - serializer = self.get_serializer(instance) - return self.success_json_response(data=await serializer.data) - - -class UpdateModelMixin: - """ - 适用于快速创建更新操作 - """ - - async def put(self, request, *args, **kwargs): - return await self.update(request, *args, **kwargs) - - async def patch(self, request, *args, **kwargs): - return await self.partial_update(request, *args, **kwargs) - - async def update(self, request, *args, **kwargs): - partial = kwargs.pop('partial', False) - instance = await self.get_object() - serializer = self.get_serializer(instance, data=request.data, partial=partial) - await serializer.is_valid(raise_exception=True) - await self.perform_update(serializer) - - # if getattr(instance, '_prefetched_objects_cache', None): - # instance._prefetched_objects_cache = {} - - return self.success_json_response(data=await serializer.data) - - async def perform_update(self, serializer): - await serializer.save() - - async def partial_update(self, request, *args, **kwargs): - kwargs['partial'] = True - return await self.update(request, *args, **kwargs) - - -class DestroyModelMixin: - """ - 用于快速删除 - """ - - async def delete(self, request, *args, **kwargs): - return await self.destroy(request, *args, **kwargs) - - async def destroy(self, request, *args, **kwargs): - instance = await self.get_object() - await self.perform_destroy(instance) - return self.success_json_response(status=HttpStatus.HTTP_204_NO_CONTENT) - - async def perform_destroy(self, instance): - await instance.delete() + pass -- Gitee From 275ada4eddefc897679ff068a8f6d868a00daabb Mon Sep 17 00:00:00 2001 From: LaoSi Date: Thu, 22 Apr 2021 17:52:17 +0800 Subject: [PATCH 34/34] =?UTF-8?q?=E4=BC=98=E5=8C=96=20ViewSet=20=E4=B8=8E?= =?UTF-8?q?=20Route=20=E5=AE=9E=E7=8E=B0=E6=96=B9=E5=BC=8F=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E4=BA=86=E5=87=A0=E4=B8=AA=E5=B0=8Fbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 尚未完成 测试用例 --- Apps/__init__.py | 24 ----- Cofing/__init__.py | 20 ---- Cofing/develop.py | 10 -- Cofing/formal.py | 10 -- Cofing/local.py | 10 -- db.py | 98 ----------------- sanic_rest_framework/converter.py | 13 +-- sanic_rest_framework/exceptions.py | 158 ++++++++++++++-------------- sanic_rest_framework/filters.py | 70 ++++++++++++ sanic_rest_framework/mixins.py | 3 - sanic_rest_framework/paginations.py | 11 +- sanic_rest_framework/request.py | 3 +- sanic_rest_framework/routes.py | 59 ----------- sanic_rest_framework/serializers.py | 30 +++--- sanic_rest_framework/views.py | 74 ++----------- 15 files changed, 189 insertions(+), 404 deletions(-) delete mode 100644 Apps/__init__.py delete mode 100644 Cofing/__init__.py delete mode 100644 Cofing/develop.py delete mode 100644 Cofing/formal.py delete mode 100644 Cofing/local.py delete mode 100644 db.py diff --git a/Apps/__init__.py b/Apps/__init__.py deleted file mode 100644 index a005523..0000000 --- a/Apps/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -@Author: WangYuXiang -@E-mile: Hill@3io.cc -@CreateTime: 2021/1/15 14:01 -@DependencyLibrary: 无 -@MainFunction:无 -@FileDoc: - __init__.py - 工厂函数 -""" -import os - -from sanic import Sanic - -from Cofing import get_config -from tortoise.contrib.sanic import register_tortoise - -def create_app(): - app = Sanic(__name__) - app.config.from_object(get_config(config_name='develop')) - register_tortoise( - app, config=db_config, modules={"models": ["models"]}, generate_schemas=False - ) - return app diff --git a/Cofing/__init__.py b/Cofing/__init__.py deleted file mode 100644 index 117b32e..0000000 --- a/Cofing/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -@Author: WangYuXiang -@E-mile: Hill@3io.cc -@CreateTime: 2021/1/15 14:07 -@DependencyLibrary: 无 -@MainFunction:无 -@FileDoc: - __init__.py - 配置初始化文件 -""" -import importlib - - -def get_config(config_name): - """ - 得到配置文件 - :param config_name: - :return: - """ - return importlib.import_module('Config.{}.Config'.format(config_name)) diff --git a/Cofing/develop.py b/Cofing/develop.py deleted file mode 100644 index 2c11a9f..0000000 --- a/Cofing/develop.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -@Author: WangYuXiang -@E-mile: Hill@3io.cc -@CreateTime: 2021/1/15 14:04 -@DependencyLibrary: 无 -@MainFunction:无 -@FileDoc: - develop.py - 开发环境配置文件 -""" diff --git a/Cofing/formal.py b/Cofing/formal.py deleted file mode 100644 index d8f6e78..0000000 --- a/Cofing/formal.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -@Author: WangYuXiang -@E-mile: Hill@3io.cc -@CreateTime: 2021/1/15 14:04 -@DependencyLibrary: 无 -@MainFunction:无 -@FileDoc: - formal.py - 正式运行环境配置文件 -""" diff --git a/Cofing/local.py b/Cofing/local.py deleted file mode 100644 index ac45825..0000000 --- a/Cofing/local.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -@Author: WangYuXiang -@E-mile: Hill@3io.cc -@CreateTime: 2021/1/15 14:03 -@DependencyLibrary: 无 -@MainFunction:无 -@FileDoc: - local.py - 本地运行配置文件 -""" diff --git a/db.py b/db.py deleted file mode 100644 index 2ff7457..0000000 --- a/db.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -@Author: WangYuXiang -@E-mile: Hill@3io.cc -@CreateTime: 2021/1/15 14:00 -@DependencyLibrary: 无 -@MainFunction:无 -@FileDoc: - db.py - 文件说明 -""" -from datetime import date -from enum import Enum, IntEnum - -from tortoise import fields -from tortoise import Model -from tortoise.fields import ForeignKeyRelation, ReverseRelation - - -class Enum1(IntEnum): - OK = 1 - BAD = 2 - - -class CharEnum(Enum): - OK = '好' - BAD = '坏' - - -class TestModel(Model): - char_field = fields.CharField(max_length=8, null=True) - float_field = fields.FloatField() - date_field = fields.DateField() - int_field = fields.IntField() - decimal_field = fields.DecimalField(max_digits=13, decimal_places=3) - datetime_field = fields.DatetimeField() - int_enum_field = fields.IntEnumField(enum_type=Enum1) - char_enum_field = fields.CharEnumField(enum_type=CharEnum) - boolean_field = fields.BooleanField() - small_int_field = fields.SmallIntField() - big_int_field = fields.BigIntField() - text_field = fields.TextField() - json_field = fields.JSONField() - uuid_field = fields.UUIDField() - one_to_many: ForeignKeyRelation["TestManyToOneModel"] = fields.ForeignKeyField('models.TestManyToOneModel', related_name='many_to_one', null=True) - one_to_one: fields.OneToOneRelation["TestOneToOneModel"] = fields.OneToOneField('models.TestOneToOneModel', related_name='one_to_one', null=True) - many_to_many: fields.ManyToManyRelation["TestModel"] = fields.ManyToManyField("models.TestModel", related_name="many_2_many", through="many_many", null=True) - - -class TestOneTOManyModel(Model): - one_to_many: ForeignKeyRelation["TestModel"] = fields.ForeignKeyField('models.TestModel', related_name='many_to_one') - - -class TestManyToOneModel(Model): - many_to_one: fields.ReverseRelation["TestModel"] - - -class TestOneToOneModel(Model): - name = fields.CharField(max_length=8, null=True) - - -class ManyToManyModel(Model): - many_2_many: fields.ManyToManyRelation["TestModel"] - - -class UserModel(Model): - name = fields.CharField(max_length=8, null=False) - birthday = fields.DateField() - phone = fields.CharField(max_length=11) - balance = fields.DecimalField(13, 3) - address: fields.ManyToManyRelation["AddressModel"] = fields.ManyToManyField( - 'models.AddressModel', through='user2address', related_name='user') - - -class AddressModel(Model): - phone = fields.CharField(12, null=False) - address = fields.CharField(100) - house_number = fields.CharField(100) - # user: fields.ManyToManyRelation[UserModel] - - -class SchoolModel(Model): - name = fields.CharField(12) - address: fields.OneToOneRelation["AddressModel"] = fields.OneToOneField("models.AddressModel", 'school') - - -class ClassRoomModel(Model): - room_number = fields.CharField(18) - student_count = fields.IntField() - students: ReverseRelation['StudentModel'] - - -class StudentModel(Model): - name = fields.CharField(max_length=12, null=False) - class_room: ForeignKeyRelation["ClassRoomModel"] = fields.ForeignKeyField('models.ClassRoomModel', 'students') - - -class DateSeriesModel(Model): - name = fields.TimeDeltaField() diff --git a/sanic_rest_framework/converter.py b/sanic_rest_framework/converter.py index 987d04c..53ffa92 100644 --- a/sanic_rest_framework/converter.py +++ b/sanic_rest_framework/converter.py @@ -14,10 +14,10 @@ """ from tortoise import fields + from sanic_rest_framework.fields import ( - CharField, ChoiceField, IntegerField, BooleanField, - DecimalField, DateTimeField, DateField, TimeField, - FloatField, EnumChoiceField + CharField, IntegerField, BooleanField, + DecimalField, DateTimeField, DateField, FloatField, EnumChoiceField ) @@ -56,12 +56,7 @@ class ModelConverter(ModelConverterBase): read_only = field_kwargs.get('read_only', False) write_only = field_kwargs.get('write_only', False) required = not model_field.null - # if read_only and required: - # raise ValueError('{}序列化器内字段{}为必填项不能使用read_only属性'.format( - # type(serializer).__name__, model_field.model_field_name)) - if model_field.pk: - field_kwargs['read_only'] = True - field_kwargs['required'] = False + kwargs = { 'read_only': read_only, 'write_only': write_only, diff --git a/sanic_rest_framework/exceptions.py b/sanic_rest_framework/exceptions.py index f6a110c..3f20bcd 100644 --- a/sanic_rest_framework/exceptions.py +++ b/sanic_rest_framework/exceptions.py @@ -14,85 +14,85 @@ from sanic.response import json from sanic_rest_framework.status import HttpStatus, RuleStatus - -class ValidationError(Exception): - """验证器通用错误类 发生错误即抛出此类""" - - def __init__(self, message, code=None, params=None): - super().__init__(message, code, params) - if isinstance(message, ValidationError): - if hasattr(message, 'error_dict'): - message = message.error_dict - elif not hasattr(message, 'message'): - message = message.error_list - else: - message, code, params = message.message, message.code, message.params - if isinstance(message, dict): - self.error_dict = {} - for field, msg in message.items(): - if not isinstance(msg, ValidationError): - msg = ValidationError(msg) - if hasattr(msg, 'error_dict'): - self.error_dict[field] = [msg.error_dict] - else: - self.error_dict[field] = msg.error_list - elif isinstance(message, list): - self.error_list = [] - for message in message: - if not isinstance(message, ValidationError): - message = ValidationError(message) - if hasattr(message, 'error_dict'): - self.error_list.extend(sum(message.error_dict.values(), [])) - else: - self.error_list.extend(message.error_list) - else: - self.message = message - self.code = code - self.params = params - self.error_list = [self] - - @property - def message_dict(self): - getattr(self, 'error_dict') - return dict(self) - - @property - def messages(self): - if hasattr(self, 'error_dict'): - return sum(dict(self).values(), []) - return list(self) - - def update_error_dict(self, error_dict): - if hasattr(self, 'error_dict'): - for field, error_list in self.error_dict.items(): - error_dict.setdefault(field, []).extend(error_list) - else: - error_dict.setdefault('__all__', []).extend(self.error_list) - return error_dict - - def __iter__(self): - if hasattr(self, 'error_dict'): - for field, errors in self.error_dict.items(): - yield field, list(ValidationError(errors)) - else: - for error in self.error_list: - message = error.message - if error.params: - message %= error.params - yield str(message) - - def __str__(self): - if hasattr(self, 'error_dict'): - return repr(dict(self)) - return repr(list(self)) - - def __repr__(self): - return 'ValidationError(%s)' % self - - def __eq__(self, other): - if not isinstance(other, ValidationError): - return NotImplemented - return hash(self) == hash(other) +# +# class ValidationError(Exception): +# """验证器通用错误类 发生错误即抛出此类""" +# +# def __init__(self, message, code=None, params=None): +# super().__init__(message, code, params) +# if isinstance(message, ValidationError): +# if hasattr(message, 'error_dict'): +# message = message.error_dict +# elif not hasattr(message, 'message'): +# message = message.error_list +# else: +# message, code, params = message.message, message.code, message.params +# if isinstance(message, dict): +# self.error_dict = {} +# for field, msg in message.items(): +# if not isinstance(msg, ValidationError): +# msg = ValidationError(msg) +# if hasattr(msg, 'error_dict'): +# self.error_dict[field] = [msg.error_dict] +# else: +# self.error_dict[field] = msg.error_list +# elif isinstance(message, list): +# self.error_list = [] +# for message in message: +# if not isinstance(message, ValidationError): +# message = ValidationError(message) +# if hasattr(message, 'error_dict'): +# self.error_list.extend(sum(message.error_dict.values(), [])) +# else: +# self.error_list.extend(message.error_list) +# else: +# self.message = message +# self.code = code +# self.params = params +# self.error_list = [self] +# +# @property +# def message_dict(self): +# getattr(self, 'error_dict') +# return dict(self) +# +# @property +# def messages(self): +# if hasattr(self, 'error_dict'): +# return sum(dict(self).values(), []) +# return list(self) +# +# def update_error_dict(self, error_dict): +# if hasattr(self, 'error_dict'): +# for field, error_list in self.error_dict.items(): +# error_dict.setdefault(field, []).extend(error_list) +# else: +# error_dict.setdefault('__all__', []).extend(self.error_list) +# return error_dict +# +# def __iter__(self): +# if hasattr(self, 'error_dict'): +# for field, errors in self.error_dict.items(): +# yield field, list(ValidationError(errors)) +# else: +# for error in self.error_list: +# message = error.message +# if error.params: +# message %= error.params +# yield str(message) +# +# def __str__(self): +# if hasattr(self, 'error_dict'): +# return repr(dict(self)) +# return repr(list(self)) +# +# def __repr__(self): +# return 'ValidationError(%s)' % self +# +# def __eq__(self, other): +# if not isinstance(other, ValidationError): +# return NotImplemented +# return hash(self) == hash(other) class ValidatorAssertError(Exception): diff --git a/sanic_rest_framework/filters.py b/sanic_rest_framework/filters.py index 76a313d..d61e24a 100644 --- a/sanic_rest_framework/filters.py +++ b/sanic_rest_framework/filters.py @@ -14,6 +14,7 @@ """ LOOKUP_SEP = '__' +from tortoise.models import Q class SimpleFilter: @@ -27,6 +28,7 @@ class SimpleFilter: class SearchFilter(SimpleFilter): + """以And进行查询""" lookup_prefixes = { '^': 'istartswith', '$': 'iendswith', @@ -106,3 +108,71 @@ class SearchFilter(SimpleFilter): """ values = request.args.get(field_name) return ''.join(values) + + +class OrSearchFilter(SearchFilter): + """以OR进行查询""" + + def filter_queryset(self, request, queryset, view): + """ + 根据定义的搜索字段过滤传入的queryset + :param request: 当前请求 + :param queryset: 查询对象 + :param view: 当前视图 + :return: + """ + search_fields = self.get_search_fields(request, view) + if not search_fields: + return queryset + orm_filters = [] + + for search_field in search_fields: + orm_filters.append(Q(**self.construct_orm_filter(search_field, request, view))) + if not orm_filters: + return queryset.filter() + return queryset.filter(Q(*orm_filters, join_type=Q.OR)) + + def dismantle_search_field(self, search_field): + """ + 拆解带有特殊字符的搜索字段 + :param search_field: 搜索字段 + :return: (field_name, lookup_suffix) + """ + lookup_suffix_keys = list(self.lookup_prefixes.keys()) + lookup_suffix = None + field_name = search_field + for lookup_suffix_key in lookup_suffix_keys: + if lookup_suffix_key in search_field: + lookup_suffix = self.lookup_prefixes[lookup_suffix_key] + field_name = search_field[len(lookup_suffix_key):] + return field_name, lookup_suffix + return field_name, lookup_suffix + + def construct_orm_filter(self, search_field, request, view): + """ + 构造适用于orm的过滤参数 + :param search_field: 搜索字段 + :param request: 当前请求 + :param view: 视图 + :return: + """ + field_name, lookup_suffix = self.dismantle_search_field(search_field) + args = request.args + + if field_name not in args: + return {} + if lookup_suffix: + orm_lookup = LOOKUP_SEP.join([field_name, lookup_suffix]) + else: + orm_lookup = field_name + return {orm_lookup: self.get_filter_value(request, field_name)} + + def get_filter_value(self, request, field_name): + """ + 根据字段名从请求中得到值 + :param request: 当前请求 + :param field_name: 字段名 + :return: + """ + values = request.args.get(field_name) + return ''.join(values) diff --git a/sanic_rest_framework/mixins.py b/sanic_rest_framework/mixins.py index 04deaaf..9083ea8 100644 --- a/sanic_rest_framework/mixins.py +++ b/sanic_rest_framework/mixins.py @@ -96,9 +96,6 @@ class UpdateModelMixin: await serializer.is_valid(raise_exception=True) await self.perform_update(serializer) - # if getattr(instance, '_prefetched_objects_cache', None): - # instance._prefetched_objects_cache = {} - return self.success_json_response(data=await serializer.data) async def perform_update(self, serializer): diff --git a/sanic_rest_framework/paginations.py b/sanic_rest_framework/paginations.py index 8c11983..dedbed6 100644 --- a/sanic_rest_framework/paginations.py +++ b/sanic_rest_framework/paginations.py @@ -42,6 +42,7 @@ class GeneralPagination(BasePagination): @property def count(self): + """总记录数""" assert hasattr(self, '_count'), '必须先执行 `.paginate_queryset()` 函数才能使用.count' return self._count @@ -90,21 +91,26 @@ class GeneralPagination(BasePagination): return self.page - 1 async def paginate_queryset(self, queryset, request, view): + """为queryset添加分页查询条件""" self.page = self.get_query_page(request) self.page_size = self.get_query_page_size(request) if not isinstance(queryset, Model): queryset = queryset.filter() self._count = await queryset.count() - return queryset.limit(self.page_size).offset(self.page * self.page_size) + return queryset.limit(self.page_size).offset((self.page - 1) * self.page_size) def get_query_page(self, request): + """得到页数""" try: - page = int(request.args.get(self.page_query_param, 0)) + page = int(request.args.get(self.page_query_param, 1)) except ValueError as exc: raise APIException('发生错误的分页数据', http_status=HttpStatus.HTTP_400_BAD_REQUEST) + if page < 1: + page = 1 return page def get_query_page_size(self, request): + """得到页记录数""" try: page = int(request.args.get(self.page_size_query_param, self.page_size)) if page > self.max_page_size: @@ -114,6 +120,7 @@ class GeneralPagination(BasePagination): return page def response(self, request, data): + """便捷的response""" return { 'count': self.count, 'next': self.get_next_link(request), diff --git a/sanic_rest_framework/request.py b/sanic_rest_framework/request.py index 148aee8..bdfe292 100644 --- a/sanic_rest_framework/request.py +++ b/sanic_rest_framework/request.py @@ -30,4 +30,5 @@ class SRFRequest(SanicRequest): data = self.json except InvalidUsage as exc: data = self.form - return {} if data is None else data + data = {} if data is None else data + return data diff --git a/sanic_rest_framework/routes.py b/sanic_rest_framework/routes.py index 32b5913..f82f8b3 100644 --- a/sanic_rest_framework/routes.py +++ b/sanic_rest_framework/routes.py @@ -142,62 +142,3 @@ class ViewSetRouter(BaseRoute): if hasattr(viewset, action): bound_methods[method] = action return bound_methods - - -class DefaultRoute(BaseRoute): - def register_route(self, viewset: object, prefix: str, name: str = None, is_base: bool = False): - """ - 注册路由到路由管理类 - :param viewset: 视图,无需 - :param prefix: - :param name: - :param is_base: - :return: - """ - if name is None: - name = prefix.replace('/', '_') - base_method_group = LIST_METHOD_GROUP - if hasattr(viewset, 'detail') and viewset.detail: - base_method_group = DETAIL_METHOD_GROUP - - viewset_method_list = self.get_viewset_method_list(viewset) - viewset_dynamic_method = [i for i in viewset_method_list if i in base_method_group['dynamic_method']] - viewset_static_method = [i for i in viewset_method_list if i in base_method_group['static_method']] - - if viewset_dynamic_method: - self.routes.append({ - 'handler': viewset.as_view(viewset_dynamic_method), - 'uri': self.dynamic_uri.format(prefix=prefix, lookup_field=viewset.lookup_field), - 'name': '{name}_detail'.format(name=name), - 'is_base': is_base - - }) - if viewset_static_method: - self.routes.append({ - 'handler': viewset.as_view(viewset_static_method), - 'uri': self.static_uri.format(prefix=prefix), - 'name': '{name}_list'.format(name=name), - 'is_base': is_base - }) - - def get_viewset_method_list(self, viewset): - """ - 得到viewSet所有请求方法 - :param viewset: 类视图 - :return: - """ - methods = [] - for method in ALL_METHOD: - if hasattr(viewset, method.lower()): - methods.append(method) - return methods - - @property - def urls(self): - return self.routes - - def initialize(self, destination: Union[Sanic, Blueprint]): - """注册路由""" - for route in self.routes: - route.pop('is_base') - destination.add_route(**route) diff --git a/sanic_rest_framework/serializers.py b/sanic_rest_framework/serializers.py index 6dc2d7c..7561785 100644 --- a/sanic_rest_framework/serializers.py +++ b/sanic_rest_framework/serializers.py @@ -10,21 +10,16 @@ """ import copy import inspect -from asyncio import coroutine from collections import OrderedDict -from typing import Any, Mapping, Coroutine +from typing import Any, Mapping -from tortoise import models, Model -from tortoise.queryset import QuerySet from tortoise import fields as tortoise_fields -from tortoise.fields.relational import ForeignKeyFieldInstance, OneToOneFieldInstance, ManyToManyFieldInstance, ManyToManyRelation +from sanic_rest_framework.converter import ModelConverter from sanic_rest_framework.fields import ( empty, SkipField, - Field, CharField, IntegerField, FloatField, DecimalField, BooleanField, DateTimeField, DateField, TimeField, - ChoiceField, SerializerMethodField + Field ) -from sanic_rest_framework.converter import ModelConverter from .exceptions import ValidationException from .helpers import BindingDict @@ -300,14 +295,19 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): res = OrderedDict() fields = self._readable_fields for field in fields: + method = getattr(self, 'read_{}'.format(field.field_name), None) try: value = await field.get_internal_value(data) except SkipField: continue - if value is None: - res[field.field_name] = None - else: - res[field.field_name] = await field.internal_to_external(value) + if value is not None: + value = await field.internal_to_external(value) + if method: + value = method(value) + if inspect.isawaitable(value): + value = await value + res[field.field_name] = value + return res # 反序列化 @@ -609,7 +609,11 @@ class ModelSerializer(Serializer): if m2m_field in validated_data: many_to_many[m2m_field] = validated_data.pop(m2m_field) try: - instance = await ModelClass.create(**validated_data) + instance = ModelClass() + for attr, value in validated_data.items(): + if attr not in many_to_many: + setattr(instance, attr, value) + await instance.save() except TypeError as exc: raise exc diff --git a/sanic_rest_framework/views.py b/sanic_rest_framework/views.py index 07ac140..16adc73 100644 --- a/sanic_rest_framework/views.py +++ b/sanic_rest_framework/views.py @@ -5,7 +5,7 @@ @DependencyLibrary: @MainFunction: @FileDoc: - views.py + login.py 基础视图文件 BaseView 只实现路由分发的基础视图 GeneralView 通用视图,可以基于其实现增删改查,提供权限套件 @@ -25,13 +25,11 @@ from sanic_rest_framework import mixins from sanic_rest_framework.constant import ALL_METHOD, DEFAULT_METHOD_MAP from sanic_rest_framework.exceptions import APIException, ValidationException from sanic_rest_framework.filters import SearchFilter -from sanic_rest_framework.mixins import CreateModelMixin, ListModelMixin, DestroyModelMixin, UpdateModelMixin, \ - RetrieveModelMixin from sanic_rest_framework.status import RuleStatus, HttpStatus -from simplejson import dumps +from simplejson import dumps, JSONEncoder from tortoise.queryset import QuerySet -__all__ = ['BaseView', 'GeneralViewView', 'ViewSetView', 'CRUDView', 'CLUDView'] +__all__ = ['BaseView', 'GeneralView', 'ViewSetView', 'ModelViewSet'] class BaseView: @@ -77,8 +75,9 @@ class BaseView: self.method_map = method_map for method, action in method_map.items(): - handler = getattr(self, action) - setattr(self, method, handler) + handler = getattr(self, action, None) + if handler: + setattr(self, method, handler) self.request = request self.args = args @@ -95,64 +94,7 @@ class BaseView: return view -# -# class BaseView: -# """只实现路由分发的基础视图 -# 在使用时应当开放全部路由 ALL_METHOD -# app.add_route('/test', BaseView.as_view(), 'test', ALL_METHOD) -# 如需限制路由则 -# app.add_route('/test', BaseView.as_view(['GET','POST']), 'test', ALL_METHOD) -# OR -# app.add_route('/test', BaseView.as_view(), 'test', ['GET','POST']) -# 注意以上方法的报错是不可控的 -# """ -# -# async def dispatch(self, request, *args, **kwargs): -# """分发路由""" -# request.user = None -# method = request.method -# if method not in self.licensed_methods: -# return HTTPResponse('405请求方法错误', status=405) -# handler = getattr(self, method.lower(), None) -# response = handler(request, *args, **kwargs) -# if inspect.isawaitable(response): -# response = await response -# return response -# -# @classmethod -# def get_methods(cls): -# methods = [] -# for method in ALL_METHOD: -# if hasattr(cls, method.lower()): -# methods.append(method) -# return methods -# -# @classmethod -# def as_view(cls, methods=None, *class_args, **class_kwargs): -# -# # 许可的方法 -# if methods is None: -# methods = cls.get_methods() -# -# # 返回的响应方法闭包 -# def view(request, *args, **kwargs): -# self = view.base_class(*class_args, **class_kwargs) -# self.licensed_methods = methods -# self.request = request -# self.args = args -# self.kwargs = kwargs -# self.app = request.app -# return self.dispatch(request, *args, **kwargs) -# -# view.base_class = cls -# view.API_DOC_CONFIG = class_kwargs.get('API_DOC_CONFIG') # 未来的API文档配置属性+ -# view.__doc__ = cls.__doc__ -# view.__module__ = cls.__module__ -# view.__name__ = cls.__name__ -# return view - - -class GeneralViewView(BaseView): +class GeneralView(BaseView): """通用视图,可以基于其实现增删改查,提供权限套件""" authentication_classes = () permission_classes = () @@ -289,7 +231,7 @@ class GeneralViewView(BaseView): await self.check_throttles(request) -class ViewSetView(GeneralViewView): +class ViewSetView(GeneralView): """ 视图集视图,可以配合Mixin实现复杂的视图集, 数据来源基于模型查询集,可以配合Route组件实现便捷的路由管理 -- Gitee