diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9568f4280d..9c431be6cf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -59,6 +59,13 @@ if(TRITON_ENABLE_GRPC) set(TRITON_COMMON_ENABLE_GRPC ON) endif() # TRITON_ENABLE_GRPC +# protobuf +# +if(TRITON_ENABLE_HTTP OR TRITON_ENABLE_METRICS OR TRITON_ENABLE_SAGEMAKER OR + TRITON_ENABLE_VERTEX_AI) + set(TRITON_COMMON_ENABLE_PROTOBUF ON) +endif() + FetchContent_MakeAvailable(repo-common repo-core repo-backend) # CUDA @@ -406,6 +413,17 @@ if(${TRITON_ENABLE_HTTP} re2::re2 ) + # model_config.h (GetElementCount, etc.) needs Protobuf + generated protos. + if(TARGET triton-common-model-config) + target_link_libraries( + http-endpoint-library + PUBLIC + triton-common-model-config + proto-library + protobuf::libprotobuf + ) + endif() + target_include_directories( http-endpoint-library PRIVATE $ diff --git a/src/common.cc b/src/common.cc index a7591b8324..c62e1435b9 100644 --- a/src/common.cc +++ b/src/common.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -105,35 +105,6 @@ ShapeToString(const std::vector& shape) return ShapeToString(shape.data(), shape.size()); } -int64_t -GetElementCount(const std::vector& dims) -{ - bool first = true; - int64_t cnt = 0; - for (auto dim : dims) { - if (dim == WILDCARD_DIM) { - return -1; - } else if (dim < 0) { // invalid dim - return -2; - } else if (dim == 0) { - return 0; - } - - if (first) { - cnt = dim; - first = false; - } else { - // Check for overflow before multiplication - if (cnt > (INT64_MAX / dim)) { - return -3; - } - cnt *= dim; - } - } - - return cnt; -} - bool Contains(const std::vector& vec, const std::string& str) { diff --git a/src/common.h b/src/common.h index 403aec9a6a..fe4d4a1e96 100644 --- a/src/common.h +++ b/src/common.h @@ -1,4 +1,4 @@ -// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -51,10 +51,6 @@ constexpr char kTritonSharedMemoryRegionPrefix[] = constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX; -/// The value for a dimension in a shape that indicates that that -/// dimension can take on any size. -constexpr int WILDCARD_DIM = -1; - // Maximum allowed depth for JSON parsing constexpr int32_t HTTP_MAX_JSON_NESTING_DEPTH = 100; @@ -162,15 +158,6 @@ TRITONSERVER_Error* GetModelVersionFromString( std::string GetEnvironmentVariableOrDefault( const std::string& variable_name, const std::string& default_value); -/// Get the number of elements in a shape. -/// -/// \param dims The shape. -/// \return The number of elements, -1 if the number of elements -/// cannot be determined because the shape contains one or more -/// wildcard dimensions, -2 if the shape contains an invalid dim, -/// or -3 if the number is too large to represent as an int64_t. -int64_t GetElementCount(const std::vector& dims); - /// Convert shape to string representation. /// /// \param shape The shape as a vector. diff --git a/src/http_server.cc b/src/http_server.cc index cce57ee254..f5a8d8cadc 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -44,6 +44,7 @@ #define TRITONJSON_STATUSRETURN(M) \ return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str()) #define TRITONJSON_STATUSSUCCESS nullptr +#include "triton/common/model_config.h" #include "triton/common/triton_json.h" namespace triton { namespace server { @@ -2614,14 +2615,13 @@ HTTPAPIServer::ParseJsonTritonIO( memory_type_id)); } } else { - const int64_t element_cnt = GetElementCount(shape_vec); + const int64_t element_cnt = triton::common::GetElementCount(shape_vec); if (element_cnt == 0) { RETURN_IF_ERR(TRITONSERVER_InferenceRequestAppendInputData( irequest, input_name, nullptr, 0 /* byte_size */, TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */)); - } else if (element_cnt == -2) { - // -2 indicates invalid dimension + } else if (element_cnt == triton::common::INVALID_SIZE) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( @@ -2629,8 +2629,7 @@ HTTPAPIServer::ParseJsonTritonIO( "': shape " + ShapeToString(shape_vec) + " contains one or more invalid dimensions") .c_str()); - } else if (element_cnt == -3) { - // -3 indicates integer overflow + } else if (element_cnt == triton::common::OVERFLOW_SIZE) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( @@ -2639,6 +2638,15 @@ HTTPAPIServer::ParseJsonTritonIO( " causes total element count to exceed maximum size of " + std::to_string(INT64_MAX)) .c_str()); + } else if (element_cnt == triton::common::WILDCARD_SIZE) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "invalid shape for input '" + std::string(input_name) + + "': shape " + ShapeToString(shape_vec) + + " contains one or more variable-size dimensions (-1); cannot " + "determine element count for JSON input") + .c_str()); } else { // JSON... presence of "data" already validated but still // checking here. Flow in this endpoint needs to be diff --git a/src/test/tensor_size_test.cc b/src/test/tensor_size_test.cc index 26969fcd42..ed1cc09bcb 100644 --- a/src/test/tensor_size_test.cc +++ b/src/test/tensor_size_test.cc @@ -173,25 +173,35 @@ assert_get_byte_size_success( { int64_t size; TRITONSERVER_Error* err; + inference::DataType core_dtype = tcore::TritonToDataType(dtype); - // Backend (old API) + // Backend (public old API) ASSERT_EQ(expected_size, tb::GetByteSize(dtype, shape)); - // Backend (new API) + // Backend (public new API) err = tb::GetByteSize(dtype, shape, &size); ASSERT_EQ(err, nullptr); ASSERT_EQ(expected_size, size); - // Common - inference::DataType core_dtype = tcore::TritonToDataType(dtype); + // Common (public API) ASSERT_EQ(tc::GetByteSize(core_dtype, shape), expected_size); - // Core + // Core (internal helper) if (test_core) { size = 0; auto status = tcore::GetByteSize(core_dtype, shape, kTensorName, &size); - ASSERT_TRUE(status.IsOk()) << status.Message(); - ASSERT_EQ(size, expected_size); + // Special case: rejects wildcard / non-fixed size with INVALID_ARG + if (expected_size == tc::WILDCARD_SIZE) { + ASSERT_FALSE(status.IsOk()); + ASSERT_EQ(status.StatusCode(), triton::core::Status::Code::INVALID_ARG); + ASSERT_TRUE( + std::string(status.Message()) + .find("contains one or more variable-size dimensions") != + std::string::npos); + } else { + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_EQ(size, expected_size); + } } } @@ -274,10 +284,6 @@ TEST_F(GetElementCountTest, GetElementCountWildcard) // Test 3: multiple -1 dims shape = {8, -1, -1}; assert_get_element_count_success(shape, expected_cnt); - - // Test 4: -1 dim before overflow - shape = {-1, 1LL << 32, 1LL << 31}; - assert_get_element_count_success(shape, expected_cnt); } TEST_F(GetElementCountTest, GetElementCountZero) @@ -323,6 +329,12 @@ TEST_F(GetElementCountTest, GetElementCountInvalidDim) error_msg = std::string("shape") + tb::ShapeToString(shape) + " contains an invalid dim."; assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg); + + // Test 4: invalid dim after a wildcard + shape = {-1, -2}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg); } TEST_F(GetElementCountTest, GetElementCountOverflow) @@ -339,11 +351,56 @@ TEST_F(GetElementCountTest, GetElementCountOverflow) shape = {1LL << 32, 1LL << 31}; error_msg = "unexpected integer overflow while calculating element count."; assert_get_element_count_error(shape, ErrorCode::kOverflow, error_msg); +} + +TEST_F(GetElementCountTest, GetElementCountMixed) +{ + std::vector shape; + std::string error_msg; + + // Test 1: -1 dim before overflow + shape = {-1, 1LL << 32, 1LL << 31}; + error_msg = "unexpected integer overflow while calculating element count."; + assert_get_element_count_error(shape, ErrorCode::kOverflow, error_msg); + + // Test 2: -1 dim before overflow 2 + shape = {1LL << 32, -1, 1LL << 31}; + error_msg = "unexpected integer overflow while calculating element count."; + assert_get_element_count_error(shape, ErrorCode::kOverflow, error_msg); // Test 3: overflows before -1 dim shape = {1LL << 32, 1LL << 31, -1}; error_msg = "unexpected integer overflow while calculating element count."; assert_get_element_count_error(shape, ErrorCode::kOverflow, error_msg); + + // Test 4: -1 dim before invalid dim + shape = {-1, -2}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg); + + // Test 5: invalid dim before -1 dim + shape = {-2, -1}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg); + + // Test 6: invalid dim before overflow dim + shape = {-2, 1LL << 32, 1LL << 31}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg); + + // Test 7: invalid dim before overflow dim 2 + shape = {1LL << 32, -2, 1LL << 31}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg); + + // Test 8: overflow dim before invalid dim + shape = {1LL << 32, 1LL << 31, -2}; + error_msg = "unexpected integer overflow while calculating element count."; + assert_get_element_count_error(shape, ErrorCode::kOverflow, error_msg); } class GetByteSizeTest : public ::testing::Test { @@ -400,7 +457,6 @@ TEST_F(GetByteSizeTest, GetByteSizeWildcard) ASSERT_TRUE(status.IsOk()) << status.Message(); ASSERT_EQ(size, sizeof(int32_t) * 8 * 8); - // Test 3: invalid shape and element count overflows dtype = TRITONSERVER_TYPE_INVALID; shape = {1LL << 40, 1LL << 40}; @@ -481,6 +537,57 @@ TEST_F(GetByteSizeTest, GetByteSizeOverflow) assert_get_byte_size_error(dtype, shape, ErrorCode::kOverflow, error_msg); } +TEST_F(GetByteSizeTest, GetByteSizeMixed) +{ + TRITONSERVER_DataType dtype = TRITONSERVER_TYPE_INT32; + std::vector shape; + std::string error_msg; + + // Test 1: wildcard dim before overflow + shape = {-1, 1LL << 32, 1LL << 31}; + error_msg = "unexpected integer overflow while calculating byte size."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kOverflow, error_msg); + + // Test 2: wildcard dim before overflow 2 + shape = {1LL << 32, -1, 1LL << 31}; + error_msg = "unexpected integer overflow while calculating byte size."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kOverflow, error_msg); + + // Test 3: overflows before wildcard dim + shape = {1LL << 32, 1LL << 31, -1}; + error_msg = "unexpected integer overflow while calculating byte size."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kOverflow, error_msg); + + // Test 4: wildcard dim before invalid dim + shape = {-1, -2}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kInvalidDim, error_msg); + + // Test 5: invalid dim before wildcard dim + shape = {-2, -1}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kInvalidDim, error_msg); + + // Test 6: invalid dim before overflow + shape = {-2, 1LL << 32, 1LL << 31}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kInvalidDim, error_msg); + + // Test 7: invalid dim before overflow 2 + shape = {1LL << 32, -2, 1LL << 31}; + error_msg = std::string("shape") + tb::ShapeToString(shape) + + " contains an invalid dim."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kInvalidDim, error_msg); + + // Test 8: overflow before invalid dim + shape = {1LL << 32, 1LL << 31, -2}; + error_msg = "unexpected integer overflow while calculating byte size."; + assert_get_byte_size_error(dtype, shape, ErrorCode::kOverflow, error_msg); +} + } // namespace int