diff --git a/chats/admin.py b/chats/admin.py index 3609acda..2bb099a8 100644 --- a/chats/admin.py +++ b/chats/admin.py @@ -10,12 +10,12 @@ @admin.display(description="Пользователи чата") -def chat_users(obj): +def chat_users(obj) -> str: return f"{obj.get_users_str()}" @admin.display(description="Количество сообщений") -def chat_message_count(obj): +def chat_message_count(obj) -> int: return obj.messages.count() diff --git a/chats/middleware.py b/chats/middleware.py index 593c182e..72eda39a 100644 --- a/chats/middleware.py +++ b/chats/middleware.py @@ -6,6 +6,10 @@ from django.contrib.auth import get_user_model from django.utils.translation import gettext_lazy as _ from rest_framework.exceptions import AuthenticationFailed +from rest_framework.authtoken.models import Token +from users.models import CustomUser +from django.contrib.auth.models import AnonymousUser + User = get_user_model() @@ -22,10 +26,9 @@ class TokenAuthentication: model = None - def get_model(self): + def get_model(self) -> Token: if self.model is not None: return self.model - from rest_framework.authtoken.models import Token return Token @@ -36,7 +39,7 @@ def get_model(self): * user -- The user to which the token belongs """ - def authenticate_credentials(self, key): + def authenticate_credentials(self, key: str) -> CustomUser: model = self.get_model() try: token = model.objects.select_related("user").get(key=key) @@ -48,7 +51,7 @@ def authenticate_credentials(self, key): return token.user - def authenticate(self, token): + def authenticate(self, token: Token) -> CustomUser: """ Returns a `User` if a correct username and password have been supplied Args: @@ -71,14 +74,13 @@ def authenticate(self, token): @database_sync_to_async -def get_user(scope): +def get_user(scope: dict) -> CustomUser | AnonymousUser: """ Return the user model instance associated with the given scope. If no user is retrieved, return an instance of `AnonymousUser`. """ # postpone model import to avoid ImproperlyConfigured error before Django # setup is complete. - from django.contrib.auth.models import AnonymousUser if "token" not in scope: raise ValueError( diff --git a/chats/models.py b/chats/models.py index 3a84ddbd..41c2a9a0 100644 --- a/chats/models.py +++ b/chats/models.py @@ -1,3 +1,4 @@ +from __future__ import annotations from abc import abstractmethod from typing import List @@ -7,6 +8,7 @@ from files.models import UserFile from projects.models import Project +from users.models import CustomUser User = get_user_model() @@ -21,10 +23,10 @@ class BaseChat(models.Model): created_at = models.DateTimeField(auto_now_add=True) - def get_last_message(self): + def get_last_message(self) -> BaseMessage: return self.messages.last() - def get_users_str(self): + def get_users_str(self) -> str: """Returns string of users separated by a comma, who are in chat Returns: @@ -34,7 +36,7 @@ def get_users_str(self): return ", ".join([user.get_full_name() for user in users]) @abstractmethod - def get_users(self): + def get_users(self) -> List[CustomUser]: """ Returns all collaborators and leader of the project. @@ -44,7 +46,7 @@ def get_users(self): pass @abstractmethod - def get_avatar(self, user): + def get_avatar(self, user: CustomUser) -> str: """ Returns avatar of the chat for given user @@ -57,7 +59,7 @@ def get_avatar(self, user): pass @abstractmethod - def get_last_messages(self, message_count): + def get_last_messages(self, message_count: int) -> List[BaseMessage]: """ Returns last messages of the chat @@ -90,18 +92,18 @@ class ProjectChat(BaseChat): Project, on_delete=models.CASCADE, related_name="project_chats" ) - def get_users(self): + def get_users(self) -> List[CustomUser]: collaborators = self.project.collaborator_set.all() users = [collaborator.user for collaborator in collaborators] return users + [self.project.leader] - def get_avatar(self, user): + def get_avatar(self, user: CustomUser) -> str: return self.project.image_address - def get_last_messages(self, message_count) -> List["BaseMessage"]: + def get_last_messages(self, message_count: int) -> List[BaseMessage]: return self.messages.order_by("-created_at")[:message_count] - def __str__(self): + def __str__(self) -> str: return f"ProjectChat<{self.project.id}> - {self.project.name}" def save( @@ -129,15 +131,15 @@ class DirectChat(BaseChat): id = models.CharField(primary_key=True, max_length=64) users = models.ManyToManyField(User, related_name="direct_chats") - def get_users(self): + def get_users(self) -> List[CustomUser]: return self.users.all() - def get_avatar(self, user): + def get_avatar(self, user) -> str: other_user = self.get_users().exclude(pk=user.pk).first() return other_user.avatar @classmethod - def get_chat(cls, user1, user2) -> "DirectChat": + def get_chat(cls, user1: CustomUser, user2: CustomUser) -> "DirectChat": """ Returns chat between two users. @@ -157,25 +159,25 @@ def get_chat(cls, user1, user2) -> "DirectChat": chat.users.set([user1, user2]) return chat - def get_last_messages(self, message_count): + def get_last_messages(self, message_count: int) -> BaseMessage: return self.messages.order_by("-created_at")[:message_count] - def get_other_user(self, user) -> User: + def get_other_user(self, user: CustomUser) -> User: return self.users.exclude(pk=user.pk).first() @classmethod - def create_from_two_users(cls, user1, user2): + def create_from_two_users(cls, user1: CustomUser, user2: CustomUser) -> DirectChat: chat = cls.objects.create(pk=cls.get_chat_id_from_users(user1, user2)) chat.users.set([user1, user2]) return chat @classmethod - def get_chat_id_from_users(cls, user1, user2): + def get_chat_id_from_users(cls, user1: CustomUser, user2: CustomUser) -> str: first_user = user1 if user1.pk < user2.pk else user2 second_user = user2 if user1.pk < user2.pk else user1 return f"{first_user.pk}_{second_user.pk}" - def __str__(self): + def __str__(self) -> str: return f"DirectChat with {self.get_users_str()}" class Meta: @@ -200,7 +202,7 @@ class BaseMessage(models.Model): is_edited = models.BooleanField(default=False) created_at = models.DateTimeField(auto_now_add=True) - def __str__(self): + def __str__(self) -> str: return f"Message<{self.pk}>" class Meta: @@ -240,7 +242,7 @@ def clean(self): if self.reply_to and self.reply_to.chat != self.chat: raise ValidationError("Reply to message from another chat") - def __str__(self): + def __str__(self) -> str: return f"ProjectChatMessage<{self.pk}>" class Meta: @@ -278,7 +280,7 @@ def clean(self): if self.reply_to and self.reply_to.chat != self.chat: raise ValidationError("Reply to message from another chat") - def __str__(self): + def __str__(self) -> str: return f"DirectChatMessage<{self.pk}>" class Meta: @@ -307,7 +309,7 @@ class FileToMessage(models.Model): null=True, ) - def __str__(self): + def __str__(self) -> str: return f"FileToMessage<{self.file}>" class Meta: diff --git a/chats/serializers.py b/chats/serializers.py index 9c6200b0..264f2cba 100644 --- a/chats/serializers.py +++ b/chats/serializers.py @@ -16,22 +16,22 @@ class DirectChatListSerializer(serializers.ModelSerializer): name = serializers.SerializerMethodField(read_only=True) image_address = serializers.SerializerMethodField(read_only=True) - def get_opponent(self, chat: DirectChat): + def get_opponent(self, chat: DirectChat) -> dict: user = self.context.get("opponent") return UserDetailSerializer( user, context={"request": self.context.get("request")} ).data - def get_name(self, chat: DirectChat): + def get_name(self, chat: DirectChat) -> str: user = self.context.get("opponent") return user.get_full_name() - def get_image_address(self, chat: DirectChat): + def get_image_address(self, chat: DirectChat) -> str: user = self.context.get("opponent") return user.avatar @classmethod - def get_last_message(cls, chat: DirectChat): + def get_last_message(cls, chat: DirectChat) -> dict: return DirectChatMessageListSerializer(chat.get_last_message()).data class Meta: @@ -42,7 +42,7 @@ class Meta: class DirectChatDetailSerializer(serializers.ModelSerializer): opponent = serializers.SerializerMethodField() - def get_opponent(self, chat: DirectChat): + def get_opponent(self, chat: DirectChat) -> dict: user = self.context.get("opponent") return UserDetailSerializer( user, context={"request": self.context.get("request")} @@ -62,15 +62,15 @@ class ProjectChatListSerializer(serializers.ModelSerializer): image_address = serializers.SerializerMethodField(read_only=True) @classmethod - def get_image_address(cls, chat: ProjectChat): + def get_image_address(cls, chat: ProjectChat) -> str: return chat.project.image_address @classmethod - def get_name(cls, chat: ProjectChat): + def get_name(cls, chat: ProjectChat) -> str: return chat.project.name @classmethod - def get_last_message(cls, chat: ProjectChat): + def get_last_message(cls, chat: ProjectChat) -> dict: return ProjectChatMessageListSerializer(chat.get_last_message()).data class Meta: @@ -84,14 +84,14 @@ class ProjectChatDetailSerializer(serializers.ModelSerializer): image_address = serializers.SerializerMethodField(read_only=True) @classmethod - def get_image_address(cls, chat: ProjectChat): + def get_image_address(cls, chat: ProjectChat) -> str: return chat.project.image_address @classmethod - def get_name(cls, chat: ProjectChat): + def get_name(cls, chat: ProjectChat) -> str: return chat.project.name - def get_users(self, chat: ProjectChat): + def get_users(self, chat: ProjectChat) -> dict: return UserListSerializer( chat.get_users(), context={"request": self.context.get("request")}, many=True ).data @@ -133,7 +133,7 @@ class DirectChatMessageListSerializer(serializers.ModelSerializer): files = serializers.SerializerMethodField() @classmethod - def get_files(cls, message: DirectChatMessage): + def get_files(cls, message: DirectChatMessage) -> list[dict]: data = [] for file_to_message in message.file_to_message.all(): file_data = UserFileSerializer(file_to_message.file).data @@ -182,7 +182,7 @@ class ProjectChatMessageListSerializer(serializers.ModelSerializer): files = serializers.SerializerMethodField() @classmethod - def get_files(cls, message: DirectChatMessage): + def get_files(cls, message: DirectChatMessage) -> dict: data = [] for file_to_message in message.file_to_message.all(): file_data = UserFileSerializer(file_to_message.file).data diff --git a/chats/utils.py b/chats/utils.py index 02f36b06..87f55315 100644 --- a/chats/utils.py +++ b/chats/utils.py @@ -9,7 +9,7 @@ WrongChatIdException, NonMatchingDirectChatIdException, ) -from chats.models import DirectChatMessage, ProjectChatMessage, FileToMessage +from chats.models import DirectChatMessage, ProjectChatMessage, FileToMessage, BaseMessage from files.models import UserFile User = get_user_model() @@ -106,7 +106,9 @@ async def create_file_to_message( ) -async def match_files_and_messages(file_urls, messages): +async def match_files_and_messages( + file_urls: list[str], messages: dict[str, Union[str, None, ProjectChatMessage]] +): for url in file_urls: file = await sync_to_async(UserFile.objects.get)(pk=url) # implicitly matches a file and a message @@ -117,7 +119,7 @@ async def match_files_and_messages(file_urls, messages): ) -def get_all_files(messages): +def get_all_files(messages: list[BaseMessage]) -> list[str]: # looks like something bad - files = [] for message in messages: diff --git a/chats/views.py b/chats/views.py index a99b5190..cf3a9dce 100644 --- a/chats/views.py +++ b/chats/views.py @@ -1,5 +1,6 @@ from django.contrib.auth import get_user_model +from django_stubs_ext import QuerySet from rest_framework import status from rest_framework.generics import ( GenericAPIView, @@ -34,11 +35,11 @@ class DirectChatList(ListAPIView): serializer_class = DirectChatListSerializer permission_classes = [IsAuthenticated] - def get_queryset(self): + def get_queryset(self) -> QuerySet[DirectChat]: user = self.request.user return user.direct_chats.all() - def get(self, request, *args, **kwargs): + def get(self, request, *args, **kwargs) -> Response: chats = self.get_queryset() serialized_chats = [] for chat in chats: @@ -68,7 +69,7 @@ class ProjectChatList(ListAPIView): serializer_class = ProjectChatListSerializer permission_classes = [IsAuthenticated] - def get_queryset(self): + def get_queryset(self) -> QuerySet[ProjectChat]: user = self.request.user return user.get_project_chats() @@ -117,7 +118,9 @@ def get(self, request, *args, **kwargs) -> Response: except ValueError: return Response( status=status.HTTP_400_BAD_REQUEST, - data={"detail": "processed id must contain two integers separated by underscore"}, + data={ + "detail": "processed id must contain two integers separated by underscore" + }, ) except AssertionError as e: return Response(status=status.HTTP_400_BAD_REQUEST, data={"detail": str(e)}) @@ -152,7 +155,7 @@ class ProjectChatMessageList(ListCreateAPIView): permission_classes = [IsProjectChatMember] pagination_class = MessageListPagination - def get_queryset(self): + def get_queryset(self) -> QuerySet[ProjectChat]: try: return ( ProjectChat.objects.get(id=self.kwargs["id"]) @@ -163,7 +166,7 @@ def get_queryset(self): except ProjectChat.DoesNotExist: return ProjectChat.objects.none() - def post(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs) -> Response: # TODO: try to create a message in a chat. If chat doesn't exist, create it and then create a message. serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -176,7 +179,7 @@ class DirectChatFileList(ListCreateAPIView): serializer_class = UserFileSerializer permission_classes = [IsAuthenticated] - def get_queryset(self): + def get_queryset(self) -> list[str]: messages = self.request.user.direct_chats.get(id=self.kwargs["id"]).messages.all() @@ -187,7 +190,7 @@ class ProjectChatFileList(ListCreateAPIView): serializer_class = UserFileSerializer permission_classes = [IsProjectChatMember] - def get_queryset(self): + def get_queryset(self) -> QuerySet[UserFile]: try: messages = ProjectChat.objects.get(id=self.kwargs["id"]).messages.all() return get_all_files(messages) @@ -216,7 +219,7 @@ class HasChatUnreadsView(GenericAPIView): ) } ) - def get(self, request, *args, **kwargs): + def get(self, request, *args, **kwargs) -> Response: user = request.user # get all user chats direct_messages = user.direct_chats.all().prefetch_related("messages")