3
3
from file_handling .chain_handler import ChainHandler
4
4
from machine_learning .ml_factory import MLFactory
5
5
from machine_learning .fml_interface import FmlInterface
6
-
6
+ from diagnostics .interface .plotting_interface import PlottingInterface
7
+ import diagnostics .mcmc_plots .posteriors as m3post
8
+ import diagnostics .mcmc_plots .diagnostics as m3diag
9
+ from pydantic .v1 .utils import deep_update
7
10
8
11
class ConfigReader :
9
12
10
13
# Strictly unecessary but nice conceptually
11
14
_file_handler = None
12
15
_interface = None
16
+ _plot_interface = None
17
+
18
+ __default_settings = {
19
+ # Settings for file and I/O
20
+ "FileSettings" : {
21
+ # Name of input file
22
+ "FileName" : "" ,
23
+ # Name of chain in file
24
+ "ChainName" : "" ,
25
+ # More printouts
26
+ "Verbose" : False ,
27
+ # Make posteriors from chain?
28
+ "MakePosteriors" : False ,
29
+ # Run Diagnostics code?
30
+ "MakeDiagnostics" : False ,
31
+ # Make an ML model to replicate the chain likelihood model
32
+ "MakeMLModel" : False
33
+ },
34
+
35
+ "ParameterSettings" :{
36
+ "CircularParameters" : [],
37
+ # List of parameter names
38
+ "ParameterNames" :[],
39
+ # List of cuts
40
+ "ParameterCuts" :[],
41
+ # Name of label branch, used in ML
42
+ "LabelName" : "" ,
43
+ # Any parameters we want to ignore
44
+ "IgnoredParameters" :[]
45
+ },
46
+
47
+ # Settings for plotting tools
48
+ "PlottingSettings" :{
49
+ # Specific Settings for posterior plots
50
+ "PosteriorSettings" : {
51
+ # Make 2D Posterior Plots?
52
+ "Make2DPosteriors" : False ,
53
+ # Make a triangle plot
54
+ "MakeTrianglePlot" : False ,
55
+ # Variables in the triangle plot
56
+ "TrianglePlot" : [],
57
+ # 1D credible intervals
58
+ "MakeCredibleIntervals" : False ,
59
+ # Output file
60
+ "PosteriorOutputFile" : "posteriors.pdf"
61
+ },
62
+
63
+ # Specific Settings for diagnostic plots
64
+ "DiagnosticsSettings" : {
65
+ # Make violin plot?
66
+ "MakeViolin" : False ,
67
+ # Make trace + AC plot?
68
+ "MakeTraceAC" : False ,
69
+ # Make effective sample size plot?
70
+ "MakeESS" : False ,
71
+ # Make MCSE plot?
72
+ "MakeMCSE" : False ,
73
+ # Make suboptimality plot
74
+ "MakeSuboptimality" : False ,
75
+ # Step for calculation
76
+ "SuboptimalitySteps" : 0 ,
77
+ # Output file
78
+ "DiagnosticsOutputFile" : "diagnostics.pdf" ,
79
+ # Print summary statistic
80
+ "PrintSummary" : False
81
+ }
82
+ },
83
+ # Specific Settings for ML Applications
84
+ "MLSettings" : {
85
+ # Fitter package either SciKit or TensorFlow
86
+ "FitterPackage" : "" ,
87
+ # Fitter Model
88
+ "FitterName" : "" ,
89
+ # Keyword arguments for fitter
90
+ "FitterKwargs" : {},
91
+ #Use an external model that's already been trained?
92
+ "AddFromExternalModel" : False ,
93
+ # Proportion of input data set used for testing (range of 0-1 )
94
+ "TestSize" : 0.0 ,
95
+ # Name to save ML model in
96
+ "MLOutputFile" : "mlmodel"
97
+ }
98
+
99
+ }
100
+
101
+
13
102
14
103
def __init__ (self , config : str ):
15
104
with open (config , 'r' ) as c :
16
105
self ._yaml_config = yaml .safe_load (c )
106
+
107
+ # Update default settings
108
+ self .__chain_settings = deep_update (self .__default_settings , self ._yaml_config )
17
109
18
-
19
- def setup_file_handler (self )-> None :
110
+ def make_file_handler (self )-> None :
20
111
# Process MCMC chain
21
- self ._file_handler = ChainHandler (self ._yaml_config ["FileSettings" ]["FileName" ],
22
- self ._yaml_config ["FileSettings" ]["ChainName" ],
23
- self ._yaml_config ["FileSettings" ]["Verbose" ])
112
+ self ._file_handler = ChainHandler (self .__chain_settings ["FileSettings" ]["FileName" ],
113
+ self .__chain_settings ["FileSettings" ]["ChainName" ],
114
+ self .__chain_settings ["FileSettings" ]["Verbose" ])
24
115
25
- self ._file_handler .ignore_plots (self ._yaml_config ["FileSettings" ]["IgnoredParameters" ])
26
- self ._file_handler .add_additional_plots (self ._yaml_config ["FileSettings" ]["ParameterNames" ])
27
- self ._file_handler .add_additional_plots (self ._yaml_config ["FileSettings" ]["LabelName" ], True )
116
+ self ._file_handler .ignore_plots (self .__chain_settings ["ParameterSettings" ]["IgnoredParameters" ])
117
+ self ._file_handler .add_additional_plots (self .__chain_settings ["ParameterSettings" ]["ParameterNames" ])
118
+
119
+ self ._file_handler .add_additional_plots (self .__chain_settings ["ParameterSettings" ]["LabelName" ], True )
28
120
29
- self ._file_handler .add_new_cuts (self ._yaml_config [ "FileSettings " ]["ParameterCuts" ])
121
+ self ._file_handler .add_new_cuts (self .__chain_settings [ "ParameterSettings " ]["ParameterCuts" ])
30
122
31
123
self ._file_handler .convert_ttree_to_array ()
32
124
33
125
34
- def setup_ml_interface (self )-> None :
126
+
127
+ def make_posterior_plots (self ):
128
+ if self ._plot_interface is None :
129
+ self ._plot_interface = PlottingInterface (self ._file_handler )
130
+
131
+ posterior_labels = []
132
+
133
+ if self .__chain_settings ['PlottingSettings' ]['PosteriorSettings' ]['Make2DPosteriors' ]:
134
+ self ._plot_interface .initialise_new_plotter (m3post .PosteriorPlotter2D (self ._file_handler ), 'posterior_2d' )
135
+ posterior_labels .append ('posterior_2d' )
136
+
137
+ if self .__chain_settings ['PlottingSettings' ]['PosteriorSettings' ]['MakeTrianglePlot' ]:
138
+ self ._plot_interface .initialise_new_plotter (m3post .TrianglePlotter (self ._file_handler ), 'posterior_triangle' )
139
+ posterior_labels .append ('posterior_triangle' )
140
+
141
+ # Which variables do we actually want 2D plots for?
142
+ self ._plot_interface .set_variables_to_plot (self .__chain_settings ['PlottingSettings' ]['PosteriorSettings' ]['TrianglePlot' ], posterior_labels )
143
+
144
+ if self .__chain_settings ['PlottingSettings' ]['PosteriorSettings' ]['Make1DPosteriors' ]:
145
+ self ._plot_interface .initialise_new_plotter (m3post .PosteriorPlotter1D (self ._file_handler ), 'posterior_1d' )
146
+ posterior_labels .append ('posterior_1d' )
147
+
148
+
149
+ self ._plot_interface .set_credible_intervals (self .__chain_settings ['PlottingSettings' ]['PosteriorSettings' ]['CredibleIntervals' ])
150
+ self ._plot_interface .set_is_circular (self .__chain_settings ['ParameterSettings' ]['CircularParameters' ])
151
+
152
+ self ._plot_interface .make_plots (self .__chain_settings ['PlottingSettings' ]['PosteriorSettings' ]['PosteriorOutputFile' ], posterior_labels )
153
+
154
+ def make_diagnostics_plots (self ):
155
+ if self ._plot_interface is None :
156
+ self ._plot_interface = PlottingInterface (self ._file_handler )
157
+
158
+ diagnostic_labels = []
159
+
160
+ if self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['MakeViolin' ]:
161
+ self ._plot_interface .initialise_new_plotter (m3diag .ViolinPlotter (self ._file_handler ), 'violin_plotter' )
162
+ diagnostic_labels .append ('violin_plotter' )
163
+ if self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['MakeTraceAC' ]:
164
+ self ._plot_interface .initialise_new_plotter (m3diag .AutocorrelationTracePlotter (self ._file_handler ), 'trace_autocorr' )
165
+ diagnostic_labels .append ('trace_autocorr' )
166
+ if self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['MakeESS' ]:
167
+ self ._plot_interface .initialise_new_plotter (m3diag .EffectiveSampleSizePlotter (self ._file_handler ), 'ess_plot' )
168
+ diagnostic_labels .append ('ess_plot' )
169
+ if self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['MakeMCSE' ]:
170
+ self ._plot_interface .initialise_new_plotter (m3diag .MarkovChainStandardError (self ._file_handler ), 'msce_plot' )
171
+ diagnostic_labels .append ('msce_plot' )
172
+
173
+ self ._plot_interface .make_plots (self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['DiagnosticsOutputFile' ], diagnostic_labels )
174
+
175
+ # Final one, covariance plotter
176
+ if self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['MakeSuboptimality' ]:
177
+ suboptimality_obj = m3diag .CovarianceMatrixUtils (self ._file_handler )
178
+ suboptimality_obj .calculate_suboptimality (self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['SuboptimalitySteps' ], self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['SubOptimalityMin' ])
179
+ suboptimality_obj .plot_suboptimality (f"suboptimality_{ self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['DiagnosticsOutputFile' ]} " )
180
+
181
+ # Finally let's make a quick simmary
182
+ if self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['PrintSummary' ]:
183
+ self ._plot_interface .print_summary (f"summary_{ self .__chain_settings ['PlottingSettings' ]['DiagnosticsSettings' ]['DiagnosticsOutputFile' ]} .txt" )
184
+
185
+ def make_ml_interface (self )-> None :
35
186
if self ._file_handler is None :
36
187
raise Exception ("Cannot initialise ML interface without first setting up file handler!" )
37
188
38
189
39
- factory = MLFactory (self ._file_handler , self ._yaml_config [ "FileSettings " ]["LabelName" ])
40
- if self ._yaml_config [ "FitterSettings " ]["FitterPackage" ].lower () == "scikit" :
41
- self ._interface = factory .setup_scikit_model (self ._yaml_config [ "FitterSettings " ]["FitterName" ],
42
- ** self ._yaml_config [ "FitterSettings " ]["FitterKwargs" ])
190
+ factory = MLFactory (self ._file_handler , self .__chain_settings [ "ParameterSettings " ]["LabelName" ])
191
+ if self .__chain_settings [ "MLSettings " ]["FitterPackage" ].lower () == "scikit" :
192
+ self ._interface = factory .make_scikit_model (self .__chain_settings [ "MLSettings " ]["FitterName" ],
193
+ ** self .__chain_settings [ "MLSettings " ]["FitterKwargs" ])
43
194
44
- elif self ._yaml_config [ "FitterSettings " ]["FitterPackage" ].lower () == "tensorflow" :
45
- self ._interface = factory .setup_tensorflow_model (self ._yaml_config [ "FitterSettings " ]["FitterName" ],
46
- ** self ._yaml_config [ "FitterSettings " ]["FitterKwargs" ])
195
+ elif self .__chain_settings [ "MLSettings " ]["FitterPackage" ].lower () == "tensorflow" :
196
+ self ._interface = factory .make_tensorflow_model (self .__chain_settings [ "MLSettings " ]["FitterName" ],
197
+ ** self .__chain_settings [ "MLSettings " ]["FitterKwargs" ])
47
198
48
199
else :
49
200
raise ValueError ("Input not recognised!" )
50
201
51
- if self ._yaml_config [ "FitterSettings " ].get ("AddFromExternalModel" ):
52
- external_model = self ._yaml_config [ "FitterSettings " ]["ExternalModel" ]
202
+ if self .__chain_settings [ "MLSettings " ].get ("AddFromExternalModel" ):
203
+ external_model = self .__chain_settings [ "MLSettings " ]["ExternalModel" ]
53
204
self ._interface .load_model (external_model )
54
205
55
206
else :
56
- self ._interface .set_training_test_set (self ._yaml_config [ "FitterSettings " ]["TestSize" ])
207
+ self ._interface .set_training_test_set (self .__chain_settings [ "MLSettings " ]["TestSize" ])
57
208
58
209
self ._interface .train_model ()
59
210
self ._interface .test_model ()
60
- self ._interface .save_model (self ._yaml_config ["FileSettings" ]["ModelOutputName" ])
61
-
62
- def __call__ (self ) -> None :
63
- self .setup_file_handler ()
64
- self .setup_ml_interface ()
211
+ self ._interface .save_model (self .__chain_settings ["MLSettings" ]["MLOutputFile" ])
65
212
213
+
214
+
215
+ def __call__ (self ) -> None :
216
+
217
+ self .make_file_handler ()
218
+ if self .__chain_settings ["FileSettings" ]["MakePosteriors" ]:
219
+ self .make_posterior_plots ()
66
220
221
+ if self .__chain_settings ["FileSettings" ]["MakeDiagnostics" ]:
222
+ self .make_diagnostics_plots ()
223
+
224
+ if self .__chain_settings ["FileSettings" ]["MakeMLModel" ]:
225
+ self .make_ml_interface ()
226
+
227
+
67
228
@property
68
229
def chain_handler (self )-> ChainHandler | None :
69
230
return self ._file_handler
70
231
71
232
@property
72
233
def ml_interface (self )-> FmlInterface | None :
73
- return self ._interface
234
+ return self ._interface
235
+
0 commit comments