-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathfake_vllm_endpoint.py
More file actions
275 lines (233 loc) · 10.2 KB
/
fake_vllm_endpoint.py
File metadata and controls
275 lines (233 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# -*- coding: utf-8 -*-
"""
Fake vLLM endpoint for OpenClaw agent training.
Based on ajet/tuner_lib/experimental/oai_model_one2many.py
"""
import os
import uuid
import asyncio
import httpx
import json
import threading
from contextlib import asynccontextmanager
from typing import Dict, List, Optional
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel
from ajet.schema.task import Task, WorkflowOutput
from ajet.copilot.job import AgentJetJob
from ajet.tuner_lib.experimental.swarm_client import SwarmClient
import sys
sys.path.insert(0, os.path.dirname(__file__))
from on_user_submit_new_requests import on_user_submit_new_requests, get_query_history
from on_compute_relative_reward import on_compute_relative_reward
# Configuration (overridable via environment variables)
SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")  # swarm server base URL
NUM_REPEAT = int(os.getenv("NUM_REPEAT", "4"))  # episodes fanned out per incoming request
TRAINING_OBJECTIVE = "Train model to be more extraverted"

# Global State
# NOTE(review): USER_REQUEST_RECORD is declared but not referenced in this
# file's visible code — confirm whether the imported handlers use it.
USER_REQUEST_RECORD: List[Dict] = []
REQUEST_COUNTER = 0  # monotonically increasing id for incoming chat requests
swarm_client: Optional[SwarmClient] = None  # initialized in lifespan()

# Training job definition synced to the swarm engine at startup.
ajet_job = AgentJetJob(
    algorithm="grpo",
    project_name="openclaw-extraversion",
    experiment_name="extraversion_training",
    n_gpu=8,
    model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct',
    batch_size=32,
    logging="swanlab",
    num_repeat=NUM_REPEAT,
    max_prompt_length=16000,  # at least 16000
    max_response_length=8000,
    max_model_len=24000,  # bigger than / equal to `max_prompt_length + max_response_length`
    max_response_length_in_one_turn=4000,
)
class EpisodeResult(BaseModel):
    """Result from a single episode execution."""
    # Unique id assigned by the swarm when the episode began.
    episode_uuid: str
    # Parsed JSON response dict (non-streaming) or raw SSE lines (streaming).
    response: Dict | List[bytes]
def extract_assistant_message(resp: Dict | List[bytes]) -> Dict:
"""Extract assistant message from response."""
if isinstance(resp, list):
content_parts: List[str] = []
for raw in resp:
line = raw.decode() if isinstance(raw, bytes) else raw
if not line.startswith("data:"):
continue
payload = line[len("data:"):].strip()
if payload == "[DONE]":
break
try:
chunk = json.loads(payload)
delta = chunk.get("choices", [{}])[0].get("delta", {})
if delta.get("content"):
content_parts.append(delta["content"])
except Exception:
pass
return {"role": "assistant", "content": "".join(content_parts)}
else:
return resp.get("choices", [{}])[0].get("message", {})
async def proxy_chat_completion(base_url: str, api_key: str, request: Request, is_stream: bool = False) -> Dict | List[bytes]:
    """Proxy a chat completion request to an upstream OpenAI-compatible server.

    Args:
        base_url: Upstream base URL (``/chat/completions`` is appended).
        api_key: Bearer token for the upstream server.
        request: Incoming FastAPI request whose JSON body is forwarded.
        is_stream: Force streaming on/off upstream, overriding whatever the
            client sent in its body.

    Returns:
        Parsed JSON dict when not streaming; a list of raw SSE line bytes
        when streaming.

    Raises:
        httpx.HTTPStatusError: When the upstream responds with 4xx/5xx.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Connection": "close",
    }
    json_data = await request.json()
    json_data["stream"] = is_stream
    # Remove fields not supported by vLLM to avoid warnings
    UNSUPPORTED_FIELDS = {"strict", "store"}
    for field in UNSUPPORTED_FIELDS:
        json_data.pop(field, None)
    # Also remove 'strict' from response_format if present
    if "response_format" in json_data and isinstance(json_data["response_format"], dict):
        json_data["response_format"].pop("strict", None)
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers)
        resp.raise_for_status()
        if is_stream:
            # NOTE(review): client.post() reads the whole body before we get
            # here, so these "stream" chunks are buffered, not relayed live —
            # confirm whether true pass-through streaming is required
            # (client.stream(...) would be needed for that).
            chunks = []
            async for line in resp.aiter_lines():
                if line.strip():
                    chunks.append(line.encode() if isinstance(line, str) else line)
            return chunks
        else:
            return resp.json()
def _check_finish_reason_length(response_data: Dict | List[bytes]) -> bool:
"""Return True if any choice has finish_reason='length'."""
if isinstance(response_data, list):
for raw in response_data:
line = raw.decode() if isinstance(raw, bytes) else raw
if not line.startswith("data:"):
continue
payload = line[len("data:"):].strip()
if payload == "[DONE]":
break
try:
chunk = json.loads(payload)
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
if finish_reason == "length":
return True
except Exception:
pass
return False
else:
choices = response_data.get("choices", [])
return any(c.get("finish_reason") == "length" for c in choices)
async def run_single_episode(episode_index: int, request: Request, is_stream: bool) -> EpisodeResult:
    """Run a single episode.

    Begins an episode with the swarm, proxies the incoming chat completion
    to the episode's model endpoint, and wraps the raw response.

    Raises:
        HTTPException: 400 ``context_length_exceeded`` when any choice
            finished with ``finish_reason == "length"``.
        Exception: Any proxy failure is re-raised after the episode is
            aborted on the swarm side.
    """
    assert swarm_client is not None
    # begin_episode is a blocking call; run it off the event loop.
    episode_uuid, api_baseurl_key = await asyncio.to_thread(swarm_client.begin_episode)
    try:
        response_data = await proxy_chat_completion(
            base_url=api_baseurl_key.base_url,
            api_key=api_baseurl_key.api_key,
            request=request,
            is_stream=is_stream,
        )
        if _check_finish_reason_length(response_data):
            # Surface truncation as an OpenAI-style context-length error.
            # Note: this raise is caught by the except below, so the episode
            # is aborted before the HTTPException propagates to the caller.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "This model's maximum context length is exceeded. Please reduce the length of the messages.",
                        "type": "invalid_request_error",
                        "param": "messages",
                        "code": "context_length_exceeded",
                    }
                },
            )
        return EpisodeResult(episode_uuid=episode_uuid, response=response_data)
    except Exception as e:
        logger.error(f"Error in episode {episode_index}: {e}")
        # NOTE(review): abort_episode is called synchronously inside a
        # coroutine — consider asyncio.to_thread if it can block; confirm.
        swarm_client.abort_episode(episode_uuid)
        raise
async def run_all_episodes(request: Request, is_stream: bool) -> List[EpisodeResult]:
    """Run NUM_REPEAT episodes concurrently and return the successful ones."""
    outcomes = await asyncio.gather(
        *(run_single_episode(idx, request, is_stream) for idx in range(NUM_REPEAT)),
        return_exceptions=True,
    )
    successes: List[EpisodeResult] = []
    for outcome in outcomes:
        # HTTPException must be tested before the generic Exception branch:
        # a 400 (context_length_exceeded) is propagated to the client as-is.
        if isinstance(outcome, HTTPException) and outcome.status_code == 400:
            raise outcome
        if isinstance(outcome, Exception):
            logger.warning(f"Episode failed: {outcome}")
            continue
        if isinstance(outcome, EpisodeResult):
            successes.append(outcome)
    if not successes:
        raise HTTPException(status_code=500, detail="All episodes failed")
    return successes
async def finalize_episodes(task: Task, valid_results: List[EpisodeResult], rewards: List[float]) -> None:
    """Finalize all episodes by sending rewards.

    Args:
        task: The task the episodes were run for.
        valid_results: Successfully completed episodes, aligned with `rewards`.
        rewards: One reward per entry in `valid_results`.
    """
    assert swarm_client is not None
    for episode_result, reward in zip(valid_results, rewards):
        workflow_output = WorkflowOutput(reward=reward, metadata={})
        # Offload the blocking swarm call so the event loop stays responsive.
        # asyncio.to_thread replaces the deprecated-in-coroutine
        # get_event_loop()/run_in_executor pair and matches the pattern
        # already used in run_single_episode.
        await asyncio.to_thread(
            swarm_client.end_episode, task, episode_result.episode_uuid, workflow_output
        )
async def handle_one2many_request(request: Request, request_id: str) -> Dict | List[bytes]:
    """Handle a one-to-many request.

    Fans the incoming chat completion out into NUM_REPEAT episodes, scores
    them with a relative reward, finalizes every episode, and returns the
    best-rewarded episode's response.

    Args:
        request: Incoming FastAPI request (OpenAI chat-completion body).
        request_id: Unique id assigned by the proxy endpoint for logging.

    Raises:
        HTTPException: 400 when the body carries no messages; 500 when all
            episodes fail (via run_all_episodes).
    """
    json_data = await request.json()
    is_stream = json_data.get('stream', False)
    messages = json_data.get('messages', [])
    if not messages:
        # BUGFIX: messages[-1] on an empty list raised a bare IndexError
        # (surfacing as a 500); reject it explicitly like OpenAI would.
        raise HTTPException(status_code=400, detail="'messages' must be a non-empty list")
    message_latest = messages[-1]
    user_query = str(message_latest.get("content", "") if isinstance(message_latest, dict) else "")
    task = Task(task_id=str(uuid.uuid4()), main_query=user_query, metadata={"TRAINING_OBJECTIVE": TRAINING_OBJECTIVE})
    await on_user_submit_new_requests(request_id, task)
    valid_results = await run_all_episodes(request, is_stream)
    all_answers = [extract_assistant_message(r.response) for r in valid_results]
    rewards = await on_compute_relative_reward(valid_results, all_answers, question=user_query)
    await finalize_episodes(task, valid_results, rewards)
    # Return the response with the highest reward (first max on ties).
    best_idx = rewards.index(max(rewards))
    return valid_results[best_idx].response
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Creates the global swarm client, then starts the training engine sync in
    a background daemon thread so the HTTP server accepts requests
    immediately instead of blocking on engine warm-up.
    """
    global swarm_client
    logger.info(f"Initializing swarm client with URL: {SWARM_URL}")
    swarm_client = SwarmClient(SWARM_URL)
    logger.info(f"Syncing train config and starting engine with num_repeat={NUM_REPEAT}")

    def start_engine_background():
        # Engine sync may be slow or fail (e.g. engine already running);
        # failures are logged and tolerated rather than crashing startup.
        try:
            swarm_client.auto_sync_train_config_and_start_engine(ajet_job, force_restart=False)
            logger.info("Swarm engine is ready!")
        except Exception as e:
            logger.warning(f"Engine auto-sync skipped or failed: {e}")

    engine_thread = threading.Thread(target=start_engine_background, daemon=True)
    engine_thread.start()
    yield
app = FastAPI(title="OpenClaw Extraversion Training", lifespan=lifespan)


@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def one2many_proxy(request: Request, path: str):
    """Main proxy endpoint.

    Intercepts POST /v1/chat/completions and fans it out into NUM_REPEAT
    training episodes; any other /v1 path returns 404.
    """
    global REQUEST_COUNTER
    if request.method == "POST" and path == "chat/completions":
        # NOTE(review): the counter increment is unguarded — safe on a single
        # asyncio event loop, racy with multiple workers; confirm deployment.
        REQUEST_COUNTER += 1
        request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}"
        logger.info(f"Received chat completion request {request_id}")
        response_data = await handle_one2many_request(request, request_id)
        if isinstance(response_data, list):
            # Streaming case: replay the buffered SSE lines to the client.
            async def stream_chunks(chunks: List[bytes]):
                for chunk in chunks:
                    yield chunk + b"\n\n"
            return StreamingResponse(stream_chunks(response_data), media_type="text/event-stream")
        return response_data
    else:
        raise HTTPException(status_code=404, detail="Not Found")
@app.get("/health")
async def health_check():
    """Health check endpoint; always returns a healthy status once the app is up."""
    return {"status": "healthy"}
@app.get("/requests")
async def get_requests():
    """Get all recorded user requests (history kept by on_user_submit_new_requests)."""
    return {"requests": get_query_history()}
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; clients point their OpenAI-compatible
    # base_url at http://<host>:8090/v1.
    uvicorn.run(app, host="0.0.0.0", port=8090)