Skip to content

Commit 2323e74

Browse files
Feature plotting utils (#9)
* adds scalar, makes demo config config more helpful * more config updates * Adds basic version of diagnostics package from T2K MaCh3, will not work just yet... * update gitignore * Implements plotting lib * Small patches' * More small fixes to naming * EVEN MORE REQUIREMENTS * Switch off deprecation warning * Oops summary diags
1 parent 28df1b8 commit 2323e74

21 files changed

+1189
-55
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
*env*
33
*__pycache__*
44
*pdf
5-
*.pkl
5+
*.pkl
6+
*.keras

MaCh3PythonUtils/config_reader.py

Lines changed: 189 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,71 +3,233 @@
33
from file_handling.chain_handler import ChainHandler
44
from machine_learning.ml_factory import MLFactory
55
from machine_learning.fml_interface import FmlInterface
6-
6+
from diagnostics.interface.plotting_interface import PlottingInterface
7+
import diagnostics.mcmc_plots.posteriors as m3post
8+
import diagnostics.mcmc_plots.diagnostics as m3diag
9+
from pydantic.v1.utils import deep_update
710

811
class ConfigReader:
912

1013
# Strictly unecessary but nice conceptually
1114
_file_handler = None
1215
_interface = None
16+
_plot_interface = None
17+
18+
__default_settings = {
19+
# Settings for file and I/O
20+
"FileSettings" : {
21+
# Name of input file
22+
"FileName": "",
23+
# Name of chain in file
24+
"ChainName": "",
25+
# More printouts
26+
"Verbose": False,
27+
# Make posteriors from chain?
28+
"MakePosteriors": False,
29+
# Run Diagnostics code?
30+
"MakeDiagnostics": False,
31+
# Make an ML model to replicate the chain likelihood model
32+
"MakeMLModel": False
33+
},
34+
35+
"ParameterSettings":{
36+
"CircularParameters" : [],
37+
# List of parameter names
38+
"ParameterNames":[],
39+
# List of cuts
40+
"ParameterCuts":[],
41+
# Name of label branch, used in ML
42+
"LabelName": "",
43+
# Any parameters we want to ignore
44+
"IgnoredParameters":[]
45+
},
46+
47+
# Settings for plotting tools
48+
"PlottingSettings":{
49+
# Specific Settings for posterior plots
50+
"PosteriorSettings": {
51+
# Make 2D Posterior Plots?
52+
"Make2DPosteriors": False,
53+
# Make a triangle plot
54+
"MakeTrianglePlot": False,
55+
# Variables in the triangle plot
56+
"TrianglePlot": [],
57+
# 1D credible intervals
58+
"MakeCredibleIntervals": False,
59+
# Output file
60+
"PosteriorOutputFile": "posteriors.pdf"
61+
},
62+
63+
# Specific Settings for diagnostic plots
64+
"DiagnosticsSettings": {
65+
# Make violin plot?
66+
"MakeViolin": False,
67+
# Make trace + AC plot?
68+
"MakeTraceAC": False,
69+
# Make effective sample size plot?
70+
"MakeESS": False,
71+
# Make MCSE plot?
72+
"MakeMCSE": False,
73+
# Make suboptimality plot
74+
"MakeSuboptimality": False,
75+
# Step for calculation
76+
"SuboptimalitySteps": 0,
77+
# Output file
78+
"DiagnosticsOutputFile": "diagnostics.pdf",
79+
# Print summary statistic
80+
"PrintSummary": False
81+
}
82+
},
83+
# Specific Settings for ML Applications
84+
"MLSettings": {
85+
# Fitter package either SciKit or TensorFlow
86+
"FitterPackage": "",
87+
# Fitter Model
88+
"FitterName": "",
89+
# Keyword arguments for fitter
90+
"FitterKwargs" : {},
91+
#Use an external model that's already been trained?
92+
"AddFromExternalModel": False,
93+
# Proportion of input data set used for testing (range of 0-1 )
94+
"TestSize": 0.0,
95+
# Name to save ML model in
96+
"MLOutputFile": "mlmodel"
97+
}
98+
99+
}
100+
101+
13102

14103
def __init__(self, config: str):
15104
with open(config, 'r') as c:
16105
self._yaml_config = yaml.safe_load(c)
106+
107+
# Update default settings
108+
self.__chain_settings = deep_update(self.__default_settings, self._yaml_config)
17109

18-
19-
def setup_file_handler(self)->None:
110+
def make_file_handler(self)->None:
20111
# Process MCMC chain
21-
self._file_handler = ChainHandler(self._yaml_config["FileSettings"]["FileName"],
22-
self._yaml_config["FileSettings"]["ChainName"],
23-
self._yaml_config["FileSettings"]["Verbose"])
112+
self._file_handler = ChainHandler(self.__chain_settings["FileSettings"]["FileName"],
113+
self.__chain_settings["FileSettings"]["ChainName"],
114+
self.__chain_settings["FileSettings"]["Verbose"])
24115

25-
self._file_handler.ignore_plots(self._yaml_config["FileSettings"]["IgnoredParameters"])
26-
self._file_handler.add_additional_plots(self._yaml_config["FileSettings"]["ParameterNames"])
27-
self._file_handler.add_additional_plots(self._yaml_config["FileSettings"]["LabelName"], True)
116+
self._file_handler.ignore_plots(self.__chain_settings["ParameterSettings"]["IgnoredParameters"])
117+
self._file_handler.add_additional_plots(self.__chain_settings["ParameterSettings"]["ParameterNames"])
118+
119+
self._file_handler.add_additional_plots(self.__chain_settings["ParameterSettings"]["LabelName"], True)
28120

29-
self._file_handler.add_new_cuts(self._yaml_config["FileSettings"]["ParameterCuts"])
121+
self._file_handler.add_new_cuts(self.__chain_settings["ParameterSettings"]["ParameterCuts"])
30122

31123
self._file_handler.convert_ttree_to_array()
32124

33125

34-
def setup_ml_interface(self)->None:
126+
127+
def make_posterior_plots(self):
128+
if self._plot_interface is None:
129+
self._plot_interface = PlottingInterface(self._file_handler)
130+
131+
posterior_labels = []
132+
133+
if self.__chain_settings['PlottingSettings']['PosteriorSettings']['Make2DPosteriors']:
134+
self._plot_interface.initialise_new_plotter(m3post.PosteriorPlotter2D(self._file_handler), 'posterior_2d')
135+
posterior_labels.append('posterior_2d')
136+
137+
if self.__chain_settings['PlottingSettings']['PosteriorSettings']['MakeTrianglePlot']:
138+
self._plot_interface.initialise_new_plotter(m3post.TrianglePlotter(self._file_handler), 'posterior_triangle')
139+
posterior_labels.append('posterior_triangle')
140+
141+
# Which variables do we actually want 2D plots for?
142+
self._plot_interface.set_variables_to_plot(self.__chain_settings['PlottingSettings']['PosteriorSettings']['TrianglePlot'], posterior_labels)
143+
144+
if self.__chain_settings['PlottingSettings']['PosteriorSettings']['Make1DPosteriors']:
145+
self._plot_interface.initialise_new_plotter(m3post.PosteriorPlotter1D(self._file_handler), 'posterior_1d')
146+
posterior_labels.append('posterior_1d')
147+
148+
149+
self._plot_interface.set_credible_intervals(self.__chain_settings['PlottingSettings']['PosteriorSettings']['CredibleIntervals'])
150+
self._plot_interface.set_is_circular(self.__chain_settings['ParameterSettings']['CircularParameters'])
151+
152+
self._plot_interface.make_plots(self.__chain_settings['PlottingSettings']['PosteriorSettings']['PosteriorOutputFile'], posterior_labels)
153+
154+
def make_diagnostics_plots(self):
155+
if self._plot_interface is None:
156+
self._plot_interface = PlottingInterface(self._file_handler)
157+
158+
diagnostic_labels = []
159+
160+
if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeViolin']:
161+
self._plot_interface.initialise_new_plotter(m3diag.ViolinPlotter(self._file_handler), 'violin_plotter')
162+
diagnostic_labels.append('violin_plotter')
163+
if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeTraceAC']:
164+
self._plot_interface.initialise_new_plotter(m3diag.AutocorrelationTracePlotter(self._file_handler), 'trace_autocorr')
165+
diagnostic_labels.append('trace_autocorr')
166+
if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeESS']:
167+
self._plot_interface.initialise_new_plotter(m3diag.EffectiveSampleSizePlotter(self._file_handler), 'ess_plot')
168+
diagnostic_labels.append('ess_plot')
169+
if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeMCSE']:
170+
self._plot_interface.initialise_new_plotter(m3diag.MarkovChainStandardError(self._file_handler), 'msce_plot')
171+
diagnostic_labels.append('msce_plot')
172+
173+
self._plot_interface.make_plots(self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile'], diagnostic_labels)
174+
175+
# Final one, covariance plotter
176+
if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['MakeSuboptimality']:
177+
suboptimality_obj = m3diag.CovarianceMatrixUtils(self._file_handler)
178+
suboptimality_obj.calculate_suboptimality(self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['SuboptimalitySteps'], self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['SubOptimalityMin'])
179+
suboptimality_obj.plot_suboptimality(f"suboptimality_{self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile']}")
180+
181+
# Finally let's make a quick simmary
182+
if self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['PrintSummary']:
183+
self._plot_interface.print_summary(f"summary_{self.__chain_settings['PlottingSettings']['DiagnosticsSettings']['DiagnosticsOutputFile']}.txt")
184+
185+
def make_ml_interface(self)->None:
35186
if self._file_handler is None:
36187
raise Exception("Cannot initialise ML interface without first setting up file handler!")
37188

38189

39-
factory = MLFactory(self._file_handler, self._yaml_config["FileSettings"]["LabelName"])
40-
if self._yaml_config["FitterSettings"]["FitterPackage"].lower() == "scikit":
41-
self._interface = factory.setup_scikit_model(self._yaml_config["FitterSettings"]["FitterName"],
42-
**self._yaml_config["FitterSettings"]["FitterKwargs"])
190+
factory = MLFactory(self._file_handler, self.__chain_settings["ParameterSettings"]["LabelName"])
191+
if self.__chain_settings["MLSettings"]["FitterPackage"].lower() == "scikit":
192+
self._interface = factory.make_scikit_model(self.__chain_settings["MLSettings"]["FitterName"],
193+
**self.__chain_settings["MLSettings"]["FitterKwargs"])
43194

44-
elif self._yaml_config["FitterSettings"]["FitterPackage"].lower() == "tensorflow":
45-
self._interface = factory.setup_tensorflow_model(self._yaml_config["FitterSettings"]["FitterName"],
46-
**self._yaml_config["FitterSettings"]["FitterKwargs"])
195+
elif self.__chain_settings["MLSettings"]["FitterPackage"].lower() == "tensorflow":
196+
self._interface = factory.make_tensorflow_model(self.__chain_settings["MLSettings"]["FitterName"],
197+
**self.__chain_settings["MLSettings"]["FitterKwargs"])
47198

48199
else:
49200
raise ValueError("Input not recognised!")
50201

51-
if self._yaml_config["FitterSettings"].get("AddFromExternalModel"):
52-
external_model = self._yaml_config["FitterSettings"]["ExternalModel"]
202+
if self.__chain_settings["MLSettings"].get("AddFromExternalModel"):
203+
external_model = self.__chain_settings["MLSettings"]["ExternalModel"]
53204
self._interface.load_model(external_model)
54205

55206
else:
56-
self._interface.set_training_test_set(self._yaml_config["FitterSettings"]["TestSize"])
207+
self._interface.set_training_test_set(self.__chain_settings["MLSettings"]["TestSize"])
57208

58209
self._interface.train_model()
59210
self._interface.test_model()
60-
self._interface.save_model(self._yaml_config["FileSettings"]["ModelOutputName"])
61-
62-
def __call__(self) -> None:
63-
self.setup_file_handler()
64-
self.setup_ml_interface()
211+
self._interface.save_model(self.__chain_settings["MLSettings"]["MLOutputFile"])
65212

213+
214+
215+
def __call__(self) -> None:
216+
217+
self.make_file_handler()
218+
if self.__chain_settings["FileSettings"]["MakePosteriors"]:
219+
self.make_posterior_plots()
66220

221+
if self.__chain_settings["FileSettings"]["MakeDiagnostics"]:
222+
self.make_diagnostics_plots()
223+
224+
if self.__chain_settings["FileSettings"]["MakeMLModel"]:
225+
self.make_ml_interface()
226+
227+
67228
@property
68229
def chain_handler(self)->ChainHandler | None:
69230
return self._file_handler
70231

71232
@property
72233
def ml_interface(self)->FmlInterface | None:
73-
return self._interface
234+
return self._interface
235+

MaCh3PythonUtils/diagnostics/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .plotting_interface import PlottingInterface

0 commit comments

Comments
 (0)