Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mistralai/client/_hooks/registration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .custom_user_agent import CustomUserAgentHook
from .deprecation_warning import DeprecationWarningHook
from .traceparent import TraceparentInjectionHook
from .tracing import TracingHook
from .types import Hooks

Expand All @@ -16,6 +17,7 @@ def init_hooks(hooks: Hooks):
"""
tracing_hook = TracingHook()
hooks.register_before_request_hook(CustomUserAgentHook())
hooks.register_before_request_hook(TraceparentInjectionHook())
hooks.register_after_success_hook(DeprecationWarningHook())
hooks.register_after_success_hook(tracing_hook)
hooks.register_before_request_hook(tracing_hook)
Expand Down
38 changes: 38 additions & 0 deletions src/mistralai/client/_hooks/traceparent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import random
from typing import Dict, Union

import httpx
from opentelemetry.propagate import inject

from .types import BeforeRequestContext, BeforeRequestHook


_EXECUTE_OPERATION_IDS = {
"execute_workflow_v1_workflows__workflow_identifier__execute_post",
"execute_workflow_registration_v1_workflows_registrations__workflow_registration_id__execute_post",
}


class TraceparentInjectionHook(BeforeRequestHook):
"""Inject a sampled traceparent on /execute requests so worker traces are always recorded."""

def before_request(
self, hook_ctx: BeforeRequestContext, request: httpx.Request
) -> Union[httpx.Request, Exception]:
if hook_ctx.operation_id not in _EXECUTE_OPERATION_IDS:
return request

# Don't overwrite an explicitly provided traceparent.
if "traceparent" in request.headers:
return request

carrier: Dict[str, str] = {}
inject(carrier)
traceparent = carrier.get("traceparent", "")
if not traceparent.endswith("-01"):
trace_id = random.getrandbits(128)
span_id = random.getrandbits(64)
traceparent = f"00-{trace_id:032x}-{span_id:016x}-01"

request.headers["traceparent"] = traceparent
return request
112 changes: 112 additions & 0 deletions src/mistralai/extra/tests/test_traceparent_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import re
import unittest
from unittest.mock import MagicMock

import httpx
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF

from mistralai.client._hooks.traceparent import TraceparentInjectionHook
from mistralai.client._hooks.types import BeforeRequestContext, HookContext


TRACEPARENT_RE = re.compile(r"^00-[0-9a-f]{32}-[0-9a-f]{16}-01$")

_EXECUTE_OP_ID = "execute_workflow_v1_workflows__workflow_identifier__execute_post"
_EXECUTE_REG_OP_ID = "execute_workflow_registration_v1_workflows_registrations__workflow_registration_id__execute_post"
_OTHER_OP_ID = "list_executions_v1_workflows__workflow_identifier__executions_get"


def _make_hook_ctx(operation_id: str = _EXECUTE_OP_ID) -> BeforeRequestContext:
ctx = HookContext(
config=MagicMock(),
base_url="https://api.mistral.ai",
operation_id=operation_id,
oauth2_scopes=None,
security_source=None,
)
return BeforeRequestContext(ctx)


def _make_request(path: str, traceparent: str | None = None) -> httpx.Request:
headers = {}
if traceparent is not None:
headers["traceparent"] = traceparent
return httpx.Request("POST", f"https://api.mistral.ai{path}", headers=headers)


class TestTraceparentInjectionHook(unittest.TestCase):
def setUp(self):
self.hook = TraceparentInjectionHook()

# --- non-execute operations must NOT be touched ---

def test_other_operation_is_unchanged(self):
req = _make_request("/v1/workflows/my-wf/executions")
result = self.hook.before_request(_make_hook_ctx(_OTHER_OP_ID), req)
assert isinstance(result, httpx.Request)
self.assertNotIn("traceparent", result.headers)

# --- execute operations: header injected ---

def test_execute_gets_sampled_traceparent(self):
req = _make_request("/v1/workflows/my-wf/execute")
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
assert isinstance(result, httpx.Request)
self.assertIn("traceparent", result.headers)
self.assertRegex(result.headers["traceparent"], TRACEPARENT_RE)

def test_execute_registration_gets_sampled_traceparent(self):
req = _make_request("/v1/workflows/registrations/reg-id/execute")
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_REG_OP_ID), req)
assert isinstance(result, httpx.Request)
self.assertIn("traceparent", result.headers)
self.assertRegex(result.headers["traceparent"], TRACEPARENT_RE)

def test_explicit_traceparent_is_not_overwritten(self):
explicit = "00-aabbccddeeff00112233445566778899-0102030405060708-01"
req = _make_request("/v1/workflows/my-wf/execute", traceparent=explicit)
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
assert isinstance(result, httpx.Request)
self.assertEqual(result.headers["traceparent"], explicit)

# --- OTEL context propagation ---

def test_propagates_sampled_active_span(self):
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = provider.get_tracer("test")

with tracer.start_as_current_span("parent") as span:
req = _make_request("/v1/workflows/my-wf/execute")
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)

assert isinstance(result, httpx.Request)
injected = result.headers["traceparent"]
self.assertTrue(injected.endswith("-01"))
trace_id_hex = f"{span.get_span_context().trace_id:032x}"
self.assertIn(trace_id_hex, injected)

def test_generates_fresh_traceparent_when_no_active_span(self):
req = _make_request("/v1/workflows/my-wf/execute")
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
assert isinstance(result, httpx.Request)
self.assertRegex(result.headers["traceparent"], TRACEPARENT_RE)

def test_generates_fresh_traceparent_when_span_is_unsampled(self):
provider = TracerProvider(sampler=ALWAYS_OFF)
tracer = provider.get_tracer("test")

with tracer.start_as_current_span("unsampled-parent"):
req = _make_request("/v1/workflows/my-wf/execute")
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)

assert isinstance(result, httpx.Request)
injected = result.headers["traceparent"]
self.assertTrue(injected.endswith("-01"), f"Expected sampled, got: {injected}")

if __name__ == "__main__":
unittest.main()
Loading