
docs: update TRT-LLM speculative decoding guide to LLM API / PyTorch backend #150

Merged: whoisj merged 4 commits into triton-inference-server:main from faradawn:fix/speculative-decoding-trtllm-backend, Apr 21, 2026
Conversation

@faradawn (Contributor) commented Apr 10, 2026

Replace the legacy TRT engine backend workflow (trtllm-build, inflight_batcher_llm,
fill_template.py) with the modern LLM API / PyTorch backend. Update the EAGLE
section to use EAGLE-3 with Llama-3.1-8B-Instruct, add a deprecation notice for
MEDUSA (unsupported on the PyTorch backend), and update the Draft Model section
to use DraftTargetDecodingConfig via model.yaml.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
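
For context, the updated Draft Model section configures speculation through a model.yaml entry along these lines (a sketch only: the decoding_type key and field names are assumed from the LLM API's DraftTargetDecodingConfig, and the draft-model path is a placeholder):

# Hypothetical model.yaml snippet for draft/target speculative decoding.
# decoding_type and field names assumed from TRT-LLM's DraftTargetDecodingConfig;
# the draft-model path is a placeholder.
backend: pytorch
speculative_config:
  decoding_type: DraftTarget
  max_draft_len: 4
  speculative_model_dir: /path/to/draft_model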
@faradawn (Contributor, Author) commented:

Tested on H100x2
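
The server was pointed at the /llmapi_repo/ model repository shown in the options table below; assuming the stock tritonserver CLI, the launch amounts to something like:

tritonserver --model-repository=/llmapi_repo/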


=============================
== Triton Inference Server ==
=============================

NVIDIA Release 25.12 (build 246541570)
Triton Server Version 2.64.0

Copyright (c) 2018-2025, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.

Various files include modifications (c) NVIDIA CORPORATION & AFFILIATES.  All rights reserved.

GOVERNING TERMS: The software and materials are governed by the NVIDIA Software License Agreement
(found at https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-software-license-agreement/)
and the Product-Specific Terms for NVIDIA AI Products
(found at https://www.nvidia.com/en-us/agreements/enterprise-software/product-specific-terms-for-ai-products/).

I0410 16:57:07.444265 1 pinned_memory_manager.cc:277] "Pinned memory pool is created at '0x7a9ffa000000' with size 268435456"
I0410 16:57:07.444400 1 cuda_memory_manager.cc:107] "CUDA memory pool is created on device 0 with size 67108864"
I0410 16:57:07.444405 1 cuda_memory_manager.cc:107] "CUDA memory pool is created on device 1 with size 67108864"
I0410 16:57:07.635758 1 model_lifecycle.cc:473] "loading: tensorrt_llm:1"
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
W0410 16:57:19.482000 176 torch/utils/cpp_extension.py:2422] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0410 16:57:19.482000 176 torch/utils/cpp_extension.py:2422] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[TensorRT-LLM] TensorRT LLM version: 1.1.0
I0410 16:57:22.818622 1 python_be.cc:2289] "TRITONBACKEND_ModelInstanceInitialize: tensorrt_llm_0_0 (CPU device 0)"
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
W0410 16:57:34.830000 339 torch/utils/cpp_extension.py:2422] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0410 16:57:34.830000 339 torch/utils/cpp_extension.py:2422] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[TensorRT-LLM] TensorRT LLM version: 1.1.0
I0410 16:57:35.506899 1 model.py:166] "[trtllm] rank0 is starting trtllm engine with args: {'model': 'TinyLlama/TinyLlama-1.1B-Chat-v1.0', 'backend': 'pytorch', 'tensor_parallel_size': 1, 'pipeline_parallel_size': 1, 'speculative_config': NGramDecodingConfig(max_draft_len=3, speculative_model_dir=None, max_concurrency=None, load_format=None, max_matching_ngram_size=3, is_keep_all=True, is_use_oldest=True, is_public_pool=True), 'disable_overlap_scheduler': True}"
[04/10/2026-16:57:35] [TRT-LLM] [I] Using LLM with PyTorch backend
[04/10/2026-16:57:35] [TRT-LLM] [W] Using default gpus_per_node: 2
[04/10/2026-16:57:35] [TRT-LLM] [I] Set nccl_plugin to None.
[04/10/2026-16:57:35] [TRT-LLM] [I] neither checkpoint_format nor checkpoint_loader were provided, checkpoint_format will be set to HF.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
rank 0 using MpiPoolSession to spawn MPI processes
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Multiple distributions found for package optimum. Picked distribution: optimum
Multiple distributions found for package modelopt. Picked distribution: nvidia-modelopt
W0410 16:57:44.402000 521 torch/utils/cpp_extension.py:2422] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0410 16:57:44.402000 521 torch/utils/cpp_extension.py:2422] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[TensorRT-LLM] TensorRT LLM version: 1.1.0
[TensorRT-LLM][INFO] Refreshed the MPI local session
`torch_dtype` is deprecated! Use `dtype` instead!

Loading safetensors weights in parallel:   0%|          | 0/1 [00:00<?, ?it/s]
Loading safetensors weights in parallel: 100%|██████████| 1/1 [00:00<00:00, 163.20it/s]

Loading weights:   0%|          | 0/449 [00:00<?, ?it/s]
Loading weights:  26%|██▌       | 117/449 [00:00<00:00, 1116.46it/s]
Loading weights:  67%|██████▋   | 299/449 [00:00<00:00, 1515.26it/s]
Loading weights: 100%|██████████| 449/449 [00:00<00:00, 1536.24it/s]
Model init total -- 1.70s
[TensorRT-LLM][INFO] Max KV cache blocks per sequence: 65 [window size=2051], tokens per block=32, primary blocks=2048, secondary blocks=0
[TensorRT-LLM][INFO] Number of tokens per block: 32.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 1.38 GiB for max tokens in paged KV cache (65536).
[TensorRT-LLM][WARNING] Attention workspace size is not enough, increase the size from 0 bytes to 67462144 bytes
[TensorRT-LLM][WARNING] Attention workspace size is not enough, increase the size from 0 bytes to 35943424 bytes
[TensorRT-LLM][WARNING] [kv cache manager] storeContextBlocks: Can not find sequence for request 2048
[TensorRT-LLM][WARNING] [kv cache manager] storeContextBlocks: Can not find sequence for request 2049
[TensorRT-LLM][WARNING] [kv cache manager] storeContextBlocks: Can not find sequence for request 2050
[TensorRT-LLM][WARNING] [kv cache manager] storeContextBlocks: Can not find sequence for request 2051
[TensorRT-LLM][WARNING] [kv cache manager] storeContextBlocks: Can not find sequence for request 2052
[TensorRT-LLM][INFO] Max KV cache blocks per sequence: 65 [window size=2051], tokens per block=32, primary blocks=116617, secondary blocks=0
[TensorRT-LLM][INFO] Number of tokens per block: 32.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 78.30 GiB for max tokens in paged KV cache (3731744).
[TensorRT-LLM][WARNING] Attention workspace size is not enough, increase the size from 0 bytes to 67462144 bytes
[TensorRT-LLM][WARNING] Attention workspace size is not enough, increase the size from 0 bytes to 35943424 bytes
I0410 16:58:08.485508 1 model_lifecycle.cc:849] "successfully loaded 'tensorrt_llm'"
I0410 16:58:08.485783 1 server.cc:620] 
+------------------+------+
| Repository Agent | Path |
+------------------+------+
+------------------+------+

I0410 16:58:08.485849 1 server.cc:647] 
+---------+-------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Backend | Path                                                  | Config                                                                                                                                                                                            |
+---------+-------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| python  | /opt/tritonserver/backends/python/libtriton_python.so | {"cmdline":{"auto-complete-config":"true","backend-directory":"/opt/tritonserver/backends","min-compute-capability":"6.000000","shm-region-prefix-name":"prefix0_","default-max-batch-size":"4"}} |
+---------+-------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

I0410 16:58:08.485894 1 server.cc:690] 
+--------------+---------+--------+
| Model        | Version | Status |
+--------------+---------+--------+
| tensorrt_llm | 1       | READY  |
+--------------+---------+--------+

I0410 16:58:08.547882 1 metrics.cc:889] "Collecting metrics for GPU 0: NVIDIA H100 NVL"
I0410 16:58:08.547920 1 metrics.cc:889] "Collecting metrics for GPU 1: NVIDIA H100 NVL"
I0410 16:58:08.558088 1 metrics.cc:782] "Collecting CPU metrics"
I0410 16:58:08.558287 1 tritonserver.cc:2598] 
+----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Option                           | Value                                                                                                                                                                                                           |
+----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| server_id                        | triton                                                                                                                                                                                                          |
| server_version                   | 2.64.0                                                                                                                                                                                                          |
| server_extensions                | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data parameters statistics trace logging |
| model_repository_path[0]         | /llmapi_repo/                                                                                                                                                                                                   |
| model_control_mode               | MODE_NONE                                                                                                                                                                                                       |
| strict_model_config              | 0                                                                                                                                                                                                               |
| model_config_name                |                                                                                                                                                                                                                 |
| rate_limit                       | OFF                                                                                                                                                                                                             |
| pinned_memory_pool_byte_size     | 268435456                                                                                                                                                                                                       |
| cuda_memory_pool_byte_size{0}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{1}    | 67108864                                                                                                                                                                                                        |
| min_supported_compute_capability | 6.0                                                                                                                                                                                                             |
| strict_readiness                 | 1                                                                                                                                                                                                               |
| exit_timeout                     | 30                                                                                                                                                                                                              |
| cache_enabled                    | 0                                                                                                                                                                                                               |
+----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

I0410 16:58:08.565597 1 grpc_server.cc:2562] "Started GRPCInferenceService at 0.0.0.0:8001"
I0410 16:58:08.565866 1 http_server.cc:4815] "Started HTTPService at 0.0.0.0:8000"
I0410 16:58:08.607374 1 http_server.cc:358] "Started Metrics Service at 0.0.0.0:8002"
[TensorRT-LLM][WARNING] [kv cache manager] storeContextBlocks: Can not find sequence for request 2048
Signal (15) received.
I0410 17:29:18.352502 1 grpc_server.cc:2640] "Timeout 30: Found 3 gRPC service connections and inference handlers"
I0410 17:29:19.352712 1 server.cc:312] "Waiting for in-flight requests to complete."
I0410 17:29:19.352808 1 server.cc:328] "Timeout 29: Found 0 model versions that have in-flight inferences"
I0410 17:29:19.353399 1 server.cc:343] "All models are stopped, unloading models"
I0410 17:29:19.353416 1 server.cc:352] "Timeout 29: Found 1 live models and 0 in-flight non-inference requests"
I0410 17:29:20.353556 1 server.cc:352] "Timeout 28: Found 1 live models and 0 in-flight non-inference requests"
I0410 17:29:20.357674 1 model.py:667] "[trtllm] Issuing finalize to trtllm backend"
I0410 17:29:21.160382 1 model.py:291] "[trtllm] Shutdown complete"
I0410 17:29:21.205203 1 model.py:689] "[trtllm] Running Garbage Collector on finalize..."
I0410 17:29:21.353736 1 server.cc:352] "Timeout 27: Found 1 live models and 0 in-flight non-inference requests"
I0410 17:29:21.647287 1 model.py:692] "[trtllm] Garbage Collector on finalize... done"
I0410 17:29:22.353934 1 server.cc:352] "Timeout 26: Found 1 live models and 0 in-flight non-inference requests"
I0410 17:29:23.354102 1 server.cc:352] "Timeout 25: Found 1 live models and 0 in-flight non-inference requests"
I0410 17:29:23.939953 1 model_lifecycle.cc:636] "successfully unloaded 'tensorrt_llm' version 1"
I0410 17:29:24.354293 1 server.cc:352] "Timeout 24: Found 0 live models and 0 in-flight non-inference requests"
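
The rank0 engine args logged above correspond to an NGram speculative decoding setup; a model.yaml along these lines would produce them (a sketch: the values are copied from the logged NGramDecodingConfig, but the decoding_type key is an assumption about the YAML schema):

# Sketch of the model.yaml implied by the logged engine args.
# decoding_type key is assumed; other values copied from the log above.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
backend: pytorch
tensor_parallel_size: 1
pipeline_parallel_size: 1
disable_overlap_scheduler: true
speculative_config:
  decoding_type: NGram
  max_draft_len: 3
  max_matching_ngram_size: 3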

Query the server

curl -X POST localhost:8000/v2/models/tensorrt_llm/generate -d '{"text_input": "The future of AI is", "sampling_param_max_tokens": 50}' | jq
{
  "model_name": "tensorrt_llm",
  "model_version": "1",
  "text_output": "The future of AI is bright, and it's not just for big companies. Small businesses can also benefit from AI by automating repetitive tasks, improving customer service, and enhancing productivity. By investing in AI, small businesses"
}
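
The same generate endpoint can also be called from Python (a minimal sketch using the requests library; the model name, port, and request fields mirror the curl call above):

# Minimal Python client for Triton's HTTP generate endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/v2/models/tensorrt_llm/generate",
    json={"text_input": "The future of AI is", "sampling_param_max_tokens": 50},
)
resp.raise_for_status()
print(resp.json()["text_output"])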

…d code blocks render correctly in GFM

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
@whoisj (Contributor) commented Apr 17, 2026

LGTM, will approve if @yinggeh doesn't have any blockers.

- Standardize EAGLE-3 naming (hyphen) throughout — was inconsistent
- Indent numbered list content 3 spaces so items render as a single list
- Fix launch_triton_server.py URL: was tensorrtllm_backend/scripts (404),
  now NVIDIA/TensorRT-LLM/triton_backend/scripts (correct repo)
- Fix engine backend archive links: point to
  tensorrtllm_backend#tensorrt-engine-backend instead of deleted archive file

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
@yinggeh (Contributor) commented Apr 21, 2026

The pre-commit error can be safely ignored:

error: cannot format Quick_Deploy/PyTorch/export.py: module 'ast' has no attribute 'Str'

@whoisj merged commit 4fe2b90 into triton-inference-server:main, Apr 21, 2026
2 of 3 checks passed