ctools 2.1.0
Loading...
Searching...
No Matches
cstsdist.py
Go to the documentation of this file.
1#! /usr/bin/env python
2# ==========================================================================
3# Generates the TS distribution for a particular model
4#
5# Copyright (C) 2011-2022 Juergen Knoedlseder
6#
7# This program is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# This program is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with this program. If not, see <http://www.gnu.org/licenses/>.
19#
20# ==========================================================================
21import sys
22import gammalib
23import ctools
24from cscripts import obsutils
25from cscripts import modutils
26from cscripts import ioutils
27from cscripts import mputils
28
29
# ============== #
# cstsdist class #
# ============== #
class cstsdist(ctools.csobservation):
    """
    Generates Test Statistic distribution for a model

    For every trial an observation container is simulated from the input
    model, the model is fitted with ctlike and the Test Statistic of the
    test source is recorded together with the number of simulated events,
    the number of predicted events and all free fit parameters. The
    results are collected in a FITS table with extension "TS_DISTRIBUTION".
    """

    # Constructor
    def __init__(self, *argv):
        """
        Constructor
        """
        # Initialise application by calling the appropriate class constructor
        self._init_csobservation(self.__class__.__name__, ctools.__version__, argv)

        # Initialise some members
        self._srcname     = ''     # Name of test source (from 'srcname')
        self._fits        = None   # Result FITS file, created by process()
        self._log_clients = False  # Passed as 'log' to obsutils.sim()
        self._model       = None   # Cached model cube (counts cube case)
        self._nthreads    = 0      # Number of processes for multiprocessing

        # Return
        return

    # State methods for pickling
    def __getstate__(self):
        """
        Extend ctools.csobservation getstate method to include some members

        Returns
        -------
        state : dict
            Pickled instance
        """
        # Set pickled dictionary
        state = {'base'        : ctools.csobservation.__getstate__(self),
                 'srcname'     : self._srcname,
                 'fits'        : self._fits,
                 'log_clients' : self._log_clients,
                 'model'       : self._model,
                 'nthreads'    : self._nthreads}

        # Return pickled dictionary
        return state

    def __setstate__(self, state):
        """
        Extend ctools.csobservation setstate method to include some members

        Parameters
        ----------
        state : dict
            Pickled instance
        """
        ctools.csobservation.__setstate__(self, state['base'])
        self._srcname     = state['srcname']
        self._fits        = state['fits']
        self._log_clients = state['log_clients']
        self._model       = state['model']
        self._nthreads    = state['nthreads']

        # Return
        return

    # Private methods
    def _get_parameters(self):
        """
        Get parameters from parfile and setup the observation
        """
        # Set observation if not done before
        if self.obs().size() == 0:
            self.obs(self._get_observations())

        # Set observation statistic
        self._set_obs_statistic(gammalib.toupper(self['statistic'].string()))

        # Get source name
        self._srcname = self['srcname'].string()

        # Set models if we have none
        if self.obs().models().size() == 0:
            self.obs().models(self['inmodel'].filename())

        # Query parameters
        self['edisp'].boolean()
        self['ntrials'].integer()
        self['debug'].boolean()

        # Read ahead output parameters
        if self._read_ahead():
            self['outfile'].filename()

        # Write input parameters into logger
        self._log_parameters(gammalib.TERSE)

        # Set number of processes for multiprocessing
        self._nthreads = mputils.nthreads(self)

        # Return
        return

    def _sim(self, seed):
        """
        Return a simulated observation container

        If the observation container holds a single counts cube, the counts
        cube model is computed once with ctmodel (and cached for subsequent
        trials) and every pixel is Poisson-randomised. Otherwise events are
        simulated from the observation container using obsutils.sim().

        Parameters
        ----------
        seed : int
            Random number generator seed

        Returns
        -------
        sim : `~gammalib.GObservations`
            Simulated observation container
        """
        # If observation is a counts cube then simulate events from the counts
        # cube model ...
        if self.obs().size() == 1 and self.obs()[0].eventtype() == 'CountsCube':

            # If no counts cube model exists then compute it now
            if self._model is None:
                model            = ctools.ctmodel(self.obs())
                model['debug']   = self['debug'].boolean()
                model['chatter'] = self['chatter'].integer()
                model.run()
                self._model = model.cube().copy() # Save copy for persistence

            # Allocate random number generator seeded with the trial seed.
            # A default-constructed gammalib.GRan() uses a fixed default seed,
            # which would give every trial the identical Poisson realisation.
            ran = gammalib.GRan(seed)

            # Get copy of model map
            counts = self._model.counts().copy()

            # Randomize counts
            for i in range(counts.npix()):
                counts[i] = ran.poisson(counts[i])

            # Copy observations
            sim = self.obs().copy()

            # Set counts map
            sim[0].events().counts(counts)

        # ... otherwise simulate events from the observation container (works
        # only for event lists)
        else:
            sim = obsutils.sim(self.obs(),
                               seed     = seed,
                               log      = self._log_clients,
                               debug    = self['debug'].boolean(),
                               nthreads = 1)

        # Return simulated observation
        return sim

    def _trial(self, seed):
        """
        Create the TS for a single trial

        Parameters
        ----------
        seed : int
            Random number generator seed (the trial is logged as seed+1)

        Returns
        -------
        result : dict
            Result dictionary with key 'colnames' (ordered list of column
            names) and key 'values' (dictionary of values keyed by column
            name)
        """
        # Write header
        self._log_header2(gammalib.EXPLICIT, 'Trial %d' % (seed+1))

        # Simulate events
        sim = self._sim(seed)

        # Determine number of events in simulation
        nevents = 0.0
        for run in sim:
            nevents += run.events().number()

        # Write simulation results
        self._log_header3(gammalib.EXPLICIT, 'Simulation')
        self._log_value(gammalib.EXPLICIT, 'Number of simulated events', nevents)

        # Fit model
        fit = ctools.ctlike(sim)
        fit['nthreads'] = 1  # Avoids OpenMP conflict
        fit['debug']    = self['debug'].boolean()
        fit['chatter']  = self['chatter'].integer()
        fit.run()

        # Get model fitting results
        logL   = fit.opt().value()
        npred  = fit.obs().npred()
        models = fit.obs().models()
        source = models[self._srcname]
        ts     = source.ts()

        # Write fit results, either explicit or normal
        if self._logExplicit():
            self._log_header3(gammalib.EXPLICIT, 'Test source model fit')
            self._log_value(gammalib.EXPLICIT, 'Test statistics', ts)
            self._log_value(gammalib.EXPLICIT, 'log likelihood', logL)
            self._log_value(gammalib.EXPLICIT, 'Number of predicted events', npred)
            for model in models:
                self._log_value(gammalib.EXPLICIT, 'Model', model.name())
                for par in model:
                    self._log_string(gammalib.EXPLICIT, str(par))
        elif self._logNormal():
            prefactor = modutils.normalisation_parameter(source)
            # Number trials from 1, consistent with the trial header above
            name      = 'Trial %d' % (seed+1)
            value     = 'TS=%.3f %s=%e +/- %e' % \
                        (ts, prefactor.name(), prefactor.value(), prefactor.error())
            self._log_value(gammalib.TERSE, name, value)

        # Initialise results with TS, number of events and number of
        # predicted events
        colnames = ['TS', 'Nevents', 'Npred']
        values   = {'TS': ts, 'Nevents': nevents, 'Npred': npred}

        # Gather free full fit parameters (value followed by its error)
        for model in models:
            model_name = model.name()
            for par in model:
                if par.is_free():

                    # Append value
                    name = model_name + '_' + par.name()
                    colnames.append(name)
                    values[name] = par.value()

                    # Append error
                    name = name+'_error'
                    colnames.append(name)
                    values[name] = par.error()

        # Bundle together results
        result = {'colnames': colnames, 'values': values}

        # Return
        return result

    def _create_fits(self, results):
        """
        Create FITS file from results

        The column layout of the parameter columns is taken from the first
        result dictionary. An empty result list yields a table containing
        only the TS, NEVENTS and NPRED columns with zero rows.

        Parameters
        ----------
        results : list of dict
            List of result dictionaries
        """
        # Gather headers for parameter columns (everything except the fixed
        # TS, Nevents and Npred columns). Guard against an empty result list
        # (e.g. ntrials=0), for which results[0] does not exist.
        fixed   = ('TS', 'Nevents', 'Npred')
        headers = []
        if results:
            headers = [colname for colname in results[0]['colnames']
                       if colname not in fixed]

        # Create FITS table columns
        nrows   = len(results)
        ts      = gammalib.GFitsTableDoubleCol('TS', nrows)
        nevents = gammalib.GFitsTableDoubleCol('NEVENTS', nrows)
        npred   = gammalib.GFitsTableDoubleCol('NPRED', nrows)
        ts.unit('')
        nevents.unit('counts')
        npred.unit('counts')
        columns = []
        for header in headers:
            name   = gammalib.toupper(header)
            column = gammalib.GFitsTableDoubleCol(name, nrows)
            column.unit('')
            columns.append(column)

        # Fill FITS table columns
        for i, result in enumerate(results):
            values     = result['values']
            ts[i]      = values['TS']
            nevents[i] = values['Nevents']
            npred[i]   = values['Npred']
            for header, column in zip(headers, columns):
                column[i] = values[header]

        # Initialise FITS Table with extension "TS_DISTRIBUTION"
        table = gammalib.GFitsBinTable(nrows)
        table.extname('TS_DISTRIBUTION')

        # Add keywords for compatibility with gammalib.GMWLSpectrum
        table.card('INSTRUME', 'CTA', 'Name of Instrument')
        table.card('TELESCOP', 'CTA', 'Name of Telescope')

        # Stamp header
        self._stamp(table)

        # Add script keywords
        table.card('NTRIALS', self['ntrials'].integer(),  'Number of trials')
        table.card('STAT',    self['statistic'].string(), 'Optimization statistic')
        table.card('EDISP',   self['edisp'].boolean(),    'Use energy dispersion?')

        # Append filled columns to fits table
        table.append(ts)
        table.append(nevents)
        table.append(npred)
        for column in columns:
            table.append(column)

        # Create the FITS file now
        self._fits = gammalib.GFits()
        self._fits.append(table)

        # Return
        return


    # Public methods
    def process(self):
        """
        Process the script
        """
        # Get parameters
        self._get_parameters()

        # Set test source model for this observation
        self.models(modutils.test_source(self.obs().models(), self._srcname))

        # Write observation into logger
        self._log_observations(gammalib.NORMAL, self.obs(), 'Input observation')

        # Write models into logger
        self._log_models(gammalib.NORMAL, self.obs().models(), 'Input model')

        # Write header
        self._log_header1(gammalib.TERSE, 'Generate TS distribution')

        # Get number of trials
        ntrials = self['ntrials'].integer()

        # Initialise results
        results = []

        # If more than a single thread is requested then use multiprocessing
        if self._nthreads > 1:
            args        = [(self, '_trial', i) for i in range(ntrials)]
            poolresults = mputils.process(self._nthreads, mputils.mpfunc, args)

        # Continue with regular processing
        for i in range(ntrials):

            # If multiprocessing was used then recover results and put them
            # into the log file
            if self._nthreads > 1:
                results.append(poolresults[i][0])
                self._log_string(gammalib.TERSE, poolresults[i][1]['log'], False)

            # ... otherwise make a trial now
            else:

                # Run trial
                result = self._trial(i)

                # Append results
                results.append(result)

        # Create FITS file
        self._create_fits(results)

        # Return
        return

    def models(self, models):
        """
        Set model

        Parameters
        ----------
        models : `~gammalib.GModels`
            Model container
        """
        # Set models of the observation container
        self.obs().models(models)

        # Return
        return

    def save(self):
        """
        Save TS distribution FITS file

        Nothing is written if no FITS file has been created yet.
        """
        # Write header
        self._log_header1(gammalib.TERSE, 'Save TS distribution')

        # Continue only if FITS file is valid
        if self._fits is not None:

            # Get outfile parameter
            outfile = self['outfile'].filename()

            # Log file name
            self._log_value(gammalib.NORMAL, 'TS distribution file', outfile.url())

            # Save TS distribution
            self._fits.saveto(outfile, self['clobber'].boolean())

        # Return
        return

    def ts_distribution(self):
        """
        Return TS distribution FITS file

        Returns
        -------
        fits : `~gammalib.GFits`
            FITS file containing TS distribution (None before process())
        """
        # Return
        return self._fits
453 Returns:
454 FITS file containing TS distribution
455 """
456 # Return
457 return self._fits
458
459
460# ======================== #
461# Main routine entry point #
462# ======================== #
463if __name__ == '__main__':
464
465 # Create instance of application
466 app = cstsdist(sys.argv)
467
468 # Execute application
469 app.execute()
__setstate__(self, state)
Definition cstsdist.py:77
_create_fits(self, results)
Definition cstsdist.py:287