Coverage for lisacattools/analyze.py: 34%
190 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-02-06 17:36 +0000
« prev ^ index » next coverage.py v7.0.5, created at 2023-02-06 17:36 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2021 - James I. Thorpe, Tyson B. Littenberg, Jean-Christophe
3# Malapert
4#
5# This file is part of lisacattools.
6#
7# lisacattools is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lisacattools is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lisacattools. If not, see <https://www.gnu.org/licenses/>.
19import logging
20import os
21from typing import Dict
22from typing import List
23from typing import NoReturn
25import corner
26import ligo.skymap.plot # noqa: F401
27import matplotlib.pyplot as plt
28import numpy as np
29import pandas as pd
30import seaborn as sns
32from .catalog import GWCatalog
33from .catalog import GWCatalogs
34from .custom_logging import UtilsLogs
35from .monitoring import UtilsMonitoring
36from .utils import FrameEnum
37from .utils import HPhist
# Add a custom "TRACE" logging level with numeric value 15, which sits
# between the standard logging.DEBUG (10) and logging.INFO (20) levels.
UtilsLogs.addLoggingLevel(
    "TRACE",
    15,
)
class LisaAnalyse:
    """Factory returning the analysis object that matches a catalog type.

    A single catalog (GWCatalog) gives a CatalogAnalysis; a time-evolution
    of the catalog (GWCatalogs) gives a HistoryAnalysis.
    """

    @staticmethod
    def create(catalog, save_dir=None):
        """Build the analysis suited to ``catalog``.

        Args:
            catalog: a GWCatalog or a GWCatalogs instance
            save_dir (str, optional): directory where plots are saved

        Returns:
            CatalogAnalysis or HistoryAnalysis wrapping ``catalog``

        Raises:
            NotImplementedError: ``catalog`` is of any other type
        """
        if isinstance(catalog, GWCatalog):
            return CatalogAnalysis(catalog, save_dir)
        if isinstance(catalog, GWCatalogs):
            return HistoryAnalysis(catalog, save_dir)
        raise NotImplementedError(f"type {type(catalog)} not implemented")
class AbstractLisaAnalyze:
    """Base class sharing helpers between the single-catalog analysis and
    the catalog-history analysis."""

    def __init__(self):
        pass

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _get_variable(
        self, dico: Dict, variable: str, default_val: object
    ) -> object:
        """Return ``dico[variable]`` when the key exists, else
        ``default_val``."""
        return dico.get(variable, default_val)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_corners_ds(self, sources, *args, **kwargs):
        """Draw a corner plot of ``sources``.

        Recognised keyword arguments (default in brackets):
            color ["red"], plot_datapoints [False], fill_contours [True],
            bins [50], smooth [1.0], levels [[0.68, 0.95]],
            fontsize [16] (axis-label font size), fig [None],
            title ["parameters"].

        When ``fig`` is given, the corner plot is drawn onto it and
        ``title`` is ignored: the caller owns that figure's title. Otherwise
        corner creates a new figure, which receives ``title`` as suptitle.
        Extra positional arguments are accepted but not used.

        Args:
            sources: samples handed to ``corner.corner``

        Returns:
            The matplotlib figure holding the corner plot.
        """
        # Any falsy value means "no figure supplied", as in the original
        # truthiness test; corner.corner creates a figure when fig is None.
        fig = self._get_variable(kwargs, "fig", None) or None
        color = self._get_variable(kwargs, "color", "red")
        plot_datapoints = self._get_variable(kwargs, "plot_datapoints", False)
        fill_contours = self._get_variable(kwargs, "fill_contours", True)
        bins = self._get_variable(kwargs, "bins", 50)
        smooth = self._get_variable(kwargs, "smooth", 1.0)
        levels = self._get_variable(kwargs, "levels", [0.68, 0.95])
        fontsize = self._get_variable(kwargs, "fontsize", 16)
        title = self._get_variable(kwargs, "title", "parameters")

        figure = corner.corner(
            sources,
            fig=fig,
            color=color,
            plot_datapoints=plot_datapoints,
            fill_contours=fill_contours,
            bins=bins,
            smooth=smooth,
            levels=levels,
            label_kwargs={"fontsize": fontsize},
        )
        if fig is None:
            figure.suptitle(title)
        return figure
class CatalogAnalysis(AbstractLisaAnalyze):
    """Handle the analysis of one catalog."""

    def __init__(self, catalog: GWCatalog, save_img_dir=None):
        """Init the analysis with a Lisa catalog.

        Args:
            catalog (GWCatalog): catalog to analyse
            save_img_dir (str, optional): directory where plots are saved;
                when falsy, plots are not written to disk
        """
        self.catalog = catalog
        self.save_img_dir = save_img_dir

    @property
    def catalog(self):
        """Catalog.

        :getter: Returns the catalog of this analysis
        :setter: Sets the catalog.
        :type: GWCatalog
        """
        return self._catalog

    @catalog.setter
    def catalog(self, value):
        self._catalog = value

    @property
    def save_img_dir(self):
        """Save image directory for plot.

        :getter: Returns the directory where plots are saved
        :setter: Sets the directory where plots are saved.
        :type: str
        """
        return self._save_img_dir

    @save_img_dir.setter
    def save_img_dir(self, value):
        self._save_img_dir = value

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_mbh_mergers_history(self) -> NoReturn:
        """Plot the cumulative number of observed mergers over time.

        Merge times are divided by 86400, so they are presumably in seconds
        and are plotted in days. Each merger is annotated with its index
        label (the source name). The figure is saved as
        ``MBH_mergers_<catalog name>.png`` when save_img_dir is set.
        """
        # Sort a copy rather than in place, so that the object handed back
        # by the catalog is never mutated.
        merge_times = self.catalog.get_detections(
            "Barycenter Merge Time"
        ).sort_values(ascending=True)
        merge_days = np.array(merge_times) / 86400
        # Prepend t=0 so that the step curve starts at zero mergers.
        merge_t = np.insert(merge_days, 0, 0)
        merge_count = np.arange(0, len(merge_times) + 1)

        fig, ax = plt.subplots(figsize=[8, 6], dpi=100)
        ax.step(merge_t, merge_count, where="post")
        # Walk names and times together by position; integer indexing of a
        # label-indexed Series (merge_times[m]) relies on pandas' deprecated
        # positional fallback and breaks when the labels are integers.
        for m, (name, day) in enumerate(zip(merge_times.index, merge_days)):
            ax.annotate(
                name,  # text of the annotation
                (day, merge_count[m]),  # point to label
                textcoords="offset points",
                xytext=(2, 5),  # offset of the text from the point
                rotation="horizontal",
                ha="left",
            )
        ax.set_xlabel("Observation Time [days]")
        ax.set_ylabel("Merger Count")
        ax.set_title(f"MBH Mergers in catalog {self.catalog.name}")
        ax.grid()
        if self.save_img_dir:
            fig.savefig(
                os.path.join(
                    self.save_img_dir,
                    "MBH_mergers_" + self.catalog.name + ".png",
                )
            )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_individual_sources(self) -> NoReturn:
        """Plot the median and 90% interval (5th-95th percentile) of
        "Mass 1" against "Mass 2" for every detected source, on log axes."""
        fig, ax = plt.subplots(figsize=[8, 6], dpi=100)
        detections = self.catalog.get_detections(["Mass 1", "Mass 2"])
        sources = list(detections.index)
        for idx, source in enumerate(sources):
            chain = self.catalog.get_source_samples(
                source, ["Mass 1", "Mass 2"]
            )
            l1, m1, h1 = np.quantile(
                np.array(chain["Mass 1"]), [0.05, 0.5, 0.95]
            )
            l2, m2, h2 = np.quantile(
                np.array(chain["Mass 2"]), [0.05, 0.5, 0.95]
            )
            # Change marker after the first 10 sources, presumably so that
            # entries stay distinguishable once matplotlib's default
            # 10-colour cycle starts repeating.
            mkr = "o" if idx < 10 else "^"
            ax.errorbar(
                m1,
                m2,
                # Asymmetric error bars: (lower, upper) distances to median.
                xerr=np.vstack((m1 - l1, h1 - m1)),
                yerr=np.vstack((m2 - l2, h2 - m2)),
                label=source,
                markersize=6,
                capsize=2,
                marker=mkr,
                markerfacecolor="none",
            )
        ax.set_xscale("log", nonpositive="clip")
        ax.set_yscale("log", nonpositive="clip")
        ax.grid()
        ax.set_xlabel("Mass 1 [MSun]")
        ax.set_ylabel("Mass 2 [MSun]")
        ax.set_title("90%% CI for Component Masses in %s " % self.catalog.name)
        ax.legend(loc="lower right")
        if self.save_img_dir:
            # NOTE(review): unlike "MBH_mergers_", there is no separator
            # before the catalog name; kept to preserve existing file names.
            fig.savefig(
                os.path.join(
                    self.save_img_dir,
                    "component_masses" + self.catalog.name + ".png",
                )
            )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_corners(self, source_name, params, *args, **kwargs) -> NoReturn:
        """Draw a corner plot of the samples of one source.

        Args:
            source_name: name of the source in the catalog
            params: parameters to fetch from the source samples and plot
            *args, **kwargs: forwarded unchanged to ``plot_corners_ds``
        """
        samples = self.catalog.get_source_samples(source_name, params)
        # plot_corners_ds takes the data as its first positional argument;
        # passing source_name there made corner plot the name string while
        # the samples were silently swallowed by *args.
        self.plot_corners_ds(samples, *args, **kwargs)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_skymap(
        self, source, nside, system: FrameEnum = FrameEnum.ECLIPTIC
    ) -> NoReturn:
        """Plot a HEALPix sky map of ``source`` in a Mollweide projection.

        Args:
            source: data handed to HPhist (presumably the source's posterior
                samples - TODO confirm)
            nside: HEALPix resolution parameter passed to HPhist
            system (FrameEnum): coordinate frame, ecliptic by default
        """
        hp_map = HPhist(source, nside, system)
        fig = plt.figure(figsize=(8, 6), dpi=100)
        ax = plt.axes(
            [0.05, 0.05, 0.9, 0.9], projection="geo degrees mollweide"
        )
        ax.grid()
        ax.imshow_hpx(hp_map, cmap="plasma")
        if self.save_img_dir:
            # NOTE(review): fixed file name, so successive calls overwrite
            # each other's output.
            fig.savefig(os.path.join(self.save_img_dir, "skymap.png"))
class HistoryAnalysis(AbstractLisaAnalyze):
    """Analyse a particular source to see how its parameter estimates
    improve over time."""

    def __init__(self, catalogs: GWCatalogs, save_img_dir=None):
        """Init the HistoryAnalysis with all catalogs to load the parameter
        estimates over the time.

        Args:
            catalogs (GWCatalogs): catalogs to follow over time
            save_img_dir (str, optional): directory where plots are saved;
                when falsy, plots are not written to disk
        """
        self.catalogs = catalogs
        self.save_img_dir = save_img_dir

    @property
    def catalogs(self):
        """Catalogs.

        :getter: Returns the catalogs of this analysis
        :setter: Sets the catalogs.
        :type: GWCatalogs
        """
        return self._catalogs

    @catalogs.setter
    def catalogs(self, value):
        self._catalogs = value

    @property
    def save_img_dir(self):
        """Save image directory for plot.

        :getter: Returns the directory where plots are saved
        :setter: Sets the directory where plots are saved.
        :type: str
        """
        return self._save_img_dir

    @save_img_dir.setter
    def save_img_dir(self, value):
        self._save_img_dir = value

    def _save_figure(self, fig, title: str) -> None:
        """Save ``fig`` as ``<title, spaces replaced by _>.png`` inside
        save_img_dir; do nothing when no directory is configured."""
        if self.save_img_dir:
            fig.savefig(
                os.path.join(
                    self.save_img_dir, title.replace(" ", "_") + ".png"
                )
            )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameter_time_evolution(
        self,
        df: pd.DataFrame,
        time_parameter: str,
        parameter: str,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Plot the parameter that evolves over time.

        Note: extra parameter can be configured:
            - plot_type, default : scatter
            - grid, default : True
            - marker, default : 's'
            - linestyle, default : '-'
            - yscale, default : log
            - title, default : Evolution

        Args:
            df (pd.DataFrame): data
            time_parameter (str): time parameter in the data
            parameter (str): parameter to plot over the time
        """
        # "plot_type" is the documented key. Earlier code read the kind from
        # the "scatter" key, which is still honoured as a fallback.
        plot_type = self._get_variable(
            kwargs,
            "plot_type",
            self._get_variable(kwargs, "scatter", "scatter"),
        )
        grid = self._get_variable(kwargs, "grid", True)
        marker = self._get_variable(kwargs, "marker", "s")
        linestyle = self._get_variable(kwargs, "linestyle", "-")
        yscale = self._get_variable(kwargs, "yscale", "log")
        title: str = self._get_variable(kwargs, "title", "Evolution")

        fig, ax = plt.subplots(figsize=[8, 6], dpi=100)
        df.plot(
            kind=plot_type,
            x=time_parameter,
            y=parameter,
            ax=ax,
            grid=grid,
            marker=marker,
            linestyle=linestyle,
        )
        ax.set_yscale(yscale)
        ax.set_title(title)
        self._save_figure(fig, title)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameter_time_evolution_from_source(
        self,
        catalog_name: str,
        source_name: str,
        time_parameter: str,
        parameter: str,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Plot the parameter that evolves over time for a given source
        starting from a catalog.

        The data come from ``self.catalogs.get_lineage(catalog_name,
        source_name)`` and are plotted by ``plot_parameter_time_evolution``,
        which accepts these extra keyword arguments:
            - plot_type, default : scatter
            - grid, default : True
            - marker, default : 's'
            - linestyle, default : '-'
            - yscale, default : log
            - title, default : Evolution

        Args:
            catalog_name (str): the evolution starts from the oldest catalog
                and runs until this one
            source_name (str): source name to follow up
            time_parameter (str): time parameter in the data
            parameter (str): parameter to plot over the time
        """
        src_hist = self.catalogs.get_lineage(catalog_name, source_name)
        self.plot_parameter_time_evolution(
            src_hist, time_parameter, parameter, *args, **kwargs
        )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameters_evolution(
        self,
        all_epochs: pd.DataFrame,
        params: List,
        scales: List,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Show evolution over many different epochs as violin plots, two
        subplots per row.

        Extra keyword arguments:
            - title, default : "Parameter Evolution"
            - x_title, default : "Observation Week" (name of the column of
              ``all_epochs`` used for the x axis)

        Args:
            all_epochs (pd.DataFrame): observation of a source at
                different epochs
            params (List): list of parameters to plot
            scales (List): y-axis scale for each plot, same order as params
        """
        title = self._get_variable(kwargs, "title", "Parameter Evolution")
        x_title = self._get_variable(kwargs, "x_title", "Observation Week")
        nrows = int(np.ceil(len(params) / 2))
        fig = plt.figure(figsize=(10.0, 10.0), dpi=100)

        for idx, param in enumerate(params):
            ax = fig.add_subplot(nrows, 2, idx + 1)
            sns.violinplot(
                ax=ax,
                x=x_title,
                y=param,
                data=all_epochs,
                scale="width",
                width=0.8,
                inner="quartile",
            )
            ax.set_yscale(scales[idx])
            ax.grid(axis="y")

        fig.suptitle(title)
        self._save_figure(fig, title)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameters_correlation_evolution(
        self,
        allEpochs: pd.DataFrame,
        wks: List,
        params: List,
        colors: List,
        *args,
        **kwargs,
    ) -> NoReturn:
        """To dig into how parameter correlations might change over time, we
        can look at a time-evolving corner plot: one corner plot per week,
        all overlaid on the same figure.

        Extra keyword arguments:
            - title, default : "Evolution of parameters"

        Args:
            allEpochs (pd.DataFrame): observation of a source at different
                epochs, with an "Observation Week" column
            wks (List): weeks to plot
            params (List): parameters to plot
            colors (List): color for each week, same order as wks
        """
        title = self._get_variable(kwargs, "title", "Evolution of parameters")
        fig = plt.figure(figsize=[8, 8], dpi=100)
        for idx, wk in enumerate(wks):
            epoch = allEpochs[allEpochs["Observation Week"] == wk]
            self.plot_corners_ds(epoch[params], fig=fig, color=colors[idx])
        fig.suptitle(title)
        self._save_figure(fig, title)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_skymap_evolution(
        self,
        nside: int,
        allEpochs: pd.DataFrame,
        wks: List,
        system: FrameEnum = FrameEnum.GALACTIC,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Plot the skymap evolution: one Mollweide sky map per week, two
        maps per row.

        Extra keyword arguments:
            - title, default : "Sky Localization Evolution"

        Args:
            nside (int): parameter for healpix related to the number of cells
            allEpochs (pd.DataFrame): observation of a source at different
                epochs, with an "Observation Week" column
            wks (List): weeks to plot
            system (FrameEnum, optional): coordinate reference frame. Defaults
                to 'FrameEnum.GALACTIC'.
        """
        title = self._get_variable(
            kwargs, "title", "Sky Localization Evolution"
        )
        fig = plt.figure(figsize=(10, 10), dpi=100)
        ncols = 2
        nrows = int(np.ceil(len(wks) / ncols))
        for idx, wk in enumerate(wks):
            hpmap = HPhist(
                allEpochs[allEpochs["Observation Week"] == wk], nside, system
            )
            ax = fig.add_subplot(
                nrows, ncols, idx + 1, projection="geo degrees mollweide"
            )
            ax.grid()
            ax.imshow_hpx(hpmap, cmap="plasma")
            ax.set_title(f"Week {wk}")
        fig.suptitle(title)
        self._save_figure(fig, title)