Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 190 additions & 95 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,8 @@ HTTPAPIServer::HTTPAPIServer(
const std::shared_ptr<SharedMemoryManager>& shm_manager, const int32_t port,
const bool reuse_port, const std::string& address,
const std::string& header_forward_pattern, const int thread_cnt,
const size_t max_input_size, const RestrictedFeatures& restricted_apis)
const int control_request_concurrency, const size_t max_input_size,
const RestrictedFeatures& restricted_apis)
: HTTPServer(port, reuse_port, address, header_forward_pattern, thread_cnt),
server_(server), trace_manager_(trace_manager), shm_manager_(shm_manager),
allocator_(nullptr), server_regex_(R"(/v2(?:/health/(live|ready))?)"),
Expand Down Expand Up @@ -1181,6 +1182,12 @@ HTTPAPIServer::HTTPAPIServer(
"setting allocator's buffer attributes function");

ConfigureGenerateMappingSchema();

max_control_requests_ = control_request_concurrency;
if (max_control_requests_ > 0) {
LOG_INFO << "Model-control async enabled, max concurrent requests: "
<< max_control_requests_;
}
}

HTTPAPIServer::~HTTPAPIServer()
Expand Down Expand Up @@ -1455,118 +1462,200 @@ HTTPAPIServer::HandleRepositoryControl(
req, EVHTP_RES_METHNALLOWED, "Method Not Allowed");
}

TRITONSERVER_Error* err = nullptr;
// Validate repository_name synchronously (lightweight)
if (!repository_name.empty()) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"'repository_name' specification is not supported");
} else {
if (action == "load") {
static auto param_deleter =
[](std::vector<TRITONSERVER_Parameter*>* params) {
if (params != nullptr) {
for (auto& param : *params) {
TRITONSERVER_ParameterDelete(param);
}
delete params;
RETURN_AND_RESPOND_IF_ERR(
req, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"'repository_name' specification is not supported"));
}

Comment thread
itsnothuy marked this conversation as resolved.
// --- Load path: parse parameters synchronously, then offload blocking call
if (action == "load") {
static auto param_deleter =
[](std::vector<TRITONSERVER_Parameter*>* params) {
if (params != nullptr) {
for (auto& param : *params) {
TRITONSERVER_ParameterDelete(param);
}
};
std::unique_ptr<
std::vector<TRITONSERVER_Parameter*>, decltype(param_deleter)>
params(new std::vector<TRITONSERVER_Parameter*>(), param_deleter);
// local variables to store the decoded file content, the data must
// be valid until TRITONSERVER_ServerLoadModelWithParameters returns.
std::list<std::vector<char>> binary_files;
// WAR for the const-ness check
std::vector<const TRITONSERVER_Parameter*> const_params;
triton::common::TritonJson::Value load_request;
size_t buffer_len = 0;
RETURN_AND_RESPOND_IF_ERR(
req, EVRequestToJson(req, "load model", &load_request, &buffer_len));

if (buffer_len > 0) {
// Parse request body for parameters
triton::common::TritonJson::Value param_json;
if (load_request.Find("parameters", &param_json)) {
// Iterate over each member in 'param_json'
std::vector<std::string> members;
RETURN_AND_RESPOND_IF_ERR(req, param_json.Members(&members));
for (const auto& m : members) {
const char* param_str = nullptr;
size_t param_len = 0;
delete params;
}
};
// Use shared_ptr so the lambda closure can be copyable (required by
// std::function) while still using a custom deleter.
std::shared_ptr<std::vector<TRITONSERVER_Parameter*>> params(
new std::vector<TRITONSERVER_Parameter*>(), param_deleter);
// local variables to store the decoded file content, the data must
// be valid until TRITONSERVER_ServerLoadModelWithParameters returns.
auto binary_files = std::make_shared<std::list<std::vector<char>>>();
// WAR for the const-ness check
auto const_params =
std::make_shared<std::vector<const TRITONSERVER_Parameter*>>();
triton::common::TritonJson::Value load_request;
size_t buffer_len = 0;
RETURN_AND_RESPOND_IF_ERR(
req, EVRequestToJson(req, "load model", &load_request, &buffer_len));

if (buffer_len > 0) {
// Parse request body for parameters
triton::common::TritonJson::Value param_json;
if (load_request.Find("parameters", &param_json)) {
// Iterate over each member in 'param_json'
std::vector<std::string> members;
RETURN_AND_RESPOND_IF_ERR(req, param_json.Members(&members));
for (const auto& m : members) {
const char* param_str = nullptr;
size_t param_len = 0;
RETURN_AND_RESPOND_IF_ERR(
req,
param_json.MemberAsString(m.c_str(), &param_str, &param_len));

TRITONSERVER_Parameter* param = nullptr;
if (m == "config") {
param = TRITONSERVER_ParameterNew(
m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str);
} else if (m.rfind("file:", 0) == 0) {
size_t decoded_size;
binary_files->emplace_back(std::vector<char>());
RETURN_AND_RESPOND_IF_ERR(
req,
param_json.MemberAsString(m.c_str(), &param_str, &param_len));

TRITONSERVER_Parameter* param = nullptr;
if (m == "config") {
param = TRITONSERVER_ParameterNew(
m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str);
} else if (m.rfind("file:", 0) == 0) {
size_t decoded_size;
binary_files.emplace_back(std::vector<char>());
RETURN_AND_RESPOND_IF_ERR(
req, DecodeBase64(
param_str, param_len, binary_files.back(),
decoded_size, m));
param = TRITONSERVER_ParameterBytesNew(
m.c_str(), binary_files.back().data(), decoded_size);
}
req, DecodeBase64(
param_str, param_len, binary_files->back(),
decoded_size, m));
param = TRITONSERVER_ParameterBytesNew(
m.c_str(), binary_files->back().data(), decoded_size);
}

if (param != nullptr) {
params->emplace_back(param);
const_params.emplace_back(param);
} else {
RETURN_AND_RESPOND_IF_ERR(
req, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"unexpected error on creating Triton parameter"));
}
if (param != nullptr) {
params->emplace_back(param);
const_params->emplace_back(param);
} else {
RETURN_AND_RESPOND_IF_ERR(
req, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"unexpected error on creating Triton parameter"));
}
}
}
}

// If async disabled, fall back to synchronous (original behavior)
if (max_control_requests_ <= 0) {
RETURN_AND_RESPOND_IF_ERR(
req, TRITONSERVER_ServerLoadModelWithParameters(
server_.get(), model_name.c_str(), const_params.data(),
const_params.size()));
} else if (action == "unload") {
// Check if the dependent models should be removed
bool unload_dependents = false;
{
triton::common::TritonJson::Value control_request;
size_t buffer_len = 0;
RETURN_AND_RESPOND_IF_ERR(
req, EVRequestToJson(
req, "unload model", &control_request, &buffer_len));

if (buffer_len > 0) {
triton::common::TritonJson::Value params_json;
if (control_request.Find("parameters", &params_json)) {
triton::common::TritonJson::Value ud_json;
if (params_json.Find("unload_dependents", &ud_json)) {
auto parse_err = ud_json.AsBool(&unload_dependents);
if (parse_err != nullptr) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ErrorCode(parse_err),
(std::string("Unable to parse 'unload_dependents': ") +
TRITONSERVER_ErrorMessage(parse_err))
.c_str());
TRITONSERVER_ErrorDelete(parse_err);
}
server_.get(), model_name.c_str(), const_params->data(),
const_params->size()));
evhtp_send_reply(req, EVHTP_RES_OK);
return;
}

// Check concurrency limit before pausing the request.
int current = control_request_cnt_.fetch_add(1, std::memory_order_acq_rel);
if (current >= max_control_requests_) {
control_request_cnt_.fetch_sub(1, std::memory_order_acq_rel);
RETURN_AND_RESPOND_WITH_ERR(
req, EVHTP_RES_SERVUNAVAIL,
"Model control request rejected: too many concurrent "
"load/unload requests");
return;
}

// Create the async request object. This captures the evhtp thread
// and pauses the request so the evhtp worker is freed immediately.
auto* ctrl_req = new ControlRequestClass(req);
Comment thread
itsnothuy marked this conversation as resolved.
TRITONSERVER_Server* raw_server = server_.get();
auto* cnt_ptr = &control_request_cnt_;

// Spawn a detached thread for the blocking call. Thread creation
// overhead is acceptable for model load/unload operations.
std::thread([ctrl_req, raw_server, model_name, params, binary_files,
const_params, cnt_ptr]() {
ctrl_req->err_ = TRITONSERVER_ServerLoadModelWithParameters(
raw_server, model_name.c_str(), const_params->data(),
const_params->size());
cnt_ptr->fetch_sub(1, std::memory_order_acq_rel);
evthr_defer(
ctrl_req->thread_, ControlRequestClass::ReplyCallback, ctrl_req);
}).detach();
return;
}

// --- Unload path: parse parameters synchronously, then offload blocking call
if (action == "unload") {
bool unload_dependents = false;
{
triton::common::TritonJson::Value control_request;
size_t buffer_len = 0;
RETURN_AND_RESPOND_IF_ERR(
req,
EVRequestToJson(req, "unload model", &control_request, &buffer_len));

if (buffer_len > 0) {
triton::common::TritonJson::Value params_json;
if (control_request.Find("parameters", &params_json)) {
triton::common::TritonJson::Value ud_json;
if (params_json.Find("unload_dependents", &ud_json)) {
TRITONSERVER_Error* parse_err = ud_json.AsBool(&unload_dependents);
if (parse_err != nullptr) {
TRITONSERVER_Error* err = TRITONSERVER_ErrorNew(
TRITONSERVER_ErrorCode(parse_err),
(std::string("Unable to parse 'unload_dependents': ") +
TRITONSERVER_ErrorMessage(parse_err))
.c_str());
TRITONSERVER_ErrorDelete(parse_err);
RETURN_AND_RESPOND_IF_ERR(req, err);
}
}
}
}
}

// If async disabled, fall back to synchronous (original behavior)
if (max_control_requests_ <= 0) {
TRITONSERVER_Error* err = nullptr;
if (unload_dependents) {
err = TRITONSERVER_ServerUnloadModelAndDependents(
server_.get(), model_name.c_str());
} else {
err = TRITONSERVER_ServerUnloadModel(server_.get(), model_name.c_str());
}
RETURN_AND_RESPOND_IF_ERR(req, err);
evhtp_send_reply(req, EVHTP_RES_OK);
return;
}

// Check concurrency limit.
int current = control_request_cnt_.fetch_add(1, std::memory_order_acq_rel);
if (current >= max_control_requests_) {
control_request_cnt_.fetch_sub(1, std::memory_order_acq_rel);
RETURN_AND_RESPOND_WITH_ERR(
req, EVHTP_RES_SERVUNAVAIL,
"Model control request rejected: too many concurrent "
"load/unload requests");
return;
}

// Create the async request object.
auto* ctrl_req = new ControlRequestClass(req);
TRITONSERVER_Server* raw_server = server_.get();
auto* cnt_ptr = &control_request_cnt_;
Comment thread
itsnothuy marked this conversation as resolved.

std::thread([ctrl_req, raw_server, model_name, unload_dependents,
cnt_ptr]() {
if (unload_dependents) {
ctrl_req->err_ = TRITONSERVER_ServerUnloadModelAndDependents(
raw_server, model_name.c_str());
} else {
ctrl_req->err_ =
TRITONSERVER_ServerUnloadModel(raw_server, model_name.c_str());
}
cnt_ptr->fetch_sub(1, std::memory_order_acq_rel);
evthr_defer(
ctrl_req->thread_, ControlRequestClass::ReplyCallback, ctrl_req);
}).detach();
return;
}

RETURN_AND_RESPOND_IF_ERR(req, err);
// Unknown action — should not normally reach here
evhtp_send_reply(req, EVHTP_RES_OK);
}

Expand Down Expand Up @@ -4880,12 +4969,14 @@ HTTPAPIServer::Create(
const std::shared_ptr<SharedMemoryManager>& shm_manager, const int32_t port,
const bool reuse_port, const std::string& address,
const std::string& header_forward_pattern, const int thread_cnt,
const size_t max_input_size, const RestrictedFeatures& restricted_features,
const int control_request_concurrency, const size_t max_input_size,
const RestrictedFeatures& restricted_features,
std::unique_ptr<HTTPServer>* http_server)
{
http_server->reset(new HTTPAPIServer(
server, trace_manager, shm_manager, port, reuse_port, address,
header_forward_pattern, thread_cnt, max_input_size, restricted_features));
header_forward_pattern, thread_cnt, control_request_concurrency,
max_input_size, restricted_features));

const std::string addr = address + ":" + std::to_string(port);
LOG_INFO << "Started HTTPService at " << addr;
Expand Down Expand Up @@ -4916,10 +5007,14 @@ HTTPAPIServer::Create(
GetValue(options, "header_forward_pattern", &header_forward_pattern));
RETURN_IF_ERR(GetValue(options, "thread_count", &thread_count));

// Read the "control_request_concurrency" option if present; otherwise keep the default below
int control_concurrency = 4; // default
GetValue(options, "control_request_concurrency", &control_concurrency);

return Create(
server, trace_manager, shm_manager, port, reuse_port, address,
header_forward_pattern, thread_count, HTTP_DEFAULT_MAX_INPUT_SIZE,
restricted_features, service);
header_forward_pattern, thread_count, control_concurrency,
HTTP_DEFAULT_MAX_INPUT_SIZE, restricted_features, service);
}


Expand Down
Loading