From cdc1477a655492bb98987b4d5ec50651e218dbf1 Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 13:40:43 +0100 Subject: [PATCH 01/11] adds scalar, makes demo config config more helpful --- MaCh3PythonUtils/machine_learning/fml_interface.py | 4 ++++ MaCh3PythonUtils/machine_learning/scikit_interface.py | 5 ++++- MaCh3PythonUtils/machine_learning/tf_interface.py | 7 +++++-- configs/tensorflow_config.yml | 2 +- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/MaCh3PythonUtils/machine_learning/fml_interface.py b/MaCh3PythonUtils/machine_learning/fml_interface.py index 86d99bc..6a22cd6 100644 --- a/MaCh3PythonUtils/machine_learning/fml_interface.py +++ b/MaCh3PythonUtils/machine_learning/fml_interface.py @@ -7,6 +7,7 @@ from matplotlib.colors import LinearSegmentedColormap from sklearn import metrics +from sklearn.preprocessing import StandardScaler import matplotlib.pyplot as plt import scipy.stats as stats import numpy as np @@ -40,6 +41,7 @@ def __init__(self, chain: ChainHandler, prediction_variable: str) -> None: self._training_labels=None self._test_data=None self._test_labels=None + self._scalar = StandardScaler() def __separate_dataframe(self)->Tuple[pd.DataFrame, pd.DataFrame]: # Separates dataframe into features + labels @@ -52,6 +54,8 @@ def set_training_test_set(self, test_size: float): # Splits in traing + test_spit features, labels = self.__separate_dataframe() self._training_data, self._test_data, self._training_labels, self._test_labels = train_test_split(features, labels, test_size=test_size) + self._training_data = self._scalar.fit_transform(self._training_data) + @property def model(self)->Any: diff --git a/MaCh3PythonUtils/machine_learning/scikit_interface.py b/MaCh3PythonUtils/machine_learning/scikit_interface.py index 719f65e..3564bcb 100644 --- a/MaCh3PythonUtils/machine_learning/scikit_interface.py +++ b/MaCh3PythonUtils/machine_learning/scikit_interface.py @@ -26,8 +26,11 @@ def train_model(self): 
self._model.fit(self._training_data, self._training_labels) def model_predict(self, test_data: DataFrame): + + scale_data = self._scalar.transform(test_data) + if self._model is None: raise ValueError("No Model has been set!") - return self._model.predict(test_data) + return self._model.predict(scale_data) \ No newline at end of file diff --git a/MaCh3PythonUtils/machine_learning/tf_interface.py b/MaCh3PythonUtils/machine_learning/tf_interface.py index c61bd67..2f5c4c5 100644 --- a/MaCh3PythonUtils/machine_learning/tf_interface.py +++ b/MaCh3PythonUtils/machine_learning/tf_interface.py @@ -48,6 +48,9 @@ def save_model(self, output_file: str): def load_model(self, input_file: str): self._model = tf.saved_model.load(input_file) - def model_predict(self, testing_data): + def model_predict(self, test_data): + + scale_data = self._scalar.transform(test_data) + # Hacky but means it's consistent with sci-kit interface - return self._model.predict_on_batch(testing_data).T[0] + return self._model.predict_on_batch(scale_data).T[0] diff --git a/configs/tensorflow_config.yml b/configs/tensorflow_config.yml index 7fbc6c3..6cf312b 100644 --- a/configs/tensorflow_config.yml +++ b/configs/tensorflow_config.yml @@ -36,4 +36,4 @@ FitterSettings: FitSettings: batch_size: 100 - epochs: 10 \ No newline at end of file + epochs: 200 \ No newline at end of file From 9f4e74dd0817a8f65b2fc2cbfbc33df80a1125c7 Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 13:54:34 +0100 Subject: [PATCH 02/11] more config updates --- configs/tensorflow_config.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/configs/tensorflow_config.yml b/configs/tensorflow_config.yml index 6cf312b..9d2d0f2 100644 --- a/configs/tensorflow_config.yml +++ b/configs/tensorflow_config.yml @@ -15,7 +15,7 @@ FitterSettings: FitterPackage : "TensorFlow" FitterName : "Sequential" - TestSize : 0.9 + TestSize : 0.4 FitterKwargs: BuildSettings: @@ -25,12 +25,15 @@ FitterSettings: Layers: - 
dense: units: 128 + activation: 'relu' - dense: units: 64 + activation: 'relu' - dropout: rate: 0.5 - dense: units: 16 + activation: 'relu' - dense: units: 1 From 85b3be6c43dff0f7a41b6dfd0c71a300c12b7e05 Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 13:55:19 +0100 Subject: [PATCH 03/11] Adds basic version of diagnostics package from T2K MaCh3, will not work just yet... --- .../diagnostics/interface/__init__.py | 1 + .../interface/plotting_interface.py | 134 ++++++++++++ .../diagnostics/plotters/__init__.py | 3 + .../plotters/diagnostics/__init__.py | 3 + .../autocorrelation_trace_plotter.py | 49 +++++ .../diagnostics/covariance_matrix_utils.py | 112 ++++++++++ .../plotters/diagnostics/simple_diag_plots.py | 73 +++++++ .../diagnostics/plotters/plotter_base.py | 200 ++++++++++++++++++ .../plotters/posteriors/__init__.py | 3 + .../posteriors/posterior_base_classes.py | 52 +++++ .../plotters/posteriors/posteriors_1d.py | 88 ++++++++ .../plotters/posteriors/posteriors_2d.py | 115 ++++++++++ 12 files changed, 833 insertions(+) create mode 100644 MaCh3PythonUtils/diagnostics/interface/__init__.py create mode 100644 MaCh3PythonUtils/diagnostics/interface/plotting_interface.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/__init__.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/diagnostics/__init__.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/plotter_base.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/posteriors/__init__.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py create mode 100644 MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py create mode 100644 
MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py
diff --git a/MaCh3PythonUtils/diagnostics/interface/__init__.py b/MaCh3PythonUtils/diagnostics/interface/__init__.py
new file mode 100644
index 0000000..18b84c2
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/interface/__init__.py
@@ -0,0 +1 @@
+from .plotting_interface import plotting_interface
\ No newline at end of file
diff --git a/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py b/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py
new file mode 100644
index 0000000..13dd1d2
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py
@@ -0,0 +1,134 @@
+'''
+Interface class for making plots!
+'''
+from typing import List
+from MaCh3_plot_lib.file_handlers import root_file_loader
+import MaCh3_plot_lib.plotters as pt
+from matplotlib.backends.backend_pdf import PdfPages
+import arviz as az
+
+class plotting_interface:
+    '''
+    full interface object for making plots
+    inputs:
+        file_loader : root_file_loader instance
+    '''
+    def __init__(self, file_loader: root_file_loader):
+        '''
+        Constructor, stores the file loader and starts with no plotters registered
+        '''
+        self._file_loader = file_loader
+        self._plotter_object_dict = {} # dict of plotter objects keyed by their plot label
+
+
+    def initialise_new_plotter(self, new_plotter: pt.plotter_base._plotting_base_class , plot_label: str)->None:
+        '''
+        Adds new plot object to our dict (an existing entry with the same label is replaced)
+        inputs :
+            new_plotter : plotting object
+            plot_label : [type=str], how do we want to call this plot?
+        '''
+        self._plotter_object_dict[plot_label] = new_plotter
+
+    def set_credible_intervals(self, credible_intervals: List[float])->None:
+        '''
+        Sets set of credible intervals across all posterior plots
+        inputs :
+            credible_intervals : [type=list[float]] sets up a list of credible intervals
+        '''
+
+        print(f"Setting credible intervals as {credible_intervals}")
+
+        # bit of defensive programming
+        if not isinstance(credible_intervals, list):
+            raise ValueError(f"Cannot set credible intervals to {credible_intervals}")
+
+        for plotter in list(self._plotter_object_dict.values()):
+            if not isinstance(plotter, pt.posterior_base_classes._posterior_plotting_base):
+                continue
+
+            # set credible intervals
+            plotter.credible_intervals = credible_intervals
+
+    def set_variables_to_plot(self, plot_variables, plot_labels: List[str]=[]):
+        '''
+        Sets variables we actually want to plot for the plotters named in plot_labels
+        '''
+        for plotter in plot_labels:
+            self._plotter_object_dict[plotter].plot_params = plot_variables
+
+
+    def set_is_multimodal(self, param_ids: List[int | str]):
+        '''
+        Lets posteriors know which parameters are multimodal
+        inputs:
+            param_ids : list of multi-modal parameter ids/names
+        '''
+        print(f"Setting {param_ids} to be multimodal")
+
+        # Loop over our plotters, only posterior plotters know about multimodality
+        for plotter in self._plotter_object_dict.values():
+            if not isinstance(plotter, pt.posterior_base_classes._posterior_plotting_base):
+                continue
+
+            plotter.set_pars_multimodal(param_ids)
+
+
+    def set_is_circular(self, param_ids: List[int | str]):
+        '''
+        Lets posteriors know which parameters are circular
+        inputs:
+            param_ids : list of circular parameter ids/names
+        '''
+        print(f"Setting {param_ids} to be circular")
+
+        # Loop over our plotters
+        for plotter in self._plotter_object_dict.values():
+            if not isinstance(plotter, pt.posterior_base_classes._posterior_plotting_base):
+                continue
+
+            plotter.set_pars_circular(param_ids)
+
+    def add_text_to_plots(self, text: str, location: tuple=(0.05, 0.95)): # Forwards text + location to every plotter
+        for plotter in self._plotter_object_dict.values():
+            plotter.add_text_to_figures(text, location)
+
+    def make_plots(self, output_file_name: str, plot_labels: List[str] | str):
+        '''
+        Outputs all plots from a list of labels to an output PDF
+        inputs :
+            output_file_name : [str] -> Output file pdf
+            plot_labels : name, or list of names, of plots in self._plotter_object_dict
+        '''
+        # Wrap a single label, list("abc") would split it into characters
+        if isinstance(plot_labels, str):
+            plot_labels = [plot_labels]
+
+        with PdfPages(output_file_name) as pdf_file:
+            for plotter_id in plot_labels:
+                plotter_obj = self._plotter_object_dict.get(plotter_id)
+                if plotter_obj is None:
+                    # dict.get never raises KeyError, unknown labels come back as None
+                    print(f"Warning:Key not found {plotter_id}, skipping")
+                    continue
+
+                print(f"Generating Plots for {plotter_id}")
+                plotter_obj.generate_all_plots()
+                print(f"Printing to {output_file_name}")
+                plotter_obj.write_to_pdf(existing_pdf_fig=pdf_file)
+
+    def print_summary(self, latex_output_name:str=None):
+        '''
+        Print stats summary to terminal and, if a name is given, output as a LaTeX table [text file]
+        inputs :
+            latex_output_name : [type=str, optional] name of output file
+        '''
+        summary = az.summary(self._file_loader.ttree_array, kind='stats', hdi_prob=0.9)
+        print(summary)
+
+        if latex_output_name is None:
+            return
+
+        with open(latex_output_name, "w") as output_file:
+            output_file.write(summary.to_latex())
+
\ No newline at end of file
diff --git a/MaCh3PythonUtils/diagnostics/plotters/__init__.py b/MaCh3PythonUtils/diagnostics/plotters/__init__.py
new file mode 100644
index 0000000..50c30a9
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/__init__.py
@@ -0,0 +1,3 @@
+from .plotter_base import _plotting_base_class
+from .posteriors import *
+from .diagnostics import *
\ No newline at end of file
diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/__init__.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/__init__.py
new file mode 100644
index 0000000..e001b6d
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/__init__.py
@@ -0,0
+1,3 @@
+from .autocorrelation_trace_plotter import *
+from .covariance_matrix_utils import *
+from .simple_diag_plots import *
\ No newline at end of file
diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py
new file mode 100644
index 0000000..adca0b7
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py
@@ -0,0 +1,49 @@
+'''
+HI : Class to make autocorrelations and traces, puts them all onto a single plot
+'''
+from MaCh3_plot_lib.file_handlers import root_file_loader
+import arviz as az
+from matplotlib import pyplot as plt
+from MaCh3_plot_lib.plotters.plotter_base import _plotting_base_class
+from matplotlib.figure import Figure
+
+class autocorrelation_trace_plotter(_plotting_base_class):
+    def __init__(self, file_loader: root_file_loader)->None:
+        # Constructor, all setup is done by the plotting base class
+        super().__init__(file_loader)
+
+    def _generate_plot(self, parameter_name: str) -> Figure:
+        '''
+        Makes a combined trace (top panel) and auto-correlation (bottom panel) plot
+        inputs :
+            parameter_name : [type=str] Single parameter name
+        '''
+        # Setup axes
+        fig, (trace_ax, autocorr_ax) = plt.subplots(nrows=2, sharex=False)
+
+        # We want the numpy array containing our parameter
+        param_array = self._file_loader.ttree_array[parameter_name].to_numpy()[0]
+
+        # Okay now we can plot our trace (might as well!)
+        trace_ax.plot(param_array, linewidth=0.05, color='purple')
+
+        #next we need to grab our autocorrelations
+        # arviz returns the autocorrelation for us, plotted against lag below
+        auto_corr = az.autocorr(param_array)
+        autocorr_ax.plot(auto_corr, color='purple')
+        # Now we do some tidying
+        trace_ax.set_ylabel(f"{parameter_name} variation", fontsize=11)
+        trace_ax.set_xlabel("Step", fontsize=11)
+        trace_ax.set_title(f"Trace for {parameter_name}", fontsize=15)
+        trace_ax.tick_params(labelsize=10)
+        # Label the autocorrelation panel
+        autocorr_ax.set_xlabel("Lag", fontsize=12)
+        autocorr_ax.set_ylabel("Autocorrelation", fontsize=12)
+        autocorr_ax.set_title(f"Autocorrelation for {parameter_name}", fontsize=15, verticalalignment='center_baseline')
+        autocorr_ax.tick_params(labelsize=10)
+        fig.suptitle(f"{parameter_name} diagnostics")
+        # Leave vertical room between the two panels so titles don't overlap
+        fig.subplots_adjust(wspace=0.0, hspace=0.3)
+
+        plt.close(fig) # Close THIS figure, pyplot's "current figure" is shared between worker threads
+        return fig
\ No newline at end of file
diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py
new file mode 100644
index 0000000..87b887a
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py
@@ -0,0 +1,112 @@
+'''
+Additional Diagnostics that can be used with MCMC but don't rely on plotting
+
+Suboptimality : https://www.jstor.org/stable/25651249?seq=3
+'''
+
+from MaCh3_plot_lib.file_handlers import root_file_loader
+import numpy as np
+from scipy.linalg import sqrtm
+from tqdm.auto import tqdm
+from matplotlib import pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+import warnings
+import mplhep as hep
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+class covariance_matrix_utils:
+    def __init__(self, file_loader: root_file_loader)->None:
+        '''
+        For calculating the covariance matrix + suboptimality
+        inputs:
+
->file_loader : [type=root_file_loader] file handler object
+        '''
+        # Let's just ignore some warnings :grin:
+        warnings.filterwarnings("ignore", category=DeprecationWarning) #Some imports are a little older
+        warnings.filterwarnings("ignore", category=UserWarning) #Some imports are a little older
+
+        file_loader.get_ttree_as_arviz() # Ensure that we have things in the correct format
+        self._parameter_names = list(file_loader.ttree_array.keys())
+        self._total_parameters = len(self._parameter_names)
+        self._ttree_data_frame = file_loader.ttree_array.to_dataframe() # All our handling will be done using this object
+        # Various useful class properties
+        self._suboptimality_array = [] # For filling with suboptimality
+        self._suboptimality_evaluation_points = [] #Where did we evaluate this?
+        self._full_covariance = self.calculate_covariance_matrix()
+        self._sqrt_full_covariance_inv = np.linalg.inv(sqrtm(self._full_covariance))
+        plt.style.use(hep.style.ROOT)
+
+    def calculate_covariance_matrix(self, min_step: int=0, max_step: int=-1)->np.ndarray:
+        '''
+        Calculates covariance matrix for chain between indices min_step and max_step
+        inputs:
+            -> min_step : [type=int] minimum index to calculate covariance from
+            -> max_step : [type=int] maximum index to calculate covariance to (<=0 means "to the end of the chain")
+        returns
+            -> covariance matrix
+        '''
+        if(max_step<=0):
+            sub_array = self._ttree_data_frame[min_step : ]
+        else:
+            sub_array = self._ttree_data_frame[min_step : max_step]
+
+
+        return sub_array.cov().to_numpy()
+
+    def _calculate_matrix_suboptimality(self, step_number: int)->float:
+        '''
+        Calculate suboptimality of the covariance of the first step_number steps w.r.t. the full-chain covariance
+        inputs:
+            -> step_number : [type=int] how many steps in are we calculating this?
+        returns:
+            -> suboptimality value
+        '''
+        new_covariance = self.calculate_covariance_matrix(max_step=step_number)
+
+        sqrt_input_cov = sqrtm(new_covariance)
+
+        # Get product of square roots
+        matrix_prod = sqrt_input_cov @ self._sqrt_full_covariance_inv
+
+        #Get eigen values
+        eigenvalues, _ = np.linalg.eig(matrix_prod)
+        # b = d * sum(lambda^-2) / (sum(lambda^-1))^2, equals 1 when all eigenvalues are equal
+        return self._total_parameters * np.sum(eigenvalues**(-2))/(np.sum(eigenvalues**(-1))**2)
+
+    def calculate_suboptimality(self, step_skip : int = 1000, min_step=0)->None:
+        '''
+        Calculates the suboptimality every step_skip steps, results are stored on the object
+        inputs :
+            -> step_skip : Number of steps to skip
+            -> min_step : smallest number of starting steps
+        '''
+        self._suboptimality_evaluation_points = np.arange(min_step, len(self._ttree_data_frame), step_skip)
+        # Make sure we have the last step as well!
+        self._suboptimality_evaluation_points = np.append(self._suboptimality_evaluation_points, len(self._ttree_data_frame)-1)
+        # make array to fill with suboptimality values
+        self._suboptimality_array = np.zeros(len(self._suboptimality_evaluation_points)) # Reset our suboptimality values
+        print(f"Calculating suboptimality for {len(self._suboptimality_evaluation_points)} points")
+
+        # Lets speed this up a bit!
+        with ThreadPoolExecutor() as executor:
+            futures = {executor.submit(self._calculate_matrix_suboptimality, step): step
+                       for step in self._suboptimality_evaluation_points}
+
+            for future in tqdm(as_completed(futures), ascii="▖▘▝▗▚▞█", total=len(self._suboptimality_evaluation_points)):
+                i = np.where(self._suboptimality_evaluation_points==futures[future])[0]
+                self._suboptimality_array[i]=future.result()
+
+
+    def plot_suboptimality(self, output_file: str):
+        print(f"Saving to {output_file}")
+        with PdfPages(output_file) as pdf:
+            fig, axes = plt.subplots()
+
+            axes.plot(self._suboptimality_evaluation_points, self._suboptimality_array)
+            axes.set_xlabel("Step Number")
+            axes.set_ylabel("Suboptimality")
+            axes.set_title("Suboptimality Plot")
+            axes.set_yscale('log')
+            pdf.savefig(fig)
+            plt.close(fig) # Figure is on disk now, release it
diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py
new file mode 100644
index 0000000..ba6b8c4
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py
@@ -0,0 +1,73 @@
+'''
+HI : Several simple diagnostics
+'''
+
+from typing import List
+from MaCh3_plot_lib.file_handlers import root_file_loader
+import arviz as az
+from matplotlib import pyplot as plt
+from MaCh3_plot_lib.plotters.plotter_base import _plotting_base_class
+from matplotlib.figure import Figure
+
+class effective_sample_size_plotter(_plotting_base_class):
+    '''
+    Calculates effective sample size : https://arxiv.org/pdf/1903.08008.pdf
+    '''
+    def __init__(self, file_loader: root_file_loader)->None:
+        # Constructor
+        super().__init__(file_loader)
+
+    def _generate_plot(self, parameter_name: str) -> Figure:
+        '''
+        Makes effective sample size
+        inputs :
+            parameter_name : [type=str] Single parameter name
+        outputs :
+            Figure
+        '''
+        fig, axes = plt.subplots()
+
+        az.plot_ess(self._file_loader.ttree_array, var_names=parameter_name,
+                    ax=axes, textsize=30, color='purple',
drawstyle="steps-mid", linestyle="-")
+        plt.close(fig) # Close THIS figure, pyplot's "current figure" is shared between worker threads
+        return fig
+
+class markov_chain_standard_error(_plotting_base_class):
+    '''
+    Calculates Markov Chain Standard Error : https://arxiv.org/pdf/1903.08008.pdf
+    '''
+    def __init__(self, file_loader: root_file_loader)->None:
+        # Constructor
+        super().__init__(file_loader)
+
+    def _generate_plot(self, parameter_name: str) -> Figure:
+        '''
+        Makes MCSE plot
+        inputs :
+            parameter_name : [type=str] Single parameter name
+        outputs :
+            Figure
+        '''
+        fig, axes = plt.subplots()
+        az.plot_mcse(self._file_loader.ttree_array, var_names=parameter_name, ax=axes, textsize=10, color='purple')
+        plt.close(fig) # Close THIS figure, pyplot's "current figure" is shared between worker threads
+        return fig
+
+
+class violin_plotter(_plotting_base_class):
+    # Class to generate Violin Plots
+    def __init__(self, file_loader: root_file_loader)->None:
+        # Constructor
+        super().__init__(file_loader)
+
+    def _generate_plot(self, parameter_name: str | List[str]) -> Figure:
+        '''
+        Generates a violin plot for a single parameter (or list of parameters) on one set of axes
+        '''
+        # One figure with a single axes object, arviz draws onto it
+        fig, axes = plt.subplots()
+        az.plot_violin(self._file_loader.ttree_array,
+                       var_names=parameter_name, ax=axes, textsize=10,
+                       shade_kwargs={'color':'purple'})
+        plt.close(fig) # Close THIS figure, pyplot's "current figure" is shared between worker threads
+        return fig
\ No newline at end of file
diff --git a/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py b/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py
new file mode 100644
index 0000000..0282289
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py
@@ -0,0 +1,200 @@
+'''
+HI : Mostly abstract base class, contains common methods for use by most other plotting classes
+'''
+
+from MaCh3_plot_lib.file_handlers import root_file_loader
+import arviz as az
+from matplotlib import pyplot as plt
+from matplotlib.figure import Figure
+from matplotlib.backends.backend_pdf import PdfPages
+import numpy as np
+import warnings
+from abc import ABC, abstractmethod
+from typing import List
+import mplhep as hep
+from tqdm.auto import tqdm
+from concurrent.futures import
ThreadPoolExecutor, as_completed
+import numpy.typing as npt
+from typing import Any
+from collections.abc import Iterable
+
+
+# Base class with common methods
+class _plotting_base_class(ABC):
+    '''
+    Abstract class with common methods for creating plots
+    inputs:
+        file_loader : [type=root_file_loader] root_file_loader class instance
+    '''
+    def __init__(self, file_loader: root_file_loader)->None:
+        '''
+        Constructor
+        '''
+        # I'm very lazy but this should speed up parallelisation since we won't have this warning
+        plt.rcParams.update({'figure.max_open_warning': 0})
+
+        file_loader.get_ttree_as_arviz() # Ensure that we have things in the correct format
+        self._file_loader=file_loader # Make sure we don't use crazy amount of memory
+        # Setup plotting styles
+        az.style.use(hep.style.ROOT)
+        plt.style.use(hep.style.ROOT)
+
+        self._figure_list = np.empty(0, dtype=Figure) # Empty 1D array of figures so len() works before plots exist
+
+        self._all_plots_generated = False # Stops us making plots multiple times
+
+        # Every parameter starts flagged as circular, setting it like this replicates the old behaviour!
+        self._circular_params=np.ones(len(self._file_loader.ttree_array)).astype(bool)
+        # Let's just ignore some warnings :grin:
+        warnings.filterwarnings("ignore", category=DeprecationWarning) #Some imports are a little older
+        warnings.filterwarnings("ignore", category=UserWarning) #Some imports are a little older
+
+        az.utils.Dask().enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) # let it be a bit cleverer
+        az.utils.Numba().enable_numba()
+
+        # Do we want figure text?
+        self._figure_text = None
+        self._text_location = (0.05, 0.85)
+
+        # Default option, plot every parameter in the chain
+        self._params_to_plot = list(self._file_loader.ttree_array.keys())
+
+    def __str__(self) -> str:
+        return "plotting_base_class"
+
+    @property
+    def figure_list(self)->npt.NDArray[Any]:
+        return self._figure_list
+
+    @figure_list.setter
+    def figure_list(self, new_figure)->None:
+        raise NotImplementedError("Cannot set figure list using property!")
+
+    @abstractmethod
+    def _generate_plot(self, parameter_name: str | List[str])->Figure | npt.NDArray[Any]:
+        # Abstract method to generate a single plot (or array of plots) for one entry of the plot list
+        pass
+
+    def generate_all_plots(self)->None:
+        # Generates plots for every parameter [can be overwritten]
+        if self._all_plots_generated: # No need to make all possible plots again!
+            return
+
+        self._figure_list = np.empty(len(self._params_to_plot), dtype=Figure)
+
+        # Parallelised loop
+        with ThreadPoolExecutor() as executor:
+            # Map each future to the slot its figure belongs in, futures complete out of order
+            futures = {executor.submit(self._generate_plot, param) : index for index, param in enumerate(self._params_to_plot)}
+            # Begin loop
+            for future in tqdm(as_completed(futures), ascii="▖▘▝▗▚▞█", total=len(self._params_to_plot)):
+                param_id = futures[future] # Unique plot ID
+                self._figure_list[param_id] = future.result()
+                plt.close()
+
+        self._all_plots_generated=True
+
+    # Lets us select a subset/add list of parameters we'd like to plot
+    @property
+    def plot_params(self)-> List[str] | List[List[str]]:
+        '''
+        Getter for parameters we want to plot
+        '''
+        return self._params_to_plot
+
+    @plot_params.setter
+    def plot_params(self, new_plot_parameter_list: List[str]|List[List[str]]):
+        if len(new_plot_parameter_list)==0:
+            raise ValueError("Parameter list cannot have length 0")
+
+        # Check our new parameters are in our list of keys
+        if isinstance(new_plot_parameter_list[0], str):
+            for parameter in new_plot_parameter_list:
+                self._parameter_not_found_error(parameter)
+
+
+        elif isinstance(new_plot_parameter_list[0][0], str):
+            for param_list in new_plot_parameter_list:
+                for param in param_list:
+                    # Every nested name must be a known parameter too
+                    self._parameter_not_found_error(param)
+
+        else:
+            raise ValueError("Plot params must be of type List[str] or List[List[str]]")
+
+        self._params_to_plot = new_plot_parameter_list
+
+
+
+    def _parameter_not_found_error(self, parameter_name: str):
+        if parameter_name not in list(self._file_loader.ttree_array.keys()):
+            raise ValueError(f"{parameter_name} not in list of parameters!")
+
+    def _get_param_index_from_name(self, parameter_name: str)->int:
+        # Gets index of parameter in our arviz array, raises ValueError for unknown names
+
+        self._parameter_not_found_error(parameter_name)
+        param_id = list(self._file_loader.ttree_array.keys()).index(parameter_name)
+        return param_id
+
+    def set_pars_circular(self, par_id_list: List[str] | List[int])->None:
+        '''
+        Let the plotter know parameter set is cyclical
+        inputs:
+            par_id_list : List[str/int] List of Parameter indices or name
+        '''
+        # Wrap a lone name/index, list("name") would split the name into characters
+        if isinstance(par_id_list, (str, int)):
+            par_id_list = [par_id_list]
+
+        for par_id in par_id_list:
+            if isinstance(par_id, int):
+                true_index = par_id
+            else:
+                true_index = self._get_param_index_from_name(par_id)
+
+            self._circular_params[true_index] = True
+
+    def add_text_to_figures(self, text: str, text_location: tuple=(0.05, 0.95))->None:
+        '''
+        Add text to all figures (applied when the figures are written to PDF)
+        inputs:
+            text : Text to add
+            text_location : location of text
+        '''
+        self._figure_text = text
+        self._text_location = text_location
+
+    def write_to_pdf(self, output_pdf_name: str=None, existing_pdf_fig: PdfPages=None)->None:
+        '''
+        Dump all our plots to PDF file must either set output name or existing_pdf_fig
+
+        inputs :
+            output_pdf_name: [type=string] Output name for NEW pdf
+            existing_pdf_fig: [type=PDFPages]
+        returns :
+            None, a PDF opened here from output_pdf_name is closed for you, an existing_pdf_fig is left open for the caller
+        '''
+        if len(self._figure_list)==0:
+            return
+
+        if output_pdf_name is not None and existing_pdf_fig is None:
+            pdf_file = PdfPages(output_pdf_name)
+        elif existing_pdf_fig is not None:
+            pdf_file = existing_pdf_fig
+        else:
+            raise ValueError("ERROR:Must set EITHER output_pdf_name OR existing_pdf_fig")
+
+        # For some arrays we might want to make them 1D
+        if not isinstance(self._figure_list[0], Figure):
+            self._figure_list = [fig for sublist in self._figure_list for fig in sublist]
+
+        for fig in tqdm(self._figure_list, ascii=" ▖▘▝▗▚▞█"):
+            # Add text to all plots! (only single-axes figures)
+            if self._figure_text is not None and len(fig.get_axes())==1:
+                fig.text(*self._text_location, self._figure_text, transform=fig.get_axes()[0].transAxes, fontsize=60, fontstyle='oblique')
+
+            pdf_file.savefig(fig)
+            plt.close(fig)
+        if existing_pdf_fig is None:
+            pdf_file.close() # We opened this file ourselves so we must finalise it
diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/__init__.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/__init__.py
new file mode 100644
index 0000000..f42f661
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/__init__.py
@@ -0,0 +1,3 @@
+from .posteriors_1d import *
+from .posteriors_2d import *
+from .posterior_base_classes import *
diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py
new file mode 100644
index 0000000..e2ac4fa
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py
@@ -0,0 +1,52 @@
+'''
+HI : Contains a set of base classes with methods common to posterior plotting code
+Separated into 1D and 2D-like objects to make life easier
+'''
+from MaCh3_plot_lib.file_handlers import root_file_loader
+from MaCh3_plot_lib.plotters import _plotting_base_class
+from typing import List
+import numpy as np
+from tqdm.auto import tqdm
+import numpy.typing as npt
+
+
+# Base class for all posterior plotters
+class _posterior_plotting_base(_plotting_base_class):
+    # Small extension of
_plotting_base class for posterior specific stuff
+    def __init__(self, file_loader: root_file_loader)->None:
+        # Setup additional features for posteriors: default credible intervals + per-parameter multimodal flags
+        super().__init__(file_loader)
+        self._credible_intervals = np.array([0.6, 0.9, 0.95])
+        self._parameter_multimodal = np.zeros(len(self._file_loader.ttree_array)).astype(bool)
+
+    def __str__(self) -> str:
+        return "posterior_plotting_base_class"
+
+    @property
+    def credible_intervals(self)->List[float]:
+        return self._credible_intervals
+
+    @credible_intervals.setter
+    def credible_intervals(self, new_creds: List[float])->None:
+        # Sets credible intervals from list
+        self._credible_intervals = np.array(new_creds) # Store as a numpy array
+        self._credible_intervals.sort() # Keep them in ascending order
+
+    def set_pars_multimodal(self, par_id_list: List[str] | List[int], is_multimodal: bool=True)->None:
+        '''
+        Let the plotter know parameter set is multimodal
+        inputs:
+            par_id_list : List[str/int] List of Parameter indices or name
+            is_multimodal : Is the parameter multi-modal?
+        '''
+        # Wrap a lone name/index, list("name") would split the name into characters
+        if isinstance(par_id_list, (str, int)):
+            par_id_list = [par_id_list]
+
+        for par_id in par_id_list:
+            if isinstance(par_id, int):
+                true_index = par_id
+            else:
+                true_index = self._get_param_index_from_name(par_id)
+
+            self._parameter_multimodal[true_index] = is_multimodal
diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py
new file mode 100644
index 0000000..3ab8d52
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py
@@ -0,0 +1,88 @@
+from MaCh3_plot_lib.file_handlers import root_file_loader
+import arviz as az
+from matplotlib import pyplot as plt
+import numpy as np
+from MaCh3_plot_lib.plotters.posteriors.posterior_base_classes import _posterior_plotting_base
+from matplotlib.figure import Figure
+
+'''
+Plotting class for 1D plots.
Slightly overcomplicated by various special cases but works a treat!
+'''
+class posterior_plotter_1D(_posterior_plotting_base):
+    def __init__(self, file_loader: root_file_loader)->None:
+        '''
+        Constructor
+        '''
+        # Inherit the abstract base class
+        super().__init__(file_loader)
+
+    def _generate_plot(self, parameter_name: str) -> Figure:
+        '''
+        Generates a single posterior plot for a parameter, bins inside each credible interval are shaded
+
+        inputs :
+            parameter name : [type=str] Name of parameter
+        '''
+        if not isinstance(parameter_name, str):
+            raise ValueError("Can only pass single parameters to posterior plotting class")
+
+        # Checks if parameter is in our array+gets the index
+        param_index = self._get_param_index_from_name(parameter_name)
+
+        fig, axes = plt.subplots(figsize=(30, 30))
+        # Set bin number and range
+        n_bins = 50
+        # Grab our density plot
+        # One shade of purple for the histogram outline + one per credible interval
+        line_colour_generator = iter(plt.cm.Purples(np.linspace(0.4, 0.8, len(self.credible_intervals)+1)))
+        line_colour = next(line_colour_generator)
+        hist_kwargs={'density' : True, 'bins' : n_bins, "alpha": 1.0, 'linewidth': None, 'edgecolor': line_colour, 'color': 'white'}
+
+        _, bins, patches = axes.hist(self._file_loader.ttree_array[parameter_name].to_numpy()[0], **hist_kwargs)
+        # Make lines for each CI
+
+        cred_rev = self.credible_intervals[::-1]
+        for credible in cred_rev:
+            line_colour = next(line_colour_generator)
+
+            # We want the bayesian credible interval, for now we set the maximum number of modes for multi-modal parameters to 20
+            hdi = az.hdi(self._file_loader.ttree_array, var_names=[parameter_name], hdi_prob=credible,
+                         multimodal=self._parameter_multimodal[param_index], max_modes=20, circular=self._circular_params[param_index])
+
+            # Might be multimodal so we want all our credible intervals in a 1D array to make plotting easier!
+            credible_bounds = hdi[parameter_name].to_numpy()
+
+            if isinstance(credible_bounds[0], float):
+                credible_bounds = np.array([credible_bounds]) #when we're not multi-modal we need to be a little careful
+
+            # set up credible interval
+            plot_label = f"{100*credible}% credible interval "
+            for bound in credible_bounds:
+                # Set up plotting options
+                # Reduce the plotting array to JUST be between our boundaries
+                mask = (bins>=bound[0]) & (bins<=bound[1])
+                if bound[0]>bound[1]:
+                    # We need a SLIGHTLY different treatment since we loop around for some parameters (delta_cp)
+                    mask = (bins>=bound[0]) | (bins<=bound[1])
+
+                for patch_index in np.where(mask[1:])[0]: # edge k+1 closes bin k, edge 0 closes no bin so it is skipped
+                    patches[patch_index].set_facecolor(line_colour)
+                    patches[patch_index].set_edgecolor(None)
+                    patches[patch_index].set_label(plot_label)
+
+        # add legend
+        # Set Some labels
+        axes.set_xlabel(parameter_name, fontsize=50)
+        axes.tick_params(labelsize=40)
+        axes.set_title(f"Posterior Density for {parameter_name}", fontsize=60)
+        axes.set_ylim(ymin=0)
+        axes.set_ylabel("Posterior Density", fontsize=50)
+
+        # Generate Unique set of labels and handles
+        plot_handles, plot_labels = axes.get_legend_handles_labels()
+        # Sometimes it loses track of the posterior histogram
+        unique_labs = [(h, l) for i, (h, l) in enumerate(zip(plot_handles, plot_labels)) if l not in plot_labels[:i]]
+        axes.legend(*zip(*unique_labs), loc=(0.6,0.85), fontsize=35, facecolor="white", edgecolor="black", frameon=True)
+        #Stop memory issues, close THIS figure since pyplot's "current figure" is shared between worker threads
+        plt.close(fig)
+        return fig # Add to global figure list
diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py
new file mode 100644
index 0000000..50c6785
--- /dev/null
+++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py
@@ -0,0 +1,115 @@
+'''
+HI : Makes 2D posterior plots. Currently no way of putting a legend on the plot (thanks arviz...)
+''' +from MaCh3_plot_lib.file_handlers import root_file_loader +import arviz as az +from typing import List, Any +from matplotlib import pyplot as plt +from itertools import combinations +from MaCh3_plot_lib.plotters.posteriors.posterior_base_classes import _posterior_plotting_base +from matplotlib.figure import Figure +import numpy as np +import numpy.typing as npt + +class posterior_plotter_2D(_posterior_plotting_base): + # For making 2D possteriors + def __init__(self, file_loader: root_file_loader)->None: + ''' + Constructor + ''' + # Inherit the abstract base class + super().__init__(file_loader) + + def _generate_plot(self, parameter_names: List[str]) -> npt.NDArray[Any]: + ''' + Generates a 2D posterior plot + inputs : + -> Parameter Names [type=List[str]] list of parameters, will plot all combinations of pairs + + returns : + -> figure + ''' + # Let's get pairs of names + name_pairs_list = list(combinations(parameter_names, 2)) + fig_list = np.empty(len(name_pairs_list), dtype=Figure) + # Now we loop over our pairs of names + for i, (par_1, par_2) in enumerate(name_pairs_list): + + fig, axes = plt.subplots(figsize=(30, 30)) + + par_1_numpy_arr = self._file_loader.ttree_array[par_1].to_numpy()[0] + par_2_numpy_arr = self._file_loader.ttree_array[par_2].to_numpy()[0] + + ciruclar = self._circular_params[self._get_param_index_from_name(par_1)] | self._circular_params[self._get_param_index_from_name(par_2)] + + az.plot_kde(par_1_numpy_arr, par_2_numpy_arr, + hdi_probs= self._credible_intervals, + contourf_kwargs={"cmap": "Purples"}, + ax=axes, is_circular=ciruclar, legend=True) + + axes.set_xlabel(par_1, fontsize=50) + axes.set_ylabel(par_2, fontsize=50) + axes.tick_params(labelsize=40) + + + cred_as_str=",".join(f"{100*i}%" for i in self._credible_intervals) + axes.set_title(f"{par_1} vs {par_2} : [{cred_as_str}] credible intervals", fontsize="60") + plt.close() # CLose the canvas + fig_list[i] = fig + return fig_list + + + +class 
triangle_plotter(_posterior_plotting_base): + # Makes triangle plots + def __init__(self, file_loader: root_file_loader)->None: + ''' + Constructor + ''' + # Inherit the abstract base class + super().__init__(file_loader) + + def _generate_plot(self, parameter_names: List[str]) -> Figure: + ''' + Generates a single triangle plot and adds it to the figure list + inputs : + -> parameter_names : [List[str]] List of parameter names to put in the triangle + returns : + -> Figure + ''' + if not isinstance(parameter_names, list): + raise ValueError("Parameter names must be list when plotting triangle plots") + + # Check we are using a valid parameter set + for param in parameter_names: + self._parameter_not_found_error(param) # Check if parameters exist + + fig, axes = plt.subplots(nrows=len(parameter_names), ncols=len(parameter_names), figsize=(30, 30)) + + az.plot_pair(self._file_loader.ttree_array, var_names=parameter_names, + marginals=True, ax=axes, colorbar=True, figsize=(30,30), + kind='kde', + textsize=30, + kde_kwargs={ + "hdi_probs": self._credible_intervals, # Plot HDI contours + "contourf_kwargs": {"cmap": "Purples"}, + 'legend':True + }, + marginal_kwargs={ + # 'fill_kwargs': {'alpha': 0.0}, + 'plot_kwargs': {"linewidth": 4.5, "color": "purple"}, + # "quantiles": credible_intervals, + "rotated": False + }, + point_estimate='mean', + ) + + cred_as_str=",".join(f"{100*i}%" for i in self._credible_intervals) + fig.suptitle(f"Triangle Plot for : {cred_as_str} credible intervals", fontsize="60") + + # axes[-1, -1].legend(axes, [f"{i}% Credible Interval" for i in self._credible_intervals], frameon=True, loc='right') + + fig.subplots_adjust(wspace=0.01, hspace=0.01) + + plt.close() + return fig \ No newline at end of file From 129deb59316a46e93bcc0ccaf36ee0a360f6312b Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:13:17 +0100 Subject: [PATCH 04/11] update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/.gitignore b/.gitignore index 595ebed..c376aaa 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ *env* *__pycache__* *pdf -*.pkl \ No newline at end of file +*.pkl +*.keras \ No newline at end of file From 8b50de765aba653c5e9f30ab33cb454f8738dc5d Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:13:27 +0100 Subject: [PATCH 05/11] Implements plotting lib --- MaCh3PythonUtils/config_reader.py | 216 +++++++++++++++--- MaCh3PythonUtils/diagnostics/__init__.py | 0 .../diagnostics/interface/__init__.py | 2 +- .../interface/plotting_interface.py | 15 +- .../diagnostics/plotters/__init__.py | 2 +- .../autocorrelation_trace_plotter.py | 10 +- .../diagnostics/covariance_matrix_utils.py | 12 +- .../plotters/diagnostics/simple_diag_plots.py | 22 +- .../diagnostics/plotters/plotter_base.py | 18 +- .../posteriors/posterior_base_classes.py | 10 +- .../plotters/posteriors/posteriors_1d.py | 14 +- .../plotters/posteriors/posteriors_2d.py | 18 +- .../file_handling/chain_handler.py | 19 +- configs/plotting_config.yml | 44 ++++ configs/scikit_config.yml | 13 +- configs/tensorflow_config.yml | 16 +- requirements.txt | 10 +- 17 files changed, 338 insertions(+), 103 deletions(-) create mode 100644 MaCh3PythonUtils/diagnostics/__init__.py create mode 100644 configs/plotting_config.yml diff --git a/MaCh3PythonUtils/config_reader.py b/MaCh3PythonUtils/config_reader.py index 3876dd9..4773db6 100644 --- a/MaCh3PythonUtils/config_reader.py +++ b/MaCh3PythonUtils/config_reader.py @@ -3,71 +3,233 @@ from file_handling.chain_handler import ChainHandler from machine_learning.ml_factory import MLFactory from machine_learning.fml_interface import FmlInterface - +from diagnostics.interface.plotting_interface import PlottingInterface +import diagnostics.plotters.posteriors as m3post +import diagnostics.plotters.diagnostics as m3diag +from pydantic.utils import deep_update class ConfigReader: # Strictly unecessary but nice conceptually _file_handler = None _interface = None 
+ _plot_interface = None + + __default_settings = { + # Settings for file and I/O + "FileSettings" : { + # Name of input file + "FileName": "", + # Name of chain in file + "ChainName": "", + # More printouts + "Verbose": False, + # Make posteriors from chain? + "MakePosteriors": False, + # Run Diagnostics code? + "MakeDiagnostics": False, + # Make an ML model to replicate the chain likelihood model + "MakeMLModel": False + }, + + "ParameterSettings":{ + "CircularParameters" : [], + # List of parameter names + "ParameterNames":[], + # List of cuts + "ParameterCuts":[], + # Name of label branch, used in ML + "LabelName": "", + # Any parameters we want to ignore + "IgnoredParameters":[] + }, + + # Settings for plotting tools + "PlottingSettings":{ + # Specific Settings for posterior plots + "PosteriorSettings": { + # Make 2D Posterior Plots? + "Make2DPosteriors": False, + # Make a triangle plot + "MakeTrianglePlot": False, + # Variables in the triangle plot + "TrianglePlot": [], + # 1D credible intervals + "MakeCredibleIntervals": False, + # Output file + "PosteriorOutputFile": "posteriors.pdf" + }, + + # Specific Settings for diagnostic plots + "DiagnosticsSettings": { + # Make violin plot? + "MakeViolin": False, + # Make trace + AC plot? + "MakeTraceAC": False, + # Make effective sample size plot? + "MakeESS": False, + # Make MCSE plot? + "MakeMCSE": False, + # Make suboptimality plot + "MakeSuboptimality": False, + # Step for calculation + "SuboptimalitySteps": 0, + # Output file + "DiagnosticsOutputFile": "diagnostics.pdf", + # Print summary statistic + "PrintSummary": False + } + }, + # Specific Settings for ML Applications + "MLSettings": { + # Fitter package either SciKit or TensorFlow + "FitterPackage": "", + # Fitter Model + "FitterName": "", + # Keyword arguments for fitter + "FitterKwargs" : {}, + #Use an external model that's already been trained? 
+ "AddFromExternalModel": False, + # Proportion of input data set used for testing (range of 0-1 ) + "TestSize": 0.0, + # Name to save ML model in + "MLOutputFile": "mlmodel" + } + + } + + def __init__(self, config: str): with open(config, 'r') as c: self._yaml_config = yaml.safe_load(c) + + # Update default settings + self.__chain_settings = deep_update(self.__default_settings, self._yaml_config) - - def setup_file_handler(self)->None: + def make_file_handler(self)->None: # Process MCMC chain - self._file_handler = ChainHandler(self._yaml_config["FileSettings"]["FileName"], - self._yaml_config["FileSettings"]["ChainName"], - self._yaml_config["FileSettings"]["Verbose"]) + self._file_handler = ChainHandler(self.__chain_settings["FileSettings"]["FileName"], + self.__chain_settings["FileSettings"]["ChainName"], + self.__chain_settings["FileSettings"]["Verbose"]) - self._file_handler.ignore_plots(self._yaml_config["FileSettings"]["IgnoredParameters"]) - self._file_handler.add_additional_plots(self._yaml_config["FileSettings"]["ParameterNames"]) - self._file_handler.add_additional_plots(self._yaml_config["FileSettings"]["LabelName"], True) + self._file_handler.ignore_plots(self.__chain_settings["ParameterSettings"]["IgnoredParameters"]) + self._file_handler.add_additional_plots(self.__chain_settings["ParameterSettings"]["ParameterNames"]) + + self._file_handler.add_additional_plots(self.__chain_settings["ParameterSettings"]["LabelName"], True) - self._file_handler.add_new_cuts(self._yaml_config["FileSettings"]["ParameterCuts"]) + self._file_handler.add_new_cuts(self.__chain_settings["ParameterSettings"]["ParameterCuts"]) self._file_handler.convert_ttree_to_array() - def setup_ml_interface(self)->None: + + def make_posterior_plots(self): + if self._plot_interface is None: + self._plot_interface = PlottingInterface(self._file_handler) + + posterior_labels = [] + + if self.__chain_settings['PlottingSettings']['PosteriorSettings']['Make2DPosteriors']: + 
self._plot_interface.initialise_new_plotter(m3post.PosteriorPlotter2D(self._file_handler), 'posterior_2d') + posterior_labels.append('posterior_2d') + + if self.__chain_settings['PlottingSettings']['PosteriorSettings']['MakeTrianglePlot']: + self._plot_interface.initialise_new_plotter(m3post.TrianglePlotter(self._file_handler), 'posterior_triangle') + posterior_labels.append('posterior_triangle') + + # Which variables do we actually want 2D plots for? + self._plot_interface.set_variables_to_plot(self.__chain_settings['PlottingSettings']['PosteriorSettings']['TrianglePlot'], posterior_labels) + + if self.__chain_settings['PlottingSettings']['PosteriorSettings']['Make1DPosteriors']: + self._plot_interface.initialise_new_plotter(m3post.PosteriorPlotter1D(self._file_handler), 'posterior_1d') + posterior_labels.append('posterior_1d') + + + self._plot_interface.set_credible_intervals(self.__chain_settings['PlottingSettings']['PosteriorSettings']['CredibleIntervals']) + self._plot_interface.set_is_circular(self.__chain_settings['ParameterSettings']['CircularParameters']) + + self._plot_interface.make_plots(self.__chain_settings['PlottingSettings']['PosteriorSettings']['PosteriorOutputFile'], posterior_labels) + + def make_diagnostics_plots(self): + if self._plot_interface is None: + self._plot_interface = PlottingInterface(self._file_handler) + + diagnostic_labels = [] + + if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeViolin']: + self._plotting_interface.initialise_new_plotter(m3diag.ViolinPlotter(self._file_handler), 'violin_plotter') + diagnostic_labels.append('violin_plotter') + if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeTraceAC']: + self._plotting_interface.initialise_new_plotter(m3diag.AutocorrelationTracePlotter(self._file_handler), 'trace_autocorr') + diagnostic_labels.append('trace_autocorr') + if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeESS']: + 
self._plotting_interface.initialise_new_plotter(m3diag.EffectiveSampleSizePlotter(self._file_handler), 'ess_plot') + diagnostic_labels.append('ess_plot') + if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeMCSE']: + self._plotting_interface.initialise_new_plotter(m3diag.MarkovChainStandardError(self._file_handler), 'msce_plot') + diagnostic_labels.append('msce_plot') + + self._plotting_interface.make_plots(self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile'], diagnostic_labels) + + # Final one, covariance plotter + if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeSuboptimality']: + suboptimality_obj = m3diag.CovarianceMatrixUtils(self._file_handler) + suboptimality_obj.calculate_suboptimality(self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['SuboptimalitySteps'], self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['SubOptimalityMin']) + suboptimality_obj.plot_suboptimality(f"suboptimality_{self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile']}") + + # Finally let's make a quick simmary + if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['PrintSummary']: + self._plotting_interface.print_summary(f"summary_{self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile']}.txt") + + def make_ml_interface(self)->None: if self._file_handler is None: raise Exception("Cannot initialise ML interface without first setting up file handler!") - factory = MLFactory(self._file_handler, self._yaml_config["FileSettings"]["LabelName"]) - if self._yaml_config["FitterSettings"]["FitterPackage"].lower() == "scikit": - self._interface = factory.setup_scikit_model(self._yaml_config["FitterSettings"]["FitterName"], - **self._yaml_config["FitterSettings"]["FitterKwargs"]) + factory = MLFactory(self._file_handler, self.__chain_settings["ParameterSettings"]["LabelName"]) + if 
self.__chain_settings["MLSettings"]["FitterPackage"].lower() == "scikit": + self._interface = factory.make_scikit_model(self.__chain_settings["MLSettings"]["FitterName"], + **self.__chain_settings["MLSettings"]["FitterKwargs"]) - elif self._yaml_config["FitterSettings"]["FitterPackage"].lower() == "tensorflow": - self._interface = factory.setup_tensorflow_model(self._yaml_config["FitterSettings"]["FitterName"], - **self._yaml_config["FitterSettings"]["FitterKwargs"]) + elif self.__chain_settings["MLSettings"]["FitterPackage"].lower() == "tensorflow": + self._interface = factory.make_tensorflow_model(self.__chain_settings["MLSettings"]["FitterName"], + **self.__chain_settings["MLSettings"]["FitterKwargs"]) else: raise ValueError("Input not recognised!") - if self._yaml_config["FitterSettings"].get("AddFromExternalModel"): - external_model = self._yaml_config["FitterSettings"]["ExternalModel"] + if self.__chain_settings["MLSettings"].get("AddFromExternalModel"): + external_model = self.__chain_settings["MLSettings"]["ExternalModel"] self._interface.load_model(external_model) else: - self._interface.set_training_test_set(self._yaml_config["FitterSettings"]["TestSize"]) + self._interface.set_training_test_set(self.__chain_settings["MLSettings"]["TestSize"]) self._interface.train_model() self._interface.test_model() - self._interface.save_model(self._yaml_config["FileSettings"]["ModelOutputName"]) - - def __call__(self) -> None: - self.setup_file_handler() - self.setup_ml_interface() + self._interface.save_model(self.__chain_settings["MLSettings"]["MLOutputFile"]) + + + def __call__(self) -> None: + + self.make_file_handler() + if self.__chain_settings["FileSettings"]["MakePosteriors"]: + self.make_posterior_plots() + if self.__chain_settings["FileSettings"]["MakeDiagnostics"]: + self.make_diagnostics_plots() + + if self.__chain_settings["FileSettings"]["MakeMLModel"]: + self.make_ml_interface() + + @property def chain_handler(self)->ChainHandler | None: return 
self._file_handler @property def ml_interface(self)->FmlInterface | None: - return self._interface \ No newline at end of file + return self._interface + diff --git a/MaCh3PythonUtils/diagnostics/__init__.py b/MaCh3PythonUtils/diagnostics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaCh3PythonUtils/diagnostics/interface/__init__.py b/MaCh3PythonUtils/diagnostics/interface/__init__.py index 18b84c2..3aa29b1 100644 --- a/MaCh3PythonUtils/diagnostics/interface/__init__.py +++ b/MaCh3PythonUtils/diagnostics/interface/__init__.py @@ -1 +1 @@ -from .plotting_interface import plotting_interface \ No newline at end of file +from .plotting_interface import PlottingInterface \ No newline at end of file diff --git a/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py b/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py index 13dd1d2..02e8534 100644 --- a/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py +++ b/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py @@ -2,26 +2,27 @@ Interface class for making plots! 
''' from typing import List -from MaCh3_plot_lib.file_handlers import root_file_loader -import MaCh3_plot_lib.plotters as pt +from file_handling.chain_handler import ChainHandler +import diagnostics.plotters as pt from matplotlib.backends.backend_pdf import PdfPages import arviz as az -class plotting_interface: +class PlottingInterface: ''' full interface object for making plots inputs: file_loader : root_file_loader instance ''' - def __init__(self, file_loader: root_file_loader): + def __init__(self, file_loader: ChainHandler): ''' Constructor object ''' self._file_loader = file_loader self._plotter_object_dict = {} # dict of objects from plotting tools + self._file_loader.make_arviz_tree() - def initialise_new_plotter(self, new_plotter: pt.plotter_base._plotting_base_class , plot_label: str)->None: + def initialise_new_plotter(self, new_plotter: pt.plotter_base._PlottingBaseClass , plot_label: str)->None: ''' Adds new plot object to our array inputs : @@ -44,7 +45,7 @@ def set_credible_intervals(self, credible_intervals: List[float])->None: raise ValueError(f"Cannot set credible intervals to {credible_intervals}") for plotter in list(self._plotter_object_dict.values()): - if not isinstance(plotter, pt.posterior_base_classes._posterior_plotting_base): + if not isinstance(plotter, pt.posterior_base_classes._PosteriorPlottingBase): continue # set credible intervals @@ -84,7 +85,7 @@ def set_is_circular(self, param_ids: List[int | str]): # Loop over our plotters for plotter in self._plotter_object_dict.values(): - if not isinstance(plotter, pt.posterior_base_classes._posterior_plotting_base): + if not isinstance(plotter, pt.posterior_base_classes._PosteriorPlottingBase): continue plotter.set_pars_circular(param_ids) diff --git a/MaCh3PythonUtils/diagnostics/plotters/__init__.py b/MaCh3PythonUtils/diagnostics/plotters/__init__.py index 50c30a9..0db26c1 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/__init__.py +++ 
b/MaCh3PythonUtils/diagnostics/plotters/__init__.py @@ -1,3 +1,3 @@ -from .plotter_base import _plotting_base_class +from .plotter_base import _PlottingBaseClass from .posteriors import * from .diagnostics import * \ No newline at end of file diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py index adca0b7..d0ef6f8 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py +++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py @@ -1,14 +1,14 @@ ''' HI : Class to make autocorrelations and traces, puts them all onto a single plot ''' -from MaCh3_plot_lib.file_handlers import root_file_loader +from file_handling.chain_handler import ChainHandler import arviz as az from matplotlib import pyplot as plt -from MaCh3_plot_lib.plotters.plotter_base import _plotting_base_class +from diagnostics.plotters.plotter_base import _PlottingBaseClass from matplotlib.figure import Figure -class autocorrelation_trace_plotter(_plotting_base_class): - def __init__(self, file_loader: root_file_loader)->None: +class AutocorrelationTracePlotter(_PlottingBaseClass): + def __init__(self, file_loader: ChainHandler)->None: # Constructor super().__init__(file_loader) @@ -22,7 +22,7 @@ def _generate_plot(self, parameter_name: str) -> Figure: fig, (trace_ax, autocorr_ax) = plt.subplots(nrows=2, sharex=False) # We want the numpy array containing our parameter - param_array = self._file_loader.ttree_array[parameter_name].to_numpy()[0] + param_array = self._file_loader.arviz_tree[parameter_name].to_numpy()[0] # Okay now we can plot our trace (might as well!) 
trace_ax.plot(param_array, linewidth=0.05, color='purple') diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py index 87b887a..40c055b 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py +++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py @@ -4,7 +4,7 @@ Suboptimality : https://www.jstor.org/stable/25651249?seq=3 ''' -from MaCh3_plot_lib.file_handlers import root_file_loader +from file_handling.chain_handler import ChainHandler import numpy as np from scipy.linalg import sqrtm from tqdm.auto import tqdm @@ -15,21 +15,21 @@ from concurrent.futures import ThreadPoolExecutor, as_completed -class covariance_matrix_utils: - def __init__(self, file_loader: root_file_loader)->None: +class CovarianceMatrixUtils: + def __init__(self, file_loader: ChainHandler)->None: ''' For calculating the covariance matrix + suboptimality inputs: - ->file_loader : [type=root_file_loader] file handler object + ->file_loader : [type=ChainHandler] file handler object ''' # Let's just ignore some warnings :grin: warnings.filterwarnings("ignore", category=DeprecationWarning) #Some imports are a little older warnings.filterwarnings("ignore", category=UserWarning) #Some imports are a little older file_loader.get_ttree_as_arviz() # Ensure that we have things in the correct format - self._parameter_names = list(file_loader.ttree_array.keys()) + self._parameter_names = list(file_loader.arviz_tree.keys()) self._total_parameters = len(self._parameter_names) - self._ttree_data_frame = file_loader.ttree_array.to_dataframe() # All our handling will be done using thi s object + self._ttree_data_frame = file_loader.arviz_tree.to_dataframe() # All our handling will be done using thi s object # Various useful class properties self._suboptimality_array = [] # For filling with suboptimality self._suboptimality_evaluation_points 
= [] #Where did we evalutate this? diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py index ba6b8c4..f9589ef 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py +++ b/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py @@ -3,17 +3,17 @@ ''' from typing import List -from MaCh3_plot_lib.file_handlers import root_file_loader +from file_handling.chain_handler import ChainHandler import arviz as az from matplotlib import pyplot as plt -from MaCh3_plot_lib.plotters.plotter_base import _plotting_base_class +from diagnostics.plotters.plotter_base import _PlottingBaseClass from matplotlib.figure import Figure -class effective_sample_size_plotter(_plotting_base_class): +class EffectiveSampleSizePlotter(_PlottingBaseClass): ''' Caclulates effective sample size : https://arxiv.org/pdf/1903.08008.pdf ''' - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: # Constructor super().__init__(file_loader) @@ -27,16 +27,16 @@ def _generate_plot(self, parameter_name: str) -> Figure: ''' fig, axes = plt.subplots() - az.plot_ess(self._file_loader.ttree_array, var_names=parameter_name, + az.plot_ess(self._file_loader.arviz_tree, var_names=parameter_name, ax=axes, textsize=30, color='purple', drawstyle="steps-mid", linestyle="-") plt.close() return fig -class markov_chain_standard_error(_plotting_base_class): +class MarkovChainStandardError(_PlottingBaseClass): ''' Calculates Markov Chain Standard Error : https://arxiv.org/pdf/1903.08008.pdf ''' - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: # Constructor super().__init__(file_loader) @@ -49,14 +49,14 @@ def _generate_plot(self, parameter_name: str) -> Figure: Figure ''' fig, axes = plt.subplots() - az.plot_mcse(self._file_loader.ttree_array, 
var_names=parameter_name, ax=axes, textsize=10, color='purple') + az.plot_mcse(self._file_loader.arviz_tree, var_names=parameter_name, ax=axes, textsize=10, color='purple') plt.close() return fig -class violin_plotter(_plotting_base_class): +class ViolinPlotter(_PlottingBaseClass): # Class to generate Violin Plots - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: # Constructor super().__init__(file_loader) @@ -66,7 +66,7 @@ def _generate_plot(self, parameter_name: str | List[str]) -> Figure: ''' # total number of axes we need fig, axes = plt.subplots() - az.plot_violin(self._file_loader.ttree_array, + az.plot_violin(self._file_loader.arviz_tree, var_names=parameter_name, ax=axes, textsize=10, shade_kwargs={'color':'purple'}) plt.close() diff --git a/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py b/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py index 0282289..8a0f214 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py +++ b/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py @@ -2,7 +2,7 @@ HI : Mostly abstract base class, contains common methods for use by most other plotting classes ''' -from MaCh3_plot_lib.file_handlers import root_file_loader +from file_handling.chain_handler import ChainHandler import arviz as az from matplotlib import pyplot as plt from matplotlib.figure import Figure @@ -20,20 +20,19 @@ # Base class with common methods -class _plotting_base_class(ABC): +class _PlottingBaseClass(ABC): ''' Abstract class with common methods for creating plots inputs: - file_loader : [type=root_file_loader] root_file_loader class instance + file_loader : [type=ChainHandler] ChainHandler class instance ''' - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: ''' Constructor ''' # I'm very lazy but this should speed up parallelisation since we won't have this warning 
plt.rcParams.update({'figure.max_open_warning': 0}) - file_loader.get_ttree_as_arviz() # Ensure that we have things in the correct format self._file_loader=file_loader # Make sure we don't use crazy amount of memory # Setup plotting styles az.style.use(hep.style.ROOT) @@ -45,7 +44,7 @@ def __init__(self, file_loader: root_file_loader)->None: self._all_plots_generated = False # Stops us making plots multiple times # Setting it like this replicates the old behaviour! - self._circular_params=np.ones(len(self._file_loader.ttree_array)).astype(bool) + self._circular_params=np.ones(len(self._file_loader.arviz_tree)).astype(bool) warnings.filterwarnings("ignore", category=DeprecationWarning) #Some imports are a little older warnings.filterwarnings("ignore", category=UserWarning) #Some imports are a little older @@ -57,7 +56,7 @@ def __init__(self, file_loader: root_file_loader)->None: self._text_location = (0.05, 0.85) # Default option - self._params_to_plot = list(self._file_loader.ttree_array.keys()) + self._params_to_plot = list(self._file_loader.arviz_tree.keys()) def __str__(self) -> str: return "plotting_base_class" @@ -116,7 +115,6 @@ def plot_params(self, new_plot_parameter_list: List[str]|List[List[str]]): elif isinstance(new_plot_parameter_list[0][0], str): for param_list in new_plot_parameter_list: for param in param_list: - print(param) self._parameter_not_found_error(param) else: @@ -127,14 +125,14 @@ def plot_params(self, new_plot_parameter_list: List[str]|List[List[str]]): def _parameter_not_found_error(self, parameter_name: str): - if parameter_name not in list(self._file_loader.ttree_array.keys()): + if parameter_name not in list(self._file_loader.arviz_tree.keys()): raise ValueError(f"{parameter_name} not in list of parameters!") def _get_param_index_from_name(self, parameter_name: str)->int: # Gets index of parameter in our arviz array self._parameter_not_found_error(parameter_name) - param_id = 
list(self._file_loader.ttree_array.keys()).index(parameter_name) + param_id = list(self._file_loader.arviz_tree.keys()).index(parameter_name) return param_id def set_pars_circular(self, par_id_list: List[str] | List[int])->None: diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py index e2ac4fa..d9d041b 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py +++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py @@ -2,8 +2,8 @@ HI : Contains a set of base classes with methods common to posterior plotting code Separated into 1D and 2D-like objects to make life easier ''' -from MaCh3_plot_lib.file_handlers import root_file_loader -from MaCh3_plot_lib.plotters import _plotting_base_class +from file_handling.chain_handler import ChainHandler +from diagnostics.plotters.plotter_base import _PlottingBaseClass from typing import List import numpy as np from tqdm.auto import tqdm @@ -11,13 +11,13 @@ # Base class for all posterior plotters -class _posterior_plotting_base(_plotting_base_class): +class _PosteriorPlottingBase(_PlottingBaseClass): # Small extension of _plotting_base class for posterior specific stuff - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: # Setup additional features for posteriors super().__init__(file_loader) self._credible_intervals = np.array([0.6, 0.9, 0.95]) - self._parameter_multimodal = np.zeros(len(self._file_loader.ttree_array)).astype(bool) + self._parameter_multimodal = np.zeros(len(self._file_loader.arviz_tree)).astype(bool) def __str__(self) -> str: return "posterior_plotting_base_class" diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py index 3ab8d52..b4da96f 100644 --- 
a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py +++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py @@ -1,15 +1,15 @@ -from MaCh3_plot_lib.file_handlers import root_file_loader +from config_reader import ChainHandler import arviz as az from matplotlib import pyplot as plt import numpy as np -from MaCh3_plot_lib.plotters.posteriors.posterior_base_classes import _posterior_plotting_base +from diagnostics.plotters.posteriors.posterior_base_classes import _PosteriorPlottingBase from matplotlib.figure import Figure ''' Plotting class for 1D plots. Slightly overcomplicated by various special cases but works a treat! ''' -class posterior_plotter_1D(_posterior_plotting_base): - def __init__(self, file_loader: root_file_loader)->None: +class PosteriorPlotter1D(_PosteriorPlottingBase): + def __init__(self, file_loader: ChainHandler)->None: ''' Constructor ''' @@ -24,7 +24,7 @@ def _generate_plot(self, parameter_name: str) -> Figure: parameter name : [type=str] Name of parameter ''' if not isinstance(parameter_name, str): - raise ValueError("Can only pass single parameters to posterior plotting class") + raise ValueError(f"Can only pass single parameters to posterior plotting class. 
Cannot plot {parameter_name}") # Checks if parameter is in our array+gets the index param_index = self._get_param_index_from_name(parameter_name) @@ -38,7 +38,7 @@ def _generate_plot(self, parameter_name: str) -> Figure: line_colour = next(line_colour_generator) hist_kwargs={'density' : True, 'bins' : n_bins, "alpha": 1.0, 'linewidth': None, 'edgecolor': line_colour, 'color': 'white'} - _, bins, patches = axes.hist(self._file_loader.ttree_array[parameter_name].to_numpy()[0], **hist_kwargs) + _, bins, patches = axes.hist(self._file_loader.arviz_tree[parameter_name].to_numpy()[0], **hist_kwargs) # Make lines for each CI cred_rev = self.credible_intervals[::-1] @@ -46,7 +46,7 @@ def _generate_plot(self, parameter_name: str) -> Figure: line_colour = next(line_colour_generator) # We want the bayesian credible interval, for now we set the maximum number of modes for multi-modal parameters to 2 - hdi = az.hdi(self._file_loader.ttree_array, var_names=[parameter_name], hdi_prob=credible, + hdi = az.hdi(self._file_loader.arviz_tree, var_names=[parameter_name], hdi_prob=credible, multimodal=self._parameter_multimodal[param_index], max_modes=20, circular=self._circular_params[param_index]) # Might be multimodal so we want all our credible intervals in a 1D array to make plotting easier! diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py index 50c6785..990e141 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py +++ b/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py @@ -1,19 +1,19 @@ ''' HI : Makes 2D posterior plots. Currently no way of putting a legend on the plot (thanks arviz...) 
''' -from MaCh3_plot_lib.file_handlers import root_file_loader +from config_reader import ChainHandler import arviz as az from typing import List, Any from matplotlib import pyplot as plt from itertools import combinations -from MaCh3_plot_lib.plotters.posteriors.posterior_base_classes import _posterior_plotting_base +from diagnostics.plotters.posteriors.posterior_base_classes import _PosteriorPlottingBase from matplotlib.figure import Figure import numpy as np import numpy.typing as npt -class posterior_plotter_2D(_posterior_plotting_base): +class PosteriorPlotter2D(_PosteriorPlottingBase): # For making 2D possteriors - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: ''' Constructor ''' @@ -37,8 +37,8 @@ def _generate_plot(self, parameter_names: List[str]) -> npt.NDArray[Any]: fig, axes = plt.subplots(figsize=(30, 30)) - par_1_numpy_arr = self._file_loader.ttree_array[par_1].to_numpy()[0] - par_2_numpy_arr = self._file_loader.ttree_array[par_2].to_numpy()[0] + par_1_numpy_arr = self._file_loader.arviz_tree[par_1].to_numpy()[0] + par_2_numpy_arr = self._file_loader.arviz_tree[par_2].to_numpy()[0] ciruclar = self._circular_params[self._get_param_index_from_name(par_1)] | self._circular_params[self._get_param_index_from_name(par_2)] @@ -60,9 +60,9 @@ def _generate_plot(self, parameter_names: List[str]) -> npt.NDArray[Any]: -class triangle_plotter(_posterior_plotting_base): +class TrianglePlotter(_PosteriorPlottingBase): # Makes triangle plots - def __init__(self, file_loader: root_file_loader)->None: + def __init__(self, file_loader: ChainHandler)->None: ''' Constructor ''' @@ -86,7 +86,7 @@ def _generate_plot(self, parameter_names: List[str]) -> Figure: fig, axes = plt.subplots(nrows=len(parameter_names), ncols=len(parameter_names), figsize=(30, 30)) - az.plot_pair(self._file_loader.ttree_array, var_names=parameter_names, + az.plot_pair(self._file_loader.arviz_tree, var_names=parameter_names, 
marginals=True, ax=axes, colorbar=True, figsize=(30,30), kind='kde', textsize=30, diff --git a/MaCh3PythonUtils/file_handling/chain_handler.py b/MaCh3PythonUtils/file_handling/chain_handler.py index 9a3ccaf..9a020c7 100644 --- a/MaCh3PythonUtils/file_handling/chain_handler.py +++ b/MaCh3PythonUtils/file_handling/chain_handler.py @@ -8,6 +8,7 @@ from concurrent.futures import ThreadPoolExecutor import gc import numpy as np +import arviz as az class ChainHandler: """ @@ -44,6 +45,8 @@ def __init__(self, file_name: str, ttree_name: str="posteriors", verbose=False)- self._verbose = verbose self._ignored_branches = [] + self._arviz_tree = None + def close_file(self)->None: ''' Closes ROOT file, should be called to avoid memory issues! @@ -183,7 +186,7 @@ def ttree_array(self)->pd.DataFrame: :rtype: Union[np.array, pd.DataFrame, ak.Array] ''' return self._ttree_array - + @ttree_array.setter def ttree_array(self, new_array: Any=None)->None: ''' @@ -192,4 +195,16 @@ def ttree_array(self, new_array: Any=None)->None: :type new_array: Any ''' # Implemented in case someone tries to do something daft! 
- raise NotImplementedError("Cannot set converted TTree array to new type") \ No newline at end of file + raise NotImplementedError("Cannot set converted TTree array to new type") + + def make_arviz_tree(self): + if self._ttree_array is None: + raise RuntimeError("Error have not converted ROOT TTree to pandas data frame yet!") + + print("Generating Arviz data struct [this may take some time!]") + self._arviz_tree = az.dict_to_dataset(self._ttree_array.to_dict(orient='list')) + + + @property + def arviz_tree(self): + return self._arviz_tree \ No newline at end of file diff --git a/configs/plotting_config.yml b/configs/plotting_config.yml new file mode 100644 index 0000000..1ce6926 --- /dev/null +++ b/configs/plotting_config.yml @@ -0,0 +1,44 @@ +FileSettings: + FileName : "/users/php20hti/php20hti/public/chains/all_par_chain_without_adaption.root" + # FileName : "/users/php20hti/t2ksft/regeneration_times_project/inputs/markov_chains/oscpar_only_datafit.root" + ChainName : "posteriors" + # ChainName : "osc_posteriors" + Verbose : False + # Model Settings + MakeDiagnostics: True + MakePosteriors: True + +ParameterSettings: + ParameterNames : ["xsec"] + # ParameterNames : ["d", "th"] + IgnoredParameters : ["LogL_systematic_xsec_cov", "LogL_systematic_nddet_cov", ] + ParameterCuts : ["step>10000"] #["LogL<12345678", "step>10000"] + CircularParameters: [] + +PlottingSettings: + DiagnosticsSettings: + # Where are we prioting? + DiagnosticsOutputFile: "diagnostics_output.pdf" + # Make Trace/AC Plot? + MakeTraceAC: True + # Make Violin Plot? + MakeViolin: False + #Make ESS Plot? + MakeESS: False + # Make MCSE Plot + MakeMCSE: False + # Make suboptimality Plot + MakeSuboptimality: False + # Steps/calculation for subopt. 
+ SuboptimalitySteps: 10000 + # Print summary stats + PrintSummary: True + + PosteriorSettings: + PosteriorOutputFile: "posteriors.pdf" + Make1DPosteriors: True + CredibleIntervals: [0.6, 0.90, 0.95] + Make2DPosteriors: False + MakeTrianglePlot: False + TrianglePlot: [] + \ No newline at end of file diff --git a/configs/scikit_config.yml b/configs/scikit_config.yml index 2568d23..6c8a3e7 100644 --- a/configs/scikit_config.yml +++ b/configs/scikit_config.yml @@ -3,15 +3,21 @@ FileSettings: # FileName : "/users/php20hti/t2ksft/regeneration_times_project/inputs/markov_chains/oscpar_only_datafit.root" ChainName : "posteriors" # ChainName : "osc_posteriors" + Verbose : False + ModelOutputName: "tf_model_full.pkl" + + # Model Settings + MakeMLModel: True + +ParameterSettings: ParameterNames : ["sin2th","delm2", "delta", "xsec", "sk", "nd"] # ParameterNames : ["d", "th"] LabelName : "LogL" IgnoredParameters : ["LogL_systematic_xsec_cov", "Log", "LogL_systematic_nddet_cov", ] ParameterCuts : ["step>10000"] #["LogL<12345678", "step>10000"] - Verbose : False - ModelOutputName: "tf_model_full.pkl" -FitterSettings: + +MLSettings: FitterPackage : "SciKit" FitterName : "histboost" @@ -20,3 +26,4 @@ FitterSettings: FitterKwargs: max_iter: 10000 + MLOutputName: "tf_model_full.pkl" diff --git a/configs/tensorflow_config.yml b/configs/tensorflow_config.yml index 9d2d0f2..c22cc5b 100644 --- a/configs/tensorflow_config.yml +++ b/configs/tensorflow_config.yml @@ -3,15 +3,19 @@ FileSettings: # FileName : "/users/php20hti/t2ksft/regeneration_times_project/inputs/markov_chains/oscpar_only_datafit.root" ChainName : "posteriors" # ChainName : "osc_posteriors" + Verbose : False + + # Model Settings + MakeMLModel: True + +ParameterSettings: ParameterNames : ["sin2th","delm2", "delta", "xsec", "sk", "nd"] - # ParameterNames : ["d", "th"] LabelName : "LogL" IgnoredParameters : ["LogL_systematic_xsec_cov", "Log", "LogL_systematic_nddet_cov", ] ParameterCuts : ["step>10000"] #["LogL<12345678", 
"step>10000"] - Verbose : False - ModelOutputName: "tf_model_full.keras" -FitterSettings: + +MLSettings: FitterPackage : "TensorFlow" FitterName : "Sequential" @@ -39,4 +43,6 @@ FitterSettings: FitSettings: batch_size: 100 - epochs: 200 \ No newline at end of file + epochs: 200 + + MLOutputName: "tf_model_full.keras" diff --git a/requirements.txt b/requirements.txt index 7bf1e2e..7af9f32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,14 @@ numpy matplotlib scikit-learn -toml uproot pandas -statsmodels mpl_scatter_density scipy -astropy pyyaml -tensorflow # the biggun' \ No newline at end of file +tensorflow # the biggun' +pydantic +arviz +mplhep +tqdm +numba \ No newline at end of file From 2f218034bee354ab9a0faea478c5de175bb8001f Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:27:32 +0100 Subject: [PATCH 06/11] Small patches' --- MaCh3PythonUtils/config_reader.py | 12 ++--- README.md | 89 ++++++++++++++++++++++++++----- configs/plotting_config.yml | 8 ++- configs/scikit_config.yml | 2 - 4 files changed, 88 insertions(+), 23 deletions(-) diff --git a/MaCh3PythonUtils/config_reader.py b/MaCh3PythonUtils/config_reader.py index 4773db6..30e4cb9 100644 --- a/MaCh3PythonUtils/config_reader.py +++ b/MaCh3PythonUtils/config_reader.py @@ -158,19 +158,19 @@ def make_diagnostics_plots(self): diagnostic_labels = [] if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeViolin']: - self._plotting_interface.initialise_new_plotter(m3diag.ViolinPlotter(self._file_handler), 'violin_plotter') + self._plot_interface.initialise_new_plotter(m3diag.ViolinPlotter(self._file_handler), 'violin_plotter') diagnostic_labels.append('violin_plotter') if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeTraceAC']: - self._plotting_interface.initialise_new_plotter(m3diag.AutocorrelationTracePlotter(self._file_handler), 'trace_autocorr') + 
self._plot_interface.initialise_new_plotter(m3diag.AutocorrelationTracePlotter(self._file_handler), 'trace_autocorr') diagnostic_labels.append('trace_autocorr') if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeESS']: - self._plotting_interface.initialise_new_plotter(m3diag.EffectiveSampleSizePlotter(self._file_handler), 'ess_plot') + self._plot_interface.initialise_new_plotter(m3diag.EffectiveSampleSizePlotter(self._file_handler), 'ess_plot') diagnostic_labels.append('ess_plot') if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeMCSE']: - self._plotting_interface.initialise_new_plotter(m3diag.MarkovChainStandardError(self._file_handler), 'msce_plot') + self._plot_interface.initialise_new_plotter(m3diag.MarkovChainStandardError(self._file_handler), 'msce_plot') diagnostic_labels.append('msce_plot') - self._plotting_interface.make_plots(self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile'], diagnostic_labels) + self._plot_interface.make_plots(self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile'], diagnostic_labels) # Final one, covariance plotter if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeSuboptimality']: @@ -180,7 +180,7 @@ def make_diagnostics_plots(self): # Finally let's make a quick simmary if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['PrintSummary']: - self._plotting_interface.print_summary(f"summary_{self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile']}.txt") + self._plot_interface.print_summary(f"summary_{self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile']}.txt") def make_ml_interface(self)->None: if self._file_handler is None: diff --git a/README.md b/README.md index 0dbf176..8c1d6c7 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,30 @@ Very simple tool for predicting likelihoods from Markov Chains Using ML tools. 
Currently only accepts chains where all variables are saved within a ROOT TTree. -# Configs +# Setup +Setup is relatively simple. The recommended way of running this is using a virtual environment +```bash +virtualenv .env +``` + +Then to setup, +```bash +source .env/bin/activate +``` +The required modules can then be installed with +```bash +pip install -r requirements.txt +``` + +# Running +Running the package is also simple. Simply do +``` +python MaCh3PythonUtils -c /path/to/config.yml +``` + +Some example configs can be found in the `configs` folder. + +# Configs Configs are in YAML format For all pacakges the initial setup is very similar: @@ -11,29 +34,71 @@ For all pacakges the initial setup is very similar: FileSettings: FileName : "/path/to/file" # Name of File ChainName : "/tree/name" # Name of Tree in file containing chain + # Do you want it to be verbose? + Verbose: False + + # What plots do we want? + MakeMLModel: True # ML stuff + MakeDiagnostics: True # Make MCMC diagnostic plots + MakePosteriors: True # Make posteriors +# Settings for parameter options +ParameterSettings: # Names of parameters to fit, finds all parameters containing names in this string as sub-string ParameterNames : ["sin2th", "sin2th","delm2_12", "delta", "xsec"] - # Name of variable you're fitting in LabelName : "LogL" - # Parameters you don't want to include in the model IgnoredParameters : ["LogL_systematic_xsec_cov", "Log", "LogL_systematic_nddet_cov", ] - # Any cuts, for MaCh3 I'd recommend capping LogL ParameterCuts : ["LogL<12345678", "step>10000"] + CircularParameters: [] # Do any parameters loop? - # Do you want it to be verbose? - Verbose:= false +``` + +# Plotting Settings +This package contains various plotting tools required for analysing markov chains which are stored in the Plotting library! + +```yaml +PlottingSettings: + DiagnosticsSettings: + # Where are we prioting? + DiagnosticsOutputFile: "diagnostics_output.pdf" + # Make Trace/AC Plot? 
+ MakeTraceAC: True + # Make Violin Plot? + MakeViolin: False + #Make ESS Plot? + MakeESS: False + # Make MCSE Plot + MakeMCSE: False + # Make suboptimality Plot + MakeSuboptimality: False + # Steps/calculation for subopt. + SuboptimalitySteps: 10000 + # Print summary stats + PrintSummary: True + + PosteriorSettings: + # Output PDF + PosteriorOutputFile: "posteriors.pdf" + # Do you want 1D CIs? + Make1DPosteriors: True + # Plotted CIs + CredibleIntervals: [0.6, 0.90, 0.95] + # Do you want 2D CIs? + Make2DPosteriors: False + # Do you want a triangle plot? + MakeTrianglePlot: False + # variables to put in the triangle - # Where is the model being pickled? - ModelOutputName : "histboost_model_full.pkl" ``` +# ML Settings + For scikit learn based pacakges the settings are then set in the following way, where FitterKwargs directly sets the keyword arguments for the scikit fitting tool being used ```yaml -FitterSettings: +MLSettings: # Package model is included in FitterPackage : "SciKit" @@ -81,6 +146,7 @@ FitterSettings: ``` + Here FitterKwargs is now split into sub-settings with `BuildSettings` being passed to the model `compile` method, `FitSettings` setting up training information, and `Layers` defining the types + kwargs of each layer in the model. New layers can be implemented in the `__TF_LAYER_IMPLEMENTATIONS` object which lives in `machine_learning/tf_interface` # Executables @@ -90,8 +156,3 @@ Simply run `python MachineLearningMCMC -c /path/to/toml/config` and it'll automa Implementing a new fitter is relatively simple. Most implementing is done in `machine_learining/ml_factory/MLFactory`. For Scikit-Learn based models, the new method just needs to imported and added to the `scikit` entry in `__IMPLEMENTED_ALGORITHMS`. For non-scikit based algorithms currentlt no implementation exists. For such cases a new interface class (which inherits from `FMLInterface`) needs to be implemented. Hopefully in future this is easy to do! 
- -# TODO: -* More libs than scikit -* Configurable fitters -* Clearer Readme \ No newline at end of file diff --git a/configs/plotting_config.yml b/configs/plotting_config.yml index 1ce6926..657b230 100644 --- a/configs/plotting_config.yml +++ b/configs/plotting_config.yml @@ -6,7 +6,7 @@ FileSettings: Verbose : False # Model Settings MakeDiagnostics: True - MakePosteriors: True + MakePosteriors: False ParameterSettings: ParameterNames : ["xsec"] @@ -35,10 +35,16 @@ PlottingSettings: PrintSummary: True PosteriorSettings: + # Output PDF PosteriorOutputFile: "posteriors.pdf" + # Do you want 1D CIs? Make1DPosteriors: True + # Plotted CIs CredibleIntervals: [0.6, 0.90, 0.95] + # Do you want 2D CIs? Make2DPosteriors: False + # Do you want a triangle plot? MakeTrianglePlot: False + # variables to put in the triangle plot TrianglePlot: [] \ No newline at end of file diff --git a/configs/scikit_config.yml b/configs/scikit_config.yml index 6c8a3e7..31ed34d 100644 --- a/configs/scikit_config.yml +++ b/configs/scikit_config.yml @@ -4,8 +4,6 @@ FileSettings: ChainName : "posteriors" # ChainName : "osc_posteriors" Verbose : False - ModelOutputName: "tf_model_full.pkl" - # Model Settings MakeMLModel: True From f315a9ce30092af9acbe884409f0b84ac13e3fcc Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:29:47 +0100 Subject: [PATCH 07/11] More small fixes to naming --- MaCh3PythonUtils/config_reader.py | 4 ++-- MaCh3PythonUtils/diagnostics/interface/plotting_interface.py | 4 ++-- .../diagnostics/{plotters => mcmc_plots}/__init__.py | 0 .../{plotters => mcmc_plots}/diagnostics/__init__.py | 0 .../diagnostics/autocorrelation_trace_plotter.py | 2 +- .../diagnostics/covariance_matrix_utils.py | 0 .../{plotters => mcmc_plots}/diagnostics/simple_diag_plots.py | 2 +- .../diagnostics/{plotters => mcmc_plots}/plotter_base.py | 0 .../{plotters => mcmc_plots}/posteriors/__init__.py | 0 .../posteriors/posterior_base_classes.py | 2 +- .../{plotters => 
mcmc_plots}/posteriors/posteriors_1d.py | 2 +- .../{plotters => mcmc_plots}/posteriors/posteriors_2d.py | 2 +- 12 files changed, 9 insertions(+), 9 deletions(-) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/__init__.py (100%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/diagnostics/__init__.py (100%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/diagnostics/autocorrelation_trace_plotter.py (96%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/diagnostics/covariance_matrix_utils.py (100%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/diagnostics/simple_diag_plots.py (97%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/plotter_base.py (100%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/posteriors/__init__.py (100%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/posteriors/posterior_base_classes.py (96%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/posteriors/posteriors_1d.py (97%) rename MaCh3PythonUtils/diagnostics/{plotters => mcmc_plots}/posteriors/posteriors_2d.py (97%) diff --git a/MaCh3PythonUtils/config_reader.py b/MaCh3PythonUtils/config_reader.py index 30e4cb9..2d03402 100644 --- a/MaCh3PythonUtils/config_reader.py +++ b/MaCh3PythonUtils/config_reader.py @@ -4,8 +4,8 @@ from machine_learning.ml_factory import MLFactory from machine_learning.fml_interface import FmlInterface from diagnostics.interface.plotting_interface import PlottingInterface -import diagnostics.plotters.posteriors as m3post -import diagnostics.plotters.diagnostics as m3diag +import diagnostics.mcmc_plots.posteriors as m3post +import diagnostics.mcmc_plots.diagnostics as m3diag from pydantic.utils import deep_update class ConfigReader: diff --git a/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py b/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py index 02e8534..2424efa 100644 --- 
a/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py +++ b/MaCh3PythonUtils/diagnostics/interface/plotting_interface.py @@ -3,7 +3,7 @@ ''' from typing import List from file_handling.chain_handler import ChainHandler -import diagnostics.plotters as pt +import diagnostics.mcmc_plots as pt from matplotlib.backends.backend_pdf import PdfPages import arviz as az @@ -124,7 +124,7 @@ def print_summary(self, latex_output_name:str=None): inputs : latex_output_name : [type=str, optional] name of output file ''' - summary = az.summary(self._file_loader.ttree_array, kind='stats', hdi_prob=0.9) + summary = az.summary(self._file_loader.arviz_tree, kind='stats', hdi_prob=0.9) if latex_output_name is None: return diff --git a/MaCh3PythonUtils/diagnostics/plotters/__init__.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/__init__.py similarity index 100% rename from MaCh3PythonUtils/diagnostics/plotters/__init__.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/__init__.py diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/__init__.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/__init__.py similarity index 100% rename from MaCh3PythonUtils/diagnostics/plotters/diagnostics/__init__.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/__init__.py diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/autocorrelation_trace_plotter.py similarity index 96% rename from MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/autocorrelation_trace_plotter.py index d0ef6f8..9769c8f 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/autocorrelation_trace_plotter.py +++ b/MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/autocorrelation_trace_plotter.py @@ -4,7 +4,7 @@ from file_handling.chain_handler import ChainHandler import arviz as az from matplotlib 
import pyplot as plt -from diagnostics.plotters.plotter_base import _PlottingBaseClass +from diagnostics.mcmc_plots.plotter_base import _PlottingBaseClass from matplotlib.figure import Figure class AutocorrelationTracePlotter(_PlottingBaseClass): diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/covariance_matrix_utils.py similarity index 100% rename from MaCh3PythonUtils/diagnostics/plotters/diagnostics/covariance_matrix_utils.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/covariance_matrix_utils.py diff --git a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/simple_diag_plots.py similarity index 97% rename from MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/simple_diag_plots.py index f9589ef..0e830ed 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/diagnostics/simple_diag_plots.py +++ b/MaCh3PythonUtils/diagnostics/mcmc_plots/diagnostics/simple_diag_plots.py @@ -6,7 +6,7 @@ from file_handling.chain_handler import ChainHandler import arviz as az from matplotlib import pyplot as plt -from diagnostics.plotters.plotter_base import _PlottingBaseClass +from diagnostics.mcmc_plots.plotter_base import _PlottingBaseClass from matplotlib.figure import Figure class EffectiveSampleSizePlotter(_PlottingBaseClass): diff --git a/MaCh3PythonUtils/diagnostics/plotters/plotter_base.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/plotter_base.py similarity index 100% rename from MaCh3PythonUtils/diagnostics/plotters/plotter_base.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/plotter_base.py diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/__init__.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/__init__.py similarity index 100% rename from 
MaCh3PythonUtils/diagnostics/plotters/posteriors/__init__.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/__init__.py diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posterior_base_classes.py similarity index 96% rename from MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posterior_base_classes.py index d9d041b..223008a 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posterior_base_classes.py +++ b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posterior_base_classes.py @@ -3,7 +3,7 @@ Separated into 1D and 2D-like objects to make life easier ''' from file_handling.chain_handler import ChainHandler -from diagnostics.plotters.plotter_base import _PlottingBaseClass +from diagnostics.mcmc_plots.plotter_base import _PlottingBaseClass from typing import List import numpy as np from tqdm.auto import tqdm diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posteriors_1d.py similarity index 97% rename from MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posteriors_1d.py index b4da96f..16de174 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_1d.py +++ b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posteriors_1d.py @@ -2,7 +2,7 @@ import arviz as az from matplotlib import pyplot as plt import numpy as np -from diagnostics.plotters.posteriors.posterior_base_classes import _PosteriorPlottingBase +from diagnostics.mcmc_plots.posteriors.posterior_base_classes import _PosteriorPlottingBase from matplotlib.figure import Figure ''' diff --git a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posteriors_2d.py 
similarity index 97% rename from MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py rename to MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posteriors_2d.py index 990e141..4751116 100644 --- a/MaCh3PythonUtils/diagnostics/plotters/posteriors/posteriors_2d.py +++ b/MaCh3PythonUtils/diagnostics/mcmc_plots/posteriors/posteriors_2d.py @@ -6,7 +6,7 @@ from typing import List, Any from matplotlib import pyplot as plt from itertools import combinations -from diagnostics.plotters.posteriors.posterior_base_classes import _PosteriorPlottingBase +from diagnostics.mcmc_plots.posteriors.posterior_base_classes import _PosteriorPlottingBase from matplotlib.figure import Figure import numpy as np import numpy.typing as npt From bce079a57a57fe65a7088b9c7e83871e249a189e Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:30:39 +0100 Subject: [PATCH 08/11] EVEN MORE REQUIREMENTS --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7af9f32..22123b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ pydantic arviz mplhep tqdm -numba \ No newline at end of file +numba +jinja2 \ No newline at end of file From ea99fc5de0458d940988dd3593556230639474b3 Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:31:18 +0100 Subject: [PATCH 09/11] Switch off deprecation warning --- MaCh3PythonUtils/config_reader.py | 2 +- summary_diagnostics_output.pdf.txt | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 summary_diagnostics_output.pdf.txt diff --git a/MaCh3PythonUtils/config_reader.py b/MaCh3PythonUtils/config_reader.py index 2d03402..dc0ff72 100644 --- a/MaCh3PythonUtils/config_reader.py +++ b/MaCh3PythonUtils/config_reader.py @@ -6,7 +6,7 @@ from diagnostics.interface.plotting_interface import PlottingInterface import diagnostics.mcmc_plots.posteriors as m3post import diagnostics.mcmc_plots.diagnostics as m3diag -from 
pydantic.utils import deep_update +from pydantic.v1.utils import deep_update class ConfigReader: diff --git a/summary_diagnostics_output.pdf.txt b/summary_diagnostics_output.pdf.txt new file mode 100644 index 0000000..e69de29 From 31013c091bf4b849363014dc6753e783d2589bae Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:31:40 +0100 Subject: [PATCH 10/11] Oops summary diags --- summary_diagnostics_output.pdf.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 summary_diagnostics_output.pdf.txt diff --git a/summary_diagnostics_output.pdf.txt b/summary_diagnostics_output.pdf.txt deleted file mode 100644 index e69de29..0000000 From 0acfcfa8577a629e9d9ef5159d2af1d8974db357 Mon Sep 17 00:00:00 2001 From: henry-israel Date: Tue, 17 Sep 2024 16:41:00 +0100 Subject: [PATCH 11/11] Oops, fixes name change --- MaCh3PythonUtils/machine_learning/fml_interface.py | 1 - MaCh3PythonUtils/machine_learning/ml_factory.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/MaCh3PythonUtils/machine_learning/fml_interface.py b/MaCh3PythonUtils/machine_learning/fml_interface.py index 6a22cd6..3b9029f 100644 --- a/MaCh3PythonUtils/machine_learning/fml_interface.py +++ b/MaCh3PythonUtils/machine_learning/fml_interface.py @@ -9,7 +9,6 @@ from sklearn import metrics from sklearn.preprocessing import StandardScaler import matplotlib.pyplot as plt -import scipy.stats as stats import numpy as np import pickle diff --git a/MaCh3PythonUtils/machine_learning/ml_factory.py b/MaCh3PythonUtils/machine_learning/ml_factory.py index 2f1b236..c48411b 100644 --- a/MaCh3PythonUtils/machine_learning/ml_factory.py +++ b/MaCh3PythonUtils/machine_learning/ml_factory.py @@ -51,13 +51,13 @@ def __setup_package_factory(self, package: str, algorithm: str, **kwargs): return self.__IMPLEMENTED_ALGORITHMS[package][algorithm](*kwargs) - def setup_scikit_model(self, algorithm: str, **kwargs)->SciKitInterface: + def 
make_scikit_model(self, algorithm: str, **kwargs)->SciKitInterface: # Simple wrapper for scikit packages interface = SciKitInterface(self._chain, self._prediction_variable) interface.add_model(self.__setup_package_factory(package="scikit", algorithm=algorithm, **kwargs)) return interface - def setup_tensorflow_model(self, algorithm: str, **kwargs): + def make_tensorflow_model(self, algorithm: str, **kwargs): interface = TfInterface(self._chain, self._prediction_variable) interface.add_model(self.__setup_package_factory(package="tensorflow", algorithm=algorithm)) diff --git a/requirements.txt b/requirements.txt index 22123b5..188bd21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ arviz mplhep tqdm numba -jinja2 \ No newline at end of file +jinja2x \ No newline at end of file