Skip to content

Commit c44ed3f

Browse files
yhna940 and yhna941
authored and committed
feat: add deterministic alg flag
1 parent e3244bc commit c44ed3f

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

README.md

+15
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,21 @@ key: "ENABLE_WEIGHT_SHARING"
180180
}
181181
```
182182

183+
* `ENABLE_DETERMINISTIC_ALGORITHMS`: Boolean flag to enable deterministic algorithm selection for TorchScript models. By default, deterministic algorithms are disabled.
184+
185+
When this flag is set to `true`, Triton will configure the PyTorch backend to use only deterministic algorithm implementations. This ensures that model outputs are reproducible across runs, at the cost of potential performance degradation. If any operation does not have a deterministic version, an error will be raised.
186+
187+
The section of the model config file specifying this parameter will look like:
188+
189+
```
190+
parameters: {
191+
key: "ENABLE_DETERMINISTIC_ALGORITHMS"
192+
value: {
193+
string_value: "true"
194+
}
195+
}
196+
```
197+
183198
* `ENABLE_CACHE_CLEANING`: Boolean flag to enable CUDA cache cleaning after each model execution.
184199
If not specified, cache cleaning is disabled. This flag has no effect if model is on CPU.
185200
Setting this flag to true will negatively impact the performance due to additional CUDA cache

src/libtorch.cc

+32
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ class ModelState : public BackendModel {
108108
bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
109109

110110
bool EnabledWeightSharing() { return enable_weight_sharing_; }
111+
bool EnableDeterministicAlgorithms()
112+
{
113+
return enable_deterministic_algorithms_;
114+
}
111115
const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
112116
{
113117
return model_outputs_;
@@ -136,6 +140,9 @@ class ModelState : public BackendModel {
136140
// Flag to indicate whether weight sharing is enabled. Defaults to false.
137141
bool enable_weight_sharing_;
138142

143+
// Flag to indicate whether deterministic algorithms are enabled.
144+
bool enable_deterministic_algorithms_;
145+
139146
// Flag pairs to indicate if various JIT settings are set and
140147
// enabled respectively. Defaults to (false, true). Default behavior
141148
// is to do nothing if not explicitly set.
@@ -233,6 +240,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
233240
: BackendModel(triton_model), enable_optimized_execution_(true),
234241
enable_inference_mode_(true), enable_cudnn_(true),
235242
enable_cache_cleaning_(false), enable_weight_sharing_(false),
243+
enable_deterministic_algorithms_(false),
236244
enable_tensor_fuser_pair_({false, true}),
237245
enable_jit_profiling_pair_({false, true}),
238246
enable_jit_executor_pair_({false, true})
@@ -455,6 +463,26 @@ ModelState::ParseParameters()
455463
.c_str());
456464
}
457465

466+
// If `ENABLE_DETERMINISTIC_ALGORITHMS` is not present in 'parameters' then
467+
// no update is made to 'enable_deterministic_algorithms_'.
468+
err = ParseParameter(
469+
params, "ENABLE_DETERMINISTIC_ALGORITHMS",
470+
&enable_deterministic_algorithms_);
471+
if (err != nullptr) {
472+
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
473+
return err;
474+
} else {
475+
TRITONSERVER_ErrorDelete(err);
476+
}
477+
} else {
478+
LOG_MESSAGE(
479+
TRITONSERVER_LOG_INFO,
480+
(std::string("Deterministic algorithms are ") +
481+
(enable_deterministic_algorithms_ ? "enabled" : "disabled") +
482+
" for model instance '" + Name() + "'")
483+
.c_str());
484+
}
485+
458486
// If 'ENABLE_JIT_PROFILING' is not present in 'parameters' then no update
459487
// is made to 'enable_jit_profiling'.
460488
bool enable_jit_profiling = false;
@@ -1588,6 +1616,10 @@ ModelInstanceState::Execute(
15881616
// enable/disable cudnn
15891617
at::globalContext().setUserEnabledCuDNN(model_state_->EnabledCudnn());
15901618

1619+
// enable/disable deterministic algorithms
1620+
at::globalContext().setDeterministicAlgorithms(
1621+
model_state_->EnableDeterministicAlgorithms(), false /* warn_only */);
1622+
15911623
// JIT. No change is made unless parameter is explicitly set.
15921624
if (std::get<0>(model_state_->EnabledJitProfiling())) {
15931625
torch::jit::getProfilingMode() =

0 commit comments

Comments
 (0)