from rest_framework import serializers, status
from rest_framework.exceptions import APIException
from rest_framework.fields import empty
from rest_framework.request import Request

from apps.qc.models import QcWechatbizuserinfo, QcCorpinfo
from apps.qc.utils import get_query_by_corpkey
from apps.user.models import User
from libs.wechat import WechatWorkerUtil
from utils.CustomField import CustomDateTimeField
from utils.exceptions import CustomProjectException


class ValidationError(APIException):
    """HTTP 400 error whose ``detail`` is stored exactly as given.

    Unlike ``APIException.__init__``, the detail is not coerced into an
    ``ErrorDetail``. When no detail is supplied, ``default_detail`` is used.
    """
    status_code = status.HTTP_400_BAD_REQUEST
    default_detail = 'Invalid input.'
    default_code = 'invalid'

    def __init__(self, detail=None):
        # Without this fallback a bare ``ValidationError()`` would carry a
        # ``None`` detail and ignore the declared default.
        if detail is None:
            detail = self.default_detail
        # Populate ``Exception.args`` (APIException.__init__ is deliberately
        # bypassed to keep the raw detail) so repr()/logging show the message.
        Exception.__init__(self, detail)
        self.detail = detail


class CurrentUserIdDefault(serializers.CurrentUserDefault):
    """Field default that resolves to the primary key of the request user."""

    def __call__(self, serializer_field):
        user = super().__call__(serializer_field)
        return user.pk


class CurrentIpDefault:
    """Field default that resolves to ``request.META['REMOTE_ADDR']``.

    Returns ``None`` when the header is absent. NOTE(review): behind a reverse
    proxy this is the proxy's address, not the client's — confirm deployment.
    """
    requires_context = True

    def __call__(self, serializer_field):
        request = serializer_field.context['request']
        return request.META.get('REMOTE_ADDR')

    def __repr__(self):
        return '%s()' % self.__class__.__name__


class BaseSerializer(serializers.ModelSerializer):
    """ModelSerializer that exposes ``self.request`` and ``self.user``.

    The request may be passed explicitly with ``request=``; otherwise it is
    read from ``context['request']``.
    """

    def __init__(self, instance=None, data=empty, request=None, **kwargs):
        super().__init__(instance, data, **kwargs)
        self.request: Request = request or self.context.get("request", None)
        # NOTE(review): without a request the fallback is the plain string
        # "AnonymousUser", not a user object — attribute access on it fails.
        self.user = self.request.user if self.request else "AnonymousUser"


class QcUserInfoSerializer(BaseSerializer):
    """Adds a ``userinfo`` field holding the ``User`` row for ``obj.uid``.

    The view opts in by defining ``userinfo_fields`` (columns passed to
    ``QuerySet.values``; ``'pk'`` is always added). On first use, all users
    referenced by ``context['queryset']`` are fetched in a single query and
    cached on this serializer instance.
    """
    userinfo = serializers.SerializerMethodField()
    # Lazily built cache: {user pk: values() dict}.
    user_mapping = None

    def get_userinfo(self, obj):
        """Return the cached ``User`` values for ``obj.uid``.

        Returns ``None`` when the view does not define ``userinfo_fields`` and
        ``{}`` when no matching user exists.
        """
        view = self.context.get("view")
        if not hasattr(view, 'userinfo_fields'):
            return None
        if self.user_mapping is None:
            # Work on a copy: ``view.userinfo_fields`` is typically shared view
            # (class) state and may be an immutable tuple, so appending to it
            # in place would leak across requests or raise AttributeError.
            fields = list(view.userinfo_fields)
            if 'pk' not in fields:
                # 'pk' is the mapping key below.
                fields.append('pk')
            # Assumes the view put the iterable of serialized objects in
            # context['queryset'] — TODO confirm for every caller.
            queryset = self.context.get("queryset")
            uids = [q.uid for q in queryset]
            users = User.objects.filter(pk__in=uids).values(*fields)
            self.user_mapping = {user.get('pk'): user for user in users}
        return self.user_mapping.get(obj.uid, {})


class QcWechatUserInfoSerializer(BaseSerializer):
    """Adds a ``wechat_userinfo`` field with the ``QcWechatbizuserinfo`` row
    whose ``userid`` matches ``obj.userid``.

    Columns default to ``('username', 'alias', 'userid')`` and may be
    overridden by the view's ``wechat_userinfo_fields``; ``'userid'`` is always
    included because it keys the mapping. Rows for every object in
    ``context['queryset']`` are fetched once and cached on the instance.
    """
    wechat_userinfo = serializers.SerializerMethodField()
    # Lazily built cache: {userid: values() dict}.
    wechat_user_mapping = None

    def get_wechat_userinfo(self, obj):
        """Return the cached WeChat user values for ``obj.userid`` (``{}`` if none)."""
        if self.wechat_user_mapping is None:
            view = self.context.get("view")
            fields = list(getattr(view, 'wechat_userinfo_fields',
                                  ('username', 'alias', 'userid')))
            # Without 'userid' every mapping key would be None and every
            # lookup would miss.
            if 'userid' not in fields:
                fields.append('userid')
            # Assumes context['queryset'] holds the objects being serialized.
            queryset = self.context.get("queryset")
            userids = [q.userid for q in queryset]
            users = QcWechatbizuserinfo.objects.filter(
                userid__in=userids).values(*fields)
            self.wechat_user_mapping = {user.get('userid'): user for user in users}
        return self.wechat_user_mapping.get(obj.userid, {})


class CurrentUserSerializer(BaseSerializer):
    """Hidden ``user`` field populated from the request user."""
    user = serializers.HiddenField(
        default=serializers.CurrentUserDefault()
    )


class CurrentUserIdSerializer(BaseSerializer):
    """Hidden ``uid`` field populated with the request user's primary key."""
    uid = serializers.HiddenField(
        default=CurrentUserIdDefault()
    )


class UserNameSerializer(serializers.ModelSerializer):
    """Read-only ``user`` field rendered as ``str(instance.user.username)``."""
    user = serializers.StringRelatedField(source='user.username')


class DateTimeListSerializer(BaseSerializer):
    """Read-only ``create_time`` / ``update_time`` using ``CustomDateTimeField``."""
    create_time = CustomDateTimeField(read_only=True)
    update_time = CustomDateTimeField(read_only=True)


class CorpkeySerializer(BaseSerializer):
    """Adds a computed ``corpkey`` field of the form ``"<corpid>-<agentid>"``."""
    corpkey = serializers.SerializerMethodField()

    def get_corpkey(self, obj):
        return f'{obj.corpid}-{obj.agentid}'


class CorpSerializerMixin:
    """Resolve the requesting user's ``QcCorpinfo`` from a ``corpkey``.

    The host class must provide ``self.request``.
    """

    def get_corp(self, corpkey=None):
        """Return the corp owned by the request user, memoized on ``self.corp``.

        Args:
            corpkey: explicit key; when ``None`` it is read from the query
                string, then from the request body.

        Raises:
            ValidationError: no corpkey could be found.
            CustomProjectException: no matching corp belongs to the user.

        Note: once memoized, later calls return the cached corp even if a
        different ``corpkey`` is passed.
        """
        if not hasattr(self, 'corp'):
            if corpkey is None:
                corpkey = self.request.query_params.get('corpkey')
                if not corpkey:
                    data = self.request.data
                    # The body may be a non-mapping (e.g. a JSON array);
                    # only mappings can carry a corpkey.
                    corpkey = data.get('corpkey') if hasattr(data, 'get') else None
            if corpkey is None:
                raise ValidationError('corpkey is required')
            uid = self.request.user.pk
            expression = get_query_by_corpkey(corpkey)
            corp = QcCorpinfo.objects.filter(**expression, uid=uid).first()
            if corp is None:
                raise CustomProjectException(detail='企业不存在')
            self.corp = corp
        return self.corp


class WechatWorkerMixin(CorpSerializerMixin):
    """Provide a ``WechatWorkerUtil`` client built from the resolved corp."""

    def get_wechat_worker(self):
        """Return the client for ``get_corp()``'s corpid/appsecret, memoized on
        ``self.wechat_worker``."""
        if not hasattr(self, 'wechat_worker'):
            corp = self.get_corp()
            self.wechat_worker = WechatWorkerUtil(corp.corpid, corp.appsecret)
        return self.wechat_worker


class SerializerFactory:
    """Create ``BaseSerializer`` subclasses for a model at runtime."""

    @staticmethod
    def build(model, fields, **kwargs):
        """Return a new ``BaseSerializer`` subclass for ``model``.

        Args:
            model: Django model class placed in the generated ``Meta``.
            fields: value for ``Meta.fields``.
            **kwargs: extra class attributes (e.g. declared serializer fields).
        """
        meta = type('Meta', (object,), {'model': model, 'fields': fields})
        attrs = {**kwargs, 'Meta': meta}
        name = '%sModelSerializerByType' % model._meta.object_name
        return type(name, (BaseSerializer,), attrs)