Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

Implementing Custom TritonServer Backends in C++ and Python

Tech 1

Environment Setup

CMake Installation

TritonServer backend compilation requires CMake 3.17 or higher. Download the latest version (3.28) from the official repository:

wget https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1.tar.gz

Extract and configure:

tar zxvf cmake-3.28.1.tar.gz
cd cmake-3.28.1/
./bootstrap

If you encounter SSL errors during bootstrapping, install the development libraries:

sudo apt-get install libssl-dev

After successful bootstrap, compile and install:

make && sudo make install

RapidJSON Installation

Clone the RapidJSON library:

git clone https://github.com/miloyip/rapidjson.git
cd rapidjson
mkdir build
cd build
cmake ..
make && sudo make install

C++ Custom Backend Implementation

Backend Structure Overview

The C++ custom backend requires implementing seven core APIs. These entry points enable TritonServer to communicate with your custom computation logic. Based on the reference implementation in the backend examples directory, here's how to structure your implementation.

Implementation Example

Create a new backend directory under backend/examples/backends/ and implement the following:

#include "triton/backend/backend_common.h"
#include "triton/backend/backend_input_collector.h"
#include "triton/backend/backend_model.h"
#include "triton/backend/backend_model_instance.h"
#include "triton/backend/backend_output_responder.h"
#include "triton/core/tritonbackend.h"

namespace triton { namespace backend { namespace custom {

class BackendContext;
class ModelResource;
class InstanceResource;

// Global state attached to the TRITONBACKEND_Backend object; created in
// TRITONBACKEND_Initialize and destroyed in TRITONBACKEND_Finalize.
class BackendState {
public:
    // Factory: returns an error instead of throwing so it can be called
    // directly from the C entry points.
    static TRITONSERVER_Error* Create(BackendState** state);
    ~BackendState();

    // Human-readable name for this backend, used in log output.
    const std::string& Identifier() const { return identifier_; }
    void SetIdentifier(const std::string& id) { identifier_ = id; }

private:
    BackendState() : identifier_("custom_backend") {}
    std::string identifier_;
};

TRITONSERVER_Error*
BackendState::Create(BackendState** state)
{
    // Translate an allocation failure into a Triton error; exceptions must
    // not cross the C API boundary.
    try {
        *state = new BackendState();
        return nullptr;  // nullptr signals success
    }
    catch (const std::bad_alloc&) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INTERNAL,
            "failed to allocate backend state");
    }
}

BackendState::~BackendState() = default;

extern "C" {

// One-time backend initialization, invoked when the shared library is loaded.
TRITONSERVER_Error*
TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend)
{
    const char* name = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &name));

    LOG_MESSAGE(
        TRITONSERVER_LOG_INFO,
        (std::string("Initializing backend: ") + name).c_str());

    // Compatibility check: the major versions must agree and the server's
    // minor API version must be at least the one we were compiled against.
    uint32_t api_major, api_minor;
    RETURN_IF_ERROR(TRITONBACKEND_ApiVersion(&api_major, &api_minor));

    const bool compatible =
        (api_major == TRITONBACKEND_API_VERSION_MAJOR) &&
        (api_minor >= TRITONBACKEND_API_VERSION_MINOR);
    if (!compatible) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_UNSUPPORTED,
            "incompatible API version");
    }

    // Fetch the backend configuration and log it (as JSON) for diagnostics.
    TRITONSERVER_Message* config_msg = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_BackendConfig(backend, &config_msg));

    const char* config_json = nullptr;
    size_t config_json_size = 0;
    RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(
        config_msg, &config_json, &config_json_size));

    LOG_MESSAGE(
        TRITONSERVER_LOG_INFO,
        (std::string("Backend config: ") + config_json).c_str());

    // Attach our per-backend state; reclaimed in TRITONBACKEND_Finalize.
    BackendState* state = nullptr;
    RETURN_IF_ERROR(BackendState::Create(&state));
    RETURN_IF_ERROR(
        TRITONBACKEND_BackendSetState(backend, reinterpret_cast<void*>(state)));

    return nullptr;
}

// Backend teardown, invoked once when the shared library is unloaded.
// Reclaims the BackendState installed by TRITONBACKEND_Initialize.
TRITONSERVER_Error*
TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend)
{
    void* vstate = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate));
    delete reinterpret_cast<BackendState*>(vstate);  // delete nullptr is a no-op
    return nullptr;
}

}  // extern "C"

// Model-level state management
//
// One ModelResource is created per loaded model (TRITONBACKEND_ModelInitialize)
// and shared by all of that model's instances. It caches the pieces of the
// model configuration the execute path needs: the single input/output tensor
// names, their (shared) datatype, and the declared dims.
class ModelResource : public BackendModel {
public:
    // Factory: parses and validates the model config; returns an error
    // (instead of throwing) so it can be used from the C entry points.
    static TRITONSERVER_Error* Create(
        TRITONBACKEND_Model* model, ModelResource** resource);
    virtual ~ModelResource() = default;
    
    // Accessors for the values extracted by Validate().
    const std::string& InputName() const { return input_name_; }
    const std::string& OutputName() const { return output_name_; }
    TRITONSERVER_DataType DataType() const { return data_type_; }
    const std::vector<int64_t>& Shape() const { return shape_; }
    
private:
    ModelResource(TRITONBACKEND_Model* model);
    // Enforces a single-input/single-output model with matching datatypes,
    // then caches names, datatype, and dims into the members below.
    TRITONSERVER_Error* Validate();
    
    std::string input_name_;
    std::string output_name_;
    TRITONSERVER_DataType data_type_;
    std::vector<int64_t> shape_;
};

// Constructor delegates config parsing to Validate(); any validation error
// is converted into a BackendModelException so the Create() factory can
// catch it and return it across the C API.
ModelResource::ModelResource(TRITONBACKEND_Model* model)
    : BackendModel(model)
{
    THROW_IF_BACKEND_MODEL_ERROR(Validate());
}

// Factory wrapper: translates the constructor's BackendModelException back
// into a TRITONSERVER_Error* suitable for the C API.
TRITONSERVER_Error*
ModelResource::Create(TRITONBACKEND_Model* model, ModelResource** resource)
{
    try {
        *resource = new ModelResource(model);
        return nullptr;
    }
    catch (const BackendModelException& ex) {
        RETURN_IF_ERROR(ex.err_);
        return nullptr;
    }
}

// Parses the model configuration (config.pbtxt, exposed as JSON) and caches
// the tensor metadata used during execution. Returns nullptr on success or
// an INVALID_ARG error describing the first violated constraint.
TRITONSERVER_Error*
ModelResource::Validate()
{
    common::TritonJson::Value inputs, outputs;
    RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs));
    RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs));
    
    // Validate single input/output configuration
    RETURN_ERROR_IF_FALSE(
        inputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG,
        "model requires exactly one input");
    RETURN_ERROR_IF_FALSE(
        outputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG,
        "model requires exactly one output");
    
    // Extract input configuration (element 0 of the "input" array).
    common::TritonJson::Value input_config;
    RETURN_IF_ERROR(inputs.IndexAsObject(0, &input_config));
    
    // Note: the returned char* points into the JSON document, so copy it
    // into an owned std::string before the document goes out of scope.
    const char* in_name = nullptr;
    size_t in_name_len = 0;
    RETURN_IF_ERROR(input_config.MemberAsString("name", &in_name, &in_name_len));
    input_name_ = std::string(in_name);
    
    // Extract output configuration (element 0 of the "output" array).
    common::TritonJson::Value output_config;
    RETURN_IF_ERROR(outputs.IndexAsObject(0, &output_config));
    
    const char* out_name = nullptr;
    size_t out_name_len = 0;
    RETURN_IF_ERROR(output_config.MemberAsString("name", &out_name, &out_name_len));
    output_name_ = std::string(out_name);
    
    // Verify datatype matching: this backend copies input to output, so the
    // two tensors must share one datatype.
    std::string input_dtype, output_dtype;
    RETURN_IF_ERROR(input_config.MemberAsString("data_type", &input_dtype));
    RETURN_IF_ERROR(output_config.MemberAsString("data_type", &output_dtype));
    RETURN_ERROR_IF_FALSE(
        input_dtype == output_dtype, TRITONSERVER_ERROR_INVALID_ARG,
        "input and output datatypes must match");
    
    // Convert the config's string form (e.g. "TYPE_FP32") to the enum.
    data_type_ = ModelConfigDataTypeToTritonServerDataType(input_dtype);
    
    // Parse shape dimensions from the input's "dims" entry.
    RETURN_IF_ERROR(backend::ParseShape(input_config, "dims", &shape_));
    
    return nullptr;
}

extern "C" {

// Called once per model load: build the ModelResource and hang it off the
// TRITONBACKEND_Model so instances and the execute path can retrieve it.
TRITONSERVER_Error*
TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model)
{
    ModelResource* resource = nullptr;
    RETURN_IF_ERROR(ModelResource::Create(model, &resource));
    RETURN_IF_ERROR(
        TRITONBACKEND_ModelSetState(model, reinterpret_cast<void*>(resource)));
    
    LOG_MESSAGE(TRITONSERVER_LOG_INFO, "Model initialized successfully");
    return nullptr;
}

// Called once per model unload: release the ModelResource installed by
// TRITONBACKEND_ModelInitialize.
TRITONSERVER_Error*
TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model)
{
    void* vstate = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
    delete reinterpret_cast<ModelResource*>(vstate);  // no-op when nullptr
    return nullptr;
}

}  // extern "C"

// Instance-level state management
//
// One InstanceResource exists per model instance (as configured by the
// instance_group setting). It keeps a back-pointer to the shared
// ModelResource so the execute path can reach the validated model config.
class InstanceResource : public BackendModelInstance {
public:
    // Factory mirroring ModelResource::Create: returns an error instead of
    // letting the base class's exception escape into the C API.
    static TRITONSERVER_Error* Create(
        ModelResource* model_state,
        TRITONBACKEND_ModelInstance* instance,
        InstanceResource** resource);
    virtual ~InstanceResource() = default;
    
    // Shared, model-wide state (owned by the model, not by this instance).
    ModelResource* GetModelState() const { return model_state_; }
    
private:
    InstanceResource(
        ModelResource* model_state,
        TRITONBACKEND_ModelInstance* instance)
        : BackendModelInstance(model_state, instance),
          model_state_(model_state)
    {
    }
    
    ModelResource* model_state_;
};

// Factory: converts the exception the BackendModelInstance base may raise
// during construction into a returnable TRITONSERVER_Error*.
TRITONSERVER_Error*
InstanceResource::Create(
    ModelResource* model_state,
    TRITONBACKEND_ModelInstance* instance,
    InstanceResource** resource)
{
    try {
        *resource = new InstanceResource(model_state, instance);
        return nullptr;
    }
    catch (const BackendModelInstanceException& ex) {
        RETURN_IF_ERROR(ex.err_);
        return nullptr;
    }
}

extern "C" {

// Per-instance setup: look up the owning model's state and wrap it together
// with the instance handle in an InstanceResource.
TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
{
    TRITONBACKEND_Model* model = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));
    
    void* vstate = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
    ModelResource* model_resource = reinterpret_cast<ModelResource*>(vstate);
    
    InstanceResource* instance_resource = nullptr;
    RETURN_IF_ERROR(
        InstanceResource::Create(model_resource, instance, &instance_resource));
    RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
        instance, reinterpret_cast<void*>(instance_resource)));
    
    return nullptr;
}

// Per-instance teardown: release the InstanceResource installed by
// TRITONBACKEND_ModelInstanceInitialize.
TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance)
{
    void* vstate = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate));
    delete reinterpret_cast<InstanceResource*>(vstate);  // no-op when nullptr
    return nullptr;
}

}  // extern "C"

extern "C" {

// Main inference execution path
//
// Triton transfers ownership of `requests` (request_count of them) to this
// call. For every request the backend must (1) send exactly one response and
// (2) release the request back to Triton, on success AND on failure. This
// backend is a passthrough: the single input tensor is copied verbatim to
// the single output tensor.
TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceExecute(
    TRITONBACKEND_ModelInstance* instance,
    TRITONBACKEND_Request** requests,
    uint32_t request_count)
{
    // Retrieve the per-instance and per-model state installed at load time.
    InstanceResource* instance_state = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
        instance, reinterpret_cast<void**>(&instance_state)));
    ModelResource* model_state = instance_state->GetModelState();
    
    LOG_MESSAGE(
        TRITONSERVER_LOG_INFO,
        (std::string("Processing batch of ") + 
         std::to_string(request_count) + " requests").c_str());
    
    // Create one response per request up front. The RESPOND_* macros below
    // null out an entry after sending an error response through it.
    std::vector<TRITONBACKEND_Response*> responses;
    responses.reserve(request_count);
    
    for (uint32_t i = 0; i < request_count; ++i) {
        TRITONBACKEND_Response* response = nullptr;
        RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, requests[i]));
        responses.push_back(response);
    }
    
    // Gather the input tensor contents across the whole batch into one
    // contiguous buffer using the backend utility collector.
    BackendInputCollector collector(
        requests, request_count, &responses,
        model_state->TritonMemoryManager(),
        false /* pinned_enabled */, nullptr /* stream */);
    
    // Acceptable placements for the collected buffer (CPU only here).
    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> memory_types = {
        {TRITONSERVER_MEMORY_CPU, 0},
        {TRITONSERVER_MEMORY_CPU_PINNED, 0}
    };
    
    const char* input_buffer = nullptr;
    size_t input_buffer_size = 0;
    TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
    int64_t device_id = 0;
    
    RESPOND_ALL_AND_SET_NULL_IF_ERROR(
        responses, request_count,
        collector.ProcessTensor(
            model_state->InputName().c_str(),
            nullptr, 0,
            memory_types,
            &input_buffer, &input_buffer_size,
            &memory_type, &device_id));
    
    // Finalize returns true when a CUDA synchronization would be required;
    // unexpected for a CPU-only configuration, so just log it.
    bool needs_sync = collector.Finalize();
    if (needs_sync) {
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "unexpected sync required");
    }
    
    // Passthrough: the output tensor aliases the collected input buffer.
    const char* output_buffer = input_buffer;
    TRITONSERVER_MemoryType output_memory_type = memory_type;
    int64_t output_device_id = device_id;
    
    // Prepend the batch dimension when the model config enables batching.
    bool supports_batching = false;
    RESPOND_ALL_AND_SET_NULL_IF_ERROR(
        responses, request_count,
        model_state->SupportsFirstDimBatching(&supports_batching));
    
    std::vector<int64_t> output_shape = model_state->Shape();
    if (supports_batching) {
        output_shape.insert(output_shape.begin(), request_count);
    }
    
    // Scatter the contiguous buffer back into each request's response.
    BackendOutputResponder responder(
        requests, request_count, &responses,
        model_state->TritonMemoryManager(),
        supports_batching, false /* pinned_enabled */, nullptr /* stream */);
    
    responder.ProcessTensor(
        model_state->OutputName().c_str(),
        model_state->DataType(),
        output_shape,
        output_buffer, output_memory_type, output_device_id);
    
    bool responder_sync = responder.Finalize();
    if (responder_sync) {
        LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "unexpected responder sync");
    }
    
    // Send every response that has not already been sent with an error.
    for (auto& response : responses) {
        if (response != nullptr) {
            LOG_IF_ERROR(
                TRITONBACKEND_ResponseSend(
                    response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
                "response send failed");
        }
    }
    
    // BUGFIX: requests must be released unconditionally. The previous code
    // wrapped TRITONBACKEND_RequestRelease in `#ifdef TRITON_ENABLE_STATS`,
    // so builds without stats support never returned the requests to Triton,
    // leaking them and hanging clients. Only statistics *reporting* is
    // optional; releasing ownership of each request is mandatory.
    for (uint32_t i = 0; i < request_count; ++i) {
#ifdef TRITON_ENABLE_STATS
        // Per-request statistics reporting (timings, success counts) would
        // go here when stats support is compiled in.
#endif
        LOG_IF_ERROR(
            TRITONBACKEND_RequestRelease(requests[i], TRITONSERVER_REQUEST_RELEASE_ALL),
            "request release failed");
    }
    
    return nullptr;
}

}  // extern "C"

}}}  // namespace triton::backend::custom

Build Commands

Compile the backend with CMake:

mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install ..
make install

This generates a shared library named libtriton_custom.so.

Model Configuration

Create a config.pbtxt file in your model repository:

backend: "custom"
max_batch_size: 16
dynamic_batching {
  max_queue_delay_microseconds: 1000000
}

input [
  {
    name: "INPUTTensor"
    data_type: TYPE_FP32
    dims: [ -1, 128 ]
  }
]

output [
  {
    name: "OUTPUTTensor"
    data_type: TYPE_FP32
    dims: [ -1, 128 ]
  }
]

instance_group [
  {
    kind: KIND_CPU
    count: 2
  }
]

The backend binary must be named following the pattern libtriton_<backend_name>.so.

Server Deployment

Launch the TritonServer container with the model repository mounted:

docker run --rm \
  -p8000:8000 \
  -p8001:8001 \
  -p8002:8002 \
  -v /path/to/model_repository:/models \
  -it nvcr.io/nvidia/tritonserver:23.12-py3

Start the server:

tritonserver --model-repository=/models

Python Custom Backend Implementation

Overview

Python backends require implementing only three methods, significantly reducing implementation complexity compared to C++ backends. The model class must be named TritonPythonModel.

Implementation

import json
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Custom inference logic exposed through Triton's Python backend.

    Triton requires the class to be named ``TritonPythonModel`` and looks
    for the three methods below: ``initialize`` (model load), ``execute``
    (per-batch inference), and ``finalize`` (model unload).
    """

    def initialize(self, args):
        """Called once while the model is being loaded.

        Parameters
        ----------
        args : dict
            Keys include ``model_config`` (JSON string with the model
            configuration), ``model_instance_kind``,
            ``model_instance_device_id``, ``model_repository``,
            ``model_version``, and ``model_name``.
        """
        self.model_config = json.loads(args["model_config"])

        # Cache each output's config and numpy dtype once, so execute()
        # does not re-parse the model configuration per request.
        self.output0_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT0")
        self.output1_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT1")

        self.output0_type = pb_utils.triton_string_to_numpy(
            self.output0_config["data_type"])
        self.output1_type = pb_utils.triton_string_to_numpy(
            self.output1_config["data_type"])

        print(f"Model initialized: {args['model_name']}")

    def execute(self, requests):
        """Run inference for one batch of requests.

        Parameters
        ----------
        requests : list
            ``pb_utils.InferenceRequest`` objects.

        Returns
        -------
        list
            One ``pb_utils.InferenceResponse`` per request, in order.
        """
        responses = []

        for req in requests:
            # Pull both input tensors out of the request as numpy arrays.
            lhs = pb_utils.get_input_tensor_by_name(req, "INPUT0").as_numpy()
            rhs = pb_utils.get_input_tensor_by_name(req, "INPUT1").as_numpy()

            # Element-wise sum and difference, cast to the configured dtypes.
            sum_tensor = pb_utils.Tensor(
                "OUTPUT0",
                (lhs + rhs).astype(self.output0_type))
            diff_tensor = pb_utils.Tensor(
                "OUTPUT1",
                (lhs - rhs).astype(self.output1_type))

            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[sum_tensor, diff_tensor]))

        return responses

    def finalize(self):
        """Called once before the model is unloaded; release resources here."""
        print("Model cleanup completed")
Model Configuration

name: "addition_backend"
backend: "python"

input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  },
  {
    name: "INPUT1"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]

output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  },
  {
    name: "OUTPUT1"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]

instance_group [{ kind: KIND_CPU }]

Directory Structure

model_repository/
└── addition_backend/
    ├── config.pbtxt
    └── 1/
        └── model.py

Client-Side Inference

HTTP Client Implementation

import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype


def main():
    """Round-trip inference test against the addition_backend model.

    Sends two random FP32 vectors over HTTP and verifies that the model
    returns their element-wise sum and difference.
    """
    model_name = "addition_backend"
    input_shape = [4]

    with httpclient.InferenceServerClient("localhost:8000") as client:
        # Random test vectors matching the model's declared dims.
        input0 = np.random.rand(*input_shape).astype(np.float32)
        input1 = np.random.rand(*input_shape).astype(np.float32)

        # Build and populate one InferInput per model input.
        inputs = []
        for name, array in (("INPUT0", input0), ("INPUT1", input1)):
            tensor = httpclient.InferInput(
                name, array.shape, np_to_triton_dtype(array.dtype))
            tensor.set_data_from_numpy(array)
            inputs.append(tensor)

        # Ask for both declared outputs.
        outputs = [
            httpclient.InferRequestedOutput(name)
            for name in ("OUTPUT0", "OUTPUT1")
        ]

        # Execute inference.
        response = client.infer(
            model_name,
            inputs,
            request_id="1",
            outputs=outputs)

        output0 = response.as_numpy("OUTPUT0")
        output1 = response.as_numpy("OUTPUT1")

        print(f"Input A: {input0}")
        print(f"Input B: {input1}")
        print(f"A + B = {output0}")
        print(f"A - B = {output1}")

        # Verify the server computed the expected results.
        assert np.allclose(input0 + input1, output0), "Addition verification failed"
        assert np.allclose(input0 - input1, output1), "Subtraction verification failed"

        print("Inference test passed")


if __name__ == "__main__":
    main()

Running the Client

Execute the client script within the Triton client container:

docker run --rm -it \
  --network host \
  nvcr.io/nvidia/tritonserver:23.12-py3-sdk \
  python client_script.py

Key Differences: C++ vs Python Backends

Aspect C++ Backend Python Backend
API Count 7 functions 3 methods
Memory Control Manual management Automatic (Python)
Performance Higher throughput Slightly lower
Complexity More boilerplate Simplified implementation
Deployment Shared library (.so) Python source files

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

Implement Image Upload Functionality for Django Integrated TinyMCE Editor

Django’s Admin panel is highly user-friendly, and pairing it with TinyMCE, an effective rich text editor, simplifies content management significantly. Combining the two is particularly useful for bloggi...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.