Coverage for lisacattools/plugins/mbh.py: 75%
191 statements
# -*- coding: utf-8 -*-
# Copyright (C) 2021 - James I. Thorpe, Tyson B. Littenberg, Jean-Christophe
# Malapert
#
# This file is part of lisacattools.
#
# lisacattools is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lisacattools is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lisacattools. If not, see <https://www.gnu.org/licenses/>.
19"""Module implemented the MBH catalog."""
import glob
import logging
import os
from itertools import chain
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd

from ..catalog import GWCatalog
from ..catalog import GWCatalogs
from ..catalog import UtilsLogs
from ..catalog import UtilsMonitoring

UtilsLogs.addLoggingLevel("TRACE", 15)


class MbhCatalogs(GWCatalogs):
    """Implementation of the MBH catalogs."""

    EXTRA_DIR = "extra_directories"

    def __init__(
        self,
        path: str,
        accepted_pattern: Optional[str] = "MBH_wk*C.h5",
        rejected_pattern: Optional[str] = None,
        *args,
        **kwargs,
    ):
52 """Init the MbhCatalogs by reading all catalogs with a specific
53 pattern in a given directory and rejecting files by another pattern.
55 The list of catalogs is sorted by "observation week"
57 Args:
58 path (str): directory
59 accepted_pattern (str, optional): pattern to accept files.
60 Defaults to "MBH_wk*C.h5".
61 rejected_pattern (str, optional): pattern to reject files.
62 Defaults to None.
64 Raises:
65 ValueError: no files found matching the accepted and rejected
66 patterns.
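
        Example:
            A minimal sketch; the directory and the rejected pattern are
            illustrative, not shipped with the package::

                catalogs = MbhCatalogs(
                    "/data/mbh", rejected_pattern="*wk01C.h5"
                )
                print(catalogs.count)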
67 """
        self.path = path
        self.accepted_pattern = accepted_pattern
        self.rejected_pattern = rejected_pattern
        self.extra_directories = kwargs.get(MbhCatalogs.EXTRA_DIR, list())
        directories = self._search_directories(
            self.path, self.extra_directories
        )
        self.cat_files = self._search_files(
            directories, accepted_pattern, rejected_pattern
        )
        if len(self.cat_files) == 0:
            raise ValueError(
                f"no files found matching the accepted "
                f"({self.accepted_pattern}) and rejected "
                f"({self.rejected_pattern}) patterns in {directories}"
            )
        self.__metadata = pd.concat(
            [self._read_cats(cat_file) for cat_file in self.cat_files]
        )
        self.__metadata = self.__metadata.sort_values(by="observation week")

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_directories(
        self, path: str, extra_directories: List[str]
    ) -> List[str]:
        """Compute the list of directories to which the pattern will be
        applied.

        Args:
            path (str): main path
            extra_directories (List[str]): other directories

        Returns:
            List[str]: list of directories to which the pattern will be
                applied
        """
        directories: List[str] = extra_directories[:]
        directories.append(path)
        return directories

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_files(
        self,
        directories: List[str],
        accepted_pattern: str,
        rejected_pattern: Optional[str],
    ) -> List[str]:
        """Search for files in the directories according to two constraints:
        an accepted pattern and a rejected pattern.

        Args:
            directories (List[str]): list of directories to scan
            accepted_pattern (str): pattern used to accept files
            rejected_pattern (str, optional): pattern used to reject files

        Returns:
            List[str]: list of files
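
        Example:
            Illustrative only; the directory and file names are
            hypothetical::

                files = catalogs._search_files(
                    ["/data/mbh"], "MBH_wk*C.h5", "MBH_wk01C.h5"
                )
                # every weekly catalog except the rejected week-1 file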
124 """
        accepted_files = [
            glob.glob(os.path.join(path, accepted_pattern))
            for path in directories
        ]
        accepted_files = list(chain(*accepted_files))
        if rejected_pattern is None:
            rejected_files = list()
        else:
            rejected_files = [
                glob.glob(os.path.join(path, rejected_pattern))
                for path in directories
            ]
            rejected_files = list(chain(*rejected_files))
        cat_files = list(set(accepted_files) - set(rejected_files))
        return cat_files

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _read_cats(self, cat_file: str) -> pd.DataFrame:
        """Read the metadata of a given catalog and record the location of
        the file.

        Args:
            cat_file (str): catalog file to load

        Returns:
            pd.DataFrame: pandas data frame
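
        Example:
            Illustrative only; the path is hypothetical::

                df = catalogs._read_cats("/data/mbh/MBH_wk10C.h5")
                # every row of df carries the source file in "location"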
152 """
153 df = pd.read_hdf(cat_file, key="metadata")
154 df["location"] = cat_file
155 return df

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def metadata(self) -> pd.DataFrame:
        __doc__ = GWCatalogs.metadata.__doc__  # noqa: F841
        return self.__metadata

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def count(self) -> int:
        __doc__ = GWCatalogs.count.__doc__  # noqa: F841
        return len(self.metadata.index)

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def files(self) -> List[str]:
        __doc__ = GWCatalogs.files.__doc__  # noqa: F841
        return self.cat_files

    @UtilsMonitoring.io(level=logging.TRACE)
    def get_catalogs_name(self) -> List[str]:
        __doc__ = GWCatalogs.get_catalogs_name.__doc__  # noqa: F841
        return list(self.metadata.index)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_first_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_first_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[0]["location"]
        name = self.metadata.index[0]
        return MbhCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_last_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_last_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[self.count - 1]["location"]
        name = self.metadata.index[self.count - 1]
        return MbhCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog(self, idx: int) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[idx]["location"]
        name = self.metadata.index[idx]
        return MbhCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog_by(self, name: str) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog_by.__doc__  # noqa: F841
        cat_idx = self.metadata.index.get_loc(name)
        return self.get_catalog(cat_idx)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_lineage(self, cat_name: str, src_name: str) -> pd.DataFrame:
        __doc__ = GWCatalogs.get_lineage.__doc__  # noqa: F841
        dfs: List[pd.Series] = list()
        # Walk the ancestry: starting from (cat_name, src_name), follow the
        # "Parent" links until a catalog without a parent is reached.
        while src_name != "" and cat_name not in [None, ""]:
            detections = self.get_catalog_by(cat_name).get_dataset(
                "detections"
            )
            src = detections.loc[[src_name]]
            # The column capitalization varies across catalog releases.
            try:
                wk = self.metadata.loc[cat_name]["observation week"]
            except KeyError:
                wk = self.metadata.loc[cat_name]["Observation Week"]
            src.insert(0, "Observation Week", wk, True)
            src.insert(1, "Catalog", cat_name, True)
            dfs.append(src)
            try:
                prnt = self.metadata.loc[cat_name]["parent"]
            except KeyError:
                prnt = self.metadata.loc[cat_name]["Parent"]
            cat_name = prnt
            src_name = src.iloc[0]["Parent"]

        histDF: pd.DataFrame = pd.concat(dfs, axis=0)
        histDF.drop_duplicates(
            subset="Log Likelihood", keep="last", inplace=True
        )
        histDF.sort_values(by="Observation Week", ascending=True, inplace=True)
        return histDF

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_lineage_data(self, lineage: pd.DataFrame) -> pd.DataFrame:
        __doc__ = GWCatalogs.get_lineage_data.__doc__  # noqa: F841

        def _process_lineage(source_epoch, source_data, obs_week):
            source_data.insert(
                len(source_data.columns), "Source", source_epoch, True
            )
            source_data.insert(
                len(source_data.columns), "Observation Week", obs_week, True
            )
            return source_data

        source_epochs = list(lineage.index)
        # Concatenate the posterior samples of every epoch in the lineage,
        # each tagged with its source name and observation week.
        merge_source_epochs: pd.DataFrame = pd.concat(
            [
                _process_lineage(
                    source_epoch,
                    self.get_catalog_by(
                        lineage.loc[source_epoch]["Catalog"]
                    ).get_source_samples(source_epoch),
                    lineage.loc[source_epoch]["Observation Week"],
                )
                for source_epoch in source_epochs
            ]
        )
        merge_source_epochs = merge_source_epochs[
            [
                "Source",
                "Observation Week",
                "Mass 1",
                "Mass 2",
                "Spin 1",
                "Spin 2",
                "Ecliptic Latitude",
                "Ecliptic Longitude",
                "Luminosity Distance",
                "Barycenter Merge Time",
                "Merger Phase",
                "Polarization",
                "cos inclination",
            ]
        ].copy()
        return merge_source_epochs

    def __repr__(self):
        return (
            f"MbhCatalogs({self.path!r}, {self.accepted_pattern!r}, "
            f"{self.rejected_pattern!r}, {self.extra_directories!r})"
        )

    def __str__(self):
        return (
            f"MbhCatalogs: {self.path} {self.accepted_pattern!r} "
            f"{self.rejected_pattern!r} {self.extra_directories!r}"
        )


class MbhCatalog(GWCatalog):
    """Implementation of the MBH catalog."""

    def __init__(self, name: str, location: str):
        """Init the MBH catalog with a name and a location.

        Args:
            name (str): name of the catalog
            location (str): location of the catalog file
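
        Example:
            Illustrative only; the name and path are hypothetical::

                catalog = MbhCatalog("MBH_wk10C", "/data/mbh/MBH_wk10C.h5")
                print(catalog.datasets)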
310 """
311 self.__name = name
312 self.__location = location
313 store = pd.HDFStore(location, "r")
314 self.__datasets = store.keys()
315 store.close()

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def datasets(self):
        """Datasets.

        :getter: Returns the list of datasets
        :type: List[str]
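
        Example:
            The keys come from the underlying HDF5 store; the values below
            are illustrative::

                catalog.datasets  # e.g. ["/metadata", "/detections", ...]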
324 """
325 return self.__datasets

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_dataset(self, name: str) -> pd.DataFrame:
        """Return a dataset based on its name.

        Args:
            name (str): name of the dataset

        Returns:
            pd.DataFrame: the dataset
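
        Example:
            A minimal sketch using the "detections" dataset, which is the
            one read by get_detections::

                detections = catalog.get_dataset("detections")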
336 """
337 return pd.read_hdf(self.location, key=name)

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def name(self) -> str:
        __doc__ = GWCatalog.name.__doc__  # noqa: F841
        return self.__name

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def location(self) -> str:
        __doc__ = GWCatalog.location.__doc__  # noqa: F841
        return self.__location

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_detections(
        self, attr: Optional[Union[List[str], str]] = None
    ) -> Union[List[str], pd.DataFrame, pd.Series]:
        __doc__ = GWCatalog.get_detections.__doc__  # noqa: F841
        detections = self.get_dataset("detections")
        return (
            list(detections.index) if attr is None else detections[attr].copy()
        )

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_attr_detections(self) -> List[str]:
        __doc__ = GWCatalog.get_attr_detections.__doc__  # noqa: F841
        return list(self.get_dataset("detections").columns)

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_median_source(self, attr: str) -> pd.DataFrame:
        __doc__ = GWCatalog.get_median_source.__doc__  # noqa: F841
        detections: pd.Series = self.get_detections(attr)
        # Pick the source whose value of `attr` is closest to the median.
        source_idx = self.get_detections()[
            np.argmin(np.abs(np.array(detections) - detections.median()))
        ]
        return self.get_detections(self.get_attr_detections()).loc[
            [source_idx]
        ]

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_source_samples(
        self, source_name: str, attr: Optional[List[str]] = None
    ) -> pd.DataFrame:
        __doc__ = GWCatalog.get_source_samples.__doc__  # noqa: F841
        samples = self.get_dataset(f"{source_name}_chain")
        return samples if attr is None else samples[attr].copy()

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_attr_source_samples(self, source_name: str) -> List[str]:
        __doc__ = GWCatalog.get_attr_source_samples.__doc__  # noqa: F841
        return list(self.get_dataset(f"{source_name}_chain").columns)

    @UtilsMonitoring.io(level=logging.TRACE)
    def describe_source_samples(self, source_name: str) -> pd.DataFrame:
        __doc__ = GWCatalog.describe_source_samples.__doc__  # noqa: F841
        return self.get_source_samples(source_name).describe()

    def __repr__(self):
        return f"MbhCatalog({self.__name!r}, {self.__location!r})"

    def __str__(self):
        return f"MbhCatalog: {self.__name} {self.__location}"