Skip to content

Commit d2ef2b0

Browse files
Extract local functions and type them
They could get tested independently. This was mostly done to understand their purpose.
1 parent 2928211 commit d2ef2b0

File tree

1 file changed

+97
-89
lines changed

1 file changed

+97
-89
lines changed

pytensor/link/c/cmodule.py

+97-89
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import textwrap
2121
import time
2222
import warnings
23-
from collections.abc import Callable
23+
from collections.abc import Callable, Collection, Sequence
2424
from contextlib import AbstractContextManager, nullcontext
2525
from io import BytesIO, StringIO
2626
from pathlib import Path
@@ -2736,6 +2736,96 @@ def check_mkl_openmp():
27362736
)
27372737

27382738

2739+
def _check_required_file(
2740+
paths: Collection[Path],
2741+
required_regexs: Collection[str | re.Pattern[str]],
2742+
) -> list[tuple[str, str]]:
2743+
"""Select path parents for each required pattern."""
2744+
libs: list[tuple[str, str]] = []
2745+
for req in required_regexs:
2746+
found = False
2747+
for path in paths:
2748+
m = re.search(req, path.name)
2749+
if m:
2750+
libs.append((str(path.parent), m.string[slice(*m.span())]))
2751+
found = True
2752+
break
2753+
if not found:
2754+
_logger.debug("Required file '%s' not found", req)
2755+
raise RuntimeError(f"Required file {req} not found")
2756+
return libs
2757+
2758+
2759+
def _get_cxx_library_dirs() -> list[str]:
2760+
"""Query C++ search dirs and return those the existing ones."""
2761+
cmd = [config.cxx, "-print-search-dirs"]
2762+
p = subprocess_Popen(
2763+
cmd,
2764+
stdout=subprocess.PIPE,
2765+
stderr=subprocess.PIPE,
2766+
stdin=subprocess.PIPE,
2767+
)
2768+
(stdout, stderr) = p.communicate(input=b"")
2769+
if p.returncode != 0:
2770+
warnings.warn(
2771+
"Pytensor cxx failed to communicate its search dirs. As a consequence, "
2772+
"it might not be possible to automatically determine the blas link flags to use.\n"
2773+
f"Command that was run: {config.cxx} -print-search-dirs\n"
2774+
f"Output printed to stderr: {stderr.decode(sys.stderr.encoding)}"
2775+
)
2776+
return []
2777+
2778+
maybe_lib_dirs = [
2779+
[Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
2780+
for line in stdout.decode(sys.getdefaultencoding()).splitlines()
2781+
if line.startswith("libraries: =")
2782+
]
2783+
if not maybe_lib_dirs:
2784+
return []
2785+
return [str(d) for d in maybe_lib_dirs[0] if d.exists() and d.is_dir()]
2786+
2787+
2788+
def _check_libs(
2789+
all_libs: Collection[Path],
2790+
required_libs: Collection[str | re.Pattern],
2791+
extra_compile_flags: Sequence[str] = (),
2792+
cxx_library_dirs: Sequence[str] = (),
2793+
) -> str:
2794+
"""Assembly library paths and try BLAS flags, returning the flags on success."""
2795+
found_libs = _check_required_file(
2796+
all_libs,
2797+
required_libs,
2798+
)
2799+
path_quote = '"' if sys.platform == "win32" else ""
2800+
libdir_ldflags = list(
2801+
dict.fromkeys(
2802+
[
2803+
f"-L{path_quote}{lib_path}{path_quote}"
2804+
for lib_path, _ in found_libs
2805+
if lib_path not in cxx_library_dirs
2806+
]
2807+
)
2808+
)
2809+
2810+
flags = (
2811+
libdir_ldflags
2812+
+ [f"-l{lib_name}" for _, lib_name in found_libs]
2813+
+ list(extra_compile_flags)
2814+
)
2815+
res = try_blas_flag(flags)
2816+
if not res:
2817+
_logger.debug("Supplied flags '%s' failed to compile", res)
2818+
raise RuntimeError(f"Supplied flags {flags} failed to compile")
2819+
2820+
if any("mkl" in flag for flag in flags):
2821+
try:
2822+
check_mkl_openmp()
2823+
except Exception as e:
2824+
_logger.debug(e)
2825+
_logger.debug("The following blas flags will be used: '%s'", res)
2826+
return res
2827+
2828+
27392829
def default_blas_ldflags() -> str:
27402830
"""Look for an available BLAS implementation in the system.
27412831
@@ -2763,88 +2853,6 @@ def default_blas_ldflags() -> str:
27632853
27642854
"""
27652855

2766-
def check_required_file(paths, required_regexs):
2767-
libs = []
2768-
for req in required_regexs:
2769-
found = False
2770-
for path in paths:
2771-
m = re.search(req, path.name)
2772-
if m:
2773-
libs.append((str(path.parent), m.string[slice(*m.span())]))
2774-
found = True
2775-
break
2776-
if not found:
2777-
_logger.debug("Required file '%s' not found", req)
2778-
raise RuntimeError(f"Required file {req} not found")
2779-
return libs
2780-
2781-
def get_cxx_library_dirs():
2782-
cmd = [config.cxx, "-print-search-dirs"]
2783-
p = subprocess_Popen(
2784-
cmd,
2785-
stdout=subprocess.PIPE,
2786-
stderr=subprocess.PIPE,
2787-
stdin=subprocess.PIPE,
2788-
)
2789-
(stdout, stderr) = p.communicate(input=b"")
2790-
if p.returncode != 0:
2791-
warnings.warn(
2792-
"Pytensor cxx failed to communicate its search dirs. As a consequence, "
2793-
"it might not be possible to automatically determine the blas link flags to use.\n"
2794-
f"Command that was run: {config.cxx} -print-search-dirs\n"
2795-
f"Output printed to stderr: {stderr.decode(sys.stderr.encoding)}"
2796-
)
2797-
return []
2798-
2799-
maybe_lib_dirs = [
2800-
[Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
2801-
for line in stdout.decode(sys.getdefaultencoding()).splitlines()
2802-
if line.startswith("libraries: =")
2803-
]
2804-
if len(maybe_lib_dirs) > 0:
2805-
maybe_lib_dirs = maybe_lib_dirs[0]
2806-
return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()]
2807-
2808-
def check_libs(
2809-
all_libs, required_libs, extra_compile_flags=None, cxx_library_dirs=None
2810-
) -> str:
2811-
if cxx_library_dirs is None:
2812-
cxx_library_dirs = []
2813-
if extra_compile_flags is None:
2814-
extra_compile_flags = []
2815-
found_libs = check_required_file(
2816-
all_libs,
2817-
required_libs,
2818-
)
2819-
path_quote = '"' if sys.platform == "win32" else ""
2820-
libdir_ldflags = list(
2821-
dict.fromkeys(
2822-
[
2823-
f"-L{path_quote}{lib_path}{path_quote}"
2824-
for lib_path, _ in found_libs
2825-
if lib_path not in cxx_library_dirs
2826-
]
2827-
)
2828-
)
2829-
2830-
flags = (
2831-
libdir_ldflags
2832-
+ [f"-l{lib_name}" for _, lib_name in found_libs]
2833-
+ extra_compile_flags
2834-
)
2835-
res = try_blas_flag(flags)
2836-
if res:
2837-
if any("mkl" in flag for flag in flags):
2838-
try:
2839-
check_mkl_openmp()
2840-
except Exception as e:
2841-
_logger.debug(e)
2842-
_logger.debug("The following blas flags will be used: '%s'", res)
2843-
return res
2844-
else:
2845-
_logger.debug("Supplied flags '%s' failed to compile", res)
2846-
raise RuntimeError(f"Supplied flags {flags} failed to compile")
2847-
28482856
# If no compiler is available we default to empty ldflags
28492857
if not config.cxx:
28502858
return ""
@@ -2854,7 +2862,7 @@ def check_libs(
28542862
else:
28552863
rpath = None
28562864

2857-
cxx_library_dirs = get_cxx_library_dirs()
2865+
cxx_library_dirs = _get_cxx_library_dirs()
28582866
searched_library_dirs = cxx_library_dirs + _std_lib_dirs
28592867
if sys.platform == "win32":
28602868
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
@@ -2884,7 +2892,7 @@ def check_libs(
28842892
try:
28852893
# 1. Try to use MKL with INTEL OpenMP threading
28862894
_logger.debug("Checking MKL flags with intel threading")
2887-
return check_libs(
2895+
return _check_libs(
28882896
all_libs,
28892897
required_libs=[
28902898
"mkl_core",
@@ -2901,7 +2909,7 @@ def check_libs(
29012909
try:
29022910
# 2. Try to use MKL with GNU OpenMP threading
29032911
_logger.debug("Checking MKL flags with GNU OpenMP threading")
2904-
return check_libs(
2912+
return _check_libs(
29052913
all_libs,
29062914
required_libs=["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
29072915
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
@@ -2924,7 +2932,7 @@ def check_libs(
29242932
try:
29252933
_logger.debug("Checking Lapack + blas")
29262934
# 4. Try to use LAPACK + BLAS
2927-
return check_libs(
2935+
return _check_libs(
29282936
all_libs,
29292937
required_libs=["lapack", "blas", "cblas", "m"],
29302938
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
@@ -2935,7 +2943,7 @@ def check_libs(
29352943
try:
29362944
# 5. Try to use BLAS alone
29372945
_logger.debug("Checking blas alone")
2938-
return check_libs(
2946+
return _check_libs(
29392947
all_libs,
29402948
required_libs=["blas", "cblas"],
29412949
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
@@ -2946,7 +2954,7 @@ def check_libs(
29462954
try:
29472955
# 6. Try to use openblas
29482956
_logger.debug("Checking openblas")
2949-
return check_libs(
2957+
return _check_libs(
29502958
all_libs,
29512959
required_libs=["openblas", "gfortran", "gomp", "m"],
29522960
extra_compile_flags=["-fopenmp", f"-Wl,-rpath,{rpath}"]

0 commit comments

Comments
 (0)