diff --git a/CLI-COMMANDS.md b/CLI-COMMANDS.md index 34539368..9929da6c 100644 --- a/CLI-COMMANDS.md +++ b/CLI-COMMANDS.md @@ -122,6 +122,22 @@ roboflow workflow version list my-workflow roboflow workflow fork other-ws/their-workflow ``` +### Fork a Universe project (async) + +```bash +# Fork a public Universe project into the default (or --workspace) workspace. +# By default this blocks until the async task completes (up to --timeout seconds). +roboflow project fork https://universe.roboflow.com/leo-ueno-uduc7/license-plate-recognition +roboflow project fork leo-ueno-uduc7/license-plate-recognition --workspace my-ws + +# Return immediately with a {taskId, url} payload instead of waiting. +roboflow project fork leo-ueno-uduc7/license-plate-recognition --no-wait + +# Poll the resulting task later (works for any async task that returns a taskId). +roboflow asynctasks get +roboflow asynctasks wait --timeout 600 +``` + ### Create a dataset version ```bash @@ -259,6 +275,7 @@ Version numbers are always numeric — that's how `x/y` is disambiguated between | `workflow` | Manage workflows | | `folder` | Manage workspace folders | | `annotation` | Annotation batches and jobs | +| `asynctasks` | Inspect async background tasks (e.g. project forks) | | `trash` | List items in Trash | | `universe` | Search Roboflow Universe | | `video` | Video inference | diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index e087d603..b0df1851 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -3,6 +3,7 @@ import os import urllib from typing import Dict, List, Optional, Union +from urllib.parse import quote import requests from requests.exceptions import RequestException @@ -813,6 +814,71 @@ def list_workflow_versions(api_key, workspace_url, workflow_url): return response.json() +def fork_project( + api_key, + dest_workspace, + *, + url=None, + source_project_slug=None, +): + """POST /{ws}/projects/fork — enqueue an async fork of a public Universe project. + + Pass ``url`` (a Universe URL) or an explicit ``source_project_slug``. The + API owns parsing/validation. Returns the server's response, e.g. + ``{"taskId": "...", "url": ""}``. + """ + payload: Dict[str, str] = {} + if url: + payload["url"] = url + if source_project_slug: + payload["source_project"] = source_project_slug + response = requests.post( + f"{API_URL}/{dest_workspace}/projects/fork", + params={"api_key": api_key}, + json=payload, + ) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def get_async_task(api_key, workspace_url, task_id): + """GET /{ws}/asynctasks/{id} — fetch the current status of an async task. + + Returns the server's status payload, e.g. + ``{"taskId": "...", "status": "running", "progress": {...}}`` or + ``{"taskId": "...", "status": "completed", "result": {...}}`` once + terminal. Raises ``RoboflowError`` for any non-2xx response (including + 404 for unknown ids or cross-workspace probes). + """ + # ``task_id`` comes from arbitrary external input; encode so a stray + # ``/``, ``?`` or ``#`` cannot mutate the request path (and still send + # the api_key with it). + encoded_task_id = quote(task_id, safe="") + response = requests.get( + f"{API_URL}/{workspace_url}/asynctasks/{encoded_task_id}", + params={"api_key": api_key}, + ) + if response.status_code != 200: + raise RoboflowError(response.text) + return response.json() + + +def get_async_task_at(api_key, polling_url): + """GET an async-task polling URL returned verbatim by the server. + + Enqueue endpoints (e.g. ``/{ws}/projects/fork``) return a fully-qualified + ``url`` alongside ``taskId``. The host may differ from ``API_URL`` (e.g. + local dev against ``localapi.roboflow.one``), so hit it directly and + only attach the api_key. Falls back to ``get_async_task`` callers when + no server-supplied URL is available. + """ + response = requests.get(polling_url, params={"api_key": api_key}) + if response.status_code != 200: + raise RoboflowError(response.text) + return response.json() + + def fork_workflow(api_key, workspace_url, *, source_workspace, source_workflow, name=None, url=None): """POST /{ws}/forkWorkflow — fork a workflow into this workspace. diff --git a/roboflow/cli/__init__.py b/roboflow/cli/__init__.py index e998c65b..1fd4a038 100644 --- a/roboflow/cli/__init__.py +++ b/roboflow/cli/__init__.py @@ -185,6 +185,7 @@ def _walk(group: Any, prefix: str = "") -> None: # --------------------------------------------------------------------------- from roboflow.cli.handlers.annotation import annotation_app # noqa: E402 +from roboflow.cli.handlers.asynctasks import asynctasks_app # noqa: E402 from roboflow.cli.handlers.auth import auth_app # noqa: E402 from roboflow.cli.handlers.batch import batch_app # noqa: E402 from roboflow.cli.handlers.completion import completion_app # noqa: E402 @@ -207,6 +208,7 @@ def _walk(group: Any, prefix: str = "") -> None: # Register ALL commands in alphabetical order for clean --help output app.add_typer(annotation_app, name="annotation") +app.add_typer(asynctasks_app, name="asynctasks") app.add_typer(auth_app, name="auth") app.add_typer(batch_app, name="batch", hidden=True) # All stubs — hidden until implemented app.add_typer(completion_app, name="completion") diff --git a/roboflow/cli/handlers/asynctasks.py b/roboflow/cli/handlers/asynctasks.py new file mode 100644 index 00000000..51dbe26f --- /dev/null +++ b/roboflow/cli/handlers/asynctasks.py @@ -0,0 +1,136 @@ +"""Async task polling commands. + +These mirror the generic ``GET /:workspace/asynctasks/:id`` endpoint so any +backend operation that returns ``{taskId, url}`` can be inspected with the +same CLI tools. +""" + +from __future__ import annotations + +from typing import Annotated + +import typer + +from roboflow.cli._compat import SortedGroup, ctx_to_args + +asynctasks_app = typer.Typer( + cls=SortedGroup, + help="Inspect async background tasks (e.g. project forks)", + no_args_is_help=True, +) + + +@asynctasks_app.command("get") +def get_async_task( + ctx: typer.Context, + task_id: Annotated[str, typer.Argument(help="Async task id (returned by /projects/fork etc.)")], +) -> None: + """Show the current status of an async task.""" + args = ctx_to_args(ctx, task_id=task_id) + _get_async_task(args) + + +@asynctasks_app.command("wait") +def wait_async_task( + ctx: typer.Context, + task_id: Annotated[str, typer.Argument(help="Async task id")], + timeout: Annotated[ + int, + typer.Option("--timeout", help="Seconds to wait for completion (0 = no timeout)."), + ] = 1800, +) -> None: + """Block until an async task is completed or failed.""" + args = ctx_to_args(ctx, task_id=task_id, timeout=timeout) + _wait_async_task(args) + + +# --------------------------------------------------------------------------- +# Business logic +# --------------------------------------------------------------------------- + + +def _resolve_ws_and_key(args): # noqa: ANN001 + from roboflow.cli._output import output_error + from roboflow.cli._resolver import resolve_default_workspace + from roboflow.config import load_roboflow_api_key + + workspace_url = args.workspace or resolve_default_workspace(api_key=args.api_key) + if not workspace_url: + output_error( + args, + "No workspace specified.", + hint="Use --workspace or run 'roboflow auth login'.", + exit_code=2, + ) + return None, None + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return None, None + return workspace_url, api_key + + +def _get_async_task(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + workspace_url, api_key = _resolve_ws_and_key(args) + if not api_key: + return + + try: + status = rfapi.get_async_task(api_key, workspace_url, args.task_id) + except rfapi.RoboflowError as exc: + # Server returns 404 for unknown ids OR cross-workspace probes. + output_error(args, str(exc), exit_code=3) + return + + output(args, status, text=f"taskId={status.get('taskId')} status={status.get('status')}") + + +def _wait_async_task(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + from roboflow.core.async_tasks import poll_until_terminal + + workspace_url, api_key = _resolve_ws_and_key(args) + if not api_key: + return + + def _print_progress(status): # noqa: ANN001 + if args.json: + return + progress = status.get("progress") + if not isinstance(progress, dict): + return + # Don't use `or` here: `current == 0` is a legitimate value. + current = progress["current"] if "current" in progress else progress.get("completed") + total = progress.get("total") + if current is not None and total is not None: + print(f"Task progress: {current}/{total}", flush=True) + + try: + final = poll_until_terminal( + api_key, + workspace_url, + args.task_id, + timeout=args.timeout, + on_update=_print_progress, + ) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + except TimeoutError as exc: + output_error(args, str(exc)) + return + + if final.get("status") == "failed": + output_error(args, final.get("error") or "Task failed.") + return + + output(args, final, text=f"taskId={final.get('taskId')} status={final.get('status')}") diff --git a/roboflow/cli/handlers/project.py b/roboflow/cli/handlers/project.py index d2e6c97e..e59bd56f 100644 --- a/roboflow/cli/handlers/project.py +++ b/roboflow/cli/handlers/project.py @@ -78,6 +78,27 @@ def restore_project( _restore_project(args) +@project_app.command("fork") +def fork_project( + ctx: typer.Context, + source: Annotated[ + str, + typer.Argument(help="Source project: Universe URL or '/' shorthand."), + ], + no_wait: Annotated[ + bool, + typer.Option("--no-wait", help="Return immediately with the taskId instead of waiting."), + ] = False, + timeout: Annotated[ + int, + typer.Option("--timeout", help="Seconds to wait for completion (0 = no timeout)."), + ] = 1800, +) -> None: + """Fork a public Universe project into a workspace.""" + args = ctx_to_args(ctx, source=source, no_wait=no_wait, timeout=timeout) + _fork_project(args) + + @project_app.command("health") def health_project( ctx: typer.Context, @@ -352,6 +373,98 @@ def _restore_project(args): # noqa: ANN001 output(args, data, text=f"Restored {workspace_url}/{project_slug} from Trash.") +def _fork_project(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + from roboflow.cli._resolver import resolve_default_workspace + from roboflow.config import load_roboflow_api_key + from roboflow.core.async_tasks import poll_until_terminal + + # The server accepts the full URL (or `/` shorthand) as `url` + # and parses it itself — forward verbatim so the CLI doesn't duplicate + # that logic. + source = (args.source or "").strip() + if not source: + output_error( + args, + "Source is required.", + hint="Use '/' or a Universe URL.", + ) + return + + dest_workspace = args.workspace or resolve_default_workspace(api_key=args.api_key) + if not dest_workspace: + output_error( + args, + "No workspace specified.", + hint="Use --workspace or run 'roboflow auth login'.", + exit_code=2, + ) + return + + api_key = args.api_key or load_roboflow_api_key(dest_workspace) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return + + try: + enqueued = rfapi.fork_project(api_key, dest_workspace, url=source) + except rfapi.RoboflowError as exc: + output_error(args, str(exc)) + return + + task_id = enqueued["taskId"] + + if args.no_wait: + polling_url = enqueued.get("url") + text = f"Fork enqueued: taskId={task_id}" + if polling_url: + text += f"\nPoll: {polling_url}" + output(args, enqueued, text=text) + return + + def _print_progress(status): # noqa: ANN001 + if args.json: + return + progress = status.get("progress") + if not isinstance(progress, dict): + return + # Don't use `or` here: `current == 0` is a legitimate value. + current = progress["current"] if "current" in progress else progress.get("completed") + total = progress.get("total") + if current is not None and total is not None: + print(f"Task progress: {current}/{total}", flush=True) + + try: + final = poll_until_terminal( + api_key, + dest_workspace, + task_id, + timeout=args.timeout, + on_update=_print_progress, + polling_url=enqueued.get("url"), + ) + except rfapi.RoboflowError as exc: + output_error(args, str(exc)) + return + except TimeoutError as exc: + output_error(args, str(exc)) + return + + if final.get("status") == "failed": + output_error(args, final.get("error") or "Fork task failed.") + return + + project_url = (final.get("result") or {}).get("url", "") + text = f"Forked.\nDestination URL: {project_url}" if project_url else "Forked." + output(args, final, text=text) + + def _health_project(args): # noqa: ANN001 import json diff --git a/roboflow/core/async_tasks.py b/roboflow/core/async_tasks.py new file mode 100644 index 00000000..fd385a83 --- /dev/null +++ b/roboflow/core/async_tasks.py @@ -0,0 +1,51 @@ +"""Helpers for polling Roboflow async tasks.""" + +from __future__ import annotations + +import time +from typing import Any, Callable, Dict, Optional + +from roboflow.adapters import rfapi + +NON_TERMINAL_STATUSES = frozenset({"created", "pending", "queued", "running", "in_progress"}) + + +def poll_until_terminal( + api_key: str, + workspace_url: str, + task_id: str, + *, + interval: float = 4.0, + timeout: float = 1800.0, + on_update: Optional[Callable[[Dict[str, Any]], None]] = None, + polling_url: Optional[str] = None, +) -> Dict[str, Any]: + """Poll an async task until status is terminal or timeout elapses. + + If ``polling_url`` is provided, hit it verbatim (the server returns one + alongside ``taskId`` from enqueue endpoints; it may point at a different + host than ``API_URL``). Otherwise build the URL from ``API_URL`` / + ``workspace_url`` / ``task_id`` via :func:`rfapi.get_async_task`. + + A non-positive ``timeout`` disables the timeout. Returns the final + status dict on terminal status. ``RoboflowError`` from the underlying + API call is propagated; ``TimeoutError`` is raised if the deadline + passes before a terminal status is observed. + """ + deadline = None if timeout <= 0 else time.monotonic() + timeout + while True: + if polling_url: + status = rfapi.get_async_task_at(api_key, polling_url) + else: + status = rfapi.get_async_task(api_key, workspace_url, task_id) + # Invoke the callback before the terminal check so the final tick + # (typically `current == total`) is delivered to the caller. + if on_update: + on_update(status) + if status.get("status") not in NON_TERMINAL_STATUSES: + return status + if deadline is not None and time.monotonic() >= deadline: + raise TimeoutError( + f"Timed out after {timeout:.0f}s waiting for task {task_id} (last status: {status.get('status')})." + ) + time.sleep(interval) diff --git a/roboflow/core/workspace.py b/roboflow/core/workspace.py index 6c8e87ab..89519409 100644 --- a/roboflow/core/workspace.py +++ b/roboflow/core/workspace.py @@ -131,6 +131,32 @@ def create_project(self, project_name, project_type, project_license, annotation return Project(self.__api_key, r.json(), self.model_format) + def fork_project( + self, + *, + url: Optional[str] = None, + source_project_slug: Optional[str] = None, + ) -> Dict[str, Any]: + """Fork a public Universe project into this workspace. + + Args: + url: Universe project URL. + source_project_slug: Source project slug when not using ``url``. + + Returns: + The API response, typically ``{"taskId": "...", "url": "..."}``. + """ + return rfapi.fork_project( + self.__api_key, + self.url, + url=url, + source_project_slug=source_project_slug, + ) + + def get_async_task(self, task_id: str) -> Dict[str, Any]: + """Return the current status of an async task owned by this workspace.""" + return rfapi.get_async_task(self.__api_key, self.url, task_id) + def devices(self) -> List["Device"]: """List v2 devices registered in this workspace. diff --git a/tests/adapters/test_rfapi_phase2.py b/tests/adapters/test_rfapi_phase2.py index cc205df5..dee26e9f 100644 --- a/tests/adapters/test_rfapi_phase2.py +++ b/tests/adapters/test_rfapi_phase2.py @@ -537,6 +537,149 @@ def test_error(self, mock_get): list_workflow_versions("key", "ws", "wf1") +class TestForkProject(unittest.TestCase): + @patch("roboflow.adapters.rfapi.requests.post") + def test_success_with_url(self, mock_post): + from roboflow.adapters.rfapi import fork_project + + mock_post.return_value = MagicMock(status_code=202, json=lambda: {"taskId": "task-1", "url": "poll"}) + + result = fork_project("key", "target-ws", url="source-ws/source-project") + + self.assertEqual(result["taskId"], "task-1") + self.assertIn("/target-ws/projects/fork", mock_post.call_args[0][0]) + payload = mock_post.call_args[1]["json"] + self.assertEqual(payload, {"url": "source-ws/source-project"}) + + @patch("roboflow.adapters.rfapi.requests.post") + def test_success_with_explicit_source_slug(self, mock_post): + from roboflow.adapters.rfapi import fork_project + + mock_post.return_value = MagicMock(status_code=202, json=lambda: {"taskId": "task-1", "url": "poll"}) + + fork_project( + "key", + "target-ws", + source_project_slug="source-project", + ) + + payload = mock_post.call_args[1]["json"] + self.assertEqual(payload, {"source_project": "source-project"}) + + @patch("roboflow.adapters.rfapi.requests.post") + def test_error(self, mock_post): + from roboflow.adapters.rfapi import RoboflowError, fork_project + + mock_post.return_value = MagicMock(status_code=403, ok=False, text="Forbidden") + with self.assertRaises(RoboflowError): + fork_project("key", "ws", url="source-ws/source-project") + + @patch("roboflow.adapters.rfapi.requests.post") + def test_any_2xx_accepted(self, mock_post): + """#8 — accept any 2xx so the SDK doesn't break if the backend ever + returns 200 (sync result) or 201 (created) instead of 202. + """ + from roboflow.adapters.rfapi import fork_project + + for code in (200, 201, 202, 204): + mock_post.return_value = MagicMock( + status_code=code, + ok=200 <= code < 300, + json=lambda: {"taskId": "t", "url": "u"}, + ) + result = fork_project("key", "ws", url="source-ws/source-project") + self.assertEqual(result["taskId"], "t") + + +class TestGetAsyncTask(unittest.TestCase): + @patch("roboflow.adapters.rfapi.requests.get") + def test_success(self, mock_get): + from roboflow.adapters.rfapi import get_async_task + + mock_get.return_value = MagicMock(status_code=200, json=lambda: {"taskId": "task-1", "status": "running"}) + + result = get_async_task("key", "ws", "task-1") + + self.assertEqual(result["status"], "running") + self.assertIn("/ws/asynctasks/task-1", mock_get.call_args[0][0]) + + @patch("roboflow.adapters.rfapi.requests.get") + def test_malformed_task_id_is_url_encoded(self, mock_get): + """A task_id containing path/query/fragment characters must not + silently mutate the request path. Each unsafe char is percent-encoded + by ``urllib.parse.quote(..., safe="")``.""" + from roboflow.adapters.rfapi import get_async_task + + mock_get.return_value = MagicMock(status_code=200, json=lambda: {"taskId": "x", "status": "running"}) + + # Task ids containing `/`, `?`, `#`, `..`, spaces, and a literal `&` + # used to land in the path verbatim — leaking the api_key with a + # forged path and confusing the router. With the fix in place each + # unsafe character is percent-encoded. + get_async_task("key", "ws", "../task?secret=1#frag&x") + + called_url = mock_get.call_args[0][0] + # Slash, dot, question mark, hash, ampersand, and space are all encoded. + self.assertIn("/ws/asynctasks/", called_url) + self.assertIn("%2F", called_url) # `/` + self.assertIn("%3F", called_url) # `?` + self.assertIn("%23", called_url) # `#` + self.assertIn("%26", called_url) # `&` + # Path doesn't end with the bare task id segments. + self.assertNotIn("/asynctasks/../task", called_url) + self.assertNotIn("?secret=1", called_url.split("/asynctasks/", 1)[1]) + + @patch("roboflow.adapters.rfapi.requests.get") + def test_get_async_task_at_uses_supplied_url(self, mock_get): + """``get_async_task_at`` hits the server-supplied polling URL + verbatim (modulo the api_key query param), so polling stays on the + host the task lives on even if it differs from ``API_URL``.""" + from roboflow.adapters.rfapi import get_async_task_at + + mock_get.return_value = MagicMock(status_code=200, json=lambda: {"taskId": "task-1", "status": "completed"}) + result = get_async_task_at("key", "https://other.host/ws/asynctasks/task-1") + self.assertEqual(result["status"], "completed") + self.assertEqual(mock_get.call_args[0][0], "https://other.host/ws/asynctasks/task-1") + self.assertEqual(mock_get.call_args[1]["params"], {"api_key": "key"}) + + @patch("roboflow.adapters.rfapi.requests.get") + def test_error(self, mock_get): + from roboflow.adapters.rfapi import RoboflowError, get_async_task + + mock_get.return_value = MagicMock(status_code=404, text="Not found") + with self.assertRaises(RoboflowError): + get_async_task("key", "ws", "missing") + + +class TestGetAsyncTaskAt(unittest.TestCase): + @patch("roboflow.adapters.rfapi.requests.get") + def test_polling_url_used_verbatim(self, mock_get): + """When the server returns a fully-qualified polling URL, the SDK must + hit it as-is (potentially on a different host than ``API_URL``) and + only attach the ``api_key`` query param. + """ + from roboflow.adapters.rfapi import get_async_task_at + + mock_get.return_value = MagicMock(status_code=200, json=lambda: {"taskId": "task-1", "status": "running"}) + + polling_url = "https://localapi.roboflow.one/ws/asynctasks/task-1" + result = get_async_task_at("api-key", polling_url) + + self.assertEqual(result["status"], "running") + # URL passed through unchanged. + self.assertEqual(mock_get.call_args[0][0], polling_url) + # api_key tacked on as a param. + self.assertEqual(mock_get.call_args[1]["params"], {"api_key": "api-key"}) + + @patch("roboflow.adapters.rfapi.requests.get") + def test_error_on_non_200(self, mock_get): + from roboflow.adapters.rfapi import RoboflowError, get_async_task_at + + mock_get.return_value = MagicMock(status_code=404, text="Not found") + with self.assertRaises(RoboflowError): + get_async_task_at("key", "https://api.roboflow.com/ws/asynctasks/missing") + + class TestForkWorkflow(unittest.TestCase): @patch("roboflow.adapters.rfapi.requests.post") def test_success(self, mock_post): diff --git a/tests/cli/test_asynctasks_handler.py b/tests/cli/test_asynctasks_handler.py new file mode 100644 index 00000000..688f27fd --- /dev/null +++ b/tests/cli/test_asynctasks_handler.py @@ -0,0 +1,197 @@ +"""Tests for the `roboflow asynctasks` CLI handler.""" + +import json +import unittest +from argparse import Namespace +from unittest.mock import patch + +from typer.testing import CliRunner + +from roboflow.cli import app + +runner = CliRunner() + + +def _make_args(**kwargs): + defaults = { + "json": False, + "workspace": "test-ws", + "api_key": "test-key", + "quiet": False, + "task_id": "task-123", + "timeout": 1800, + } + defaults.update(kwargs) + return Namespace(**defaults) + + +class TestAsyncTasksRegistration(unittest.TestCase): + def test_app_exists(self) -> None: + from roboflow.cli.handlers.asynctasks import asynctasks_app + + self.assertIsNotNone(asynctasks_app) + + def test_get_help(self) -> None: + result = runner.invoke(app, ["asynctasks", "get", "--help"]) + self.assertEqual(result.exit_code, 0) + + def test_wait_help(self) -> None: + result = runner.invoke(app, ["asynctasks", "wait", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("timeout", result.output.lower()) + + +class TestAsyncTaskGet(unittest.TestCase): + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_get_text(self, _mock_key, mock_get): + from roboflow.cli.handlers.asynctasks import _get_async_task + + mock_get.return_value = { + "taskId": "task-123", + "status": "running", + "progress": {"percent": 42}, + } + args = _make_args() + with patch("builtins.print") as mock_print: + _get_async_task(args) + + mock_get.assert_called_once_with("test-key", "test-ws", "task-123") + printed = mock_print.call_args[0][0] + self.assertIn("task-123", printed) + self.assertIn("running", printed) + + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_get_json(self, _mock_key, mock_get): + from roboflow.cli.handlers.asynctasks import _get_async_task + + payload = { + "taskId": "task-123", + "status": "completed", + "result": {"forked": True, "url": "https://app.roboflow.com/x/y"}, + } + mock_get.return_value = payload + args = _make_args(json=True) + with patch("builtins.print") as mock_print: + _get_async_task(args) + + # Server payload pass-through. + out = json.loads(mock_print.call_args[0][0]) + self.assertEqual(out, payload) + + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_get_404_exits_three(self, _mock_key, mock_get): + from roboflow.adapters.rfapi import RoboflowError + from roboflow.cli.handlers.asynctasks import _get_async_task + + mock_get.side_effect = RoboflowError('{"error":"Async task not found"}') + args = _make_args() + with self.assertRaises(SystemExit) as ctx: + _get_async_task(args) + self.assertEqual(ctx.exception.code, 3) + + +class TestAsyncTaskWait(unittest.TestCase): + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_until_completed(self, _mock_key, mock_get): + from roboflow.cli.handlers.asynctasks import _wait_async_task + + mock_get.side_effect = [ + {"taskId": "task-1", "status": "pending", "progress": None}, + {"taskId": "task-1", "status": "running", "progress": {"current": 1, "total": 3}}, + {"taskId": "task-1", "status": "completed", "result": {"ok": True}}, + ] + args = _make_args(task_id="task-1") + with patch("builtins.print") as mock_print: + _wait_async_task(args) + + self.assertEqual(mock_get.call_count, 3) + printed = mock_print.call_args[0][0] + self.assertIn("completed", printed) + mock_print.assert_any_call("Task progress: 1/3", flush=True) + + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_until_failed_exits_one(self, _mock_key, mock_get): + from roboflow.cli.handlers.asynctasks import _wait_async_task + + mock_get.return_value = { + "taskId": "task-1", + "status": "failed", + "error": "Source dataset not public", + } + args = _make_args(task_id="task-1") + with self.assertRaises(SystemExit) as ctx: + _wait_async_task(args) + self.assertEqual(ctx.exception.code, 1) + + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.core.async_tasks.time.monotonic") + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_timeout_exits_one(self, _mock_key, mock_get, mock_monotonic): + from roboflow.cli.handlers.asynctasks import _wait_async_task + + # Two get_async_task calls, then deadline check trips. + mock_get.return_value = {"taskId": "task-1", "status": "running"} + # monotonic sequence: start, deadline-check-1 (still under), deadline-check-2 (over) + mock_monotonic.side_effect = [0.0, 0.5, 99999.0] + args = _make_args(task_id="task-1", timeout=10) + with self.assertRaises(SystemExit) as ctx: + _wait_async_task(args) + self.assertEqual(ctx.exception.code, 1) + + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_server_error_exits_three(self, _mock_key, mock_get): + from roboflow.adapters.rfapi import RoboflowError + from roboflow.cli.handlers.asynctasks import _wait_async_task + + mock_get.side_effect = RoboflowError('{"error":"Async task not found"}') + args = _make_args(task_id="task-1") + with self.assertRaises(SystemExit) as ctx: + _wait_async_task(args) + self.assertEqual(ctx.exception.code, 3) + + +class TestPollUntilTerminalCallback(unittest.TestCase): + """#6 — `on_update` must fire on the terminal tick too, so callers driving + a progress bar see the final ``current == total`` event before the loop + returns.""" + + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + def test_on_update_called_for_terminal_status(self, mock_get): + from roboflow.core.async_tasks import poll_until_terminal + + mock_get.side_effect = [ + {"taskId": "t", "status": "running", "progress": {"current": 1, "total": 2}}, + {"taskId": "t", "status": "completed", "progress": {"current": 2, "total": 2}}, + ] + seen = [] + result = poll_until_terminal("k", "ws", "t", on_update=seen.append) + + # Both ticks delivered, including the completed one. + self.assertEqual([s["status"] for s in seen], ["running", "completed"]) + self.assertEqual(result["status"], "completed") + + +class TestAsyncTaskNoWorkspace(unittest.TestCase): + @patch("roboflow.cli._resolver.resolve_default_workspace", return_value=None) + def test_get_no_workspace_exits_two(self, _mock_resolve): + from roboflow.cli.handlers.asynctasks import _get_async_task + + args = _make_args(workspace=None, api_key=None) + with self.assertRaises(SystemExit) as ctx: + _get_async_task(args) + self.assertEqual(ctx.exception.code, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cli/test_project_fork_handler.py b/tests/cli/test_project_fork_handler.py new file mode 100644 index 00000000..d96f7806 --- /dev/null +++ b/tests/cli/test_project_fork_handler.py @@ -0,0 +1,195 @@ +"""Tests for the `roboflow project fork` CLI handler.""" + +import json +import unittest +from argparse import Namespace +from unittest.mock import patch + +from typer.testing import CliRunner + +from roboflow.cli import app + +runner = CliRunner() + + +def _make_args(**kwargs): + """Create a Namespace with CLI defaults and fork-command defaults.""" + defaults = { + "json": False, + "workspace": "test-ws", + "api_key": "test-key", + "quiet": False, + "no_wait": False, + "timeout": 1800, + } + defaults.update(kwargs) + return Namespace(**defaults) + + +class TestProjectForkRegistration(unittest.TestCase): + def test_fork_help_exists(self) -> None: + result = runner.invoke(app, ["project", "fork", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("Universe", result.output) + self.assertIn("no", result.output.lower()) + self.assertIn("wait", result.output.lower()) + self.assertIn("timeout", result.output.lower()) + + +class TestForkProjectNoWait(unittest.TestCase): + @patch("roboflow.adapters.rfapi.fork_project") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_url_form_no_wait_text(self, _mock_key, mock_fork): + from roboflow.cli.handlers.project import _fork_project + + mock_fork.return_value = { + "taskId": "task-123", + "url": "https://api.roboflow.com/test-ws/asynctasks/task-123", + } + args = _make_args( + source="https://universe.roboflow.com/ws/proj", + no_wait=True, + ) + with patch("builtins.print") as mock_print: + _fork_project(args) + + mock_fork.assert_called_once_with( + "test-key", + "test-ws", + url="https://universe.roboflow.com/ws/proj", + ) + printed = mock_print.call_args[0][0] + self.assertIn("task-123", printed) + # #10 — server-supplied polling URL surfaces so the user can poll later. + self.assertIn("https://api.roboflow.com/test-ws/asynctasks/task-123", printed) + + @patch("roboflow.adapters.rfapi.fork_project") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_shorthand_no_wait_json(self, _mock_key, mock_fork): + from roboflow.cli.handlers.project import _fork_project + + mock_fork.return_value = {"taskId": "task-456", "url": "poll-url"} + args = _make_args(json=True, source="ws/proj", no_wait=True) + with patch("builtins.print") as mock_print: + _fork_project(args) + + mock_fork.assert_called_once_with( + "test-key", + "test-ws", + url="ws/proj", + ) + out = json.loads(mock_print.call_args[0][0]) + # Server response is passed through verbatim. + self.assertEqual(out, {"taskId": "task-456", "url": "poll-url"}) + + +class TestForkProjectWait(unittest.TestCase): + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.adapters.rfapi.fork_project") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_until_completed_text(self, _mock_key, mock_fork, mock_get): + from roboflow.cli.handlers.project import _fork_project + + # No `url` in the fork response → poll_until_terminal falls back to + # rfapi.get_async_task (which `mock_get` patches). + mock_fork.return_value = {"taskId": "task-1"} + mock_get.side_effect = [ + {"taskId": "task-1", "status": "running", "progress": {"current": 1, "total": 2}}, + { + "taskId": "task-1", + "status": "completed", + "result": { + "forked": True, + "datasetUrl": "license-plates", + "id": "test-ws/license-plates", + "name": "License Plates", + "url": "https://app.roboflow.com/test-ws/license-plates", + }, + }, + ] + args = _make_args(source="ws/proj") + with patch("builtins.print") as mock_print: + _fork_project(args) + + printed = mock_print.call_args[0][0] + self.assertIn("Forked", printed) + self.assertIn("Destination URL", printed) + self.assertIn("https://app.roboflow.com/test-ws/license-plates", printed) + mock_print.assert_any_call("Task progress: 1/2", flush=True) + self.assertEqual(mock_get.call_count, 2) + + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.adapters.rfapi.fork_project") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_until_completed_json(self, _mock_key, mock_fork, mock_get): + from roboflow.cli.handlers.project import _fork_project + + mock_fork.return_value = {"taskId": "task-1"} + terminal_payload = { + "taskId": "task-1", + "status": "completed", + "result": {"forked": True, "url": "https://app.roboflow.com/x/y"}, + } + mock_get.return_value = terminal_payload + args = _make_args(json=True, source="ws/proj") + with patch("builtins.print") as mock_print: + _fork_project(args) + + # Server payload is passed through unchanged in --json mode. + out = json.loads(mock_print.call_args[0][0]) + self.assertEqual(out, terminal_payload) + + @patch("roboflow.core.async_tasks.time.sleep", lambda *_a, **_k: None) + @patch("roboflow.adapters.rfapi.get_async_task") + @patch("roboflow.adapters.rfapi.fork_project") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_wait_until_failed_exits_one(self, _mock_key, mock_fork, mock_get): + from roboflow.cli.handlers.project import _fork_project + + mock_fork.return_value = {"taskId": "task-1"} + mock_get.return_value = { + "taskId": "task-1", + "status": "failed", + "error": "Source dataset is not public", + } + args = _make_args(source="ws/proj") + with self.assertRaises(SystemExit) as ctx: + _fork_project(args) + self.assertEqual(ctx.exception.code, 1) + + +class TestForkProjectErrors(unittest.TestCase): + def test_empty_source_exits(self): + from roboflow.cli.handlers.project import _fork_project + + args = _make_args(source="") + with self.assertRaises(SystemExit) as ctx: + _fork_project(args) + self.assertEqual(ctx.exception.code, 1) + + @patch("roboflow.adapters.rfapi.fork_project") + @patch("roboflow.config.load_roboflow_api_key", return_value="test-key") + def test_server_error_passes_through(self, _mock_key, mock_fork): + from roboflow.adapters.rfapi import RoboflowError + from roboflow.cli.handlers.project import _fork_project + + mock_fork.side_effect = RoboflowError('{"error":"You already own that dataset."}') + args = _make_args(source="ws/proj") + with self.assertRaises(SystemExit) as ctx: + _fork_project(args) + self.assertEqual(ctx.exception.code, 1) + + @patch("roboflow.cli._resolver.resolve_default_workspace", return_value=None) + def test_no_workspace_exits_two(self, _mock_resolve): + from roboflow.cli.handlers.project import _fork_project + + args = _make_args(workspace=None, api_key=None, source="ws/proj") + with self.assertRaises(SystemExit) as ctx: + _fork_project(args) + self.assertEqual(ctx.exception.code, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/manual/uselocal b/tests/manual/uselocal index b91bcbc3..f9c0b713 100644 --- a/tests/manual/uselocal +++ b/tests/manual/uselocal @@ -1,8 +1,15 @@ #!/bin/env bash -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -cp $SCRIPT_DIR/data/.config-staging $SCRIPT_DIR/data/.config +SCRIPT_PATH="${BASH_SOURCE[0]}" +[ -n "$ZSH_VERSION" ] && SCRIPT_PATH="${(%):-%x}" +SCRIPT_DIR="$( cd "$( dirname "$SCRIPT_PATH" )" && pwd )" +cp "$SCRIPT_DIR/data/.config-staging" "$SCRIPT_DIR/data/.config" export API_URL=https://localapi.roboflow.one export APP_URL=https://localapp.roboflow.one export DEDICATED_DEPLOYMENT_URL=https://staging.roboflow.cloud export ROBOFLOW_CONFIG_DIR=$SCRIPT_DIR/data/.config +MKCERT_ROOT_CA=$HOME/Library/Application\ Support/mkcert/rootCA.pem +if [ -f "$MKCERT_ROOT_CA" ]; then + export REQUESTS_CA_BUNDLE=$MKCERT_ROOT_CA + export CURL_CA_BUNDLE=$MKCERT_ROOT_CA +fi # need to set it in /etc/hosts to the IP of host.docker.internal! diff --git a/tests/test_workspace.py b/tests/test_workspace.py index 19d3c78b..5b7f4aa0 100644 --- a/tests/test_workspace.py +++ b/tests/test_workspace.py @@ -4,8 +4,18 @@ import tempfile import unittest import zipfile +from unittest.mock import patch -from roboflow.core.workspace import _zip_directory +from roboflow.core.workspace import Workspace, _zip_directory + + +def _make_workspace(): + return Workspace( + {"workspace": {"name": "Test", "projects": [], "url": "test-ws"}}, + api_key="test-key", + default_workspace="test-ws", + model_format="yolov8", + ) class TestZipDirectory(unittest.TestCase): @@ -39,5 +49,46 @@ def test_filters_hidden_and_junk_entries(self): os.unlink(zip_path) +class TestWorkspaceAsyncTasks(unittest.TestCase): + @patch("roboflow.adapters.rfapi.fork_project") + def test_fork_project_uses_workspace_destination(self, mock_fork): + workspace = _make_workspace() + mock_fork.return_value = {"taskId": "task-1", "url": "poll-url"} + + result = workspace.fork_project(url="source-ws/source-project") + + self.assertEqual(result, {"taskId": "task-1", "url": "poll-url"}) + mock_fork.assert_called_once_with( + "test-key", + "test-ws", + url="source-ws/source-project", + source_project_slug=None, + ) + + @patch("roboflow.adapters.rfapi.fork_project") + def test_fork_project_accepts_explicit_source_slug(self, mock_fork): + workspace = _make_workspace() + mock_fork.return_value = {"taskId": "task-1", "url": "poll-url"} + + workspace.fork_project(source_project_slug="source-project") + + mock_fork.assert_called_once_with( + "test-key", + "test-ws", + url=None, + source_project_slug="source-project", + ) + + @patch("roboflow.adapters.rfapi.get_async_task") + def test_get_async_task_uses_workspace_destination(self, mock_get): + workspace = _make_workspace() + mock_get.return_value = {"taskId": "task-1", "status": "running"} + + result = workspace.get_async_task("task-1") + + self.assertEqual(result, {"taskId": "task-1", "status": "running"}) + mock_get.assert_called_once_with("test-key", "test-ws", "task-1") + + if __name__ == "__main__": unittest.main()