20
20
import textwrap
21
21
import time
22
22
import warnings
23
- from collections .abc import Callable
23
+ from collections .abc import Callable , Collection , Sequence
24
24
from contextlib import AbstractContextManager , nullcontext
25
25
from io import BytesIO , StringIO
26
26
from pathlib import Path
@@ -2736,6 +2736,96 @@ def check_mkl_openmp():
2736
2736
)
2737
2737
2738
2738
2739
+ def _check_required_file (
2740
+ paths : Collection [Path ],
2741
+ required_regexs : Collection [str | re .Pattern [str ]],
2742
+ ) -> list [tuple [str , str ]]:
2743
+ """Select path parents for each required pattern."""
2744
+ libs : list [tuple [str , str ]] = []
2745
+ for req in required_regexs :
2746
+ found = False
2747
+ for path in paths :
2748
+ m = re .search (req , path .name )
2749
+ if m :
2750
+ libs .append ((str (path .parent ), m .string [slice (* m .span ())]))
2751
+ found = True
2752
+ break
2753
+ if not found :
2754
+ _logger .debug ("Required file '%s' not found" , req )
2755
+ raise RuntimeError (f"Required file { req } not found" )
2756
+ return libs
2757
+
2758
+
2759
+ def _get_cxx_library_dirs () -> list [str ]:
2760
+ """Query C++ search dirs and return those the existing ones."""
2761
+ cmd = [config .cxx , "-print-search-dirs" ]
2762
+ p = subprocess_Popen (
2763
+ cmd ,
2764
+ stdout = subprocess .PIPE ,
2765
+ stderr = subprocess .PIPE ,
2766
+ stdin = subprocess .PIPE ,
2767
+ )
2768
+ (stdout , stderr ) = p .communicate (input = b"" )
2769
+ if p .returncode != 0 :
2770
+ warnings .warn (
2771
+ "Pytensor cxx failed to communicate its search dirs. As a consequence, "
2772
+ "it might not be possible to automatically determine the blas link flags to use.\n "
2773
+ f"Command that was run: { config .cxx } -print-search-dirs\n "
2774
+ f"Output printed to stderr: { stderr .decode (sys .stderr .encoding )} "
2775
+ )
2776
+ return []
2777
+
2778
+ maybe_lib_dirs = [
2779
+ [Path (p ).resolve () for p in line [len ("libraries: =" ) :].split (":" )]
2780
+ for line in stdout .decode (sys .getdefaultencoding ()).splitlines ()
2781
+ if line .startswith ("libraries: =" )
2782
+ ]
2783
+ if not maybe_lib_dirs :
2784
+ return []
2785
+ return [str (d ) for d in maybe_lib_dirs [0 ] if d .exists () and d .is_dir ()]
2786
+
2787
+
2788
+ def _check_libs (
2789
+ all_libs : Collection [Path ],
2790
+ required_libs : Collection [str | re .Pattern ],
2791
+ extra_compile_flags : Sequence [str ] = (),
2792
+ cxx_library_dirs : Sequence [str ] = (),
2793
+ ) -> str :
2794
+ """Assembly library paths and try BLAS flags, returning the flags on success."""
2795
+ found_libs = _check_required_file (
2796
+ all_libs ,
2797
+ required_libs ,
2798
+ )
2799
+ path_quote = '"' if sys .platform == "win32" else ""
2800
+ libdir_ldflags = list (
2801
+ dict .fromkeys (
2802
+ [
2803
+ f"-L{ path_quote } { lib_path } { path_quote } "
2804
+ for lib_path , _ in found_libs
2805
+ if lib_path not in cxx_library_dirs
2806
+ ]
2807
+ )
2808
+ )
2809
+
2810
+ flags = (
2811
+ libdir_ldflags
2812
+ + [f"-l{ lib_name } " for _ , lib_name in found_libs ]
2813
+ + list (extra_compile_flags )
2814
+ )
2815
+ res = try_blas_flag (flags )
2816
+ if not res :
2817
+ _logger .debug ("Supplied flags '%s' failed to compile" , res )
2818
+ raise RuntimeError (f"Supplied flags { flags } failed to compile" )
2819
+
2820
+ if any ("mkl" in flag for flag in flags ):
2821
+ try :
2822
+ check_mkl_openmp ()
2823
+ except Exception as e :
2824
+ _logger .debug (e )
2825
+ _logger .debug ("The following blas flags will be used: '%s'" , res )
2826
+ return res
2827
+
2828
+
2739
2829
def default_blas_ldflags () -> str :
2740
2830
"""Look for an available BLAS implementation in the system.
2741
2831
@@ -2763,88 +2853,6 @@ def default_blas_ldflags() -> str:
2763
2853
2764
2854
"""
2765
2855
2766
- def check_required_file (paths , required_regexs ):
2767
- libs = []
2768
- for req in required_regexs :
2769
- found = False
2770
- for path in paths :
2771
- m = re .search (req , path .name )
2772
- if m :
2773
- libs .append ((str (path .parent ), m .string [slice (* m .span ())]))
2774
- found = True
2775
- break
2776
- if not found :
2777
- _logger .debug ("Required file '%s' not found" , req )
2778
- raise RuntimeError (f"Required file { req } not found" )
2779
- return libs
2780
-
2781
- def get_cxx_library_dirs ():
2782
- cmd = [config .cxx , "-print-search-dirs" ]
2783
- p = subprocess_Popen (
2784
- cmd ,
2785
- stdout = subprocess .PIPE ,
2786
- stderr = subprocess .PIPE ,
2787
- stdin = subprocess .PIPE ,
2788
- )
2789
- (stdout , stderr ) = p .communicate (input = b"" )
2790
- if p .returncode != 0 :
2791
- warnings .warn (
2792
- "Pytensor cxx failed to communicate its search dirs. As a consequence, "
2793
- "it might not be possible to automatically determine the blas link flags to use.\n "
2794
- f"Command that was run: { config .cxx } -print-search-dirs\n "
2795
- f"Output printed to stderr: { stderr .decode (sys .stderr .encoding )} "
2796
- )
2797
- return []
2798
-
2799
- maybe_lib_dirs = [
2800
- [Path (p ).resolve () for p in line [len ("libraries: =" ) :].split (":" )]
2801
- for line in stdout .decode (sys .getdefaultencoding ()).splitlines ()
2802
- if line .startswith ("libraries: =" )
2803
- ]
2804
- if len (maybe_lib_dirs ) > 0 :
2805
- maybe_lib_dirs = maybe_lib_dirs [0 ]
2806
- return [str (d ) for d in maybe_lib_dirs if d .exists () and d .is_dir ()]
2807
-
2808
- def check_libs (
2809
- all_libs , required_libs , extra_compile_flags = None , cxx_library_dirs = None
2810
- ) -> str :
2811
- if cxx_library_dirs is None :
2812
- cxx_library_dirs = []
2813
- if extra_compile_flags is None :
2814
- extra_compile_flags = []
2815
- found_libs = check_required_file (
2816
- all_libs ,
2817
- required_libs ,
2818
- )
2819
- path_quote = '"' if sys .platform == "win32" else ""
2820
- libdir_ldflags = list (
2821
- dict .fromkeys (
2822
- [
2823
- f"-L{ path_quote } { lib_path } { path_quote } "
2824
- for lib_path , _ in found_libs
2825
- if lib_path not in cxx_library_dirs
2826
- ]
2827
- )
2828
- )
2829
-
2830
- flags = (
2831
- libdir_ldflags
2832
- + [f"-l{ lib_name } " for _ , lib_name in found_libs ]
2833
- + extra_compile_flags
2834
- )
2835
- res = try_blas_flag (flags )
2836
- if res :
2837
- if any ("mkl" in flag for flag in flags ):
2838
- try :
2839
- check_mkl_openmp ()
2840
- except Exception as e :
2841
- _logger .debug (e )
2842
- _logger .debug ("The following blas flags will be used: '%s'" , res )
2843
- return res
2844
- else :
2845
- _logger .debug ("Supplied flags '%s' failed to compile" , res )
2846
- raise RuntimeError (f"Supplied flags { flags } failed to compile" )
2847
-
2848
2856
# If no compiler is available we default to empty ldflags
2849
2857
if not config .cxx :
2850
2858
return ""
@@ -2854,7 +2862,7 @@ def check_libs(
2854
2862
else :
2855
2863
rpath = None
2856
2864
2857
- cxx_library_dirs = get_cxx_library_dirs ()
2865
+ cxx_library_dirs = _get_cxx_library_dirs ()
2858
2866
searched_library_dirs = cxx_library_dirs + _std_lib_dirs
2859
2867
if sys .platform == "win32" :
2860
2868
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
@@ -2884,7 +2892,7 @@ def check_libs(
2884
2892
try :
2885
2893
# 1. Try to use MKL with INTEL OpenMP threading
2886
2894
_logger .debug ("Checking MKL flags with intel threading" )
2887
- return check_libs (
2895
+ return _check_libs (
2888
2896
all_libs ,
2889
2897
required_libs = [
2890
2898
"mkl_core" ,
@@ -2901,7 +2909,7 @@ def check_libs(
2901
2909
try :
2902
2910
# 2. Try to use MKL with GNU OpenMP threading
2903
2911
_logger .debug ("Checking MKL flags with GNU OpenMP threading" )
2904
- return check_libs (
2912
+ return _check_libs (
2905
2913
all_libs ,
2906
2914
required_libs = ["mkl_core" , "mkl_rt" , "mkl_gnu_thread" , "gomp" , "pthread" ],
2907
2915
extra_compile_flags = [f"-Wl,-rpath,{ rpath } " ] if rpath is not None else [],
@@ -2924,7 +2932,7 @@ def check_libs(
2924
2932
try :
2925
2933
_logger .debug ("Checking Lapack + blas" )
2926
2934
# 4. Try to use LAPACK + BLAS
2927
- return check_libs (
2935
+ return _check_libs (
2928
2936
all_libs ,
2929
2937
required_libs = ["lapack" , "blas" , "cblas" , "m" ],
2930
2938
extra_compile_flags = [f"-Wl,-rpath,{ rpath } " ] if rpath is not None else [],
@@ -2935,7 +2943,7 @@ def check_libs(
2935
2943
try :
2936
2944
# 5. Try to use BLAS alone
2937
2945
_logger .debug ("Checking blas alone" )
2938
- return check_libs (
2946
+ return _check_libs (
2939
2947
all_libs ,
2940
2948
required_libs = ["blas" , "cblas" ],
2941
2949
extra_compile_flags = [f"-Wl,-rpath,{ rpath } " ] if rpath is not None else [],
@@ -2946,7 +2954,7 @@ def check_libs(
2946
2954
try :
2947
2955
# 6. Try to use openblas
2948
2956
_logger .debug ("Checking openblas" )
2949
- return check_libs (
2957
+ return _check_libs (
2950
2958
all_libs ,
2951
2959
required_libs = ["openblas" , "gfortran" , "gomp" , "m" ],
2952
2960
extra_compile_flags = ["-fopenmp" , f"-Wl,-rpath,{ rpath } " ]
0 commit comments