200 lines
7.2 KiB
Python
200 lines
7.2 KiB
Python
import datetime
|
|
|
|
from django.db.models import Q
|
|
from django.db.models.fields import IntegerField
|
|
from rest_framework import status
|
|
from rest_framework.decorators import action
|
|
from rest_framework.permissions import IsAuthenticated as DRFIsAuthenticated
|
|
from rest_framework.viewsets import ModelViewSet, GenericViewSet
|
|
|
|
from utils.jwt_authenticate import MyJSONWebTokenAuthentication
|
|
from utils.pagination import CustomPagination
|
|
from utils.response import ApiResponse
|
|
from utils.tools import generate_options_by_choices
|
|
|
|
|
|
class IsAuthenticated(DRFIsAuthenticated):
    """Grant access only when the request carries a non-anonymous user.

    If the user object exposes a boolean ``is_anonymous`` attribute, that flag
    decides. Otherwise the check falls back to the original rule: the user is
    rejected only when its string form is ``'AnonymousUser'``.
    """

    def has_permission(self, request, view):
        user = request.user
        # No user object at all (None / falsy) -> deny.
        if not user:
            return False
        is_anonymous = getattr(user, 'is_anonymous', None)
        if isinstance(is_anonymous, bool):
            # Prefer the explicit flag: comparing str(user) would also reject a
            # real account whose string form happens to be 'AnonymousUser'.
            return not is_anonymous
        # Fallback for user objects without a boolean is_anonymous flag
        # (e.g. whatever the custom JWT authentication may return).
        return str(user) != 'AnonymousUser'
|
|
|
|
|
|
class BaseModelViewSet(ModelViewSet):
    """ModelViewSet base with JWT auth, query-param filtering and ApiResponse output.

    Subclass configuration (class attributes):

    * ``white_list``: action names that bypass ``permission_classes``.
    * ``filter_fields``: query params turned into exact-match filters on
      non-detail actions.
    * ``fuzzy_filter_fields``: entries of ``filter_fields`` matched with
      ``__icontains`` instead of exact equality.
    * ``exclude_param_blank``: when False (default), blank query-param values
      are ignored. When True, blank values are applied as filters too.
    * ``multi_filter_fields``: fields filtered with ``__in`` from repeated
      ``<field>[]`` query params on non-detail actions.
    * ``serializer_class``: a serializer class, or a dict mapping action names
      (plus an optional ``'default'`` key) to serializer classes.

    A subclass may also define ``filter_<action>_queryset()`` to supply the
    starting queryset for that action instead of ``super().get_queryset()``.
    """

    permission_classes = [IsAuthenticated]
    authentication_classes = [MyJSONWebTokenAuthentication]
    white_list = []
    filter_fields = []
    fuzzy_filter_fields = []
    exclude_param_blank = False
    multi_filter_fields = []
    pagination_class = CustomPagination

    def list(self, request, *args, **kwargs):
        """List the filtered queryset, paginating only when the client asks.

        Pagination is used only if ``page`` or ``page_size`` is present and
        non-empty and a paginator is configured. Otherwise every filtered row
        is returned in a plain ``ApiResponse``. The serializer context gets a
        ``'queryset'`` entry holding exactly the objects being serialized.
        """
        queryset = self.filter_queryset(self.get_queryset())
        params = request.query_params
        wants_page = any([params.get('page'), params.get('page_size')])

        # Only invoke the paginator when pagination was actually requested.
        page = self.paginate_queryset(queryset) if wants_page else None

        context = self.get_serializer_context()
        if page is not None:
            context['queryset'] = page
            serializer = self.get_serializer(page, many=True, context=context)
            return self.get_paginated_response(serializer.data)

        # Unpaginated: the context must describe the full collection being
        # serialized, not a first page that is never returned.
        context['queryset'] = queryset
        serializer = self.get_serializer(queryset, many=True, context=context)
        return ApiResponse(data=serializer.data)

    def create(self, request, *args, **kwargs):
        """Validate and save a new object; respond with HTTP 201 and its data."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return ApiResponse(data=serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    def retrieve(self, request, *args, **kwargs):
        """Return the serialized data of a single object."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return ApiResponse(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Update one object; a ``partial=True`` kwarg enables partial update."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}

        return ApiResponse(data=serializer.data)

    def destroy(self, request, *args, **kwargs):
        """Soft-delete when the instance has ``is_delete``, otherwise delete the row."""
        instance = self.get_object()
        if hasattr(instance, 'is_delete'):
            instance.is_delete = True
            instance.save()
        else:
            instance.delete()
        # '删除成功' means "deleted successfully".
        return ApiResponse(message='删除成功', data=None)

    def get_serializer_class(self):
        """Resolve ``serializer_class``, which may be a per-action dict.

        For a dict, the entry for the current action is used, falling back to
        the ``'default'`` entry (``None`` if neither exists).
        """
        if isinstance(self.serializer_class, dict):
            return self.serializer_class.get(self.action) or self.serializer_class.get('default')

        return self.serializer_class

    def check_permissions(self, request):
        """Skip permission checks for OPTIONS requests and ``white_list`` actions."""
        if request.method.lower() in ['options']:
            return True
        if self.action in self.white_list:
            return True
        return super().check_permissions(request)

    def get_queryset(self):
        """Build the base queryset and apply query-param and ownership filters.

        The base queryset comes from ``filter_<action>_queryset()`` when the
        subclass defines it, otherwise from ``super().get_queryset()``. Then:

        * Non-detail actions apply ``filter_fields`` / ``fuzzy_filter_fields``
          and ``multi_filter_fields`` from the query string.
        * A model with a ``user`` attribute is restricted to
          ``request.user``.
        * A model with an integer ``uid`` field is restricted to the user's id
          or ``uid=0``.
        * A model with ``is_delete`` excludes soft-deleted rows.
        """
        action_queryset_func = getattr(self, f'filter_{self.action}_queryset', None)
        if callable(action_queryset_func):
            queryset = action_queryset_func()
        else:
            queryset = super().get_queryset()

        # Inspect the model of the queryset actually in use: a
        # filter_<action>_queryset hook may return a different model, and
        # ``self.queryset`` may not be set at all in that case.
        model = queryset.model
        params = self.request.query_params
        q = Q()

        if not self.detail:
            for key in self.filter_fields:
                if key not in params:
                    continue
                value = params.get(key)
                if not value and not self.exclude_param_blank:
                    continue
                lookup = f'{key}__icontains' if key in self.fuzzy_filter_fields else key
                q.add(Q(**{lookup: value}), Q.AND)

            for key in self.multi_filter_fields:
                # getlist() returns [] when '<key>[]' is absent.
                values = params.getlist(f'{key}[]')
                if values:
                    q.add(Q(**{f'{key}__in': values}), Q.AND)

        if hasattr(model, 'user'):
            # Ownership scoping for every action (this also covers ``list``,
            # which previously received a redundant extra user_id filter).
            q.add(Q(user=self.request.user), Q.AND)

        if hasattr(model, 'uid'):
            uid_field = model._meta.get_field('uid')
            if isinstance(uid_field, IntegerField):
                # NOTE(review): uid == 0 rows presumably represent shared/system
                # records visible to every user -- confirm.
                q.add(Q(uid=self.request.user.id) | Q(uid=0), Q.AND)

        if hasattr(model, 'is_delete'):
            q.add(Q(is_delete=False), Q.AND)

        return queryset.filter(q)
|
|
|
|
|
|
class OrderDataMixin:
    """Sort already-serialized rows in Python according to query parameters.

    Reads ``field`` (the row key to sort by) and ``order`` from
    ``self.request.query_params``. ``order == 'descend'`` (the default) sorts
    descending; any other value sorts ascending.
    """

    def order_data(self, data):
        """Return ``data`` sorted by the requested field.

        Each row's value is interpreted as an ``int`` if ``int()`` accepts it,
        otherwise as a ``'%Y-%m-%d %H:%M:%S'`` timestamp. Anything else,
        including a missing key or ``None``, sorts as 1998-02-01 00:00:00.

        :param data: iterable of dict-like rows (anything with ``.get``).
        :returns: ``data`` unchanged when no ``field`` param is given,
            otherwise a new sorted list.
        """
        params = self.request.query_params
        field = params.get('field')
        if field is None:
            return data
        descending = params.get('order', 'descend') == 'descend'
        fallback = datetime.datetime(year=1998, month=2, day=1)

        def sort_key(row):
            value = row.get(field)
            # Keys are (rank, value) tuples: ints rank 0, timestamps and the
            # fallback rank 1. This keeps int and datetime from ever being
            # compared directly, which raised TypeError when a column mixed
            # numbers with missing or non-numeric values. When all keys share
            # a rank, the order is the same as plain int/datetime keys.
            try:
                return (0, int(value))
            except (TypeError, ValueError, OverflowError):
                pass
            try:
                return (1, datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S'))
            except (TypeError, ValueError):
                return (1, fallback)

        return sorted(data, key=sort_key, reverse=descending)
|
|
|
|
|
|
class ChoicesMixin:
    """Expose a GET ``choices`` extra action (``detail=False``) listing option sets.

    Subclasses fill ``choices_map`` with ``response_key -> choices class``
    entries. Each class must provide a ``.choices`` attribute (presumably a
    Django ``TextChoices``/``IntegerChoices`` -- confirm per subclass), which
    is converted by the project helper ``generate_options_by_choices``.
    """

    # response key -> object exposing ``.choices``
    choices_map = {}

    def get_choices(self):
        """Return a dict with one converted option list per ``choices_map`` entry."""
        return {
            name: generate_options_by_choices(choice_cls.choices)
            for name, choice_cls in self.choices_map.items()
        }

    @action(methods=['get'], detail=False)
    def choices(self, request, *args, **kwargs):
        """Respond with every configured option set wrapped in ``ApiResponse``."""
        return ApiResponse(data=self.get_choices())
|
|
|
|
|
|
class ReadOnlyModelViewSet(BaseModelViewSet):
    """BaseModelViewSet whose write actions are accepted but do nothing.

    ``create``, ``update`` and ``destroy`` all return an empty ``ApiResponse``
    without touching the database. Reads (``list``/``retrieve``) behave as in
    BaseModelViewSet.
    """

    def _ignore_write(self, request, *args, **kwargs):
        # NOTE(review): writes are silently swallowed with an empty ApiResponse
        # (whatever status ApiResponse defaults to) rather than rejected with
        # 405 -- confirm clients rely on this before changing it.
        return ApiResponse()

    create = _ignore_write
    update = _ignore_write
    destroy = _ignore_write
|
|
|
|
|
|
class BaseChoicesModelViewSet(BaseModelViewSet, ChoicesMixin):
    """BaseModelViewSet plus the GET ``choices`` extra action from ChoicesMixin."""
|
|
|
|
|
|
class ReadOnlyChoicesModelViewSet(ReadOnlyModelViewSet, ChoicesMixin):
    """ReadOnlyModelViewSet plus the GET ``choices`` extra action from ChoicesMixin."""
|
|
|
|
|
|
class BaseViewSet(GenericViewSet):
    """GenericViewSet using the project's JWT authentication and permission policy.

    ``serializer_class`` may be a serializer class, or a dict keyed by action
    name with an optional ``'default'`` fallback entry.
    """

    authentication_classes = [MyJSONWebTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_serializer_class(self):
        """Return the serializer for the current action (``None`` if unresolved)."""
        configured = self.serializer_class
        if not isinstance(configured, dict):
            return configured
        # Use 'default' when the action has no (truthy) entry of its own.
        return configured.get(self.action) or configured.get('default')
|