import datetime

from django.db.models import Q
from django.db.models.fields import IntegerField
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated as DRFIsAuthenticated
from rest_framework.viewsets import ModelViewSet, GenericViewSet

from utils.jwt_authenticate import MyJSONWebTokenAuthentication
from utils.pagination import CustomPagination
from utils.response import ApiResponse
from utils.tools import generate_options_by_choices


def _resolve_serializer_class(serializer_class, action_name):
    """Pick the serializer class for the current viewset action.

    ``serializer_class`` is either a single serializer class, or a dict that
    maps action names to serializer classes with an optional ``'default'``
    entry used when the action has no entry of its own. Returns None when the
    dict contains neither.
    """
    if isinstance(serializer_class, dict):
        return serializer_class.get(action_name) or serializer_class.get('default')
    return serializer_class


class IsAuthenticated(DRFIsAuthenticated):
    """Grant access only when the request carries a non-anonymous user."""

    def has_permission(self, request, view):
        user = request.user
        if not user:
            return False
        # Prefer the user's own ``is_authenticated`` flag. Comparing
        # ``str(user)`` to 'AnonymousUser' would wrongly reject a real account
        # whose string form happens to be that text. The string check is kept
        # only as a fallback for user objects that lack the flag.
        is_authenticated = getattr(user, 'is_authenticated', None)
        if is_authenticated is None:
            return str(user) != 'AnonymousUser'
        return bool(is_authenticated)


class BaseModelViewSet(ModelViewSet):
    """ModelViewSet with JWT auth, query-param filtering and ownership scoping.

    Responses are wrapped in ``ApiResponse``. Subclasses configure filtering
    through the class attributes below.
    """

    permission_classes = [IsAuthenticated]
    authentication_classes = [MyJSONWebTokenAuthentication]
    # Action names (e.g. 'list', 'retrieve') that bypass permission checks.
    white_list = []
    # Query params copied verbatim into ``filter(**{key: value})`` on non-detail routes.
    filter_fields = []
    # Subset of ``filter_fields`` matched with ``__icontains`` instead of equality.
    fuzzy_filter_fields = []
    # False: params sent with an empty value are ignored.
    # True: they are applied as filters too.
    exclude_param_blank = False
    # Fields filtered with ``__in`` from repeated ``<field>[]`` query params.
    multi_filter_fields = []
    pagination_class = CustomPagination

    def list(self, request, *args, **kwargs):
        """Serialize the filtered queryset, paginating only when requested.

        Pagination applies when the client sends ``page`` or ``page_size``
        and the paginator yields a page. Otherwise the whole filtered
        queryset is returned. The serializer context's ``'queryset'`` entry
        is always the exact collection being serialized.
        """
        queryset = self.filter_queryset(self.get_queryset())
        context = self.get_serializer_context()

        params = request.query_params
        if params.get('page') or params.get('page_size'):
            page = self.paginate_queryset(queryset)
            if page is not None:
                context['queryset'] = page
                serializer = self.get_serializer(page, many=True, context=context)
                return self.get_paginated_response(serializer.data)

        context['queryset'] = queryset
        serializer = self.get_serializer(queryset, many=True, context=context)
        return ApiResponse(data=serializer.data)

    def create(self, request, *args, **kwargs):
        """Validate and save a new object; respond with HTTP 201 and its data."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return ApiResponse(data=serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    def retrieve(self, request, *args, **kwargs):
        """Return the serialized object addressed by the URL."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return ApiResponse(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Full (PUT) or partial (PATCH, via ``partial=True``) update."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}

        return ApiResponse(data=serializer.data)

    def destroy(self, request, *args, **kwargs):
        """Delete the object: soft delete when the model has ``is_delete``."""
        instance = self.get_object()
        if hasattr(instance, 'is_delete'):
            # Soft delete: keep the row, just flag it; get_queryset hides it.
            instance.is_delete = True
            instance.save()
        else:
            instance.delete()
        return ApiResponse(message='删除成功', data=None)

    def get_serializer_class(self):
        """Resolve ``serializer_class`` (class or per-action dict) for this action."""
        return _resolve_serializer_class(self.serializer_class, self.action)

    def check_permissions(self, request):
        """Run permission classes unless the request is exempt.

        OPTIONS requests and actions listed in ``white_list`` skip the
        permission classes entirely.
        """
        if request.method.lower() in ['options']:
            return True
        if self.action in self.white_list:
            return True
        return super().check_permissions(request)

    def get_queryset(self):
        """Build the base queryset for the current action and apply scoping.

        A ``filter_<action>_queryset()`` method, when defined, supplies the
        starting queryset; otherwise ``self.queryset`` is used. Then:
          * non-detail routes get ``filter_fields`` / ``multi_filter_fields``
            filters from the query string;
          * models with a ``user`` attribute are limited to the request user;
          * models whose ``uid`` field is an IntegerField are limited to the
            request user's id or ``uid=0``;
          * models with ``is_delete`` hide soft-deleted rows.
        """
        action_queryset_func = getattr(self, f'filter_{self.action}_queryset', None)
        if callable(action_queryset_func):
            queryset = action_queryset_func()
        else:
            queryset = super().get_queryset()

        # Introspect the queryset actually being filtered, not self.queryset:
        # an action hook may return a different queryset, and self.queryset
        # may be None when every action supplies its own.
        model = queryset.model
        params = self.request.query_params
        user = self.request.user

        q = Q()
        if not self.detail:
            for key in self.filter_fields:
                if key not in params:
                    continue
                value = params.get(key)
                if not value and not self.exclude_param_blank:
                    continue
                lookup = f'{key}__icontains' if key in self.fuzzy_filter_fields else key
                q &= Q(**{lookup: value})

            for key in self.multi_filter_fields:
                list_key = f'{key}[]'
                if list_key in params:
                    q &= Q(**{f'{key}__in': params.getlist(list_key)})

        if hasattr(model, 'user'):
            # One owner filter covers every action, including 'list'.
            q &= Q(user=user)
        if hasattr(model, 'uid') and isinstance(model._meta.get_field('uid'), IntegerField):
            q &= Q(uid=user.id) | Q(uid=0)
        if hasattr(model, 'is_delete'):
            q &= Q(is_delete=False)

        return queryset.filter(q)


class OrderDataMixin:
    """Sort already-serialized rows (dicts) in Python by a client-chosen field.

    Reads ``field`` and ``order`` from the query string. ``order`` defaults to
    'descend'; any other value sorts ascending.
    """

    def order_data(self, data):
        params = self.request.query_params
        field = params.get('field')
        if field is None:
            return data
        order = params.get('order', 'descend')

        def sort_key(obj):
            # Keys are (rank, value) tuples so an int is never compared with
            # a datetime, which raised TypeError when a column mixed numbers
            # and non-numbers. Homogeneous columns keep their previous order.
            value = obj.get(field)
            try:
                return (1, int(value))
            except (TypeError, ValueError, OverflowError):
                pass
            try:
                return (0, datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S'))
            except (TypeError, ValueError):
                # Unparseable values sort together at a fixed early date.
                return (0, datetime.datetime(year=1998, month=2, day=1))

        return sorted(data, key=sort_key, reverse=order == 'descend')


class ChoicesMixin:
    """Expose a GET ``choices`` action listing option sets for the client."""

    # Maps a response key to an object with a ``.choices`` attribute
    # (presumably a Django ``Choices`` enum; confirm against subclasses).
    choices_map = {}

    def get_choices(self):
        return {
            key: generate_options_by_choices(choices.choices)
            for key, choices in self.choices_map.items()
        }

    @action(methods=['get'], detail=False)
    def choices(self, request, *args, **kwargs):
        return ApiResponse(data=self.get_choices())


class ReadOnlyModelViewSet(BaseModelViewSet):
    """BaseModelViewSet whose write actions do nothing.

    create/update/destroy return a bare ``ApiResponse()`` without touching
    the database. NOTE(review): clients receive that default response rather
    than an HTTP 405; confirm this is intended.
    """

    def destroy(self, request, *args, **kwargs):
        return ApiResponse()

    def update(self, request, *args, **kwargs):
        return ApiResponse()

    def create(self, request, *args, **kwargs):
        return ApiResponse()


class BaseChoicesModelViewSet(BaseModelViewSet, ChoicesMixin):
    pass


class ReadOnlyChoicesModelViewSet(ReadOnlyModelViewSet, ChoicesMixin):
    pass


class BaseViewSet(GenericViewSet):
    """GenericViewSet with JWT auth and per-action serializer resolution."""

    permission_classes = [IsAuthenticated]
    authentication_classes = [MyJSONWebTokenAuthentication]

    def get_serializer_class(self):
        """Resolve ``serializer_class`` (class or per-action dict) for this action."""
        return _resolve_serializer_class(self.serializer_class, self.action)