diff --git a/lilac/router_tasks.py b/lilac/router_tasks.py index be0413490..9e2e6549f 100644 --- a/lilac/router_tasks.py +++ b/lilac/router_tasks.py @@ -12,3 +12,9 @@ def get_task_manifest() -> TaskManifest: """Get the tasks, both completed and pending.""" return get_task_manager().manifest() + + +@router.post('/{task_id}/cancel') +def cancel_task(task_id: str) -> None: + """Cancel a task.""" + get_task_manager().cancel_task(task_id) diff --git a/lilac/tasks.py b/lilac/tasks.py index 4e24e2931..43a453dcb 100644 --- a/lilac/tasks.py +++ b/lilac/tasks.py @@ -2,6 +2,7 @@ import functools import uuid +from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum @@ -74,25 +75,29 @@ class TaskManifest(BaseModel): class TaskManager: """Manage FastAPI background tasks.""" - _tasks: dict[TaskId, TaskInfo] + _task_info: dict[TaskId, TaskInfo] + _task_threads: dict[TaskId, Thread] + _task_stopped: dict[TaskId, bool] # If true, task cancellation was requested. def __init__(self) -> None: # Maps a task id to the current progress of that task. Shared across all processes. - self._tasks = {} + self._task_info = {} + self._task_threads = {} + self._task_stopped = defaultdict(lambda: False) def get_task_info(self, task_id: TaskId) -> TaskInfo: """Get the task info for a task.""" - return self._tasks[task_id] + return self._task_info[task_id] def manifest(self) -> TaskManifest: """Get all tasks.""" tasks_with_progress = [ (task.total_progress / task.total_len) - for task in self._tasks.values() + for task in self._task_info.values() if task.total_progress and task.total_len and task.status != TaskStatus.COMPLETED ] return TaskManifest( - tasks=self._tasks, + tasks=self._task_info, progress=sum(tasks_with_progress) / len(tasks_with_progress) if tasks_with_progress else None, ) @@ -113,12 +118,38 @@ def task_id( start_timestamp=datetime.now().isoformat(), total_len=total_len, ) - self._tasks[task_id] = new_task + self._task_info[task_id] = new_task return task_id + def launch_task(self, task_id: TaskId, run_fn: Callable[..., Any]) -> None: + """Start a task in a background thread.""" + + def _wrapper() -> None: + try: + run_fn() + except Exception as e: + log(e) + self.set_error(task_id, str(e)) + else: + self.set_completed(task_id) + + thread = Thread(target=_wrapper, daemon=True) + thread.start() + self._task_threads[task_id] = thread + + def cancel_task(self, task_id: TaskId) -> None: + """Mark a thread for cancellation. + + The thread is not guaranteed to stop unless you also use get_progress_bar. If you implement + your own task execution logic, you can check tm._task_stopped[task_id] to see if the task + has been cancelled. + """ + self._task_stopped[task_id] = True + self._task_info[task_id].message = 'Task cancellation requested.' + def report_progress(self, task_id: TaskId, progress: int) -> None: """Report the progress of a task.""" - task = self._tasks[task_id] + task = self._task_info[task_id] task.total_progress = progress elapsed_sec = (datetime.now() - datetime.fromisoformat(task.start_timestamp)).total_seconds() ex_per_sec = progress / elapsed_sec if elapsed_sec else 0 @@ -128,7 +159,7 @@ def report_progress(self, task_id: TaskId, progress: int) -> None: def set_error(self, task_id: TaskId, error: str) -> None: """Mark a task as errored.""" - task = self._tasks[task_id] + task = self._task_info[task_id] task.status = TaskStatus.ERROR task.error = error task.end_timestamp = datetime.now().isoformat() @@ -136,7 +167,7 @@ def set_error(self, task_id: TaskId, error: str) -> None: def set_completed(self, task_id: TaskId) -> None: """Mark a task completed.""" end_timestamp = datetime.now().isoformat() - task = self._tasks[task_id] + task = self._task_info[task_id] task.end_timestamp = end_timestamp elapsed = datetime.fromisoformat(end_timestamp) - datetime.fromisoformat(task.start_timestamp) @@ -180,6 +211,8 @@ def progress_reporter(it: Iterator[TProgress]) -> Iterator[TProgress]: progress = offset try: for item in tqdm(it, initial=progress, total=task_info.total_len, desc=task_info.description): + if task_manager._task_stopped[task_id]: + raise AssertionError('Task cancelled successfully!') progress += 1 if progress % 100 == 0: task_manager.report_progress(task_id, progress) @@ -195,15 +228,4 @@ def progress_reporter(it: Iterator[TProgress]) -> Iterator[TProgress]: def launch_task(task_id: TaskId, run_fn: Callable) -> None: """Launch a task in a thread, handling exit conditions, etc..""" tm = get_task_manager() - - def _wrapper() -> None: - try: - run_fn() - except Exception as e: - log(e) - tm.set_error(task_id, str(e)) - else: - tm.set_completed(task_id) - - thread = Thread(target=_wrapper, daemon=True) - thread.start() + tm.launch_task(task_id, run_fn) diff --git a/web/blueprint/src/lib/components/TaskStatus.svelte b/web/blueprint/src/lib/components/TaskStatus.svelte index 6d7e73037..d6462824b 100644 --- a/web/blueprint/src/lib/components/TaskStatus.svelte +++ b/web/blueprint/src/lib/components/TaskStatus.svelte @@ -83,6 +83,7 @@
{task.name}
+
{ + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/tasks/{task_id}/cancel', + path: { + 'task_id': taskId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + }