Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions samples/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ add_library(trt_samples_common STATIC
getOptions.cpp
getOptions.h
getoptWin.h
globalTimerKernel.cu
globalTimerKernel.h
half.h
logger.cpp
logger.h
Expand Down
46 changes: 46 additions & 0 deletions samples/common/globalTimerKernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "globalTimerKernel.h"

namespace
{
//! Device kernel: samples the PTX %globaltimer special register (a
//! nanosecond-resolution GPU clock) and stores the value to \p timestamp.
//! Intended to be launched as a single thread (<<<1, 1>>>); uses no shared
//! memory. A null \p timestamp makes the kernel a no-op.
__global__ void readGlobalTimerKernel(uint64_t* timestamp)
{
    if (timestamp != nullptr)
    {
        uint64_t now;
        // volatile: the timer read must not be cached, hoisted, or elided.
        asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(now));
        *timestamp = now;
    }
}
} // namespace

namespace sample
{
//! Launch readGlobalTimerKernel on \p stream so that the current GPU
//! global-timer value (ns) is written to \p dTimestamp.
//!
//! NOTE: cudaGetLastError() only surfaces synchronous launch errors (invalid
//! configuration, bad stream, etc.). Any asynchronous execution errors from
//! the kernel itself -- e.g. a dereference of a bad device pointer -- will not
//! be reported here; they become visible only on a subsequent synchronization
//! point such as cudaEventRecord / cudaStreamSynchronize / cudaDeviceSynchronize.
cudaError_t launchGlobalTimerKernel(uint64_t* dTimestamp, cudaStream_t stream) noexcept
{
    // A single thread is sufficient: the kernel performs one register read
    // and one 8-byte store.
    dim3 const gridDim{1};
    dim3 const blockDim{1};
    readGlobalTimerKernel<<<gridDim, blockDim, 0, stream>>>(dTimestamp);
    return cudaGetLastError();
}
} // namespace sample
42 changes: 42 additions & 0 deletions samples/common/globalTimerKernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TRT_GLOBAL_TIMER_KERNEL_H
#define TRT_GLOBAL_TIMER_KERNEL_H

#include <cstdint>
#include <cuda_runtime_api.h>

namespace sample
{
//! Launch a single-thread kernel that writes the current value of the PTX
//! %globaltimer register (GPU timer in ns) to \p dTimestamp on \p stream.
//!
//! Used as a replacement for cudaEventElapsedTime() when Confidential Compute
//! is enabled, where cudaEventElapsedTime() is documented to be unreliable.
//!
//! \param dTimestamp Device pointer to a uint64_t. Must be non-null and point
//! to valid device memory reachable from \p stream.
//! \param stream CUDA stream on which to launch the kernel.
//! \return Result of \c cudaGetLastError() after the launch. This reports
//! synchronous launch errors only; asynchronous execution errors will
//! surface on a subsequent \c cudaEventRecord /
//! \c cudaStreamSynchronize / similar.
[[nodiscard]] cudaError_t launchGlobalTimerKernel(uint64_t* dTimestamp, cudaStream_t stream) noexcept;
} // namespace sample

#endif // TRT_GLOBAL_TIMER_KERNEL_H
62 changes: 62 additions & 0 deletions samples/common/sampleDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,71 @@

#include <iomanip>

#if !defined(_WIN32)
#include <dlfcn.h>
#endif

namespace sample
{

namespace
{
// Minimal local declarations of the NVML types/constants needed to query the
// Confidential Compute state. Keeping them local avoids a build-time
// dependency on nvml.h or libnvidia-ml; the entry points are resolved with
// dlopen/dlsym at runtime instead.
using NvmlReturnT = int32_t;
constexpr NvmlReturnT kNVML_SUCCESS = 0;
constexpr uint32_t kNVML_CC_FEATURE_ENABLED = 1;

struct NvmlConfComputeSystemState
{
    uint32_t environment;
    uint32_t ccFeature;
    uint32_t devToolsMode;
};

using NvmlInitFn = NvmlReturnT (*)();
using NvmlShutdownFn = NvmlReturnT (*)();
using NvmlGetCcStateFn = NvmlReturnT (*)(NvmlConfComputeSystemState*);

// Query NVML for the Confidential Compute feature state.
// Best-effort: any failure along the way (library not present, symbol
// missing, init error, query error) is reported as "CC disabled" (false)
// rather than as an error.
bool queryConfidentialCompute()
{
#if defined(_WIN32)
    // dlopen/dlsym are unavailable on Windows; treat CC as disabled there.
    return false;
#else
    void* nvmlLib = dlopen("libnvidia-ml.so.1", RTLD_LAZY | RTLD_LOCAL);
    if (nvmlLib == nullptr)
    {
        return false;
    }

    auto const initFn = reinterpret_cast<NvmlInitFn>(dlsym(nvmlLib, "nvmlInit_v2"));
    auto const shutdownFn = reinterpret_cast<NvmlShutdownFn>(dlsym(nvmlLib, "nvmlShutdown"));
    auto const getStateFn
        = reinterpret_cast<NvmlGetCcStateFn>(dlsym(nvmlLib, "nvmlSystemGetConfComputeState"));

    bool ccEnabled = false;
    bool const haveAllSymbols = initFn != nullptr && shutdownFn != nullptr && getStateFn != nullptr;
    if (haveAllSymbols && initFn() == kNVML_SUCCESS)
    {
        NvmlConfComputeSystemState ccState{};
        ccEnabled = getStateFn(&ccState) == kNVML_SUCCESS && ccState.ccFeature == kNVML_CC_FEATURE_ENABLED;
        // Shutdown is only attempted after a successful init.
        shutdownFn();
    }

    dlclose(nvmlLib);
    return ccEnabled;
#endif
}
} // namespace

// True when the Confidential Compute feature is enabled on this system.
// The underlying NVML query runs exactly once; the result is cached in a
// function-local static for all subsequent calls (thread-safe under C++11
// magic-statics initialization).
bool isConfidentialComputeEnabled()
{
    static bool const sCcEnabled{queryConfidentialCompute()};
    return sCcEnabled;
}

// Construct GPU UUID string in the same format as nvidia-smi does.
std::string getUuidString(cudaUUID_t uuid)
{
Expand Down
42 changes: 40 additions & 2 deletions samples/common/sampleDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,25 @@
#define TRT_SAMPLE_DEVICE_H

#include <cassert>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <thread>

#include "common.h"
#include "globalTimerKernel.h"
#include "sampleUtils.h"

namespace sample
{

//! True when Confidential Compute is enabled on the current system. Cached on
//! the first call. When true, TrtCudaEvent falls back to a GPU global-timer
//! kernel because cudaEventElapsedTime() is unreliable under CC (see nvbug
//! 5598617, mirrors TRT-LLM PR #11657).
[[nodiscard]] bool isConfidentialComputeEnabled();

class TrtCudaEvent;

namespace
Expand Down Expand Up @@ -99,6 +107,10 @@ class TrtCudaEvent
{
const uint32_t flags = blocking ? cudaEventBlockingSync : cudaEventDefault;
CHECK(cudaEventCreateWithFlags(&mEvent, flags));
if (isConfidentialComputeEnabled())
{
CHECK(cudaMalloc(&mDeviceTimestamp, sizeof(uint64_t)));
}
}

TrtCudaEvent(const TrtCudaEvent&) = delete;
Expand All @@ -112,6 +124,10 @@ class TrtCudaEvent
//! Destroy the CUDA event and release the device-side timestamp buffer
//! (allocated only when Confidential Compute is enabled).
~TrtCudaEvent()
{
    // The free is guarded to mirror the conditional allocation path; the
    // buffer and the event are independent resources, so release order
    // does not matter.
    if (mDeviceTimestamp != nullptr)
    {
        CHECK(cudaFree(mDeviceTimestamp));
    }
    CHECK(cudaEventDestroy(mEvent));
}

cudaEvent_t get() const
Expand All @@ -121,6 +137,10 @@ class TrtCudaEvent

void record(const TrtCudaStream& stream)
{
if (mDeviceTimestamp != nullptr)
{
CHECK(launchGlobalTimerKernel(mDeviceTimestamp, stream.get()));
}
CHECK(cudaEventRecord(mEvent, stream.get()));
}

Expand All @@ -136,13 +156,29 @@ class TrtCudaEvent
synchronize();
e.synchronize();

if (mDeviceTimestamp != nullptr && e.mDeviceTimestamp != nullptr)
{
// Confidential Compute path: read %globaltimer values captured at record().
// cudaEventElapsedTime() is unreliable under CC (nvbug 5598617); the global
// timer kernel reads the same underlying register directly.
// Use signed int64_t so the subtraction is well-defined if the events
// are ever measured out of order (otherwise the unsigned->signed cast
// is implementation-defined in C++17).
int64_t endNs{0};
int64_t startNs{0};
CHECK(cudaMemcpy(&endNs, mDeviceTimestamp, sizeof(int64_t), cudaMemcpyDeviceToHost));
CHECK(cudaMemcpy(&startNs, e.mDeviceTimestamp, sizeof(int64_t), cudaMemcpyDeviceToHost));
return static_cast<float>(endNs - startNs) / 1.0e6F;
}

float time{0};
CHECK(cudaEventElapsedTime(&time, e.get(), get()));
return time;
}

private:
cudaEvent_t mEvent{};
uint64_t* mDeviceTimestamp{nullptr};
};

inline void TrtCudaStream::wait(TrtCudaEvent& event)
Expand Down Expand Up @@ -361,8 +397,10 @@ struct HostDeallocator
cudaPointerAttributes attrs;
CHECK(cudaPointerGetAttributes(&attrs, ptr));

// If pinned, call cudaFreeHost() to deallocate it.
if (attrs.type == cudaMemoryTypeHost)
// If pinned, call cudaFreeHost() to deallocate it. Under Confidential
// Compute, memory returned by cudaMallocHost() may be reported as
// cudaMemoryTypeManaged; it must still be released via cudaFreeHost.
if (attrs.type == cudaMemoryTypeHost || attrs.type == cudaMemoryTypeManaged)
{
CHECK(cudaFreeHost(ptr));
}
Expand Down
Loading