From 2b743b22af489e1687bc11c58d7cbb6e01b165e5 Mon Sep 17 00:00:00 2001 From: Denys Fedoryshchenko Date: Thu, 9 Apr 2026 00:27:00 +0300 Subject: [PATCH] linting fixes Signed-off-by: Denys Fedoryshchenko --- api/admin.py | 20 +- api/db.py | 16 +- api/email_sender.py | 16 +- api/main.py | 300 +++++++++++++++----- api/models.py | 44 ++- api/pubsub.py | 12 +- api/pubsub_mongo.py | 131 +++++++-- api/user_manager.py | 47 ++- migrations/20231102101356_user.py | 78 ++--- migrations/20231215122000_node_models.py | 119 ++++---- scripts/usermanager.py | 98 +++++-- tests/e2e_tests/test_count_handler.py | 10 +- tests/e2e_tests/test_pipeline.py | 22 +- tests/e2e_tests/test_pubsub_handler.py | 12 +- tests/e2e_tests/test_regression_handler.py | 13 +- tests/e2e_tests/test_unsubscribe_handler.py | 8 +- tests/e2e_tests/test_user_creation.py | 16 +- tests/unit_tests/conftest.py | 23 +- tests/unit_tests/test_authz_handler.py | 17 +- tests/unit_tests/test_events_handler.py | 12 +- tests/unit_tests/test_node_handler.py | 24 +- tests/unit_tests/test_token_handler.py | 8 +- tests/unit_tests/test_user_group_handler.py | 12 +- tests/unit_tests/test_user_handler.py | 48 +++- 24 files changed, 784 insertions(+), 322 deletions(-) diff --git a/api/admin.py b/api/admin.py index d673cba2..3ade6afa 100644 --- a/api/admin.py +++ b/api/admin.py @@ -29,10 +29,14 @@ async def setup_admin_user(db, username, email, password=None): if not password: password = os.getenv("KCI_INITIAL_PASSWORD") if not password: - print("Password is empty and KCI_INITIAL_PASSWORD is not set, aborting.") + print( + "Password is empty and KCI_INITIAL_PASSWORD is not set, aborting." + ) return None else: - retyped = getpass.getpass(f"Retype password for user '{username}': ") + retyped = getpass.getpass( + f"Retype password for user '{username}': " + ) if password != retyped: print("Sorry, passwords do not match, aborting.") return None @@ -67,7 +71,9 @@ async def main(args): db = Database(args.mongo, args.database) await db.initialize_beanie() await db.create_indexes() - created = await setup_admin_user(db, args.username, args.email, password=args.password) + created = await setup_admin_user( + db, args.username, args.email, password=args.password + ) return created is not None @@ -79,8 +85,12 @@ async def main(args): help="Mongo server connection string", ) parser.add_argument("--username", default="admin", help="Admin username") - parser.add_argument("--database", default="kernelci", help="KernelCI database name") - parser.add_argument("--email", required=True, help="Admin user email address") + parser.add_argument( + "--database", default="kernelci", help="KernelCI database name" + ) + parser.add_argument( + "--email", required=True, help="Admin user email address" + ) parser.add_argument( "--password", default="", diff --git a/api/db.py b/api/db.py index ed2301b3..d04cc070 100644 --- a/api/db.py +++ b/api/db.py @@ -151,7 +151,9 @@ def _translate_operators(self, attributes): if isinstance(op_value, str) and op_value.isdecimal(): op_value = int(op_value) if translated_attributes.get(key): - translated_attributes[key].update({op_key: op_value}) + translated_attributes[key].update( + {op_key: op_value} + ) else: translated_attributes[key] = {op_key: op_value} return translated_attributes @@ -251,7 +253,9 @@ async def insert_many(self, model, documents): result = await col.insert_many(documents) return result.inserted_ids - async def _create_recursively(self, hierarchy: Hierarchy, parent: Node, cls, col): + async def _create_recursively( + self, hierarchy: Hierarchy, parent: Node, cls, col + ): obj = parse_node_obj(hierarchy.node) if parent: obj.parent = parent.id @@ -259,7 +263,9 @@ async def _create_recursively(self, hierarchy: Hierarchy, parent: Node, cls, col obj.update() if obj.parent == obj.id: raise ValueError("Parent cannot be the same as the object") - res = await col.replace_one({"_id": ObjectId(obj.id)}, obj.dict(by_alias=True)) + res = await col.replace_one( + {"_id": ObjectId(obj.id)}, obj.dict(by_alias=True) + ) if res.matched_count == 0: raise ValueError(f"No object found with id: {obj.id}") else: @@ -293,7 +299,9 @@ async def update(self, obj): obj.update() if obj.parent == obj.id: raise ValueError("Parent cannot be the same as the object") - res = await col.replace_one({"_id": ObjectId(obj.id)}, obj.dict(by_alias=True)) + res = await col.replace_one( + {"_id": ObjectId(obj.id)}, obj.dict(by_alias=True) + ) if res.matched_count == 0: raise ValueError(f"No object found with id: {obj.id}") return obj.__class__(**await col.find_one(ObjectId(obj.id))) diff --git a/api/email_sender.py b/api/email_sender.py index 489f11ba..1afb851d 100644 --- a/api/email_sender.py +++ b/api/email_sender.py @@ -26,9 +26,13 @@ def __init__(self): def _smtp_connect(self): """Method to create a connection with SMTP server""" if self._settings.smtp_port == 465: - smtp = smtplib.SMTP_SSL(self._settings.smtp_host, self._settings.smtp_port) + smtp = smtplib.SMTP_SSL( + self._settings.smtp_host, self._settings.smtp_port + ) else: - smtp = smtplib.SMTP(self._settings.smtp_host, self._settings.smtp_port) + smtp = smtplib.SMTP( + self._settings.smtp_host, self._settings.smtp_port + ) smtp.starttls() smtp.login(self._settings.email_sender, self._settings.email_password) return smtp @@ -60,7 +64,11 @@ def _send_email(self, email_msg): detail="Failed to send email", ) from exc - def create_and_send_email(self, email_subject, email_content, email_recipient): + def create_and_send_email( + self, email_subject, email_content, email_recipient + ): """Method to create and send email""" - email_msg = self._create_email(email_subject, email_content, email_recipient) + email_msg = self._create_email( + email_subject, email_content, email_recipient + ) self._send_email(email_msg) diff --git a/api/main.py b/api/main.py index 4cace41c..a2538cb6 100644 --- a/api/main.py +++ b/api/main.py @@ -111,7 +111,8 @@ def _validate_startup_environment(): details.append("empty: " + ", ".join(sorted(empty))) raise RuntimeError( "Startup environment validation failed. " - "Set required environment variables before starting the API. " + "; ".join(details) + "Set required environment variables before starting the API. " + + "; ".join(details) ) @@ -159,7 +160,9 @@ async def subscription_cleanup_task(): while True: try: await asyncio.sleep(SUBSCRIPTION_CLEANUP_INTERVAL_MINUTES * 60) - cleaned = await pubsub.cleanup_stale_subscriptions(SUBSCRIPTION_MAX_AGE_MINUTES) + cleaned = await pubsub.cleanup_stale_subscriptions( + SUBSCRIPTION_MAX_AGE_MINUTES + ) if cleaned > 0: metrics.add("subscriptions_cleaned", 1) print(f"Cleaned up {cleaned} stale subscriptions") @@ -219,7 +222,9 @@ async def ensure_initial_admin_user(): await db.create( User( username=username, - hashed_password=Authentication.get_password_hash(initial_password), + hashed_password=Authentication.get_password_hash( + initial_password + ), email=email, is_superuser=1, is_verified=1, @@ -280,7 +285,9 @@ def get_current_user( def get_current_superuser( - user: User = Depends(fastapi_users_instance.current_user(active=True, superuser=True)), + user: User = Depends( + fastapi_users_instance.current_user(active=True, superuser=True) + ), ): """Get current active superuser""" return user @@ -381,11 +388,17 @@ def _resolve_public_base_url(request: Request) -> str: forwarded_header = request.headers.get("forwarded") if forwarded_header and is_proxy_request: - forwarded_host, forwarded_proto = _parse_forwarded_header(forwarded_header) + forwarded_host, forwarded_proto = _parse_forwarded_header( + forwarded_header + ) if is_proxy_request: - forwarded_host = forwarded_host or request.headers.get("x-forwarded-host") - forwarded_proto = forwarded_proto or request.headers.get("x-forwarded-proto") + forwarded_host = forwarded_host or request.headers.get( + "x-forwarded-host" + ) + forwarded_proto = forwarded_proto or request.headers.get( + "x-forwarded-proto" + ) if forwarded_host: scheme = forwarded_proto or request.url.scheme @@ -447,7 +460,9 @@ async def _create_user_for_invite( ) user_create.groups = groups - created_user = await register_router.routes[0].endpoint(request, user_create, user_manager) + created_user = await register_router.routes[0].endpoint( + request, user_create, user_manager + ) if invite.is_superuser: user_from_id = await db.find_by_id(User, created_user.id) @@ -458,12 +473,16 @@ async def _create_user_for_invite( app.include_router( - fastapi_users_instance.get_auth_router(auth_backend, requires_verification=True), + fastapi_users_instance.get_auth_router( + auth_backend, requires_verification=True + ), prefix="/user", tags=["user"], ) -register_router = fastapi_users_instance.get_register_router(UserRead, UserCreate) +register_router = fastapi_users_instance.get_register_router( + UserRead, UserCreate +) @app.post( @@ -566,7 +585,9 @@ async def accept_invite_page(): @app.get("/user/invite/url", response_model=InviteUrlResponse, tags=["user"]) -async def invite_url_preview(request: Request, current_user: User = Depends(get_current_superuser)): +async def invite_url_preview( + request: Request, current_user: User = Depends(get_current_superuser) +): """Preview the resolved public URL used in invite links (admin-only)""" metrics.add("http_requests_total", 1) public_base_url = _resolve_public_base_url(request) @@ -611,7 +632,9 @@ async def accept_invite(accept: InviteAcceptRequest): detail="Invite already accepted", ) - user_from_id.hashed_password = user_manager.password_helper.hash(accept.password) + user_from_id.hashed_password = user_manager.password_helper.hash( + accept.password + ) user_from_id.is_verified = True updated_user = await db.update(user_from_id) @@ -628,7 +651,9 @@ async def accept_invite(accept: InviteAcceptRequest): tags=["user"], ) -users_router = fastapi_users_instance.get_users_router(UserRead, UserUpdate, requires_verification=True) +users_router = fastapi_users_instance.get_users_router( + UserRead, UserUpdate, requires_verification=True +) app.add_api_route( path="/whoami", @@ -655,7 +680,12 @@ async def accept_invite(accept: InviteAcceptRequest): ) -@app.patch("/user/me", response_model=UserRead, tags=["user"], response_model_by_alias=False) +@app.patch( + "/user/me", + response_model=UserRead, + tags=["user"], + response_model_by_alias=False, +) async def update_me( request: Request, user: UserUpdateRequest, @@ -686,13 +716,23 @@ async def update_me( if not group: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=(f"User group does not exist with name: {group_name}"), + detail=( + f"User group does not exist with name: {group_name}" + ), ) groups.append(group) - user_update = UserUpdate(**(user.model_dump(exclude={"groups", "is_superuser"}, exclude_none=True))) + user_update = UserUpdate( + **( + user.model_dump( + exclude={"groups", "is_superuser"}, exclude_none=True + ) + ) + ) if groups: user_update.groups = groups - return await users_router.routes[1].endpoint(request, user_update, current_user, user_manager) + return await users_router.routes[1].endpoint( + request, user_update, current_user, user_manager + ) @app.patch( @@ -731,15 +771,21 @@ async def update_user( if not group: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=(f"User group does not exist with name: {group_name}"), + detail=( + f"User group does not exist with name: {group_name}" + ), ) groups.append(group) - user_update = UserUpdate(**(user.model_dump(exclude={"groups"}, exclude_none=True))) + user_update = UserUpdate( + **(user.model_dump(exclude={"groups"}, exclude_none=True)) + ) if groups: user_update.groups = groups - updated_user = await users_router.routes[3].endpoint(user_update, request, user_from_id, user_manager) + updated_user = await users_router.routes[3].endpoint( + user_update, request, user_from_id, user_manager + ) # Update superuser explicitly since fastapi-users update route ignores it. if user.is_superuser is not None: user_from_id = await db.find_by_id(User, updated_user.id) @@ -749,14 +795,18 @@ async def update_user( @app.get("/user-groups", response_model=PageModel, tags=["user"]) -async def get_user_groups(request: Request, current_user: User = Depends(get_current_superuser)): +async def get_user_groups( + request: Request, current_user: User = Depends(get_current_superuser) +): """List user groups (admin-only).""" metrics.add("http_requests_total", 1) query_params = dict(request.query_params) for pg_key in ["limit", "offset"]: query_params.pop(pg_key, None) paginated_resp = await db.find_by_attributes(UserGroup, query_params) - paginated_resp.items = serialize_paginated_data(UserGroup, paginated_resp.items) + paginated_resp.items = serialize_paginated_data( + UserGroup, paginated_resp.items + ) return paginated_resp @@ -766,7 +816,9 @@ async def get_user_groups(request: Request, current_user: User = Depends(get_cur tags=["user"], response_model_by_alias=False, ) -async def get_user_group(group_id: str, current_user: User = Depends(get_current_superuser)): +async def get_user_group( + group_id: str, current_user: User = Depends(get_current_superuser) +): """Get a user group by id (admin-only).""" metrics.add("http_requests_total", 1) group = await db.find_by_id(UserGroup, group_id) @@ -785,7 +837,8 @@ async def get_user_group(group_id: str, current_user: User = Depends(get_current response_model_by_alias=False, ) async def create_user_group( - group: UserGroupCreateRequest, current_user: User = Depends(get_current_superuser) + group: UserGroupCreateRequest, + current_user: User = Depends(get_current_superuser), ): """Create a user group (admin-only).""" metrics.add("http_requests_total", 1) @@ -798,8 +851,14 @@ async def create_user_group( return await db.create(UserGroup(name=group.name)) -@app.delete("/user-groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT, tags=["user"]) -async def delete_user_group(group_id: str, current_user: User = Depends(get_current_superuser)): +@app.delete( + "/user-groups/{group_id}", + status_code=status.HTTP_204_NO_CONTENT, + tags=["user"], +) +async def delete_user_group( + group_id: str, current_user: User = Depends(get_current_superuser) +): """Delete a user group (admin-only).""" metrics.add("http_requests_total", 1) group = await db.find_by_id(UserGroup, group_id) @@ -812,7 +871,9 @@ async def delete_user_group(group_id: str, current_user: User = Depends(get_curr if assigned_count: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail=("User group is assigned to users and cannot be deleted. Remove it from users first."), + detail=( + "User group is assigned to users and cannot be deleted. Remove it from users first." + ), ) await db.delete_by_id(UserGroup, group_id) return Response(status_code=status.HTTP_204_NO_CONTENT) @@ -835,13 +896,19 @@ def _user_can_edit_node(user: User, node: Node) -> bool: user_group_names = {group.name for group in user.groups} if "node:edit:any" in user_group_names: return True - if any(group_name in user_group_names for group_name in getattr(node, "user_groups", [])): + if any( + group_name in user_group_names + for group_name in getattr(node, "user_groups", []) + ): return True runtime = _get_node_runtime(node) if runtime: runtime_editor = ":".join(["runtime", runtime, "node-editor"]) runtime_admin = ":".join(["runtime", runtime, "node-admin"]) - if runtime_editor in user_group_names or runtime_admin in user_group_names: + if ( + runtime_editor in user_group_names + or runtime_admin in user_group_names + ): return True return False @@ -871,7 +938,9 @@ async def authorize_user(node_id: str, user: User = Depends(get_current_user)): tags=["user"], response_model_exclude={"items": {"__all__": {"hashed_password"}}}, ) -async def get_users(request: Request, current_user: User = Depends(get_current_user)): +async def get_users( + request: Request, current_user: User = Depends(get_current_user) +): """Get all the users if no request parameters have passed. Get the matching users otherwise.""" metrics.add("http_requests_total", 1) @@ -900,14 +969,18 @@ async def update_password( ) user_update = UserUpdate(password=new_password) user_from_username = await db.find_one(User, username=credentials.username) - await users_router.routes[3].endpoint(user_update, request, user_from_username, user_manager) + await users_router.routes[3].endpoint( + user_update, request, user_from_username, user_manager + ) # EventHistory is now stored by pubsub.publish_cloudevent() # No need for separate _get_eventhistory function -def _parse_event_id_filter(query_params: dict, event_id: str, event_ids: str) -> None: +def _parse_event_id_filter( + query_params: dict, event_id: str, event_ids: str +) -> None: """Parse and validate event id/ids filter parameters. Modifies query_params in place to add _id filter. @@ -922,12 +995,20 @@ def _parse_event_id_filter(query_params: dict, event_id: str, event_ids: str) -> try: query_params["_id"] = ObjectId(event_id) except (errors.InvalidId, TypeError) as exc: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid id format") from exc + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid id format", + ) from exc elif event_ids: try: - ids_list = [ObjectId(x.strip()) for x in event_ids.split(",") if x.strip()] + ids_list = [ + ObjectId(x.strip()) for x in event_ids.split(",") if x.strip() + ] except (errors.InvalidId, TypeError) as exc: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid ids format") from exc + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid ids format", + ) from exc if not ids_list: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -1140,7 +1221,9 @@ async def get_telemetry(request: Request): query_params["is_infra_error"] = False paginated_resp = await db.find_by_attributes(TelemetryEvent, query_params) - paginated_resp.items = serialize_paginated_data(TelemetryEvent, paginated_resp.items) + paginated_resp.items = serialize_paginated_data( + TelemetryEvent, paginated_resp.items + ) return paginated_resp @@ -1223,10 +1306,20 @@ async def get_telemetry_stats(request: Request): "$group": { "_id": {f: f"${f}" for f in group_by}, "total": {"$sum": 1}, - "pass": {"$sum": {"$cond": [{"$eq": ["$result", "pass"]}, 1, 0]}}, - "fail": {"$sum": {"$cond": [{"$eq": ["$result", "fail"]}, 1, 0]}}, - "incomplete": {"$sum": {"$cond": [{"$eq": ["$result", "incomplete"]}, 1, 0]}}, - "skip": {"$sum": {"$cond": [{"$eq": ["$result", "skip"]}, 1, 0]}}, + "pass": { + "$sum": {"$cond": [{"$eq": ["$result", "pass"]}, 1, 0]} + }, + "fail": { + "$sum": {"$cond": [{"$eq": ["$result", "fail"]}, 1, 0]} + }, + "incomplete": { + "$sum": { + "$cond": [{"$eq": ["$result", "incomplete"]}, 1, 0] + } + }, + "skip": { + "$sum": {"$cond": [{"$eq": ["$result", "skip"]}, 1, 0]} + }, "infra_error": {"$sum": {"$cond": ["$is_infra_error", 1, 0]}}, } } @@ -1265,14 +1358,18 @@ async def get_telemetry_stats(request: Request): @app.get("/telemetry/anomalies", tags=["telemetry"]) async def get_telemetry_anomalies( - window: str = Query("6h", description="Time window: 1h, 3h, 6h, 12h, 24h, 48h"), + window: str = Query( + "6h", description="Time window: 1h, 3h, 6h, 12h, 24h, 48h" + ), threshold: float = Query( 0.5, ge=0.0, le=1.0, description="Min failure/infra error rate to flag (0.0-1.0)", ), - min_total: int = Query(3, ge=1, description="Min events in window to consider (avoids noise)"), + min_total: int = Query( + 3, ge=1, description="Min events in window to consider (avoids noise)" + ), ): """Detect anomalies in telemetry data. @@ -1307,8 +1404,14 @@ async def get_telemetry_anomalies( "device_type": "$device_type", }, "total": {"$sum": 1}, - "fail": {"$sum": {"$cond": [{"$eq": ["$result", "fail"]}, 1, 0]}}, - "incomplete": {"$sum": {"$cond": [{"$eq": ["$result", "incomplete"]}, 1, 0]}}, + "fail": { + "$sum": {"$cond": [{"$eq": ["$result", "fail"]}, 1, 0]} + }, + "incomplete": { + "$sum": { + "$cond": [{"$eq": ["$result", "incomplete"]}, 1, 0] + } + }, "infra_error": {"$sum": {"$cond": ["$is_infra_error", 1, 0]}}, } }, @@ -1316,7 +1419,9 @@ async def get_telemetry_anomalies( { "$addFields": { "infra_rate": {"$divide": ["$infra_error", "$total"]}, - "fail_rate": {"$divide": [{"$add": ["$fail", "$incomplete"]}, "$total"]}, + "fail_rate": { + "$divide": [{"$add": ["$fail", "$incomplete"]}, "$total"] + }, } }, { @@ -1408,7 +1513,11 @@ async def translate_null_query_params(query_params: dict): return translated -@app.get("/node/{node_id}", response_model=Union[Node, None], response_model_by_alias=False) +@app.get( + "/node/{node_id}", + response_model=Union[Node, None], + response_model_by_alias=False, +) async def get_node(node_id: str): """Get node information from the provided node id""" metrics.add("http_requests_total", 1) @@ -1455,7 +1564,9 @@ async def get_nodes(request: Request): model = Node translated_params = model.translate_fields(query_params) paginated_resp = await db.find_by_attributes(model, translated_params) - paginated_resp.items = serialize_paginated_data(model, paginated_resp.items) + paginated_resp.items = serialize_paginated_data( + model, paginated_resp.items + ) return paginated_resp except KeyError as error: raise HTTPException( @@ -1488,7 +1599,9 @@ async def get_nodes_fast(request: Request): try: # Query using the base Node model, regardless of the specific # node type, use asyncio.wait_for with timeout 30 seconds - resp = await asyncio.wait_for(db_find_node_nonpaginated(query_params), timeout=15) + resp = await asyncio.wait_for( + db_find_node_nonpaginated(query_params), timeout=15 + ) return resp except asyncio.TimeoutError as error: raise HTTPException( @@ -1618,15 +1731,21 @@ async def put_node( # Sanity checks # Note: do not update node ownership fields, don't update 'state' # until we've checked the state transition is valid. - update_data = node.model_dump(exclude={"owner", "submitter", "user_groups", "state"}) + update_data = node.model_dump( + exclude={"owner", "submitter", "user_groups", "state"} + ) new_node_def = node_from_id.model_copy(update=update_data) # 1- Parse and validate node to specific subtype specialized_node = parse_node_obj(new_node_def) # 2 - State transition checks - is_valid, message = specialized_node.validate_node_state_transition(node.state) + is_valid, message = specialized_node.validate_node_state_transition( + node.state + ) if not is_valid: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=message + ) # if state changes, reset processed_by_kcidb_bridge flag if node.state != new_node_def.state: new_node_def.processed_by_kcidb_bridge = False @@ -1666,7 +1785,9 @@ class NodePatchRequest(BaseModel): processed_by_kcidb_bridge: Optional[bool] = None -@app.patch("/node/{node_id}", response_model=Node, response_model_by_alias=False) +@app.patch( + "/node/{node_id}", response_model=Node, response_model_by_alias=False +) async def patch_node( node_id: str, patch: NodePatchRequest, @@ -1703,11 +1824,6 @@ async def patch_node( # State transition checks if new_state is not None: -<<<<<<< HEAD - is_valid, message = specialized_node.validate_node_state_transition(new_state) - if not is_valid: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message) -======= is_valid, message = specialized_node.validate_node_state_transition( new_state ) @@ -1715,7 +1831,6 @@ async def patch_node( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=message ) ->>>>>>> 7df261e (api: Implement PATCH endpoint for node) if new_state != new_node_def.state: new_node_def.processed_by_kcidb_bridge = False new_node_def.state = new_state @@ -1744,7 +1859,9 @@ class NodeUpdateRequest(BaseModel): @app.put("/batch/nodeset", response_model=int) -async def put_batch_nodeset(data: NodeUpdateRequest, user: str = Depends(get_current_user)): +async def put_batch_nodeset( + data: NodeUpdateRequest, user: str = Depends(get_current_user) +): """ Set a field to a value for multiple nodes TBD: Make db.bulkupdate to update multiple nodes in one go @@ -1775,11 +1892,16 @@ async def put_batch_nodeset(data: NodeUpdateRequest, user: str = Depends(get_cur await db.update(node_from_id) updated += 1 else: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Field not supported") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Field not supported", + ) return updated -async def _set_node_ownership_recursively(user: User, hierarchy: Hierarchy, submitter: str, treeid: str): +async def _set_node_ownership_recursively( + user: User, hierarchy: Hierarchy, submitter: str, treeid: str +): """Set node ownership information for a hierarchy of nodes""" if not hierarchy.node.owner: hierarchy.node.owner = user.username @@ -1789,7 +1911,9 @@ async def _set_node_ownership_recursively(user: User, hierarchy: Hierarchy, subm await _set_node_ownership_recursively(user, node, submitter, treeid) -@app.put("/nodes/{node_id}", response_model=List[Node], response_model_by_alias=False) +@app.put( + "/nodes/{node_id}", response_model=List[Node], response_model_by_alias=False +) async def put_nodes( node_id: str, nodes: Hierarchy, @@ -1823,7 +1947,9 @@ async def put_nodes( # ----------------------------------------------------------------------------- # Key/Value namespace enabled store @app.get("/kv/{namespace}/{key}", response_model=Union[str, None]) -async def get_kv(namespace: str, key: str, user: User = Depends(get_current_user)): +async def get_kv( + namespace: str, key: str, user: User = Depends(get_current_user) +): """Get a key value pair from the store""" metrics.add("http_requests_total", 1) return await db.get_kv(namespace, key) @@ -1853,7 +1979,9 @@ async def post_kv( # Delete a key-value pair from the store @app.delete("/kv/{namespace}/{key}", response_model=Optional[str]) -async def delete_kv(namespace: str, key: str, user: User = Depends(get_current_user)): +async def delete_kv( + namespace: str, key: str, user: User = Depends(get_current_user) +): """Delete a key-value pair from the store""" metrics.add("http_requests_total", 1) await db.del_kv(namespace, key) @@ -1911,7 +2039,9 @@ async def unsubscribe(sub_id: int, user: User = Depends(get_current_user)): detail=f"Subscription id not found: {str(error)}", ) from error except RuntimeError as error: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error) + ) from error @app.get("/listen/{sub_id}") @@ -1933,7 +2063,9 @@ async def listen(sub_id: int, user: User = Depends(get_current_user)): @app.post("/publish/{channel}") -async def publish(event: PublishEvent, channel: str, user: User = Depends(get_current_user)): +async def publish( + event: PublishEvent, channel: str, user: User = Depends(get_current_user) +): """Publish an event on the provided Pub/Sub channel""" metrics.add("http_requests_total", 1) event_dict = PublishEvent.dict(event) @@ -1953,7 +2085,9 @@ async def publish(event: PublishEvent, channel: str, user: User = Depends(get_cu @app.post("/push/{list_name}") -async def push(raw: dict, list_name: str, user: User = Depends(get_current_user)): +async def push( + raw: dict, list_name: str, user: User = Depends(get_current_user) +): """Push a message on the provided list""" metrics.add("http_requests_total", 1) attributes = dict(raw) @@ -2067,7 +2201,9 @@ async def icons(icon_name: str): metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) if not re.match(r"^[A-Za-z0-9_.-]+\.png$", icon_name): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid icon name") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid icon name" + ) icon_path = os.path.join(root_dir, "templates", icon_name) return FileResponse(icon_path) @@ -2082,13 +2218,17 @@ async def serve_css(filename: str): # Security: only allow safe filenames if not re.match(r"^[A-Za-z0-9_.-]+\.css$", filename): print(f"[CSS] Invalid filename pattern: {filename}") - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename" + ) file_path = os.path.join(root_dir, "static", "css", filename) print(f"[CSS] Looking for file at: {file_path}") print(f"[CSS] File exists: {os.path.isfile(file_path)}") if not os.path.isfile(file_path): print(f"[CSS] File not found: {file_path}") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) print(f"[CSS] Serving file: {file_path}") return FileResponse( file_path, @@ -2109,13 +2249,17 @@ async def serve_js(filename: str): # Security: only allow safe filenames if not re.match(r"^[A-Za-z0-9_.-]+\.js$", filename): print(f"[JS] Invalid filename pattern: {filename}") - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename" + ) file_path = os.path.join(root_dir, "static", "js", filename) print(f"[JS] Looking for file at: {file_path}") print(f"[JS] File exists: {os.path.isfile(file_path)}") if not os.path.isfile(file_path): print(f"[JS] File not found: {file_path}") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) print(f"[JS] Serving file: {file_path}") return FileResponse( file_path, @@ -2188,10 +2332,16 @@ def traceback_exception_handler(request: Request, exc: Exception): # https://github.com/DeanWay/fastapi-versioning/issues/30 for sub_app in versioned_app.routes: if hasattr(sub_app.app, "add_exception_handler"): - sub_app.app.add_exception_handler(ValueError, value_error_exception_handler) - sub_app.app.add_exception_handler(errors.InvalidId, invalid_id_exception_handler) + sub_app.app.add_exception_handler( + ValueError, value_error_exception_handler + ) + sub_app.app.add_exception_handler( + errors.InvalidId, invalid_id_exception_handler + ) # print traceback for all other exceptions - sub_app.app.add_exception_handler(Exception, traceback_exception_handler) + sub_app.app.add_exception_handler( + Exception, traceback_exception_handler + ) @versioned_app.middleware("http") diff --git a/api/models.py b/api/models.py index 179a1b9b..37390837 100644 --- a/api/models.py +++ b/api/models.py @@ -40,8 +40,14 @@ class Subscription(BaseModel): id: int = Field(description="Subscription ID") channel: str = Field(description="Subscription channel name") - user: str = Field(description=("Username of the user that created the subscription (owner)")) - promiscuous: bool = Field(description="Listen to all users messages", default=False) + user: str = Field( + description=( + "Username of the user that created the subscription (owner)" + ) + ) + promiscuous: bool = Field( + description="Listen to all users messages", default=False + ) class SubscriptionStats(Subscription): @@ -49,7 +55,8 @@ class SubscriptionStats(Subscription): created: datetime = Field(description="Timestamp of connection creation") last_poll: Optional[datetime] = Field( - default=None, description="Timestamp when connection last polled for data" + default=None, + description="Timestamp when connection last polled for data", ) @@ -65,15 +72,24 @@ class SubscriberState(BaseModel): Enables catch-up on missed events after reconnection. """ - subscriber_id: str = Field(description="Unique subscriber identifier (client-provided)") + subscriber_id: str = Field( + description="Unique subscriber identifier (client-provided)" + ) channel: str = Field(description="Subscribed channel name") - user: str = Field(description="Username of subscriber (for ownership validation)") - promiscuous: bool = Field(default=False, description="If true, receive all messages regardless of owner") + user: str = Field( + description="Username of subscriber (for ownership validation)" + ) + promiscuous: bool = Field( + default=False, + description="If true, receive all messages regardless of owner", + ) last_event_id: int = Field( - default=0, description="Last acknowledged event ID (implicit ACK on next poll)" + default=0, + description="Last acknowledged event ID (implicit ACK on next poll)", ) created_at: datetime = Field( - default_factory=datetime.utcnow, description="Subscription creation timestamp" + default_factory=datetime.utcnow, + description="Subscription creation timestamp", ) last_poll: Optional[datetime] = Field( default=None, description="Last poll timestamp (used for stale cleanup)" @@ -110,7 +126,9 @@ class User( """API User model""" username: Annotated[str, Indexed(unique=True)] - groups: List[UserGroup] = Field(default=[], description="A list of groups that the user belongs to") + groups: List[UserGroup] = Field( + default=[], description="A list of groups that the user belongs to" + ) @field_validator("groups") def validate_groups(cls, groups): # pylint: disable=no-self-argument @@ -182,7 +200,9 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class UserUpdateRequest(schemas.BaseUserUpdate): """Update user request schema for API router""" - username: Annotated[Optional[str], Indexed(unique=True), Field(default=None)] + username: Annotated[ + Optional[str], Indexed(unique=True), Field(default=None) + ] groups: List[str] = Field(default=[]) @field_validator("groups") @@ -197,7 +217,9 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class UserUpdate(schemas.BaseUserUpdate): """Schema used for sending update user request to 'fastapi-users' router""" - username: Annotated[Optional[str], Indexed(unique=True), Field(default=None)] + username: Annotated[ + Optional[str], Indexed(unique=True), Field(default=None) + ] groups: List[UserGroup] = Field(default=[]) @field_validator("groups") diff --git a/api/pubsub.py b/api/pubsub.py index 23cca2a4..3535450b 100644 --- a/api/pubsub.py +++ b/api/pubsub.py @@ -43,7 +43,9 @@ def __init__(self, host=None, db_number=None): host = self._settings.redis_host if db_number is None: db_number = self._settings.redis_db_number - self._redis = aioredis.from_url("redis://" + host + "/" + str(db_number), health_check_interval=30) + self._redis = aioredis.from_url( + "redis://" + host + "/" + str(db_number), health_check_interval=30 + ) # self._subscriptions is a dict that matches a subscription id # (key) with a Subscription object ('sub') and a redis # PubSub object ('redis_sub'). For instance: @@ -63,7 +65,9 @@ def _start_keep_alive_timer(self): return if not self._keep_alive_timer or self._keep_alive_timer.done(): loop = asyncio.get_running_loop() - self._keep_alive_timer = asyncio.run_coroutine_threadsafe(self._keep_alive(), loop) + self._keep_alive_timer = asyncio.run_coroutine_threadsafe( + self._keep_alive(), loop + ) async def _keep_alive(self): while True: @@ -140,7 +144,9 @@ async def listen(self, sub_id, user=None): self._subscriptions[sub_id]["last_poll"] = datetime.utcnow() msg = None try: - msg = await sub["redis_sub"].get_message(ignore_subscribe_messages=True, timeout=1.0) + msg = await sub["redis_sub"].get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) except aioredis.ConnectionError: async with self._lock: channel = self._subscriptions[sub_id]["sub"].channel diff --git a/api/pubsub_mongo.py b/api/pubsub_mongo.py index d842203b..308a086b 100644 --- a/api/pubsub_mongo.py +++ b/api/pubsub_mongo.py @@ -68,14 +68,22 @@ async def create(cls, *args, mongo_client=None, **kwargs): await pubsub._init() return pubsub - def __init__(self, mongo_client=None, host=None, db_number=None, mongo_db_name="kernelci"): + def __init__( + self, + mongo_client=None, + host=None, + db_number=None, + mongo_db_name="kernelci", + ): self._settings = PubSubSettings() if host is None: host = self._settings.redis_host if db_number is None: db_number = self._settings.redis_db_number - self._redis = aioredis.from_url("redis://" + host + "/" + str(db_number), health_check_interval=30) + self._redis = aioredis.from_url( + "redis://" + host + "/" + str(db_number), health_check_interval=30 + ) # MongoDB setup if mongo_client is None: @@ -115,7 +123,9 @@ async def _migrate_eventhistory_if_needed(self): # Check if collection exists collections = await self._mongo_db.list_collection_names() if self.EVENT_HISTORY_COLLECTION not in collections: - logger.info("eventhistory collection does not exist, will be created") + logger.info( + "eventhistory collection does not exist, will be created" + ) return # Check existing indexes @@ -131,7 +141,9 @@ async def _migrate_eventhistory_if_needed(self): ttl = index_info["expireAfterSeconds"] if ttl == 86400: old_format_detected = True - logger.warning("Detected old eventhistory format (24h TTL). Migration required.") + logger.warning( + "Detected old eventhistory format (24h TTL). Migration required." + ) # Check for new sequence_id index if "key" in index_info: @@ -160,7 +172,9 @@ async def _migrate_eventhistory(self, col): # Drop all documents (they lack required fields) result = await col.delete_many({}) - logger.info("Deleted %d old eventhistory documents", result.deleted_count) + logger.info( + "Deleted %d old eventhistory documents", result.deleted_count + ) logger.info("eventhistory migration complete") @@ -183,12 +197,15 @@ async def _ensure_indexes(self): # Compound index for filtered event queries (kind + timestamp) await event_col.create_index( - [("data.kind", ASCENDING), ("timestamp", ASCENDING)], name="kind_timestamp" + [("data.kind", ASCENDING), ("timestamp", ASCENDING)], + name="kind_timestamp", ) # Subscriber state indexes # Unique index on subscriber_id - await sub_col.create_index("subscriber_id", unique=True, name="unique_subscriber_id") + await sub_col.create_index( + "subscriber_id", unique=True, name="unique_subscriber_id" + ) # Index for stale cleanup await sub_col.create_index("last_poll", name="last_poll") @@ -198,7 +215,9 @@ def _start_keep_alive_timer(self): return if not self._keep_alive_timer or self._keep_alive_timer.done(): loop = asyncio.get_running_loop() - self._keep_alive_timer = asyncio.run_coroutine_threadsafe(self._keep_alive(), loop) + self._keep_alive_timer = asyncio.run_coroutine_threadsafe( + self._keep_alive(), loop + ) async def _keep_alive(self): """Send periodic BEEP to keep connections alive""" @@ -233,7 +252,9 @@ async def _get_next_event_id(self) -> int: """Get next sequential event ID from Redis""" return await self._redis.incr(self.EVENT_SEQ_KEY) - async def _store_event(self, channel: str, data: Dict[str, Any], owner: Optional[str] = None) -> int: + async def _store_event( + self, channel: str, data: Dict[str, Any], owner: Optional[str] = None + ) -> int: """Store event in eventhistory collection and return sequence ID Uses the same collection as /events API endpoint (EventHistory model). @@ -249,7 +270,9 @@ async def _store_event(self, channel: str, data: Dict[str, Any], owner: Optional } col = self._mongo_db[self.EVENT_HISTORY_COLLECTION] # Use w=1 for acknowledged writes (durability) - await col.with_options(write_concern=WriteConcern(w=1)).insert_one(event_doc) + await col.with_options(write_concern=WriteConcern(w=1)).insert_one( + event_doc + ) return sequence_id async def _get_subscriber_state(self, subscriber_id: str) -> Optional[Dict]: @@ -272,11 +295,21 @@ def _decode_redis_message(msg: Dict) -> Dict: """Decode Redis message bytes to strings for JSON serialization""" return { "type": msg.get("type"), - "pattern": (msg.get("pattern").decode("utf-8") if msg.get("pattern") else None), + "pattern": ( + msg.get("pattern").decode("utf-8") + if msg.get("pattern") + else None + ), "channel": ( - msg["channel"].decode("utf-8") if isinstance(msg["channel"], bytes) else msg["channel"] + msg["channel"].decode("utf-8") + if isinstance(msg["channel"], bytes) + else msg["channel"] + ), + "data": ( + msg["data"].decode("utf-8") + if isinstance(msg["data"], bytes) + else msg["data"] ), - "data": (msg["data"].decode("utf-8") if isinstance(msg["data"], bytes) else msg["data"]), } def _eventhistory_to_cloudevent(self, event: Dict) -> str: @@ -327,7 +360,9 @@ async def _get_missed_events( cursor = col.find(query).sort("sequence_id", ASCENDING).limit(limit) return await cursor.to_list(length=limit) - async def subscribe(self, channel: str, user: str, options: Optional[Dict] = None) -> Subscription: + async def subscribe( + self, channel: str, user: str, options: Optional[Dict] = None + ) -> Subscription: """Subscribe to a Pub/Sub channel Args: @@ -346,7 +381,9 @@ async def subscribe(self, channel: str, user: str, options: Optional[Dict] = Non async with self._lock: redis_sub = self._redis.pubsub() - sub = Subscription(id=sub_id, channel=channel, user=user, promiscuous=promiscuous) + sub = Subscription( + id=sub_id, channel=channel, user=user, promiscuous=promiscuous + ) await redis_sub.subscribe(channel) self._subscriptions[sub_id] = { @@ -390,7 +427,9 @@ async def _setup_durable_subscription( if existing: # Existing subscriber - verify ownership if existing["user"] != user: - raise RuntimeError(f"Subscriber {subscriber_id} owned by different user") + raise RuntimeError( + f"Subscriber {subscriber_id} owned by different user" + ) # Load pending catch-up events missed = await self._get_missed_events( channel=existing["channel"], @@ -447,7 +486,9 @@ async def unsubscribe(self, sub_id: int, user: Optional[str] = None): await sub["redis_sub"].unsubscribe() await sub["redis_sub"].close() - async def _get_listen_subscription(self, sub_id: int, user: Optional[str] = None): + async def _get_listen_subscription( + self, sub_id: int, user: Optional[str] = None + ): async with self._lock: sub_data = self._subscriptions.get(sub_id) if not sub_data: @@ -460,7 +501,9 @@ async def _get_listen_subscription(self, sub_id: int, user: Optional[str] = None return sub, sub_data - async def _update_listen_subscription_state(self, sub: Subscription, sub_data: dict): + async def _update_listen_subscription_state( + self, sub: Subscription, sub_data: dict + ): subscriber_id = sub_data.get("subscriber_id") if subscriber_id and sub_data.get("last_delivered_id"): await self._update_subscriber_state( @@ -470,7 +513,9 @@ async def _update_listen_subscription_state(self, sub: Subscription, sub_data: d return subscriber_id - def _consume_pending_catchup(self, sub_id: int, sub: Subscription, sub_data: dict) -> Optional[Dict]: + def _consume_pending_catchup( + self, sub_id: int, sub: Subscription, sub_data: dict + ) -> Optional[Dict]: if not sub_data.get("pending_catchup"): return None @@ -486,7 +531,9 @@ def _consume_pending_catchup(self, sub_id: int, sub: Subscription, sub_data: dic "type": "message", } - async def _rebuild_redis_subscription(self, sub_id: int, sub: Subscription, sub_data: dict): + async def _rebuild_redis_subscription( + self, sub_id: int, sub: Subscription, sub_data: dict + ): async with self._lock: channel = sub.channel new_redis_sub = self._redis.pubsub() @@ -494,7 +541,9 @@ async def _rebuild_redis_subscription(self, sub_id: int, sub: Subscription, sub_ self._subscriptions[sub_id]["redis_sub"] = new_redis_sub sub_data["redis_sub"] = new_redis_sub - def _maybe_update_delivery_offset(self, subscriber_id: Optional[str], sub_data: dict, msg_data: Any): + def _maybe_update_delivery_offset( + self, subscriber_id: Optional[str], sub_data: dict, msg_data: Any + ): if not subscriber_id or not isinstance(msg_data, dict): return @@ -502,7 +551,9 @@ def _maybe_update_delivery_offset(self, subscriber_id: Optional[str], sub_data: if sequence_id: sub_data["last_delivered_id"] = sequence_id - def _should_deliver_to_user(self, sub: Subscription, msg_data: dict) -> bool: + def _should_deliver_to_user( + self, sub: Subscription, msg_data: dict + ) -> bool: if sub.promiscuous: return True @@ -518,7 +569,9 @@ async def _listen_for_message( while True: self._subscriptions[sub_id]["last_poll"] = datetime.utcnow() try: - msg = await sub_data["redis_sub"].get_message(ignore_subscribe_messages=True, timeout=1.0) + msg = await sub_data["redis_sub"].get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) except aioredis.ConnectionError: await self._rebuild_redis_subscription(sub_id, sub, sub_data) continue @@ -530,13 +583,17 @@ async def _listen_for_message( continue msg_data = json.loads(msg["data"]) - self._maybe_update_delivery_offset(subscriber_id, sub_data, msg_data) + self._maybe_update_delivery_offset( + subscriber_id, sub_data, msg_data + ) if not self._should_deliver_to_user(sub, msg_data): continue return self._decode_redis_message(msg) - async def listen(self, sub_id: int, user: Optional[str] = None) -> Optional[Dict]: + async def listen( + self, sub_id: int, user: Optional[str] = None + ) -> Optional[Dict]: """Listen for Pub/Sub messages For durable subscriptions (with subscriber_id): @@ -547,7 +604,9 @@ async def listen(self, sub_id: int, user: Optional[str] = None) -> Optional[Dict Returns message dict or None on error. """ sub, sub_data = await self._get_listen_subscription(sub_id, user) - subscriber_id = await self._update_listen_subscription_state(sub, sub_data) + subscriber_id = await self._update_listen_subscription_state( + sub, sub_data + ) pending_msg = self._consume_pending_catchup(sub_id, sub, sub_data) if pending_msg: return pending_msg @@ -555,13 +614,17 @@ async def listen(self, sub_id: int, user: Optional[str] = None) -> Optional[Dict if not sub_data.get("catchup_done"): sub_data["catchup_done"] = True - return await self._listen_for_message(sub_id, sub, subscriber_id, sub_data) + return await self._listen_for_message( + sub_id, sub, subscriber_id, sub_data + ) async def publish(self, channel: str, message: str): """Publish a message on a channel (Redis only, no durability)""" await self._redis.publish(channel, message) - async def publish_cloudevent(self, channel: str, data: Any, attributes: Optional[Dict] = None): + async def publish_cloudevent( + self, channel: str, data: Any, attributes: Optional[Dict] = None + ): """Publish a CloudEvent on a Pub/Sub channel Events are: @@ -606,7 +669,9 @@ async def pop(self, list_name: str) -> Optional[Dict]: if data is not None: return data - async def push_cloudevent(self, list_name: str, data: Any, attributes: Optional[Dict] = None): + async def push_cloudevent( + self, list_name: str, data: Any, attributes: Optional[Dict] = None + ): """Push a CloudEvent on a list""" if not attributes: attributes = { @@ -631,7 +696,9 @@ async def subscription_stats(self) -> List[SubscriptionStats]: subscriptions.append(stats) return subscriptions - async def cleanup_stale_subscriptions(self, max_age_minutes: int = 30) -> int: + async def cleanup_stale_subscriptions( + self, max_age_minutes: int = 30 + ) -> int: """Remove subscriptions not polled recently For durable subscriptions, only the in-memory state is cleaned up. @@ -654,7 +721,9 @@ async def cleanup_stale_subscriptions(self, max_age_minutes: int = 30) -> int: return len(stale_ids) - async def cleanup_stale_subscriber_states(self, max_age_days: int = 30) -> int: + async def cleanup_stale_subscriber_states( + self, max_age_days: int = 30 + ) -> int: """Remove subscriber states not used for a long time This is separate from subscription cleanup - it removes the diff --git a/api/user_manager.py b/api/user_manager.py index 06237009..448cfd5e 100644 --- a/api/user_manager.py +++ b/api/user_manager.py @@ -32,7 +32,9 @@ class UserManager(ObjectIDIDMixin, BaseUserManager[User, PydanticObjectId]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._email_sender = None - self._template_env = jinja2.Environment(loader=jinja2.PackageLoader("api", "templates")) + self._template_env = jinja2.Environment( + loader=jinja2.PackageLoader("api", "templates") + ) @property def email_sender(self): @@ -41,7 +43,9 @@ def email_sender(self): self._email_sender = EmailSender() return self._email_sender - async def on_after_register(self, user: User, request: Optional[Request] = None): + async def on_after_register( + self, user: User, request: Optional[Request] = None + ): """Handler to execute after successful user registration""" print(f"User {user.id} {user.username} has registered.") @@ -65,17 +69,23 @@ async def on_after_login( """Handler to execute after successful user login""" print(f"User {user.id} {user.username} logged in.") - async def on_after_forgot_password(self, user: User, token: str, request: Optional[Request] = None): + async def on_after_forgot_password( + self, user: User, token: str, request: Optional[Request] = None + ): """Handler to execute after successful forgot password request""" template = self._template_env.get_template("reset-password.jinja2") subject = "Reset Password Token for KernelCI API account" content = template.render(username=user.username, token=token) self.email_sender.create_and_send_email(subject, content, user.email) - async def on_after_reset_password(self, user: User, request: Optional[Request] = None): + async def on_after_reset_password( + self, user: User, request: Optional[Request] = None + ): """Handler to execute after successful password reset""" print(f"User {user.id} {user.username} has reset their password.") - template = self._template_env.get_template("reset-password-successful.jinja2") + template = self._template_env.get_template( + "reset-password-successful.jinja2" + ) subject = "Password reset successful for KernelCI API account" content = template.render( username=user.username, @@ -90,20 +100,29 @@ async def send_invite_accepted_email(self, user: User): self.email_sender.create_and_send_email(subject, content, user.email) async def on_after_update( - self, user: User, update_dict: Dict[str, Any], request: Optional[Request] = None + self, + user: User, + update_dict: Dict[str, Any], + request: Optional[Request] = None, ): """Handler to execute after successful user update""" print(f"User {user.id} {user.username} has been updated.") - async def on_before_delete(self, user: User, request: Optional[Request] = None): + async def on_before_delete( + self, user: User, request: Optional[Request] = None + ): """Handler to execute before user delete.""" print(f"User {user.id} {user.username} is going to be deleted.") - async def on_after_delete(self, user: User, request: Optional[Request] = None): + async def on_after_delete( + self, user: User, request: Optional[Request] = None + ): """Handler to execute after user delete.""" print(f"User {user.id} {user.username} was successfully deleted.") - async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> User | None: + async def authenticate( + self, credentials: OAuth2PasswordRequestForm + ) -> User | None: """ Overload user authentication method `BaseUserManager.authenticate`. This is to fix login endpoint to receive `username` instead of `email`. @@ -113,14 +132,18 @@ async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> User | N self.password_helper.hash(credentials.password) return None - verified, updated_password_hash = self.password_helper.verify_and_update( - credentials.password, user.hashed_password + verified, updated_password_hash = ( + self.password_helper.verify_and_update( + credentials.password, user.hashed_password + ) ) if not verified: return None # Update password hash to a more robust one if needed if updated_password_hash is not None: - await self.user_db.update(user, {"hashed_password": updated_password_hash}) + await self.user_db.update( + user, {"hashed_password": updated_password_hash} + ) return user diff --git a/migrations/20231102101356_user.py b/migrations/20231102101356_user.py index 61b2341f..c462d4fc 100644 --- a/migrations/20231102101356_user.py +++ b/migrations/20231102101356_user.py @@ -11,9 +11,15 @@ """ -name = '20231102101356_user' +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pymongo.database + +name = "20231102101356_user" dependencies = [] + def user_upgrade_needed(user): """Checks if a DB user passed as a parameter needs to be migrated with this script. @@ -27,7 +33,7 @@ def user_upgrade_needed(user): """ # The existence of a 'profile' key seems to be enough to detect a # pre-migration user - if 'profile' in user: + if "profile" in user: return True else: return False @@ -43,35 +49,30 @@ def upgrade(db: "pymongo.database.Database"): # Check if the user is an admin (superuser), remove it from the # "admin" user group if it is is_superuser = False - new_groups_list = [g for g in user['profile']['groups'] - if g['name'] != 'admin'] - if len(new_groups_list) != len(user['profile']['groups']): + new_groups_list = [ + g for g in user["profile"]["groups"] if g["name"] != "admin" + ] + if len(new_groups_list) != len(user["profile"]["groups"]): is_superuser = True - user['profile']['groups'] = new_groups_list + user["profile"]["groups"] = new_groups_list # User update db.user.replace_one( + {"_id": user["_id"]}, { - "_id": user['_id'] - }, - { - "_id": user['_id'], - "email": user['profile']['email'], - "hashed_password": user['profile']['hashed_password'], - "is_active": user['active'], + "_id": user["_id"], + "email": user["profile"]["email"], + "hashed_password": user["profile"]["hashed_password"], + "is_active": user["active"], "is_superuser": is_superuser, "is_verified": False, - "username": user['profile']['username'], - "groups": user['profile']['groups'] + "username": user["profile"]["username"], + "groups": user["profile"]["groups"], }, ) # Sanity check: check if there are any old-format users in the # "admin" group. Remove the group if there aren't any remaining_admins = db.user.count( - { - "groups": { - "$elemMatch": {"name": "admin"} - } - } + {"groups": {"$elemMatch": {"name": "admin"}}} ) if remaining_admins == 0: db.usergroup.delete_one({"name": "admin"}) @@ -80,15 +81,13 @@ def upgrade(db: "pymongo.database.Database"): def downgrade(db: "pymongo.database.Database"): - superusers = db.user.find({'is_superuser': True}) + superusers = db.user.find({"is_superuser": True}) if superusers: # Create the 'admin' group if it doesn't exist db.usergroup.update_one( - {'name': 'admin'}, - {'$setOnInsert': {'name': 'admin'}}, - upsert=True + {"name": "admin"}, {"$setOnInsert": {"name": "admin"}}, upsert=True ) - admin_group = db.usergroup.find_one({'name': 'admin'}) + admin_group = db.usergroup.find_one({"name": "admin"}) users = db.user.find() db.user.drop_indexes() @@ -96,25 +95,26 @@ def downgrade(db: "pymongo.database.Database"): # Skip users that weren't migrated (unlikely corner case) if user_upgrade_needed(user): continue - if user.get('is_superuser') == True: + if user.get("is_superuser"): # Add user to admin group - new_groups_list = [g for g in user['groups'] - if g['name'] != 'admin'] + new_groups_list = [ + g for g in user["groups"] if g["name"] != "admin" + ] new_groups_list.append(admin_group) - user['groups'] = new_groups_list + user["groups"] = new_groups_list db.user.replace_one( { - '_id': user['_id'], + "_id": user["_id"], }, { - '_id': user['_id'], - 'active': user['is_active'], - 'profile': { - 'email': user['email'], - 'hashed_password': user['hashed_password'], - 'username': user['username'], - 'groups': user['groups'], - } - } + "_id": user["_id"], + "active": user["is_active"], + "profile": { + "email": user["email"], + "hashed_password": user["hashed_password"], + "username": user["username"], + "groups": user["groups"], + }, + }, ) diff --git a/migrations/20231215122000_node_models.py b/migrations/20231215122000_node_models.py index ca1bc752..eff694d0 100644 --- a/migrations/20231215122000_node_models.py +++ b/migrations/20231215122000_node_models.py @@ -5,16 +5,21 @@ """Migration for Node objects to comply with the models after commits: - api.models: basic definitions of Node submodels - api.main: use node endpoints for all type of Node subtypes - api.db: remove regression collection +api.models: basic definitions of Node submodels +api.main: use node endpoints for all type of Node subtypes +api.db: remove regression collection """ +from typing import TYPE_CHECKING + from bson.objectid import ObjectId -name = '20231215122000_node_models' -dependencies = ['20231102101356_user'] +if TYPE_CHECKING: + import pymongo.database + +name = "20231215122000_node_models" +dependencies = ["20231102101356_user"] def node_upgrade_needed(node): @@ -30,7 +35,7 @@ def node_upgrade_needed(node): """ # The existence of a 'revision' key seems to be enough to detect a # pre-migration Node - if 'revision' in node: + if "revision" in node: return True else: return False @@ -43,22 +48,17 @@ def upgrade(db: "pymongo.database.Database"): # Skip any node that's not in the old format if not node_upgrade_needed(node): continue - if not node.get('data'): + if not node.get("data"): # Initialize 'data' field if it's empty: a generic Node # with no specific type may have an emtpy 'data' field - db.node.update_one( - {'_id': node['_id']}, - {'$set': {'data': {}}} - ) + db.node.update_one({"_id": node["_id"]}, {"$set": {"data": {}}}) # move 'revision' to 'data.kernel_revision' db.node.update_one( - {'_id': node['_id']}, + {"_id": node["_id"]}, { - '$set': { - 'data.kernel_revision': node['revision'] - }, - '$unset': {'revision': ''} - } + "$set": {"data.kernel_revision": node["revision"]}, + "$unset": {"revision": ""}, + }, ) # Re-format regressions: move them from "regression" to "node" @@ -66,51 +66,55 @@ def upgrade(db: "pymongo.database.Database"): for regression in regressions: db.node.insert_one( { - 'name': regression.get('name'), - 'group': regression.get('group'), - 'path': regression.get('path'), - 'kind': 'regression', - 'data': { - 'pass_node': ObjectId(regression['regression_data'][0]['_id']), - 'fail_node': ObjectId(regression['regression_data'][1]['_id']) + "name": regression.get("name"), + "group": regression.get("group"), + "path": regression.get("path"), + "kind": "regression", + "data": { + "pass_node": ObjectId( + regression["regression_data"][0]["_id"] + ), + "fail_node": ObjectId( + regression["regression_data"][1]["_id"] + ), }, - 'artifacts': regression.get('artifacts'), - 'created': regression.get('created'), - 'updated': regression.get('updated'), - 'timeout': regression.get('timeout'), - 'owner': regression.get('owner'), + "artifacts": regression.get("artifacts"), + "created": regression.get("created"), + "updated": regression.get("updated"), + "timeout": regression.get("timeout"), + "owner": regression.get("owner"), } ) - db.regression.delete_one({'_id': regression['_id']}) + db.regression.delete_one({"_id": regression["_id"]}) -def downgrade(db: 'pymongo.database.Database'): +def downgrade(db: "pymongo.database.Database"): # Move regressions back to "regression" - regressions = db.node.find({'kind': 'regression'}) + regressions = db.node.find({"kind": "regression"}) for regression in regressions: fail_node = db.node.find_one( - {'_id': ObjectId(regression['data']['fail_node'])} + {"_id": ObjectId(regression["data"]["fail_node"])} ) db.regression.insert_one( { - 'name': regression.get('name'), - 'group': regression.get('group'), - 'path': regression.get('path'), - 'kind': 'regression', - 'parent': regression['data']['fail_node'], - 'regression_data': [ - regression['data']['pass_node'], - regression['data']['fail_node'] + "name": regression.get("name"), + "group": regression.get("group"), + "path": regression.get("path"), + "kind": "regression", + "parent": regression["data"]["fail_node"], + "regression_data": [ + regression["data"]["pass_node"], + regression["data"]["fail_node"], ], - 'revision': fail_node['data']['kernel_revision'], - 'artifacts': regression.get('artifacts'), - 'created': regression.get('created'), - 'updated': regression.get('updated'), - 'timeout': regression.get('timeout'), - 'owner': regression.get('owner'), + "revision": fail_node["data"]["kernel_revision"], + "artifacts": regression.get("artifacts"), + "created": regression.get("created"), + "updated": regression.get("updated"), + "timeout": regression.get("timeout"), + "owner": regression.get("owner"), } ) - db.node.delete_one({'_id': regression['_id']}) + db.node.delete_one({"_id": regression["_id"]}) # Downgrade node format nodes = db.node.find() @@ -120,18 +124,13 @@ def downgrade(db: 'pymongo.database.Database'): continue # move 'data.kernel_revision' to 'revision' db.node.update_one( - {'_id': node['_id']}, + {"_id": node["_id"]}, { - '$set': { - 'revision': node['data']['kernel_revision'] - }, - '$unset': {'data.kernel_revision': ''} - } + "$set": {"revision": node["data"]["kernel_revision"]}, + "$unset": {"data.kernel_revision": ""}, + }, ) # unset 'data' if it's empty - node['data'].pop('kernel_revision', None) - if len(node['data']) == 0: - db.node.update_one( - {'_id': node['_id']}, - {'$unset': {'data': ''}} - ) + node["data"].pop("kernel_revision", None) + if len(node["data"]) == 0: + db.node.update_one({"_id": node["_id"]}, {"$unset": {"data": ""}}) diff --git a/scripts/usermanager.py b/scripts/usermanager.py index 784f861d..c02f4292 100755 --- a/scripts/usermanager.py +++ b/scripts/usermanager.py @@ -4,7 +4,6 @@ import json import os import re -import sys import urllib.error import urllib.parse import urllib.request @@ -17,7 +16,9 @@ DEFAULT_CONFIG_PATHS = [ os.path.join(os.getcwd(), "usermanager.toml"), - os.path.join(os.path.expanduser("~"), ".config", "kernelci", "usermanager.toml"), + os.path.join( + os.path.expanduser("~"), ".config", "kernelci", "usermanager.toml" + ), ] @@ -168,7 +169,9 @@ def _resolve_group_id(group_id, api_url, token): if _looks_like_object_id(group_id): return group_id query = urllib.parse.urlencode({"name": group_id}) - status, body = _request_json("GET", f"{api_url}/user-groups?{query}", token=token) + status, body = _request_json( + "GET", f"{api_url}/user-groups?{query}", token=token + ) if status >= 400: _print_response(status, body) raise SystemExit(1) @@ -212,11 +215,15 @@ def _resolve_group_name(group_name, api_url, token): def _resolve_group_names(group_names, api_url, token): - return _dedupe([_resolve_group_name(name, api_url, token) for name in group_names]) + return _dedupe( + [_resolve_group_name(name, api_url, token) for name in group_names] + ) def _update_user_groups(resolved_id, add_groups, remove_groups, api_url, token): - status, body = _request_json("GET", f"{api_url}/user/{resolved_id}", token=token) + status, body = _request_json( + "GET", f"{api_url}/user/{resolved_id}", token=token + ) if status >= 400: _print_response(status, body) raise SystemExit(1) @@ -226,9 +233,13 @@ def _update_user_groups(resolved_id, add_groups, remove_groups, api_url, token): raise SystemExit("Failed to parse user response") from exc current_groups = _extract_group_names(payload) data = { - "groups": _apply_group_changes(current_groups, add_groups, remove_groups), + "groups": _apply_group_changes( + current_groups, add_groups, remove_groups + ), } - return _request_json("PATCH", f"{api_url}/user/{resolved_id}", data, token=token) + return _request_json( + "PATCH", f"{api_url}/user/{resolved_id}", data, token=token + ) def _request_json(method, url, data=None, token=None, form=False): @@ -274,7 +285,7 @@ def _require_token(token, args): ) -def main(): +def main(): # noqa: C901 command_help = [ ("accept-invite", "Accept an invite"), ("assign-group", "Assign group(s) to a user"), @@ -295,9 +306,11 @@ def main(): ("whoami", "Show current user"), ] command_list = "\n".join( - " {:<18} {}".format(name, desc) for name, desc in command_help) + " {:<18} {}".format(name, desc) for name, desc in command_help + ) default_paths = "\n".join( - " - {}".format(path) for path in DEFAULT_CONFIG_PATHS) + " - {}".format(path) for path in DEFAULT_CONFIG_PATHS + ) parser = argparse.ArgumentParser( description="KernelCI API user management helper", epilog=( @@ -324,12 +337,14 @@ def main(): help="Path to usermanager.toml (defaults to first match in the lookup list below)", ) parser.add_argument( - "--api-url", help="API base URL, e.g. " "http://localhost:8001/latest" + "--api-url", help="API base URL, e.g. http://localhost:8001/latest" ) parser.add_argument("--token", help="Bearer token for admin/user actions") parser.add_argument("--instance", help="Instance name from config") parser.add_argument( - "--token-label", default="Auth", help="Label used when prompting for a token" + "--token-label", + default="Auth", + help="Label used when prompting for a token", ) subparsers = parser.add_subparsers(dest="command", required=True) @@ -349,9 +364,13 @@ def main(): help="Group name or id; can be used multiple times or with commas", ) - subparsers.add_parser("config-example", help="Print a sample usermanager.toml") + subparsers.add_parser( + "config-example", help="Print a sample usermanager.toml" + ) - create_group = subparsers.add_parser("create-group", help="Create user group") + create_group = subparsers.add_parser( + "create-group", help="Create user group" + ) create_group.add_argument("name") deassign_group = subparsers.add_parser( @@ -365,7 +384,9 @@ def main(): help="Group name or id; can be used multiple times or with commas", ) - delete_group = subparsers.add_parser("delete-group", help="Delete user group") + delete_group = subparsers.add_parser( + "delete-group", help="Delete user group" + ) delete_group.add_argument("group_id") delete_user = subparsers.add_parser("delete-user", help="Delete user by id") @@ -377,7 +398,9 @@ def main(): generate_token.add_argument("--username", required=True) generate_token.add_argument("--password") - get_group = subparsers.add_parser("get-group", help="Get user group by id or name") + get_group = subparsers.add_parser( + "get-group", help="Get user group by id or name" + ) get_group.add_argument("group_id") get_user = subparsers.add_parser("get-user", help="Get user by id") @@ -393,11 +416,11 @@ def main(): invite.add_argument("--return-token", action="store_true") invite.add_argument("--resend-if-exists", action="store_true") - invite_url = subparsers.add_parser("invite-url", help="Preview invite URL base") + subparsers.add_parser("invite-url", help="Preview invite URL base") - list_groups = subparsers.add_parser("list-groups", help="List user groups") + subparsers.add_parser("list-groups", help="List user groups") - list_users = subparsers.add_parser("list-users", help="List users") + subparsers.add_parser("list-users", help="List users") login = subparsers.add_parser("login", help="Get an auth token") login.add_argument("--username", required=True) @@ -410,7 +433,10 @@ def main(): update_user.add_argument("--email", help="Set email") update_user.add_argument("--password", help="Set password") update_user.add_argument( - "--superuser", dest="is_superuser", action="store_true", help="Grant superuser" + "--superuser", + dest="is_superuser", + action="store_true", + help="Grant superuser", ) update_user.add_argument( "--no-superuser", @@ -419,10 +445,16 @@ def main(): help="Revoke superuser", ) update_user.add_argument( - "--active", dest="is_active", action="store_true", help="Set is_active true" + "--active", + dest="is_active", + action="store_true", + help="Set is_active true", ) update_user.add_argument( - "--inactive", dest="is_active", action="store_false", help="Set is_active false" + "--inactive", + dest="is_active", + action="store_false", + help="Set is_active false", ) update_user.add_argument( "--verified", @@ -436,7 +468,9 @@ def main(): action="store_false", help="Set is_verified false", ) - update_user.set_defaults(is_active=None, is_verified=None, is_superuser=None) + update_user.set_defaults( + is_active=None, is_verified=None, is_superuser=None + ) update_user.add_argument( "--set-groups", help="Replace all groups with a comma-separated list", @@ -454,7 +488,7 @@ def main(): help="Remove group(s); can be used multiple times or with commas", ) - whoami = subparsers.add_parser("whoami", help="Show current user") + subparsers.add_parser("whoami", help="Show current user") args = parser.parse_args() @@ -529,7 +563,9 @@ def main(): "POST", f"{api_url}/user/invite", payload, token=token ) elif args.command == "invite-url": - status, body = _request_json("GET", f"{api_url}/user/invite/url", token=token) + status, body = _request_json( + "GET", f"{api_url}/user/invite/url", token=token + ) elif args.command == "accept-invite": invite_token = _prompt_if_missing( args.token, @@ -542,7 +578,9 @@ def main(): secret=True, ) payload = {"token": invite_token, "password": password} - status, body = _request_json("POST", f"{api_url}/user/accept-invite", payload) + status, body = _request_json( + "POST", f"{api_url}/user/accept-invite", payload + ) elif args.command == "login": password = _prompt_if_missing( args.password, @@ -655,7 +693,9 @@ def main(): if not add_groups: raise SystemExit("No groups specified. Use --group.") add_groups = _resolve_group_names(add_groups, api_url, token) - status, body = _update_user_groups(resolved_id, add_groups, [], api_url, token) + status, body = _update_user_groups( + resolved_id, add_groups, [], api_url, token + ) elif args.command == "deassign-group": resolved_id = _resolve_user_id(args.user_id, api_url, token) remove_groups = _parse_group_list(args.group) @@ -666,7 +706,9 @@ def main(): resolved_id, [], remove_groups, api_url, token ) elif args.command == "list-groups": - status, body = _request_json("GET", f"{api_url}/user-groups", token=token) + status, body = _request_json( + "GET", f"{api_url}/user-groups", token=token + ) elif args.command == "get-group": resolved_id = _resolve_group_id(args.group_id, api_url, token) status, body = _request_json( diff --git a/tests/e2e_tests/test_count_handler.py b/tests/e2e_tests/test_count_handler.py index f86ec914..148e9ddb 100644 --- a/tests/e2e_tests/test_count_handler.py +++ b/tests/e2e_tests/test_count_handler.py @@ -10,7 +10,10 @@ @pytest.mark.asyncio -@pytest.mark.dependency(depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], scope="session") +@pytest.mark.dependency( + depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], + scope="session", +) async def test_count_nodes(test_async_client): """ Test Case : Test KernelCI API GET /count endpoint @@ -24,7 +27,10 @@ async def test_count_nodes(test_async_client): @pytest.mark.asyncio -@pytest.mark.dependency(depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], scope="session") +@pytest.mark.dependency( + depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], + scope="session", +) async def test_count_nodes_matching_attributes(test_async_client): """ Test Case : Test KernelCI API GET /count endpoint with attributes diff --git a/tests/e2e_tests/test_pipeline.py b/tests/e2e_tests/test_pipeline.py index f87099cc..567d5098 100644 --- a/tests/e2e_tests/test_pipeline.py +++ b/tests/e2e_tests/test_pipeline.py @@ -10,11 +10,17 @@ from cloudevents.http import from_json from .listen_handler import create_listen_task -from .test_node_handler import create_node, get_node_by_id, patch_node, update_node +from .test_node_handler import ( + create_node, + get_node_by_id, + update_node, +) @pytest.mark.dependency( - depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel"], + depends=[ + "tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel" + ], scope="session", ) @pytest.mark.order(4) @@ -37,7 +43,9 @@ async def test_node_pipeline(test_async_client): """ # Create Task to listen pubsub event on 'node' channel - task_listen = create_listen_task(test_async_client, pytest.node_channel_subscription_id) # pylint: disable=no-member + task_listen = create_listen_task( + test_async_client, pytest.node_channel_subscription_id + ) # pylint: disable=no-member # Create a node node = { @@ -47,7 +55,9 @@ async def test_node_pipeline(test_async_client): "data": { "kernel_revision": { "tree": "mainline", - "url": ("https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git"), + "url": ( + "https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git" + ), "branch": "master", "commit": "2a987e65025e2b79c6d453b78cb5985ac6e5eb28", "describe": "v5.16-rc4-31-g2a987e65025e", @@ -82,7 +92,9 @@ async def test_node_pipeline(test_async_client): node = response.json() # Create Task to listen 'updated' event on 'node' channel - task_listen = create_listen_task(test_async_client, pytest.node_channel_subscription_id) # pylint: disable=no-member + task_listen = create_listen_task( + test_async_client, pytest.node_channel_subscription_id + ) # pylint: disable=no-member # Update node.state node.update({"state": "done"}) diff --git a/tests/e2e_tests/test_pubsub_handler.py b/tests/e2e_tests/test_pubsub_handler.py index abc93d8e..36ed772c 100644 --- a/tests/e2e_tests/test_pubsub_handler.py +++ b/tests/e2e_tests/test_pubsub_handler.py @@ -13,7 +13,9 @@ @pytest.mark.dependency( - depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel"], + depends=[ + "tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel" + ], scope="session", ) @pytest.mark.asyncio @@ -24,7 +26,9 @@ async def test_pubsub_handler(test_async_client): Use pubsub listener task to verify published event message. """ # Create Task to listen pubsub event on 'test_channel' channel - task_listen = create_listen_task(test_async_client, pytest.test_channel_subscription_id) # pylint: disable=no-member + task_listen = create_listen_task( + test_async_client, pytest.test_channel_subscription_id + ) # pylint: disable=no-member # Created and publish CloudEvent attributes = { @@ -35,7 +39,9 @@ async def test_pubsub_handler(test_async_client): event = CloudEvent(attributes, data) headers, body = to_structured(event) headers["Authorization"] = f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member - response = await test_async_client.post("publish/test_channel", headers=headers, data=body) + response = await test_async_client.post( + "publish/test_channel", headers=headers, data=body + ) assert response.status_code == 200 # Get result of pubsub event listener diff --git a/tests/e2e_tests/test_regression_handler.py b/tests/e2e_tests/test_regression_handler.py index 13d91594..08e207e4 100644 --- a/tests/e2e_tests/test_regression_handler.py +++ b/tests/e2e_tests/test_regression_handler.py @@ -11,7 +11,10 @@ from .test_node_handler import create_node, get_node_by_attribute -@pytest.mark.dependency(depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], scope="session") +@pytest.mark.dependency( + depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], + scope="session", +) @pytest.mark.asyncio async def test_regression_handler(test_async_client): """ @@ -24,7 +27,9 @@ async def test_regression_handler(test_async_client): method. """ # Get "checkout" node - response = await get_node_by_attribute(test_async_client, {"name": "checkout"}) + response = await get_node_by_attribute( + test_async_client, {"name": "checkout"} + ) checkout_node = response.json()["items"][0] # Create a 'kver' passed node @@ -57,7 +62,9 @@ async def test_regression_handler(test_async_client): # Create a "kver" regression node regression_fields = ["group", "name", "path", "state"] - regression_node = {field: failed_node_obj[field] for field in regression_fields} + regression_node = { + field: failed_node_obj[field] for field in regression_fields + } regression_node["kind"] = "regression" regression_node["data"] = { diff --git a/tests/e2e_tests/test_unsubscribe_handler.py b/tests/e2e_tests/test_unsubscribe_handler.py index 9fc8b2ad..dff85f65 100644 --- a/tests/e2e_tests/test_unsubscribe_handler.py +++ b/tests/e2e_tests/test_unsubscribe_handler.py @@ -11,7 +11,9 @@ @pytest.mark.asyncio @pytest.mark.dependency( - depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel"], + depends=[ + "tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel" + ], scope="session", ) @pytest.mark.order("last") @@ -32,7 +34,9 @@ async def test_unsubscribe_node_channel(test_async_client): @pytest.mark.asyncio @pytest.mark.dependency( - depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel"], + depends=[ + "tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel" + ], scope="session", ) @pytest.mark.order("last") diff --git a/tests/e2e_tests/test_user_creation.py b/tests/e2e_tests/test_user_creation.py index 864b0ea3..8286ce89 100644 --- a/tests/e2e_tests/test_user_creation.py +++ b/tests/e2e_tests/test_user_creation.py @@ -17,7 +17,9 @@ @pytest.mark.dependency( - depends=["tests/e2e_tests/test_user_group_handler.py::test_create_user_groups"], + depends=[ + "tests/e2e_tests/test_user_group_handler.py::test_create_user_groups" + ], scope="session", ) @pytest.mark.dependency() @@ -82,7 +84,9 @@ async def test_create_regular_user(test_async_client): "Accept": "application/json", "Authorization": f"Bearer {pytest.ADMIN_BEARER_TOKEN}", }, - data=json.dumps({"username": username, "password": password, "email": email}), + data=json.dumps( + {"username": username, "password": password, "email": email} + ), ) assert response.status_code == 200 assert ( @@ -171,7 +175,13 @@ async def test_create_user_negative(test_async_client): "Accept": "application/json", "Authorization": f"Bearer {pytest.BEARER_TOKEN}", }, - data=json.dumps({"username": "test", "password": "test", "email": "test@kernelci.org"}), + data=json.dumps( + { + "username": "test", + "password": "test", + "email": "test@kernelci.org", + } + ), ) assert response.status_code == 403 assert response.json() == {"detail": "Forbidden"} diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index a48eb3aa..ead60a5c 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -208,7 +208,9 @@ def mock_pubsub_subscriptions(mocker): redis_mock = fakeredis.aioredis.FakeRedis() sub = Subscription(id=1, channel="test", user="test") mocker.patch.object(pubsub, "_redis", redis_mock) - subscriptions_mock = dict({1: {"sub": sub, "redis_sub": pubsub._redis.pubsub()}}) + subscriptions_mock = dict( + {1: {"sub": sub, "redis_sub": pubsub._redis.pubsub()}} + ) mocker.patch.object(pubsub, "_subscriptions", subscriptions_mock) return pubsub @@ -247,8 +249,14 @@ async def mock_init_beanie(mocker): """Mocks async call to Database method to initialize Beanie""" async_mock = AsyncMock() client = AsyncMongoMockClient() - init = await init_beanie(document_models=[User], database=client.get_database(name="db")) - mocker.patch("api.db.Database.initialize_beanie", side_effect=async_mock, return_value=init) + init = await init_beanie( + document_models=[User], database=client.get_database(name="db") + ) + mocker.patch( + "api.db.Database.initialize_beanie", + side_effect=async_mock, + return_value=init, + ) return async_mock @@ -266,7 +274,9 @@ def mock_db_update(mocker): async def mock_beanie_get_user_by_id(mocker): """Mocks async call to external method to get model by id""" async_mock = AsyncMock() - mocker.patch("fastapi_users_db_beanie.BeanieUserDatabase.get", side_effect=async_mock) + mocker.patch( + "fastapi_users_db_beanie.BeanieUserDatabase.get", side_effect=async_mock + ) return async_mock @@ -274,7 +284,10 @@ async def mock_beanie_get_user_by_id(mocker): async def mock_beanie_user_update(mocker): """Mocks async call to external method to update user""" async_mock = AsyncMock() - mocker.patch("fastapi_users_db_beanie.BeanieUserDatabase.update", side_effect=async_mock) + mocker.patch( + "fastapi_users_db_beanie.BeanieUserDatabase.update", + side_effect=async_mock, + ) return async_mock diff --git a/tests/unit_tests/test_authz_handler.py b/tests/unit_tests/test_authz_handler.py index e6df2d73..0475f1e4 100644 --- a/tests/unit_tests/test_authz_handler.py +++ b/tests/unit_tests/test_authz_handler.py @@ -73,14 +73,18 @@ def test_user_can_edit_node_with_matching_group(): def test_user_can_edit_node_with_runtime_editor_group(): """Runtime editor group grants edit access.""" - user = _make_user(groups=[UserGroup(name="runtime:lava-collabora:node-editor")]) + user = _make_user( + groups=[UserGroup(name="runtime:lava-collabora:node-editor")] + ) node = _make_node(runtime="lava-collabora") assert _user_can_edit_node(user, node) def test_user_can_edit_node_with_runtime_admin_group(): """Runtime admin group grants edit access.""" - user = _make_user(groups=[UserGroup(name="runtime:lava-collabora:node-admin")]) + user = _make_user( + groups=[UserGroup(name="runtime:lava-collabora:node-admin")] + ) node = _make_node(runtime="lava-collabora") assert _user_can_edit_node(user, node) @@ -102,7 +106,9 @@ def test_user_can_edit_node_as_superuser(): def test_user_cannot_edit_node_without_access(): """Unrelated user cannot edit when no access applies.""" user = _make_user(username="alice") - node = _make_node(owner="bob", user_groups=["team-a"], runtime="lava-collabora") + node = _make_node( + owner="bob", user_groups=["team-a"], runtime="lava-collabora" + ) assert not _user_can_edit_node(user, node) @@ -118,4 +124,7 @@ def test_user_me_rejects_groups_update(test_client): data=json.dumps(payload), ) assert response.status_code == 400 - assert response.json()["detail"] == "User groups can only be updated by an admin user" + assert ( + response.json()["detail"] + == "User groups can only be updated by an admin user" + ) diff --git a/tests/unit_tests/test_events_handler.py b/tests/unit_tests/test_events_handler.py index c48d43ce..8b21176d 100644 --- a/tests/unit_tests/test_events_handler.py +++ b/tests/unit_tests/test_events_handler.py @@ -44,7 +44,9 @@ def test_get_events_filter_by_ids(mock_db_find_by_attributes, test_client): def test_get_events_rejects_both_id_and_ids(test_client): """GET /events rejects requests with both id and ids parameters.""" - resp = test_client.get("events?id=deadbeefdeadbeefdeadbeef&ids=deadbeefdeadbeefdeadbeef") + resp = test_client.get( + "events?id=deadbeefdeadbeefdeadbeef&ids=deadbeefdeadbeefdeadbeef" + ) assert resp.status_code == 400 @@ -54,7 +56,9 @@ def test_get_events_rejects_invalid_id(test_client): assert resp.status_code == 400 -def test_get_events_filter_by_node_id_alias(mock_db_find_by_attributes, test_client): +def test_get_events_filter_by_node_id_alias( + mock_db_find_by_attributes, test_client +): """GET /events?node_id= aliases to data.id filter.""" node_id = "693af4f5fee8383e92b6b0eb" mock_db_find_by_attributes.return_value = [] @@ -69,5 +73,7 @@ def test_get_events_filter_by_node_id_alias(mock_db_find_by_attributes, test_cli def test_get_events_rejects_node_id_and_data_id(test_client): """GET /events rejects requests with both node_id and data.id parameters.""" - resp = test_client.get("events?node_id=693af4f5fee8383e92b6b0eb&data.id=693af4f5fee8383e92b6b0eb") + resp = test_client.get( + "events?node_id=693af4f5fee8383e92b6b0eb&data.id=693af4f5fee8383e92b6b0eb" + ) assert resp.status_code == 400 diff --git a/tests/unit_tests/test_node_handler.py b/tests/unit_tests/test_node_handler.py index 1cc685f9..fcd37a51 100644 --- a/tests/unit_tests/test_node_handler.py +++ b/tests/unit_tests/test_node_handler.py @@ -18,7 +18,9 @@ from tests.unit_tests.conftest import BEARER_TOKEN -def test_create_node_endpoint(mock_db_create, mock_publish_cloudevent, test_client): +def test_create_node_endpoint( + mock_db_create, mock_publish_cloudevent, test_client +): """ Test Case : Test KernelCI API /node endpoint Expected Result : @@ -89,7 +91,9 @@ def test_create_node_endpoint(mock_db_create, mock_publish_cloudevent, test_clie } -def test_get_nodes_by_attributes_endpoint(mock_db_find_by_attributes, test_client): +def test_get_nodes_by_attributes_endpoint( + mock_db_find_by_attributes, test_client +): """ Test Case : Test KernelCI API GET /nodes?attribute_name=attribute_value endpoint for the positive path @@ -156,7 +160,9 @@ def test_get_nodes_by_attributes_endpoint(mock_db_find_by_attributes, test_clien assert len(response.json()["items"]) > 0 -def test_get_nodes_by_attributes_endpoint_node_not_found(mock_db_find_by_attributes, test_client): +def test_get_nodes_by_attributes_endpoint_node_not_found( + mock_db_find_by_attributes, test_client +): """ Test Case : Test KernelCI API GET /nodes?attribute_name=attribute_value endpoint for the node not found @@ -165,7 +171,9 @@ def test_get_nodes_by_attributes_endpoint_node_not_found(mock_db_find_by_attribu Empty list """ - mock_db_find_by_attributes.return_value = PageModel(items=[], total=0, limit=50, offset=0) + mock_db_find_by_attributes.return_value = PageModel( + items=[], total=0, limit=50, offset=0 + ) params = {"name": "checkout", "revision.tree": "baseline"} response = test_client.get("nodes", params=params) @@ -233,7 +241,9 @@ def test_get_node_by_id_endpoint(mock_db_find_by_id, test_client): } -def test_get_node_by_id_endpoint_empty_response(mock_db_find_by_id, test_client): +def test_get_node_by_id_endpoint_empty_response( + mock_db_find_by_id, test_client +): """ Test Case : Test KernelCI API GET /node/{node_id} endpoint for negative path @@ -340,7 +350,9 @@ def test_get_all_nodes_empty_response(mock_db_find_by_attributes, test_client): HTTP Response Code 200 OK Empty list as no Node object is added. """ - mock_db_find_by_attributes.return_value = PageModel(items=[], total=0, limit=50, offset=0) + mock_db_find_by_attributes.return_value = PageModel( + items=[], total=0, limit=50, offset=0 + ) response = test_client.get("nodes") print("response.json()", response.json()) diff --git a/tests/unit_tests/test_token_handler.py b/tests/unit_tests/test_token_handler.py index db14b8b6..34405820 100644 --- a/tests/unit_tests/test_token_handler.py +++ b/tests/unit_tests/test_token_handler.py @@ -16,7 +16,9 @@ @pytest.mark.asyncio -async def test_token_endpoint(test_async_client, mock_user_find, mock_beanie_user_update): +async def test_token_endpoint( + test_async_client, mock_user_find, mock_beanie_user_update +): """ Test Case : Test KernelCI API /user/login endpoint Expected Result : @@ -49,7 +51,9 @@ async def test_token_endpoint(test_async_client, mock_user_find, mock_beanie_use @pytest.mark.asyncio -async def test_token_endpoint_incorrect_password(test_async_client, mock_user_find): +async def test_token_endpoint_incorrect_password( + test_async_client, mock_user_find +): """ Test Case : Test KernelCI API /user/login endpoint for negative path Incorrect password should be passed to the endpoint diff --git a/tests/unit_tests/test_user_group_handler.py b/tests/unit_tests/test_user_group_handler.py index d701eeaa..24d94cc4 100644 --- a/tests/unit_tests/test_user_group_handler.py +++ b/tests/unit_tests/test_user_group_handler.py @@ -35,7 +35,9 @@ def test_list_user_groups(mock_db_find_by_attributes, test_client): def test_create_user_group(mock_db_find_one, mock_db_create, test_client): """POST /user-groups creates a new user group.""" mock_db_find_one.return_value = None - mock_db_create.return_value = UserGroup(name="runtime:pull-labs-demo:node-editor") + mock_db_create.return_value = UserGroup( + name="runtime:pull-labs-demo:node-editor" + ) response = test_client.post( "user-groups", @@ -49,7 +51,9 @@ def test_create_user_group(mock_db_find_one, mock_db_create, test_client): assert response.json()["name"] == "runtime:pull-labs-demo:node-editor" -def test_delete_user_group(mock_db_find_by_id, mock_db_count, mock_db_delete_by_id, test_client): +def test_delete_user_group( + mock_db_find_by_id, mock_db_count, mock_db_delete_by_id, test_client +): """DELETE /user-groups/{id} removes an unused user group.""" mock_db_find_by_id.return_value = UserGroup(name="team-a") mock_db_count.return_value = 0 @@ -68,7 +72,9 @@ def test_delete_user_group(mock_db_find_by_id, mock_db_count, mock_db_delete_by_ ) -def test_delete_user_group_when_assigned(mock_db_find_by_id, mock_db_count, test_client): +def test_delete_user_group_when_assigned( + mock_db_find_by_id, mock_db_count, test_client +): """DELETE /user-groups/{id} rejects when group is assigned to users.""" mock_db_find_by_id.return_value = UserGroup(name="team-a") mock_db_count.return_value = 2 diff --git a/tests/unit_tests/test_user_handler.py b/tests/unit_tests/test_user_handler.py index 823e1444..c85a899a 100644 --- a/tests/unit_tests/test_user_handler.py +++ b/tests/unit_tests/test_user_handler.py @@ -19,7 +19,9 @@ @pytest.mark.asyncio -async def test_create_regular_user(mock_db_find_one, mock_db_create, test_async_client): +async def test_create_regular_user( + mock_db_find_one, mock_db_create, test_async_client +): """ Test Case : Test KernelCI API /user/register endpoint to create regular user when requested with admin user's bearer token @@ -42,8 +44,17 @@ async def test_create_regular_user(mock_db_find_one, mock_db_create, test_async_ response = await test_async_client.post( "user/register", - headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, - data=json.dumps({"username": "test", "password": "test", "email": "test@kernelci.org"}), + headers={ + "Accept": "application/json", + "Authorization": ADMIN_BEARER_TOKEN, + }, + data=json.dumps( + { + "username": "test", + "password": "test", + "email": "test@kernelci.org", + } + ), ) print(response.json()) assert response.status_code == 200 @@ -59,7 +70,9 @@ async def test_create_regular_user(mock_db_find_one, mock_db_create, test_async_ @pytest.mark.asyncio -async def test_create_admin_user(test_async_client, mock_db_find_one, mock_db_find_by_id, mock_db_update): +async def test_create_admin_user( + test_async_client, mock_db_find_one, mock_db_find_by_id, mock_db_update +): """ Test Case : Test KernelCI API /user/register endpoint to create admin user when requested with admin user's bearer token @@ -83,7 +96,10 @@ async def test_create_admin_user(test_async_client, mock_db_find_one, mock_db_fi response = await test_async_client.post( "user/register", - headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + headers={ + "Accept": "application/json", + "Authorization": ADMIN_BEARER_TOKEN, + }, data=json.dumps( { "username": "test_admin", @@ -118,7 +134,13 @@ async def test_create_user_endpoint_negative(test_async_client): response = await test_async_client.post( "user/register", headers={"Accept": "application/json", "Authorization": BEARER_TOKEN}, - data=json.dumps({"username": "test", "password": "test", "email": "test@kernelci.org"}), + data=json.dumps( + { + "username": "test", + "password": "test", + "email": "test@kernelci.org", + } + ), ) print(response.json()) assert response.status_code == 403 @@ -152,7 +174,10 @@ async def test_create_user_with_group( response = await test_async_client.post( "user/register", - headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + headers={ + "Accept": "application/json", + "Authorization": ADMIN_BEARER_TOKEN, + }, data=json.dumps( { "username": "test", @@ -176,7 +201,9 @@ async def test_create_user_with_group( @pytest.mark.asyncio -async def test_get_user_by_id_endpoint(test_async_client, mock_beanie_get_user_by_id): +async def test_get_user_by_id_endpoint( + test_async_client, mock_beanie_get_user_by_id +): """ Test Case : Test KernelCI API GET /user/{user_id} endpoint with admin token @@ -197,7 +224,10 @@ async def test_get_user_by_id_endpoint(test_async_client, mock_beanie_get_user_b response = await test_async_client.get( "user/61bda8f2eb1a63d2b7152418", - headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + headers={ + "Accept": "application/json", + "Authorization": ADMIN_BEARER_TOKEN, + }, ) print("response.json()", response.json()) assert response.status_code == 200