-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathclient.py
More file actions
128 lines (106 loc) · 4.31 KB
/
client.py
File metadata and controls
128 lines (106 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""OpenAI text client."""
import json
from typing import Any, Unpack
from celeste.parameters import ParameterMapper
from celeste.providers.openai.responses.client import (
OpenAIResponsesClient as OpenAIResponsesMixin,
)
from celeste.providers.openai.responses.streaming import (
OpenAIResponsesStream as _OpenAIResponsesStream,
)
from celeste.tools import ToolCall, ToolResult
from celeste.types import ImageContent, Message, TextContent, VideoContent
from celeste.utils import build_image_data_url
from ...client import TextClient
from ...io import (
TextInput,
TextOutput,
)
from ...parameters import TextParameters
from ...streaming import TextStream
from ..openresponses.client import OpenResponsesTextStream
from .parameters import OPENAI_PARAMETER_MAPPERS
class OpenAITextStream(_OpenAIResponsesStream, OpenResponsesTextStream):
    """Streaming handler for OpenAI text generation.

    Pure composition: inherits the OpenAI Responses-API stream base first and
    the OpenResponses text-stream behaviour second, adding no logic of its own.
    """
class OpenAITextClient(OpenAIResponsesMixin, TextClient):
    """OpenAI text client using the Responses API.

    Builds the Responses-API ``input`` payload from a ``TextInput`` and parses
    the response's ``output`` items into text content and tool calls.
    """

    @classmethod
    def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]:
        """Return the OpenAI-specific parameter mappers for text requests."""
        return OPENAI_PARAMETER_MAPPERS

    async def generate(
        self,
        prompt: str | None = None,
        *,
        messages: list[Message] | None = None,
        **parameters: Unpack[TextParameters],
    ) -> TextOutput:
        """Generate text from a prompt or a message history.

        Args:
            prompt: Plain-text prompt. Not sent when ``messages`` is given.
            messages: Conversation history; each entry becomes one input item.
            **parameters: Text parameters forwarded to ``_predict``.

        Returns:
            The ``TextOutput`` produced by ``_predict``.
        """
        inputs = TextInput(prompt=prompt, messages=messages)
        return await self._predict(inputs, **parameters)

    async def analyze(
        self,
        prompt: str | None = None,
        *,
        messages: list[Message] | None = None,
        image: ImageContent | None = None,
        video: VideoContent | None = None,
        **parameters: Unpack[TextParameters],
    ) -> TextOutput:
        """Analyze image(s) or video(s) with a prompt or a message history.

        Args:
            prompt: Instruction text sent after the image parts.
            messages: Conversation history. When given, ``prompt`` and
                ``image`` are not sent (see ``_init_request``).
            image: A single image or a list of images, each sent as an
                ``input_image`` part.
            video: Stored on the ``TextInput``. NOTE(review): ``_init_request``
                never reads ``inputs.video``, so video is currently not sent
                to the API -- confirm whether this is intended.
            **parameters: Text parameters forwarded to ``_predict``.

        Returns:
            The ``TextOutput`` produced by ``_predict``.
        """
        inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video)
        return await self._predict(inputs, **parameters)

    def _init_request(self, inputs: TextInput) -> dict[str, Any]:
        """Build the Responses-API ``input`` payload from ``inputs``.

        With ``inputs.messages`` set, every message becomes one input item:
        ``ToolResult`` messages become ``function_call_output`` items and all
        other messages are sent as their ``model_dump()``. Prompt and image
        fields are ignored in that case.

        Otherwise a single user turn is built: one ``input_image`` part per
        image, followed by one ``input_text`` part holding the prompt (``""``
        when the prompt is ``None``).
        """
        if inputs.messages is not None:
            items: list[dict[str, Any]] = [
                self._tool_result_item(msg)
                if isinstance(msg, ToolResult)
                else msg.model_dump()
                for msg in inputs.messages
            ]
            return {"input": items}

        content: list[dict[str, Any]] = []
        if inputs.image is not None:
            images = inputs.image if isinstance(inputs.image, list) else [inputs.image]
            content.extend(
                {"type": "input_image", "image_url": build_image_data_url(img)}
                for img in images
            )
        content.append({"type": "input_text", "text": inputs.prompt or ""})
        return {"input": [{"role": "user", "content": content}]}

    @staticmethod
    def _tool_result_item(msg: ToolResult) -> dict[str, Any]:
        """Convert a ``ToolResult`` into a ``function_call_output`` input item."""
        raw = msg.content
        if isinstance(raw, str):
            output = raw
        else:
            # JSON-encode structured results so the model receives valid JSON
            # instead of a Python repr such as "{'a': 1}"; fall back to str()
            # for values json cannot serialize.
            try:
                output = json.dumps(raw)
            except (TypeError, ValueError):
                output = str(raw)
        return {
            "type": "function_call_output",
            "call_id": msg.tool_call_id,
            "output": output,
        }

    def _parse_content(
        self,
        response_data: dict[str, Any],
    ) -> TextContent:
        """Extract the generated text from a Responses-API payload.

        Concatenates, in order, the ``text`` of every ``output_text`` part of
        every ``message`` item. This matches the aggregation done by the
        OpenAI SDK's ``Response.output_text``, so text split across several
        parts is no longer truncated to the first part. Returns ``""`` when
        there is no text output.
        """
        # The mixin's _parse_content yields the list of output items (dicts)
        # that is iterated below; ``or []`` guards against a missing list.
        output = super()._parse_content(response_data) or []
        return "".join(
            part.get("text") or ""
            for item in output
            if item.get("type") == "message"
            for part in item.get("content") or []
            if part.get("type") == "output_text"
        )

    def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
        """Parse ``function_call`` output items into ``ToolCall`` objects.

        The id prefers ``call_id``, which is the value ``_init_request`` later
        sends back as ``call_id`` on ``function_call_output`` items. It falls
        back to ``id`` and then ``""`` when those are missing or empty.

        Raises:
            json.JSONDecodeError: If ``arguments`` is a non-empty string that
                is not valid JSON.
        """
        return [
            ToolCall(
                id=item.get("call_id") or item.get("id") or "",
                name=item["name"],
                arguments=self._decode_arguments(item.get("arguments")),
            )
            for item in response_data.get("output") or []
            if item.get("type") == "function_call"
        ]

    @staticmethod
    def _decode_arguments(raw: Any) -> Any:
        """Decode a function call's ``arguments`` field.

        Missing/``None`` and empty or whitespace-only strings yield ``{}``
        (``json.loads("")`` would raise). Other strings are JSON-decoded, and
        already-decoded values are returned unchanged.
        """
        if raw is None:
            return {}
        if isinstance(raw, str):
            return json.loads(raw) if raw.strip() else {}
        return raw

    def _stream_class(self) -> type[TextStream]:
        """Return the stream class this provider uses for streaming responses."""
        return OpenAITextStream
# Public API of this module: the client and its streaming counterpart.
__all__ = [
    "OpenAITextClient",
    "OpenAITextStream",
]