Skip to content

Commit 392ea04

Browse files
Feature ml patch (#10)
* adds scalar, makes demo config config more helpful * more config updates * Adds basic version of diagnostics package from T2K MaCh3, will not work just yet... * update gitignore * Implements plotting lib * Small patches' * More small fixes to naming * EVEN MORE REQUIREMENTS * Switch off deprecation warning * Oops summary diags * Oops, fixes name change
1 parent 2323e74 commit 392ea04

File tree

4 files changed

+3
-5
lines changed

4 files changed

+3
-5
lines changed

MaCh3PythonUtils/machine_learning/fml_interface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sklearn import metrics
1010
from sklearn.preprocessing import StandardScaler
1111
import matplotlib.pyplot as plt
12-
import scipy.stats as stats
1312
import numpy as np
1413

1514
import pickle

MaCh3PythonUtils/machine_learning/ml_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ def __setup_package_factory(self, package: str, algorithm: str, **kwargs):
5151

5252
return self.__IMPLEMENTED_ALGORITHMS[package][algorithm](*kwargs)
5353

54-
def setup_scikit_model(self, algorithm: str, **kwargs)->SciKitInterface:
54+
def make_scikit_model(self, algorithm: str, **kwargs)->SciKitInterface:
5555
# Simple wrapper for scikit packages
5656
interface = SciKitInterface(self._chain, self._prediction_variable)
5757
interface.add_model(self.__setup_package_factory(package="scikit", algorithm=algorithm, **kwargs))
5858
return interface
5959

60-
def setup_tensorflow_model(self, algorithm: str, **kwargs):
60+
def make_tensorflow_model(self, algorithm: str, **kwargs):
6161
interface = TfInterface(self._chain, self._prediction_variable)
6262

6363
interface.add_model(self.__setup_package_factory(package="tensorflow", algorithm=algorithm))

configs/tensorflow_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,3 @@ MLSettings:
4444

4545
FitSettings:
4646
batch_size: 100
47-
epochs: 200

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ arviz
1212
mplhep
1313
tqdm
1414
numba
15-
jinja2
15+
jinja2

0 commit comments

Comments
 (0)