Skip to content

Feature ml patch #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 17, 2024
1 change: 0 additions & 1 deletion MaCh3PythonUtils/machine_learning/fml_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np

import pickle
Expand Down
4 changes: 2 additions & 2 deletions MaCh3PythonUtils/machine_learning/ml_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def __setup_package_factory(self, package: str, algorithm: str, **kwargs):

return self.__IMPLEMENTED_ALGORITHMS[package][algorithm](*kwargs)

def setup_scikit_model(self, algorithm: str, **kwargs)->SciKitInterface:
def make_scikit_model(self, algorithm: str, **kwargs)->SciKitInterface:
# Simple wrapper for scikit packages
interface = SciKitInterface(self._chain, self._prediction_variable)
interface.add_model(self.__setup_package_factory(package="scikit", algorithm=algorithm, **kwargs))
return interface

def setup_tensorflow_model(self, algorithm: str, **kwargs):
def make_tensorflow_model(self, algorithm: str, **kwargs):
interface = TfInterface(self._chain, self._prediction_variable)

interface.add_model(self.__setup_package_factory(package="tensorflow", algorithm=algorithm))
Expand Down
1 change: 0 additions & 1 deletion configs/tensorflow_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,3 @@ MLSettings:

FitSettings:
batch_size: 100
epochs: 200
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ arviz
mplhep
tqdm
numba
jinja2
jinja2