Skip to content

Commit c41862a

Browse files
authored
build(cudnn-sys): Add CUDNN_INCLUDE_DIR (#213)
* build(cudnn-sys): Add CUDNN_INCLUDE_DIR Enables users to specify a non-standard cuDNN install path. This seems to be needed for the newer editions of the CUDA toolkit, as cuDNN isn't included by default (at least in the Fedora repo's, you have to install from a tarball)
1 parent 3c2e043 commit c41862a

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

crates/cudnn-sys/build/cudnn_sdk.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use std::env;
12
use std::error;
23
use std::fs;
34
use std::path;
5+
use std::path::Path;
46

57
/// Represents the cuDNN SDK installation.
68
#[derive(Debug, Clone)]
@@ -57,17 +59,25 @@ impl CudnnSdk {
5759
}
5860

5961
fn find_cudnn_include_dir() -> Result<path::PathBuf, Box<dyn error::Error>> {
62+
let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR");
63+
6064
#[cfg(not(target_os = "windows"))]
6165
const CUDNN_DEFAULT_PATHS: &[&str] = &["/usr/include", "/usr/local/include"];
6266
#[cfg(target_os = "windows")]
6367
const CUDNN_DEFAULT_PATHS: &[&str] = &[
6468
"C:/Program Files/NVIDIA/CUDNN/v9.x/include",
6569
"C:/Program Files/NVIDIA/CUDNN/v8.x/include",
6670
];
67-
CUDNN_DEFAULT_PATHS
71+
72+
let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect();
73+
if let Some(override_path) = &cudnn_include_dir {
74+
cudnn_paths.push(Path::new(override_path));
75+
}
76+
77+
cudnn_paths
6878
.iter()
69-
.find(|s| Self::is_cudnn_include_path(s))
70-
.map(path::PathBuf::from)
79+
.find(|p| Self::is_cudnn_include_path(p))
80+
.map(|p| p.to_path_buf())
7181
.ok_or("Cannot find cuDNN include directory.".into())
7282
}
7383

0 commit comments

Comments
 (0)