Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[engine] Fix some type hints when request is passed #9077

Merged
merged 7 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cvat/apps/engine/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from django.utils import timezone
from django_rq.queues import DjangoRQ, DjangoScheduler
from rest_framework import serializers, status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rq.job import Job as RQJob
Expand All @@ -40,6 +39,7 @@
from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export
from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField
from cvat.apps.engine.serializers import RqIdSerializer
from cvat.apps.engine.types import PatchedRequest
from cvat.apps.engine.utils import (
build_annotations_file_name,
build_backup_file_name,
Expand Down Expand Up @@ -180,7 +180,7 @@ def location(self) -> Location:
def __init__(
self,
db_instance: Union[models.Project, models.Task, models.Job],
request: Request,
request: PatchedRequest,
export_callback: Callable,
save_images: Optional[bool] = None,
*,
Expand Down Expand Up @@ -525,7 +525,7 @@ def location(self) -> Location:
def __init__(
self,
db_instance: Union[models.Project, models.Task],
request: Request,
request: PatchedRequest,
*,
version: int = 2,
) -> None:
Expand Down
16 changes: 13 additions & 3 deletions cvat/apps/engine/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
ValidationParamsSerializer,
)
from cvat.apps.engine.task import JobFileMapping, _create_thread
from cvat.apps.engine.types import PatchedRequest
from cvat.apps.engine.utils import (
av_scan_paths,
define_dependent_job,
Expand Down Expand Up @@ -1147,7 +1148,16 @@ def create_backup(
log_exception(logger)
raise

def _import(importer, request, queue, rq_id, Serializer, file_field_name, location_conf, filename=None):
def _import(
importer: TaskImporter | ProjectImporter,
request: PatchedRequest,
queue: django_rq.queues.Queue,
rq_id: str,
Serializer: TaskFileSerializer | ProjectFileSerializer,
file_field_name: str,
location_conf: dict,
filename: str | None = None,
):
rq_job = queue.fetch_job(rq_id)

if (user_id_from_meta := getattr(rq_job, 'meta', {}).get(RQJobMetaField.USER, {}).get('id')) and user_id_from_meta != request.user.id:
Expand Down Expand Up @@ -1235,7 +1245,7 @@ def _import(importer, request, queue, rq_id, Serializer, file_field_name, locati
def get_backup_dirname():
return TmpDirManager.TMP_ROOT

def import_project(request, queue_name, filename=None):
def import_project(request: PatchedRequest, queue_name: str, filename: str | None = None):
if 'rq_id' in request.data:
rq_id = request.data['rq_id']
else:
Expand Down Expand Up @@ -1264,7 +1274,7 @@ def import_project(request, queue_name, filename=None):
filename=filename
)

def import_task(request, queue_name, filename=None):
def import_task(request: PatchedRequest, queue_name: str, filename: str | None = None):
rq_id = request.data.get('rq_id', RQId(
RequestAction.IMPORT, RequestTarget.TASK, uuid.uuid4(),
subresource=RequestSubresource.BACKUP,
Expand Down
5 changes: 5 additions & 0 deletions cvat/apps/engine/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
#
# SPDX-License-Identifier: MIT

from typing import Protocol
from uuid import uuid4


class WithUUID(Protocol):
uuid: str


class RequestTrackingMiddleware:
def __init__(self, get_response):
self.get_response = get_response
Expand Down
45 changes: 30 additions & 15 deletions cvat/apps/engine/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import django_rq
from attr.converters import to_bool
from django.conf import settings
from django.http import HttpRequest
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework import mixins, status
Expand All @@ -33,9 +32,18 @@
from cvat.apps.engine.handlers import clear_import_cache
from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.models import Location, RequestAction, RequestSubresource, RequestTarget
from cvat.apps.engine.models import (
Job,
Location,
Project,
RequestAction,
RequestSubresource,
RequestTarget,
Task,
)
from cvat.apps.engine.rq_job_handler import RQId
from cvat.apps.engine.serializers import DataSerializer, RqIdSerializer
from cvat.apps.engine.types import PatchedRequest
from cvat.apps.engine.utils import is_dataset_export

slogger = ServerLogManager(__name__)
Expand Down Expand Up @@ -159,7 +167,7 @@ def create_file(metadata, file_size, upload_dir):
return tus_file

class TusChunk:
def __init__(self, request):
def __init__(self, request: PatchedRequest):
self.META = request.META
self.offset = int(request.META.get("HTTP_UPLOAD_OFFSET", 0))
self.size = int(request.META.get("CONTENT_LENGTH", settings.TUS_DEFAULT_CHUNK_SIZE))
Expand Down Expand Up @@ -215,7 +223,7 @@ def _tus_response(self, status, data=None, extra_headers=None):
response.__setitem__(key, value)
return response

def _get_metadata(self, request):
def _get_metadata(self, request: PatchedRequest):
metadata = {}
if request.META.get("HTTP_UPLOAD_METADATA"):
for kv in request.META.get("HTTP_UPLOAD_METADATA").split(","):
Expand All @@ -230,7 +238,7 @@ def _get_metadata(self, request):
metadata[splited_metadata[0]] = ""
return metadata

def upload_data(self, request):
def upload_data(self, request: PatchedRequest):
tus_request = request.headers.get('Upload-Length', None) is not None or request.method == 'OPTIONS'
bulk_file_upload = request.headers.get('Upload-Multiple', None) is not None
start_upload = request.headers.get('Upload-Start', None) is not None
Expand All @@ -247,9 +255,9 @@ def upload_data(self, request):
else: # backward compatibility case - no upload headers were found
return self.upload_finished(request)

def init_tus_upload(self, request):
def init_tus_upload(self, request: PatchedRequest):
if request.method == 'OPTIONS':
return self._tus_response(status=status.HTTP_204)
return self._tus_response(status=status.HTTP_204_NO_CONTENT)
else:
metadata = self._get_metadata(request)
filename = metadata.get('filename', '')
Expand Down Expand Up @@ -321,7 +329,7 @@ def init_tus_upload(self, request):
extra_headers={'Location': urljoin(location, tus_file.file_id),
'Upload-Filename': tus_file.filename})

def append_tus_chunk(self, request, file_id):
def append_tus_chunk(self, request: PatchedRequest, file_id: str):
tus_file = TusFile(str(file_id), self.get_upload_dir())
if request.method == 'HEAD':
if tus_file.exists():
Expand Down Expand Up @@ -405,7 +413,7 @@ class PartialUpdateModelMixin:
Almost the same as UpdateModelMixin, but has no public PUT / update() method.
"""

def _update(self, request, *args, **kwargs):
def _update(self, request: PatchedRequest, *args, **kwargs):
# This method must not be named "update" not to be matched with the PUT method
return mixins.UpdateModelMixin.update(self, request, *args, **kwargs)

Expand All @@ -420,7 +428,7 @@ def partial_update(self, request, *args, **kwargs):
class DatasetMixin:
def export_dataset_v1(
self,
request,
request: PatchedRequest,
save_images: bool,
*,
get_data: Optional[Callable[[int], dict[str, Any]]] = None,
Expand Down Expand Up @@ -466,7 +474,7 @@ def export_dataset_v1(
},
)
@action(detail=True, methods=['POST'], serializer_class=None, url_path='dataset/export')
def export_dataset_v2(self, request: HttpRequest, pk: int):
def export_dataset_v2(self, request: PatchedRequest, pk: int):
self._object = self.get_object() # force call of check_object_permissions()

save_images = is_dataset_export(request)
Expand All @@ -476,7 +484,14 @@ def export_dataset_v2(self, request: HttpRequest, pk: int):
return dataset_export_manager.export()

# FUTURE-TODO: migrate to new API
def import_annotations(self, request, db_obj, import_func, rq_func, rq_id_factory):
def import_annotations(
self,
request: PatchedRequest,
db_obj: Project | Task | Job,
import_func: Callable[..., None],
rq_func: Callable[..., None],
rq_id_factory: RQId,
):
is_tus_request = request.headers.get('Upload-Length', None) is not None or \
request.method == 'OPTIONS'
if is_tus_request:
Expand Down Expand Up @@ -508,7 +523,7 @@ def import_annotations(self, request, db_obj, import_func, rq_func, rq_id_factor


class BackupMixin:
def export_backup_v1(self, request: HttpRequest) -> Response:
def export_backup_v1(self, request: PatchedRequest) -> Response:
db_object = self.get_object() # force to call check_object_permissions

export_backup_manager = BackupExportManager(db_object, request, version=1)
Expand All @@ -520,7 +535,7 @@ def export_backup_v1(self, request: HttpRequest) -> Response:
return response

# FUTURE-TODO: migrate to new API
def import_backup_v1(self, request: HttpRequest, import_func: Callable) -> Response:
def import_backup_v1(self, request: PatchedRequest, import_func: Callable) -> Response:
location = request.query_params.get("location", Location.LOCAL)
if location == Location.CLOUD_STORAGE:
file_name = request.query_params.get("filename", "")
Expand Down Expand Up @@ -554,7 +569,7 @@ def import_backup_v1(self, request: HttpRequest, import_func: Callable) -> Respo
},
)
@action(detail=True, methods=['POST'], serializer_class=None, url_path='backup/export')
def export_backup_v2(self, request: HttpRequest, pk: int):
def export_backup_v2(self, request: PatchedRequest, pk: int):
db_object = self.get_object() # force to call check_object_permissions

export_backup_manager = BackupExportManager(db_object, request, version=2)
Expand Down
4 changes: 3 additions & 1 deletion cvat/apps/engine/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from rest_framework.pagination import PageNumberPagination

from cvat.apps.engine.types import PatchedRequest


class CustomPagination(PageNumberPagination):
page_size_query_param = "page_size"

def get_page_size(self, request):
def get_page_size(self, request: PatchedRequest):
page_size = 0
try:
value = request.query_params[self.page_size_query_param]
Expand Down
Loading
Loading