diff --git a/src/http_server.cc b/src/http_server.cc index 765cfde64b..d003f87191 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -1134,7 +1134,8 @@ HTTPAPIServer::HTTPAPIServer( const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const size_t max_input_size, const RestrictedFeatures& restricted_apis) + const int control_request_concurrency, const size_t max_input_size, + const RestrictedFeatures& restricted_apis) : HTTPServer(port, reuse_port, address, header_forward_pattern, thread_cnt), server_(server), trace_manager_(trace_manager), shm_manager_(shm_manager), allocator_(nullptr), server_regex_(R"(/v2(?:/health/(live|ready))?)"), @@ -1181,6 +1182,12 @@ HTTPAPIServer::HTTPAPIServer( "setting allocator's buffer attributes function"); ConfigureGenerateMappingSchema(); + + max_control_requests_ = control_request_concurrency; + if (max_control_requests_ > 0) { + LOG_INFO << "Model-control async enabled, max concurrent requests: " + << max_control_requests_; + } } HTTPAPIServer::~HTTPAPIServer() @@ -1455,118 +1462,228 @@ HTTPAPIServer::HandleRepositoryControl( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); } - TRITONSERVER_Error* err = nullptr; + // Validate repository_name synchronously (lightweight) if (!repository_name.empty()) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } else { - if (action == "load") { - static auto param_deleter = - [](std::vector* params) { - if (params != nullptr) { - for (auto& param : *params) { - TRITONSERVER_ParameterDelete(param); - } - delete params; + RETURN_AND_RESPOND_IF_ERR( + req, TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported")); + } + + // --- Load path: parse parameters synchronously, then offload blocking call + if (action == "load") { + static auto param_deleter = + [](std::vector* params) { + if (params != nullptr) { + for (auto& param : *params) { + TRITONSERVER_ParameterDelete(param); } - }; - std::unique_ptr< - std::vector, decltype(param_deleter)> - params(new std::vector(), param_deleter); - // local variables to store the decoded file content, the data must - // be valid until TRITONSERVER_ServerLoadModelWithParameters returns. - std::list> binary_files; - // WAR for the const-ness check - std::vector const_params; - triton::common::TritonJson::Value load_request; - size_t buffer_len = 0; - RETURN_AND_RESPOND_IF_ERR( - req, EVRequestToJson(req, "load model", &load_request, &buffer_len)); - - if (buffer_len > 0) { - // Parse request body for parameters - triton::common::TritonJson::Value param_json; - if (load_request.Find("parameters", ¶m_json)) { - // Iterate over each member in 'param_json' - std::vector members; - RETURN_AND_RESPOND_IF_ERR(req, param_json.Members(&members)); - for (const auto& m : members) { - const char* param_str = nullptr; - size_t param_len = 0; + delete params; + } + }; + // Use shared_ptr for these containers because they are captured by value + // in the lambda passed to std::thread, ensuring the data outlives the + // thread and is properly cleaned up when all references are released. + std::shared_ptr> params( + new std::vector(), param_deleter); + // local variables to store the decoded file content, the data must + // be valid until TRITONSERVER_ServerLoadModelWithParameters returns. + auto binary_files = std::make_shared>>(); + // WAR for the const-ness check + auto const_params = + std::make_shared>(); + triton::common::TritonJson::Value load_request; + size_t buffer_len = 0; + RETURN_AND_RESPOND_IF_ERR( + req, EVRequestToJson(req, "load model", &load_request, &buffer_len)); + + if (buffer_len > 0) { + // Parse request body for parameters + triton::common::TritonJson::Value param_json; + if (load_request.Find("parameters", ¶m_json)) { + // Iterate over each member in 'param_json' + std::vector members; + RETURN_AND_RESPOND_IF_ERR(req, param_json.Members(&members)); + for (const auto& m : members) { + const char* param_str = nullptr; + size_t param_len = 0; + RETURN_AND_RESPOND_IF_ERR( + req, + param_json.MemberAsString(m.c_str(), ¶m_str, ¶m_len)); + + TRITONSERVER_Parameter* param = nullptr; + if (m == "config") { + param = TRITONSERVER_ParameterNew( + m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str); + } else if (m.rfind("file:", 0) == 0) { + size_t decoded_size; + binary_files->emplace_back(std::vector()); RETURN_AND_RESPOND_IF_ERR( - req, - param_json.MemberAsString(m.c_str(), ¶m_str, ¶m_len)); - - TRITONSERVER_Parameter* param = nullptr; - if (m == "config") { - param = TRITONSERVER_ParameterNew( - m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str); - } else if (m.rfind("file:", 0) == 0) { - size_t decoded_size; - binary_files.emplace_back(std::vector()); - RETURN_AND_RESPOND_IF_ERR( - req, DecodeBase64( - param_str, param_len, binary_files.back(), - decoded_size, m)); - param = TRITONSERVER_ParameterBytesNew( - m.c_str(), binary_files.back().data(), decoded_size); - } + req, DecodeBase64( + param_str, param_len, binary_files->back(), + decoded_size, m)); + param = TRITONSERVER_ParameterBytesNew( + m.c_str(), binary_files->back().data(), decoded_size); + } - if (param != nullptr) { - params->emplace_back(param); - const_params.emplace_back(param); - } else { - RETURN_AND_RESPOND_IF_ERR( - req, TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter")); - } + if (param != nullptr) { + params->emplace_back(param); + const_params->emplace_back(param); + } else { + RETURN_AND_RESPOND_IF_ERR( + req, TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter")); } } } + } + + // If async disabled, fall back to synchronous (original behavior) + if (max_control_requests_ <= 0) { RETURN_AND_RESPOND_IF_ERR( req, TRITONSERVER_ServerLoadModelWithParameters( - server_.get(), model_name.c_str(), const_params.data(), - const_params.size())); - } else if (action == "unload") { - // Check if the dependent models should be removed - bool unload_dependents = false; - { - triton::common::TritonJson::Value control_request; - size_t buffer_len = 0; - RETURN_AND_RESPOND_IF_ERR( - req, EVRequestToJson( - req, "unload model", &control_request, &buffer_len)); - - if (buffer_len > 0) { - triton::common::TritonJson::Value params_json; - if (control_request.Find("parameters", ¶ms_json)) { - triton::common::TritonJson::Value ud_json; - if (params_json.Find("unload_dependents", &ud_json)) { - auto parse_err = ud_json.AsBool(&unload_dependents); - if (parse_err != nullptr) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ErrorCode(parse_err), - (std::string("Unable to parse 'unload_dependents': ") + - TRITONSERVER_ErrorMessage(parse_err)) - .c_str()); - TRITONSERVER_ErrorDelete(parse_err); - } + server_.get(), model_name.c_str(), const_params->data(), + const_params->size())); + evhtp_send_reply(req, EVHTP_RES_OK); + return; + } + + // Check concurrency limit before pausing the request. + int current = control_request_cnt_.fetch_add(1, std::memory_order_acq_rel); + if (current >= max_control_requests_) { + control_request_cnt_.fetch_sub(1, std::memory_order_acq_rel); + RETURN_AND_RESPOND_WITH_ERR( + req, EVHTP_RES_SERVUNAVAIL, + "Model control request rejected: too many concurrent " + "load/unload requests"); + return; + } + + // Create the async request object. This captures the evhtp thread + // and pauses the request so the evhtp worker is freed immediately. + auto* ctrl_req = new ControlRequestClass(req); + TRITONSERVER_Server* raw_server = server_.get(); + auto* cnt_ptr = &control_request_cnt_; + + // Spawn a detached thread for the blocking call. Thread creation + // overhead is acceptable for model load/unload operations. + // Wrap in try/catch to handle std::system_error if thread creation fails. + try { + std::thread([ctrl_req, raw_server, model_name, params, binary_files, + const_params, cnt_ptr]() { + ctrl_req->err_ = TRITONSERVER_ServerLoadModelWithParameters( + raw_server, model_name.c_str(), const_params->data(), + const_params->size()); + cnt_ptr->fetch_sub(1, std::memory_order_acq_rel); + evthr_defer( + ctrl_req->thread_, ControlRequestClass::ReplyCallback, ctrl_req); + }).detach(); + } + catch (const std::system_error& e) { + // Thread creation failed — clean up and return error synchronously. + control_request_cnt_.fetch_sub(1, std::memory_order_acq_rel); + // ctrl_req has already paused the request; we must resume it. + evhtp_request_resume(req); + delete ctrl_req; + RETURN_AND_RESPOND_WITH_ERR( + req, EVHTP_RES_SERVERR, + (std::string("Failed to spawn load thread: ") + e.what()).c_str()); + return; + } + return; + } + + // --- Unload path: parse parameters synchronously, then offload blocking call + if (action == "unload") { + bool unload_dependents = false; + { + triton::common::TritonJson::Value control_request; + size_t buffer_len = 0; + RETURN_AND_RESPOND_IF_ERR( + req, + EVRequestToJson(req, "unload model", &control_request, &buffer_len)); + + if (buffer_len > 0) { + triton::common::TritonJson::Value params_json; + if (control_request.Find("parameters", ¶ms_json)) { + triton::common::TritonJson::Value ud_json; + if (params_json.Find("unload_dependents", &ud_json)) { + TRITONSERVER_Error* parse_err = ud_json.AsBool(&unload_dependents); + if (parse_err != nullptr) { + TRITONSERVER_Error* err = TRITONSERVER_ErrorNew( + TRITONSERVER_ErrorCode(parse_err), + (std::string("Unable to parse 'unload_dependents': ") + + TRITONSERVER_ErrorMessage(parse_err)) + .c_str()); + TRITONSERVER_ErrorDelete(parse_err); + RETURN_AND_RESPOND_IF_ERR(req, err); } } } } + } + + // If async disabled, fall back to synchronous (original behavior) + if (max_control_requests_ <= 0) { + TRITONSERVER_Error* err = nullptr; if (unload_dependents) { err = TRITONSERVER_ServerUnloadModelAndDependents( server_.get(), model_name.c_str()); } else { err = TRITONSERVER_ServerUnloadModel(server_.get(), model_name.c_str()); } + RETURN_AND_RESPOND_IF_ERR(req, err); + evhtp_send_reply(req, EVHTP_RES_OK); + return; + } + + // Check concurrency limit. + int current = control_request_cnt_.fetch_add(1, std::memory_order_acq_rel); + if (current >= max_control_requests_) { + control_request_cnt_.fetch_sub(1, std::memory_order_acq_rel); + RETURN_AND_RESPOND_WITH_ERR( + req, EVHTP_RES_SERVUNAVAIL, + "Model control request rejected: too many concurrent " + "load/unload requests"); + return; + } + + // Create the async request object. + auto* ctrl_req = new ControlRequestClass(req); + TRITONSERVER_Server* raw_server = server_.get(); + auto* cnt_ptr = &control_request_cnt_; + + // Wrap in try/catch to handle std::system_error if thread creation fails. + try { + std::thread([ctrl_req, raw_server, model_name, unload_dependents, + cnt_ptr]() { + if (unload_dependents) { + ctrl_req->err_ = TRITONSERVER_ServerUnloadModelAndDependents( + raw_server, model_name.c_str()); + } else { + ctrl_req->err_ = + TRITONSERVER_ServerUnloadModel(raw_server, model_name.c_str()); + } + cnt_ptr->fetch_sub(1, std::memory_order_acq_rel); + evthr_defer( + ctrl_req->thread_, ControlRequestClass::ReplyCallback, ctrl_req); + }).detach(); + } + catch (const std::system_error& e) { + // Thread creation failed — clean up and return error synchronously. + control_request_cnt_.fetch_sub(1, std::memory_order_acq_rel); + evhtp_request_resume(req); + delete ctrl_req; + RETURN_AND_RESPOND_WITH_ERR( + req, EVHTP_RES_SERVERR, + (std::string("Failed to spawn unload thread: ") + e.what()).c_str()); + return; } + return; } - RETURN_AND_RESPOND_IF_ERR(req, err); + // Unknown action — should not normally reach here evhtp_send_reply(req, EVHTP_RES_OK); } @@ -3874,6 +3991,24 @@ HTTPAPIServer::HandleInfer( request_release_payload.release(); } +void +HTTPAPIServer::ControlRequestClass::ReplyCallback( + evthr_t* thr, void* arg, void* shared) +{ + auto* ctrl_req = reinterpret_cast(arg); + evhtp_request_t* req = ctrl_req->req_; + if (req != nullptr) { + if (ctrl_req->err_ != nullptr) { + EVBufferAddErrorJson(req->buffer_out, ctrl_req->err_); + evhtp_send_reply(req, HttpCodeFromError(ctrl_req->err_)); + } else { + evhtp_send_reply(req, EVHTP_RES_OK); + } + evhtp_request_resume(req); + } + delete ctrl_req; +} + void HTTPAPIServer::InferRequestClass::ReplyCallback( evthr_t* thr, void* arg, void* shared) @@ -4892,12 +5027,14 @@ HTTPAPIServer::Create( const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const size_t max_input_size, const RestrictedFeatures& restricted_features, + const int control_request_concurrency, const size_t max_input_size, + const RestrictedFeatures& restricted_features, std::unique_ptr* http_server) { http_server->reset(new HTTPAPIServer( server, trace_manager, shm_manager, port, reuse_port, address, - header_forward_pattern, thread_cnt, max_input_size, restricted_features)); + header_forward_pattern, thread_cnt, control_request_concurrency, + max_input_size, restricted_features)); const std::string addr = address + ":" + std::to_string(port); LOG_INFO << "Started HTTPService at " << addr; @@ -4928,10 +5065,14 @@ HTTPAPIServer::Create( GetValue(options, "header_forward_pattern", &header_forward_pattern)); RETURN_IF_ERR(GetValue(options, "thread_count", &thread_count)); + // Use model_load_thread_count as concurrency limit if available + int control_concurrency = 4; // default + GetValue(options, "control_request_concurrency", &control_concurrency); + return Create( server, trace_manager, shm_manager, port, reuse_port, address, - header_forward_pattern, thread_count, HTTP_DEFAULT_MAX_INPUT_SIZE, - restricted_features, service); + header_forward_pattern, thread_count, control_concurrency, + HTTP_DEFAULT_MAX_INPUT_SIZE, restricted_features, service); } diff --git a/src/http_server.h b/src/http_server.h index 2187ec1700..6be968fdd3 100644 --- a/src/http_server.h +++ b/src/http_server.h @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include #include "common.h" #include "data_compressor.h" @@ -198,7 +200,8 @@ class HTTPAPIServer : public HTTPServer { const std::shared_ptr& smb_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const size_t max_input_size, const RestrictedFeatures& restricted_apis, + const int control_request_concurrency, const size_t max_input_size, + const RestrictedFeatures& restricted_apis, std::unique_ptr* http_server); static TRITONSERVER_Error* Create( @@ -380,6 +383,53 @@ class HTTPAPIServer : public HTTPServer { evhtp_res response_code_{EVHTP_RES_OK}; }; + // Lightweight request class for model repository control operations + // (load/unload). Follows the same async pattern as InferRequestClass: + // capture the evhtp thread, pause the request, do blocking work off-thread, + // then defer the reply back via evthr_defer. + class ControlRequestClass { + public: + explicit ControlRequestClass(evhtp_request_t* req) + : req_(req), thread_(nullptr), err_(nullptr) + { + evhtp_connection_t* htpconn = evhtp_request_get_connection(req); + thread_ = htpconn->thread; + evhtp_request_pause(req); + evhtp_request_set_hook( + req_, evhtp_hook_on_request_fini, + (evhtp_hook)(void*)ControlRequestFiniHook, + reinterpret_cast(this)); + } + + ~ControlRequestClass() + { + if (req_ != nullptr) { + evhtp_request_unset_hook(req_, evhtp_hook_on_request_fini); + } + if (err_ != nullptr) { + TRITONSERVER_ErrorDelete(err_); + } + } + + // Called by evhtp when the connection is closed before we reply. + // Nulls req_ so ReplyCallback will skip the reply safely. + static evhtp_res ControlRequestFiniHook(evhtp_request* req, void* arg) + { + auto* ctrl_req = reinterpret_cast(arg); + ctrl_req->req_ = nullptr; + return EVHTP_RES_OK; + } + + // Deferred onto the evhtp thread via evthr_defer. Sends the reply + // and resumes the paused request. Deletes the ControlRequestClass. + // Defined in http_server.cc because it uses file-local helper functions. + static void ReplyCallback(evthr_t* thr, void* arg, void* shared); + + evhtp_request_t* req_; + evthr_t* thread_; + TRITONSERVER_Error* err_; + }; + class GenerateRequestClass : public InferRequestClass { public: explicit GenerateRequestClass( @@ -503,6 +553,7 @@ class HTTPAPIServer : public HTTPServer { const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, + const int control_request_concurrency, const size_t max_input_size = HTTP_DEFAULT_MAX_INPUT_SIZE, const RestrictedFeatures& restricted_apis = {}); @@ -720,6 +771,14 @@ class HTTPAPIServer : public HTTPServer { RestrictedFeatures restricted_apis_{}; bool RespondIfRestricted( evhtp_request_t* req, const Restriction& restriction); + + // Maximum number of concurrent model repository control operations + // (load/unload) that can run off the evhtp worker threads. Each request + // spawns a detached thread; this limit prevents thread explosion. + // Configured via control_request_concurrency; main.cc wires it from + // --model-load-thread-count. 0 disables async (synchronous). + int max_control_requests_; + std::atomic control_request_cnt_{0}; }; }} // namespace triton::server diff --git a/src/main.cc b/src/main.cc index f0a6fc4e7b..8c8510b60f 100644 --- a/src/main.cc +++ b/src/main.cc @@ -138,7 +138,9 @@ StartHttpService( server, trace_manager, shm_manager, g_triton_params.http_port_, g_triton_params.reuse_http_port_, g_triton_params.http_address_, g_triton_params.http_forward_header_pattern_, - g_triton_params.http_thread_cnt_, g_triton_params.http_max_input_size_, + g_triton_params.http_thread_cnt_, + g_triton_params.model_load_thread_count_, + g_triton_params.http_max_input_size_, g_triton_params.http_restricted_apis_, service); if (err == nullptr) { err = (*service)->Start();