Implementing Custom TritonServer Backends in C++ and Python
Environment Setup
CMake Installation
TritonServer backend compilation requires CMake 3.17 or higher. Download the latest version (3.28) from the official repository:
wget https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1.tar.gz
Extract and configure:
tar zxvf cmake-3.28.1.tar.gz
cd cmake-3.28.1/
./bootstrap
If you encounter SSL errors during bootstrapping, install the development libraries:
sudo apt-get install libssl-dev
After successful bootstrap, compile and install:
make && sudo make install
RapidJSON Installation
Clone the RapidJSON library:
git clone https://github.com/Tencent/rapidjson.git
cd rapidjson
mkdir build
cd build
cmake ..
make && sudo make install
C++ Custom Backend Implementation
Backend Structure Overview
The C++ custom backend requires implementing seven core APIs. These entry points enable TritonServer to communicate with your custom computation logic. Based on the reference implementation in the backend examples directory, here's how to structure your implementation.
Implementation Example
Create a new backend directory under backend/examples/backends/ and implement the following:
#include "triton/backend/backend_common.h"
#include "triton/backend/backend_input_collector.h"
#include "triton/backend/backend_model.h"
#include "triton/backend/backend_model_instance.h"
#include "triton/backend/backend_output_responder.h"
#include "triton/core/tritonbackend.h"
namespace triton { namespace backend { namespace custom {
class BackendContext;
class ModelResource;
class InstanceResource;
// State management for the backend
class BackendState {
public:
// Heap-allocates a BackendState and stores it in *state. Returns a
// TRITONSERVER_ERROR_INTERNAL error on allocation failure, nullptr on success.
static TRITONSERVER_Error* Create(BackendState** state);
~BackendState();
// Human-readable identifier for this backend; defaults to "custom_backend".
const std::string& Identifier() const { return identifier_; }
void SetIdentifier(const std::string& id) { identifier_ = id; }
private:
// Constructor is private: instances must be obtained through Create().
BackendState() : identifier_("custom_backend") {}
std::string identifier_;
};
// Factory for BackendState. Converts an allocation failure into a Triton
// error object instead of letting the std::bad_alloc escape into the C API.
// On success *state owns the new object and nullptr is returned.
TRITONSERVER_Error*
BackendState::Create(BackendState** state)
{
  TRITONSERVER_Error* result = nullptr;
  try {
    *state = new BackendState();
  }
  catch (const std::bad_alloc&) {
    result = TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INTERNAL,
        "failed to allocate backend state");
  }
  return result;
}

BackendState::~BackendState() = default;
extern "C" {
// Backend initialization - called once when loaded
// Entry point called exactly once when Triton loads this backend's shared
// library. Verifies API compatibility, logs the backend configuration, and
// attaches a BackendState that TRITONBACKEND_Finalize later frees.
TRITONSERVER_Error*
TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend)
{
const char* backend_name = nullptr;
RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &backend_name));
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Initializing backend: ") + backend_name).c_str());
// Verify API version compatibility: the major version must match exactly;
// the server's minor version must be at least the one we compiled against.
uint32_t major_version, minor_version;
RETURN_IF_ERROR(TRITONBACKEND_ApiVersion(&major_version, &minor_version));
if ((major_version != TRITONBACKEND_API_VERSION_MAJOR) ||
(minor_version < TRITONBACKEND_API_VERSION_MINOR)) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"incompatible API version");
}
// Retrieve backend configuration (the message is owned by Triton; we only
// borrow the serialized buffer for logging).
TRITONSERVER_Message* config_message = nullptr;
RETURN_IF_ERROR(TRITONBACKEND_BackendConfig(backend, &config_message));
const char* config_buffer = nullptr;
size_t config_size = 0;
RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(
config_message, &config_buffer, &config_size));
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Backend config: ") + config_buffer).c_str());
// Create and attach backend state; ownership passes to the backend object
// and the state is reclaimed in TRITONBACKEND_Finalize.
BackendState* backend_state = nullptr;
RETURN_IF_ERROR(BackendState::Create(&backend_state));
RETURN_IF_ERROR(
TRITONBACKEND_BackendSetState(backend, reinterpret_cast<void*>(backend_state)));
return nullptr;
}
// Backend cleanup - called when unloaded
TRITONSERVER_Error*
TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend)
{
  // Reclaim the BackendState attached during TRITONBACKEND_Initialize.
  void* attached = nullptr;
  RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &attached));
  // delete on a null pointer is a well-defined no-op, so no guard is needed.
  delete reinterpret_cast<BackendState*>(attached);
  return nullptr;
}
} // extern "C"
// Model-level state management
// Per-model state, one instance per loaded model. Parses and validates the
// model configuration (config.pbtxt) once at load time and caches the
// single input/output tensor names, datatype and dims for use in Execute.
class ModelResource : public BackendModel {
public:
// Builds and validates a ModelResource for `model`; returns a Triton error
// (never throws) if the configuration is invalid or allocation fails.
static TRITONSERVER_Error* Create(
TRITONBACKEND_Model* model, ModelResource** resource);
virtual ~ModelResource() = default;
// Accessors for the configuration values cached by Validate().
const std::string& InputName() const { return input_name_; }
const std::string& OutputName() const { return output_name_; }
TRITONSERVER_DataType DataType() const { return data_type_; }
const std::vector<int64_t>& Shape() const { return shape_; }
private:
// Private: throws BackendModelException on invalid config (see Create).
ModelResource(TRITONBACKEND_Model* model);
// Checks the config has exactly one input and one output with matching
// datatypes, and caches names/datatype/dims into the members below.
TRITONSERVER_Error* Validate();
std::string input_name_;
std::string output_name_;
TRITONSERVER_DataType data_type_;
std::vector<int64_t> shape_;
};
// Constructor: runs config validation eagerly. A validation failure is
// converted into a BackendModelException so that the Create() factory can
// translate it back into a TRITONSERVER_Error for the C API boundary.
ModelResource::ModelResource(TRITONBACKEND_Model* model)
: BackendModel(model)
{
THROW_IF_BACKEND_MODEL_ERROR(Validate());
}
// Factory for ModelResource. The constructor throws BackendModelException
// when Validate() rejects the model configuration; this catches that
// exception and hands the wrapped error back across the C API.
TRITONSERVER_Error*
ModelResource::Create(TRITONBACKEND_Model* model, ModelResource** resource)
{
  try {
    *resource = new ModelResource(model);
    return nullptr;  // success
  }
  catch (const BackendModelException& ex) {
    RETURN_IF_ERROR(ex.err_);
  }
  return nullptr;
}
// Validates the model configuration and caches the tensor metadata.
// Requires exactly one input and one output with identical datatypes; on
// success fills input_name_, output_name_, data_type_ and shape_.
// Returns nullptr on success, an INVALID_ARG error otherwise.
TRITONSERVER_Error*
ModelResource::Validate()
{
common::TritonJson::Value inputs, outputs;
RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs));
RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs));
// Validate single input/output configuration
RETURN_ERROR_IF_FALSE(
inputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG,
"model requires exactly one input");
RETURN_ERROR_IF_FALSE(
outputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG,
"model requires exactly one output");
// Extract input configuration (name only; dims are parsed below)
common::TritonJson::Value input_config;
RETURN_IF_ERROR(inputs.IndexAsObject(0, &input_config));
const char* in_name = nullptr;
size_t in_name_len = 0;
RETURN_IF_ERROR(input_config.MemberAsString("name", &in_name, &in_name_len));
input_name_ = std::string(in_name);
// Extract output configuration
common::TritonJson::Value output_config;
RETURN_IF_ERROR(outputs.IndexAsObject(0, &output_config));
const char* out_name = nullptr;
size_t out_name_len = 0;
RETURN_IF_ERROR(output_config.MemberAsString("name", &out_name, &out_name_len));
output_name_ = std::string(out_name);
// Verify datatype matching -- this backend passes the input buffer through
// unchanged, so input and output must use the same element type.
std::string input_dtype, output_dtype;
RETURN_IF_ERROR(input_config.MemberAsString("data_type", &input_dtype));
RETURN_IF_ERROR(output_config.MemberAsString("data_type", &output_dtype));
RETURN_ERROR_IF_FALSE(
input_dtype == output_dtype, TRITONSERVER_ERROR_INVALID_ARG,
"input and output datatypes must match");
data_type_ = ModelConfigDataTypeToTritonServerDataType(input_dtype);
// Parse shape dimensions (without any batch dimension; that is added at
// execute time when the model supports first-dim batching).
RETURN_IF_ERROR(backend::ParseShape(input_config, "dims", &shape_));
return nullptr;
}
extern "C" {
// Model loading
TRITONSERVER_Error*
TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model)
{
  // Build the per-model resource (parses/validates config.pbtxt) and hand
  // ownership to Triton via the model state; ModelFinalize reclaims it.
  ModelResource* resource = nullptr;
  RETURN_IF_ERROR(ModelResource::Create(model, &resource));
  RETURN_IF_ERROR(
      TRITONBACKEND_ModelSetState(model, reinterpret_cast<void*>(resource)));
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, "Model initialized successfully");
  return nullptr;
}
// Model unloading
TRITONSERVER_Error*
TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model)
{
  // Reclaim the ModelResource attached in TRITONBACKEND_ModelInitialize.
  void* attached = nullptr;
  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &attached));
  delete reinterpret_cast<ModelResource*>(attached);  // no-op when null
  return nullptr;
}
} // extern "C"
// Instance-level state management
// Per-instance state, one per entry in the config's instance_group. Holds a
// non-owning pointer back to the shared ModelResource so Execute can reach
// the cached tensor metadata.
class InstanceResource : public BackendModelInstance {
public:
// Builds an InstanceResource; returns a Triton error (never throws) on
// failure. `model_state` must outlive the instance (Triton guarantees the
// model outlives its instances).
static TRITONSERVER_Error* Create(
ModelResource* model_state,
TRITONBACKEND_ModelInstance* instance,
InstanceResource** resource);
virtual ~InstanceResource() = default;
// Non-owning accessor to the shared per-model state.
ModelResource* GetModelState() const { return model_state_; }
private:
InstanceResource(
ModelResource* model_state,
TRITONBACKEND_ModelInstance* instance)
: BackendModelInstance(model_state, instance),
model_state_(model_state)
{
}
ModelResource* model_state_;
};
// Factory for InstanceResource. Translates a constructor exception from the
// BackendModelInstance base back into a Triton error for the C API.
TRITONSERVER_Error*
InstanceResource::Create(
    ModelResource* model_state, TRITONBACKEND_ModelInstance* instance,
    InstanceResource** resource)
{
  try {
    *resource = new InstanceResource(model_state, instance);
    return nullptr;  // success
  }
  catch (const BackendModelInstanceException& ex) {
    RETURN_IF_ERROR(ex.err_);
  }
  return nullptr;
}
extern "C" {
// Instance initialization
TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
{
  // Look up the owning model and its ModelResource so the new instance can
  // share the cached configuration.
  TRITONBACKEND_Model* owning_model = nullptr;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &owning_model));
  void* raw_model_state = nullptr;
  RETURN_IF_ERROR(TRITONBACKEND_ModelState(owning_model, &raw_model_state));
  auto* model_resource = reinterpret_cast<ModelResource*>(raw_model_state);
  // Create the per-instance resource and attach it to the instance handle;
  // ModelInstanceFinalize reclaims it.
  InstanceResource* instance_resource = nullptr;
  RETURN_IF_ERROR(
      InstanceResource::Create(model_resource, instance, &instance_resource));
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
      instance, reinterpret_cast<void*>(instance_resource)));
  return nullptr;
}
// Instance cleanup
TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance)
{
  // Reclaim the InstanceResource attached in ModelInstanceInitialize.
  void* attached = nullptr;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &attached));
  delete reinterpret_cast<InstanceResource*>(attached);  // no-op when null
  return nullptr;
}
} // extern "C"
extern "C" {
// Main inference execution path
// Main inference entry point: gathers the single input tensor across the
// batch, copies it straight through to the output (passthrough model), sends
// one response per request, and releases every request.
//
// Ownership contract: Execute takes ownership of `requests` and MUST release
// each request exactly once via TRITONBACKEND_RequestRelease.
TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceExecute(
    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
    uint32_t request_count)
{
  // Retrieve the instance state, and through it the shared model state.
  InstanceResource* instance_state = nullptr;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
      instance, reinterpret_cast<void**>(&instance_state)));
  ModelResource* model_state = instance_state->GetModelState();
  LOG_MESSAGE(
      TRITONSERVER_LOG_INFO,
      (std::string("Processing batch of ") + std::to_string(request_count) +
       " requests").c_str());
  // Create one response per request up front. From here on, per-request
  // failures are reported through the response objects
  // (RESPOND_ALL_AND_SET_NULL_IF_ERROR) rather than by returning an error.
  std::vector<TRITONBACKEND_Response*> responses;
  responses.reserve(request_count);
  for (uint32_t i = 0; i < request_count; ++i) {
    TRITONBACKEND_Response* response = nullptr;
    RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, requests[i]));
    responses.push_back(response);
  }
  // Gather the (possibly batched) input into one contiguous CPU buffer.
  BackendInputCollector collector(
      requests, request_count, &responses, model_state->TritonMemoryManager(),
      false /* pinned_enabled */, nullptr /* stream */);
  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> memory_types = {
      {TRITONSERVER_MEMORY_CPU, 0}, {TRITONSERVER_MEMORY_CPU_PINNED, 0}};
  const char* input_buffer = nullptr;
  size_t input_buffer_size = 0;
  TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
  int64_t device_id = 0;
  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
      responses, request_count,
      collector.ProcessTensor(
          model_state->InputName().c_str(), nullptr, 0, memory_types,
          &input_buffer, &input_buffer_size, &memory_type, &device_id));
  // Finalize returns true only when an async copy is still in flight; this
  // CPU-only backend never expects that.
  const bool needs_sync = collector.Finalize();
  if (needs_sync) {
    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "unexpected sync required");
  }
  // Passthrough: the output reuses the unmodified input buffer.
  const char* output_buffer = input_buffer;
  TRITONSERVER_MemoryType output_memory_type = memory_type;
  int64_t output_device_id = device_id;
  // Prepend the batch dimension when the config enables first-dim batching.
  bool supports_batching = false;
  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
      responses, request_count,
      model_state->SupportsFirstDimBatching(&supports_batching));
  std::vector<int64_t> output_shape = model_state->Shape();
  if (supports_batching) {
    output_shape.insert(output_shape.begin(), request_count);
  }
  // Scatter the output buffer back into each request's response.
  BackendOutputResponder responder(
      requests, request_count, &responses, model_state->TritonMemoryManager(),
      supports_batching, false /* pinned_enabled */, nullptr /* stream */);
  responder.ProcessTensor(
      model_state->OutputName().c_str(), model_state->DataType(), output_shape,
      output_buffer, output_memory_type, output_device_id);
  const bool responder_sync = responder.Finalize();
  if (responder_sync) {
    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "unexpected responder sync");
  }
  // Send every response that was not already sent-with-error above.
  for (auto& response : responses) {
    if (response != nullptr) {
      LOG_IF_ERROR(
          TRITONBACKEND_ResponseSend(
              response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
          "response send failed");
    }
  }
  // FIX: request release must be unconditional. The previous version wrapped
  // the release in `#ifdef TRITON_ENABLE_STATS`, which leaked every request
  // (and left clients hanging) on builds compiled without statistics
  // support. Only statistics *reporting* belongs under that guard.
  for (uint32_t i = 0; i < request_count; ++i) {
#ifdef TRITON_ENABLE_STATS
    // TRITONBACKEND_ModelInstanceReportStatistics would be called here,
    // using timestamps captured around the compute section above.
#endif  // TRITON_ENABLE_STATS
    LOG_IF_ERROR(
        TRITONBACKEND_RequestRelease(
            requests[i], TRITONSERVER_REQUEST_RELEASE_ALL),
        "request release failed");
  }
  return nullptr;
}
} // extern "C"
}}} // namespace triton::backend::custom
Build Commands
Compile the backend with CMake:
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install ..
make install
This generates a shared library named libtriton_custom.so.
Model Configuration
Create a config.pbtxt file in your model repository:
backend: "custom"
max_batch_size: 16
dynamic_batching {
max_queue_delay_microseconds: 1000000
}
input [
{
name: "INPUTTensor"
data_type: TYPE_FP32
dims: [ -1, 128 ]
}
]
output [
{
name: "OUTPUTTensor"
data_type: TYPE_FP32
dims: [ -1, 128 ]
}
]
instance_group [
{
kind: KIND_CPU
count: 2
}
]
The backend binary must be named following the pattern libtriton_<backend_name>.so.
Server Deployment
Launch the TritonServer container with the model repository mounted:
docker run --rm \
-p8000:8000 \
-p8001:8001 \
-p8002:8002 \
-v /path/to/model_repository:/models \
-it nvcr.io/nvidia/tritonserver:23.12-py3
Start the server:
tritonserver --model-repository=/models
Python Custom Backend Implementation
Overview
Python backends require implementing at most three methods — initialize, execute, and finalize — of which only execute is strictly required, significantly reducing implementation complexity compared to C++ backends. The model class must be named TritonPythonModel.
Implementation
import json
import numpy as np
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
    """Required entry-point class for the Triton Python backend.

    For each request, computes OUTPUT0 = INPUT0 + INPUT1 and
    OUTPUT1 = INPUT0 - INPUT1.
    """

    def initialize(self, args):
        """Called exactly once while the model is being loaded.

        Parameters
        ----------
        args : dict
            str -> str mapping with keys including model_config (JSON
            string), model_instance_kind, model_instance_device_id,
            model_repository, model_version, and model_name.
        """
        self.model_config = json.loads(args["model_config"])
        # Cache the two output configurations declared in config.pbtxt.
        self.output0_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT0")
        self.output1_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT1")
        # Resolve the Triton datatype strings to numpy dtypes once, up front.
        self.output0_type = pb_utils.triton_string_to_numpy(
            self.output0_config["data_type"])
        self.output1_type = pb_utils.triton_string_to_numpy(
            self.output1_config["data_type"])
        print(f"Model initialized: {args['model_name']}")

    def execute(self, requests):
        """Handle one batch of inference requests.

        Parameters
        ----------
        requests : list
            pb_utils.InferenceRequest objects.

        Returns
        -------
        list
            One pb_utils.InferenceResponse per request, in order.
        """
        responses = []
        for req in requests:
            # Fetch both operands as numpy arrays.
            lhs = pb_utils.get_input_tensor_by_name(req, "INPUT0").as_numpy()
            rhs = pb_utils.get_input_tensor_by_name(req, "INPUT1").as_numpy()
            # Build the two output tensors, cast to the configured dtypes.
            total = pb_utils.Tensor(
                "OUTPUT0", (lhs + rhs).astype(self.output0_type))
            difference = pb_utils.Tensor(
                "OUTPUT1", (lhs - rhs).astype(self.output1_type))
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[total, difference]))
        return responses

    def finalize(self):
        """Called once before the model is unloaded; release resources here."""
        print("Model cleanup completed")
Model Configuration
name: "addition_backend"
backend: "python"
input [
{
name: "INPUT0"
data_type: TYPE_FP32
dims: [ 4 ]
},
{
name: "INPUT1"
data_type: TYPE_FP32
dims: [ 4 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_FP32
dims: [ 4 ]
},
{
name: "OUTPUT1"
data_type: TYPE_FP32
dims: [ 4 ]
}
]
instance_group [{ kind: KIND_CPU }]
Directory Structure
model_repository/
└── addition_backend/
├── config.pbtxt
└── 1/
└── model.py
Client-Side Inference
HTTP Client Implementation
import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype
def main():
    """Round-trip test against the addition_backend model over HTTP.

    Sends two random FP32 vectors, requests their sum and difference, and
    asserts the server's results against a local computation.
    """
    model_name = "addition_backend"
    shape = [4]
    with httpclient.InferenceServerClient("localhost:8000") as client:
        # Random test operands.
        a = np.random.rand(*shape).astype(np.float32)
        b = np.random.rand(*shape).astype(np.float32)
        # Build and populate the two input tensors.
        infer_inputs = []
        for label, arr in (("INPUT0", a), ("INPUT1", b)):
            tensor = httpclient.InferInput(
                label, arr.shape, np_to_triton_dtype(arr.dtype))
            tensor.set_data_from_numpy(arr)
            infer_inputs.append(tensor)
        # Ask for both outputs explicitly.
        requested = [
            httpclient.InferRequestedOutput(name)
            for name in ("OUTPUT0", "OUTPUT1")
        ]
        # Run inference and pull the results back as numpy arrays.
        response = client.infer(
            model_name, infer_inputs, request_id="1", outputs=requested)
        out_sum = response.as_numpy("OUTPUT0")
        out_diff = response.as_numpy("OUTPUT1")
        # Show and verify the round trip.
        print(f"Input A: {a}")
        print(f"Input B: {b}")
        print(f"A + B = {out_sum}")
        print(f"A - B = {out_diff}")
        assert np.allclose(a + b, out_sum), "Addition verification failed"
        assert np.allclose(a - b, out_diff), "Subtraction verification failed"
        print("Inference test passed")


if __name__ == "__main__":
    main()
Running the Client
Execute the client script within the Triton client container:
docker run --rm -it \
--network host \
nvcr.io/nvidia/tritonserver:23.12-py3-sdk \
python client_script.py
Key Differences: C++ vs Python Backends
| Aspect | C++ Backend | Python Backend |
|---|---|---|
| API Count | 7 functions | 3 methods |
| Memory Control | Manual management | Automatic (Python) |
| Performance | Higher throughput | Slightly lower |
| Complexity | More boilerplate | Simplified implementation |
| Deployment | Shared library (.so) | Python source files |