This README showcases how to deploy a simple ResNet model on Triton Inference Server. While Triton does not yet have a dedicated JAX backend, JAX/Flax models can be deployed using the Python backend. If you are new to Triton, it is recommended to watch the getting started video and review Part 1 of the conceptual guide before proceeding. For demonstration purposes, we use a pre-trained model provided by flaxmodels.
Before diving into the specifics of execution, it helps to understand the underlying structure. The recommended path for serving a JAX or Flax model is a "Python model". Python models in Triton are classes with three Triton-specific functions: initialize, execute, and finalize. Users can customize this class to serve any Python function they write, or any model they want, as long as it can be loaded in the Python runtime. The initialize function runs when the Python model is loaded into memory, and the finalize function runs when the model is unloaded from memory. Both of these functions are optional to define. In this example, we use the initialize and execute functions to load and run (respectively) a ResNet50 model.
class TritonPythonModel:
    def initialize(self, args):
        ...

    def execute(self, requests):
        ...
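To make the lifecycle concrete, here is a simplified, standalone sketch of how Triton drives these methods. This is an illustration only: a real model.py runs inside Triton and uses triton_python_backend_utils (pb_utils) to unpack request tensors and build responses; the dictionary-based "requests" below are a stand-in so the pattern can be shown outside the server.

```python
class TritonPythonModel:
    def initialize(self, args):
        # Runs once, when Triton loads this Python model into memory.
        # A real JAX/Flax model would load its parameters here; this
        # stand-in "model" just adds a constant to its input.
        self.offset = 1.0

    def execute(self, requests):
        # Called for every batch of inference requests; must return
        # exactly one response per request, in order.
        responses = []
        for request in requests:
            # In a real model: pb_utils.get_input_tensor_by_name(request, "image")
            # and pb_utils.InferenceResponse(...) for the output.
            responses.append(request["image"] + self.offset)
        return responses

    def finalize(self):
        # Optional cleanup, run when the model is unloaded from memory.
        self.offset = None


# Simulating what the server does with this class:
model = TritonPythonModel()
model.initialize({})
print(model.execute([{"image": 2.0}]))  # [3.0]
model.finalize()
```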
To use Triton, we need to build a model repository. The structure of the repository is as follows:
model_repository
|
+-- resnet50
|
+-- config.pbtxt
+-- 1
|
+-- model.py
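For reference, a minimal config.pbtxt for this layout might look like the following. The tensor names match the client code below; the dims are assumptions based on a 224x224 ImageNet ResNet50 with NHWC inputs and 1000 output classes, so adjust them to your model:

```
name: "resnet50"
backend: "python"
max_batch_size: 0
input [
  {
    name: "image"
    data_type: TYPE_FP32
    dims: [ 1, 224, 224, 3 ]
  }
]
output [
  {
    name: "fc_out"
    data_type: TYPE_FP32
    dims: [ 1, 1000 ]
  }
]
```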
For this example, we have pre-built the model repository. Next, we install the required dependencies and launch the Triton Inference Server.
# Replace the yy.mm in the image name with the release year and month
# of the Triton version needed, eg. 22.12
docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v $(pwd):/workspace/ -v $(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:<yy.mm>-py3 bash
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
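On the client side, the server expects a float32 array named "image". Here is a sketch of preparing such an array with NumPy; the random array stands in for a decoded RGB image, and the normalization constants are the usual ImageNet statistics, which may differ from your model's exact preprocessing:

```python
import numpy as np

# Stand-in for a decoded RGB image; in practice this would come from an
# image file (e.g. via Pillow), resized to the network's input resolution.
raw = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float32)

# Scale to [0, 1] and normalize with the standard ImageNet mean/std.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image = (raw / 255.0 - mean) / std

# Add a batch dimension: shape (1, 224, 224, 3), dtype float32, matching
# the "image" input declared in config.pbtxt.
image = image[np.newaxis, ...].astype(np.float32)
```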
Let's break down the client application. First, we set up a connection with the Triton Inference Server.
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")
Then we set the input and output arrays.
# Set Inputs
input_tensors = [
httpclient.InferInput("image", image.shape, datatype="FP32")
]
input_tensors[0].set_data_from_numpy(image)
# Set outputs
outputs = [
httpclient.InferRequestedOutput("fc_out")
]
Lastly, we send a request to the Triton Inference Server.
# Query
query_response = client.infer(model_name="resnet50",
inputs=input_tensors,
outputs=outputs)
# Output
out = query_response.as_numpy("fc_out")
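The fc_out tensor holds the raw class logits. To turn them into a prediction, you would typically apply a softmax and take the argmax; a sketch with NumPy follows, where the random array stands in for the server response and the 1000-class ImageNet label space is an assumption:

```python
import numpy as np

# Stand-in for the server response; in the real client this is
# query_response.as_numpy("fc_out"), with shape (1, num_classes).
out = np.random.randn(1, 1000).astype(np.float32)

# Numerically stable softmax over the class axis.
logits = out[0]
exp = np.exp(logits - logits.max())
probs = exp / exp.sum()

# Index of the most likely class; map it through your label file
# (e.g. the ImageNet class names) to get a human-readable prediction.
top1 = int(np.argmax(probs))
```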