Skip to content

Commit 4b64d6e

Browse files
authored
Merge pull request #121 from sdpython/cpp
Fixes tests testing decision criterion
2 parents de886c4 + 53e8361 commit 4b64d6e

36 files changed

+761
-490
lines changed

_cmake/constants.cmake

+13-7
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,12 @@ set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -O3") # -DNDEBUG
6262
#
6363
# C++ 14 or C++ 17 or...
6464
#
65+
# We need to use the same C++ version as scikit-learn to avoid crashes.
66+
set(CMAKE_CXX_SCIKITLEARN 11)
6567
if (PYTHON_MANYLINUX EQUAL "1")
6668
set(CMAKE_CXX_STANDARD_REQUIRED ON)
6769
set(CMAKE_CXX_EXTENSIONS OFF)
68-
set(CMAKE_CXX_STANDARD 17)
70+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
6971
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unknown-pragmas -Wextra")
7072
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wno-unknown-pragmas -Wextra")
7173
if(APPLE)
@@ -82,28 +84,32 @@ if (PYTHON_MANYLINUX EQUAL "1")
8284
else()
8385
if(MSVC)
8486
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++17")
85-
set(CMAKE_CXX_STANDARD 17)
87+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
8688
elseif(APPLE)
8789
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.15")
8890
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
89-
set(CMAKE_CXX_STANDARD 17)
91+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
9092
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unknown-pragmas -Wextra")
9193
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wno-unknown-pragmas -Wextra")
9294
else()
9395
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unknown-pragmas -Wextra")
9496
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wno-unknown-pragmas -Wextra")
9597
if(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "15")
9698
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++23")
97-
set(CMAKE_CXX_STANDARD 23)
99+
# set(CMAKE_CXX_STANDARD 23)
100+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
98101
elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "11")
99102
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++20")
100-
set(CMAKE_CXX_STANDARD 20)
103+
# set(CMAKE_CXX_STANDARD 20)
104+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
101105
elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "9")
102106
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
103-
set(CMAKE_CXX_STANDARD 17)
107+
# set(CMAKE_CXX_STANDARD 17)
108+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
104109
elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "6")
105110
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
106-
set(CMAKE_CXX_STANDARD 14)
111+
# set(CMAKE_CXX_STANDARD 14)
112+
set(CMAKE_CXX_STANDARD ${CMAKE_CXX_SCIKITLEARN})
107113
else()
108114
message(FATAL_ERROR "gcc>=6.0 is needed but "
109115
"${CMAKE_C_COMPILER_VERSION} was detected.")

_cmake/finalize.cmake

+35-1
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,46 @@ else()
99
config_content_cuda "HAS_CUDA = 0")
1010
endif()
1111

12+
execute_process(
13+
COMMAND ${Python3_EXECUTABLE} -c "import cython;print(cython.__version__)"
14+
OUTPUT_VARIABLE CYTHON_VERSION
15+
ERROR_VARIABLE CYTHON_version_error
16+
RESULT_VARIABLE CYTHON_version_result
17+
OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_STRIP_TRAILING_WHITESPACE)
18+
19+
execute_process(
20+
COMMAND ${Python3_EXECUTABLE} -c "import sklearn;print(sklearn.__version__)"
21+
OUTPUT_VARIABLE SKLEARN_VERSION
22+
ERROR_VARIABLE SKLEARN_version_error
23+
RESULT_VARIABLE SKLEARN_version_result
24+
OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_STRIP_TRAILING_WHITESPACE)
25+
26+
execute_process(
27+
COMMAND ${Python3_EXECUTABLE} -c "import numpy;print(numpy.__version__)"
28+
OUTPUT_VARIABLE NUMPY_VERSION
29+
ERROR_VARIABLE NUMPY_version_error
30+
RESULT_VARIABLE NUMPY_version_result
31+
OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_STRIP_TRAILING_WHITESPACE)
32+
33+
execute_process(
34+
COMMAND ${Python3_EXECUTABLE} -c "import scipy;print(scipy.__version__)"
35+
OUTPUT_VARIABLE SCIPY_VERSION
36+
ERROR_VARIABLE SCIPY_version_error
37+
RESULT_VARIABLE SCIPY_version_result
38+
OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_STRIP_TRAILING_WHITESPACE)
39+
1240
set(
1341
config_content_comma
1442
"${config_content_cuda}"
1543
"\nCXX_FLAGS = '${CMAKE_CXX_FLAGS}'"
1644
"\nCMAKE_CXX_STANDARD_REQUIRED = '${CMAKE_CXX_STANDARD_REQUIRED}'"
1745
"\nCMAKE_CXX_EXTENSIONS = '${CMAKE_CXX_EXTENSIONS}'"
18-
"\nCMAKE_CXX_STANDARD = ${CMAKE_CXX_STANDARD}\n")
46+
"\nCMAKE_CXX_STANDARD = ${CMAKE_CXX_STANDARD}"
47+
"\n\n# Was compiled with the following versions."
48+
"\nCYTHON_VERSION = '${CYTHON_VERSION}'"
49+
"\nSKLEARN_VERSION = '${SKLEARN_VERSION}'"
50+
"\nNUMPY_VERSION = '${NUMPY_VERSION}'"
51+
"\nSCIPY_VERSION = '${SCIPY_VERSION}'"
52+
"\n")
1953

2054
string(REPLACE ";" "" config_content "${config_content_comma}")

_doc/api/batch.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
mlinsights.mlbatch
2+
==================
13

2-
Speed up batch training
3-
=======================
4+
This was written for an older version of scikit-learn and never
5+
revisited since. It may not bring much value.
46

57
MLCache
68
+++++++

_doc/api/blaslapack.rst

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11

2-
Blas & Lapack
3-
=============
2+
mlinsights.mlmodel.direct_blas_lapack
3+
=====================================
44

5-
.. contents::
6-
:local:
5+
A simple attempt to call scipy BLAS and LAPACK functions from Cython.
76

87
Lapack
98
++++++

_doc/api/helpers.rst

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11

2-
Helpers
3-
=======
4-
5-
.. contents::
6-
:local:
2+
mlinsights.helpers
3+
==================
74

85
Formatting
96
++++++++++

_doc/api/index.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ API
33
===
44

55
.. toctree::
6+
:maxdepth: 2
67

78
plotting
89
helpers
910
metrics
1011
batch
1112
mlmodel
12-
tree
13+
mlmodel_tree
1314
search_rank
15+
sklapi
16+
tree
1417
blaslapack
1518
timeseries

_doc/api/metrics.rst

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
2-
metrics
3-
=======
1+
mlinsights.metrics
2+
==================
43

54
.. autofunction:: mlinsights.metrics.correlations.non_linear_correlations

_doc/api/mlmodel.rst

+3-60
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
=======================
2-
Machine Learning Models
3-
=======================
4-
5-
.. contents::
6-
:local:
1+
==================
2+
mlinsights.mlmodel
3+
==================
74

85
Helpers
96
=======
@@ -67,12 +64,6 @@ PiecewiseRegressor
6764
.. autoclass:: mlinsights.mlmodel.piecewise_estimator.PiecewiseRegressor
6865
:members:
6966

70-
PiecewiseTreeRegressor
71-
++++++++++++++++++++++
72-
73-
.. autoclass:: mlinsights.mlmodel.piecewise_tree_regression.PiecewiseTreeRegressor
74-
:members:
75-
7667
QuantileMLPRegressor
7768
++++++++++++++++++++
7869

@@ -160,57 +151,9 @@ TraceableTfidfVectorizer
160151
.. autoclass:: mlinsights.mlmodel.sklearn_text.TraceableTfidfVectorizer
161152
:members:
162153

163-
Exploration
164-
===========
165-
166-
The following implementation play with :epkg:`scikit-learn`
167-
API, it overwrites the code handling parameters.
168-
169-
SkBaseTransformLearner
170-
++++++++++++++++++++++
171-
172-
.. autoclass:: mlinsights.sklapi.sklearn_base_transform_learner.SkBaseTransformLearner
173-
:members:
174-
175-
SkBaseTransformStacking
176-
+++++++++++++++++++++++
177-
178-
.. autoclass:: mlinsights.sklapi.sklearn_base_transform_stacking.SkBaseTransformStacking
179-
:members:
180-
181154
Exploration in C
182155
================
183156

184-
The following classes require :epkg:`scikit-learn` *>= 1.3.0*,
185-
otherwise, they do not get compiled.
186-
187-
SimpleRegressorCriterion
188-
++++++++++++++++++++++++
189-
190-
.. autoclass:: mlinsights.mlmodel.piecewise_tree_regression_criterion.SimpleRegressorCriterion
191-
:members:
192-
193-
SimpleRegressorCriterionFast
194-
++++++++++++++++++++++++++++
195-
196-
A similar design but a much faster implementation close to what
197-
:epkg:`scikit-learn` implements.
198-
199-
.. autoclass:: mlinsights.mlmodel.piecewise_tree_regression_criterion_fast.SimpleRegressorCriterionFast
200-
:members:
201-
202-
LinearRegressorCriterion
203-
++++++++++++++++++++++++
204-
205-
The next one implements a criterion which optimizes the mean square error
206-
assuming the points falling into one node of the tree are approximated by
207-
a line. The mean square error is the error made with a linear regressor
208-
and not a constant anymore. The documentation will be completed later.
209-
210-
`mlinsights.mlmodel.piecewise_tree_regression_criterion_linear.LinearRegressorCriterion`
211-
212-
`mlinsights.mlmodel.piecewise_tree_regression_criterion_linear_fast.SimpleRegressorCriterionFast`
213-
214157
Losses
215158
++++++
216159

_doc/api/mlmodel_tree.rst

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
==========================
2+
mlinsights.mlmodel (trees)
3+
==========================
4+
5+
.. _blog-internal-api-impurity-improvement:
6+
7+
Note about potential issues
8+
===========================
9+
10+
The main estimator `PiecewiseTreeRegressor` is based on the implementation
11+
of a new criterion. It relies on a non-public API and as such is more likely
12+
to break. The unit tests are unstable. They work when *scikit-learn*
13+
and this package are compiled with the same set of tools. If installed
14+
from PyPI, you can check which versions were used at compilation time.
15+
16+
.. runpython::
17+
:showcode:
18+
19+
from mlinsights._config import (
20+
CYTHON_VERSION,
21+
NUMPY_VERSION,
22+
SCIPY_VERSION,
23+
SKLEARN_VERSION,
24+
)
25+
print(f"CYTHON_VERSION: {CYTHON_VERSION}")
26+
print(f"NUMPY_VERSION: {NUMPY_VERSION}")
27+
print(f"SCIPY_VERSION: {SCIPY_VERSION}")
28+
print(f"SKLEARN_VERSION: {SKLEARN_VERSION}")
29+
30+
31+
The signature of method *impurity_improvement* has changed in version 0.24.
32+
It is usually easy to handle two versions of *scikit-learn*, even when overloaded
33+
in a class, except that this method is implemented in Cython.
34+
The method must be overloaded the same way with the same signature.
35+
Tricks such as `*args` or `**kwargs` cannot be used.
36+
The way it was handled is implemented in
37+
PR `88 <https://github.com/sdpython/mlinsights/pull/88>`_.
38+
39+
Estimators
40+
==========
41+
42+
PiecewiseTreeRegressor
43+
++++++++++++++++++++++
44+
45+
.. autoclass:: mlinsights.mlmodel.piecewise_tree_regression.PiecewiseTreeRegressor
46+
:members:
47+
48+
Criterions
49+
==========
50+
51+
The following classes require :epkg:`scikit-learn` *>= 1.3.0*,
52+
otherwise, they do not get compiled. Section :ref:`blog-internal-api-impurity-improvement`
53+
explains why the execution may crash.
54+
55+
SimpleRegressorCriterion
56+
++++++++++++++++++++++++
57+
58+
.. autoclass:: mlinsights.mlmodel.piecewise_tree_regression_criterion.SimpleRegressorCriterion
59+
:members:
60+
61+
SimpleRegressorCriterionFast
62+
++++++++++++++++++++++++++++
63+
64+
A similar design but a much faster implementation close to what
65+
:epkg:`scikit-learn` implements.
66+
67+
.. autoclass:: mlinsights.mlmodel.piecewise_tree_regression_criterion_fast.SimpleRegressorCriterionFast
68+
:members:
69+
70+
LinearRegressorCriterion
71+
++++++++++++++++++++++++
72+
73+
The next one implements a criterion which optimizes the mean square error
74+
assuming the points falling into one node of the tree are approximated by
75+
a line. The mean square error is the error made with a linear regressor
76+
and not a constant anymore. The documentation will be completed later.
77+
78+
`mlinsights.mlmodel.piecewise_tree_regression_criterion_linear.LinearRegressorCriterion`
79+
80+
`mlinsights.mlmodel.piecewise_tree_regression_criterion_linear_fast.SimpleRegressorCriterionFast`

_doc/api/plotting.rst

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
2-
plotting
3-
========
1+
mlinsights.plotting
2+
===================
43

54
.. autofunction:: mlinsights.plotting.gallery.plot_gallery_images
65

_doc/api/search_rank.rst

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
2-
search_rank
3-
===========
1+
mlinsights.search_rank
2+
======================
43

54
SearchEngineVectors
65
+++++++++++++++++++

0 commit comments

Comments
 (0)