Coverage for lisacattools/plugins/ucb.py: 82% (173 statements)
« prev ^ index » next — coverage.py v7.0.5, created at 2023-02-06 17:36 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2021 - James I. Thorpe, Tyson B. Littenberg, Jean-Christophe
3# Malapert
4#
5# This file is part of lisacattools.
6#
7# lisacattools is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lisacattools is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lisacattools. If not, see <https://www.gnu.org/licenses/>.
19"""Module implemented the UCB catalog."""
20import glob
21import logging
22import os
23from itertools import chain
24from typing import List
25from typing import Optional
26from typing import Union
28import numpy as np
29import pandas as pd
31from ..catalog import GWCatalog
32from ..catalog import GWCatalogs
33from ..catalog import UtilsLogs
34from ..catalog import UtilsMonitoring
35from ..utils import CacheManager
# Register a custom TRACE logging level at severity 15 (between the
# standard DEBUG=10 and INFO=20) so the @UtilsMonitoring decorators
# below can log at logging.TRACE.
UtilsLogs.addLoggingLevel("TRACE", 15)
class UcbCatalogs(GWCatalogs):
    """Implementation of the UCB catalogs.

    Scans one or more directories for catalog files matching an accepted
    glob pattern (minus a rejected pattern) and exposes their merged
    metadata, sorted by "Observation Time".
    """

    # Keyword-argument name for additional directories to scan.
    EXTRA_DIR = "extra_directories"

    def __init__(
        self,
        path: str,
        accepted_pattern: Optional[str] = "*.h5",
        rejected_pattern: Optional[str] = "*chain*",
        *args,
        **kwargs,
    ):
        """Init the UcbCatalogs by reading all catalogs with a specific
        pattern in a given directory and rejecting files by another pattern.

        The list of catalogs is sorted by "Observation Time".

        Args:
            path (str): main directory to scan
            accepted_pattern (str, optional): glob pattern to accept files.
                Defaults to "*.h5".
            rejected_pattern (str, optional): glob pattern to reject files.
                Defaults to "*chain*".

        Keyword Args:
            extra_directories (List[str]): extra directories to scan in
                addition to ``path``.

        Raises:
            ValueError: no files found matching the accepted and rejected
                patterns.
        """
        self.path = path
        self.accepted_pattern = accepted_pattern
        self.rejected_pattern = rejected_pattern
        # kwargs.get replaces the original membership test + conditional
        # expression; behavior is identical (default is a fresh list).
        self.extra_directories = kwargs.get(UcbCatalogs.EXTRA_DIR, list())
        directories = self._search_directories(
            self.path, self.extra_directories
        )
        self.cat_files = self._search_files(
            directories, accepted_pattern, rejected_pattern
        )
        if len(self.cat_files) == 0:
            raise ValueError(
                f"no files found matching the accepted \
({self.accepted_pattern}) and rejected \
({self.rejected_pattern}) patterns in {directories}"
            )
        # One metadata row per catalog file; each row carries a "location"
        # column pointing back to its source file (see _read_cats).
        self.__metadata = pd.concat(
            [self._read_cats(cat_file) for cat_file in self.cat_files]
        )
        self.__metadata = self.__metadata.sort_values(by="Observation Time")

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_directories(
        self, path: str, extra_directories: List[str]
    ) -> List[str]:
        """Compute the list of directories on which the pattern will be
        applied.

        Args:
            path (str): main path
            extra_directories (List[str]): other directories

        Returns:
            List[str]: the extra directories followed by the main path
        """
        # Concatenation builds a new list so the caller's list is not
        # mutated (the original copied then appended for the same reason).
        return extra_directories + [path]

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_files(
        self, directories: List[str], accepted_pattern, rejected_pattern
    ) -> List[str]:
        """Search files in directories according to a set of constraints:
        accepted and rejected glob patterns.

        Args:
            directories (List[str]): list of directories to scan
            accepted_pattern (str): glob pattern to accept files
            rejected_pattern (str, optional): glob pattern to reject files;
                None disables rejection

        Returns:
            List[str]: accepted files minus rejected files (order is not
                guaranteed because a set difference is used)
        """
        accepted_files = list(
            chain.from_iterable(
                glob.glob(os.path.join(directory, accepted_pattern))
                for directory in directories
            )
        )
        if rejected_pattern is None:
            rejected_files: List[str] = []
        else:
            # Cleanup: the original re-tested `rejected_pattern is None`
            # inside this branch, which was unreachable dead code.
            rejected_files = list(
                chain.from_iterable(
                    glob.glob(os.path.join(directory, rejected_pattern))
                    for directory in directories
                )
            )
        return list(set(accepted_files) - set(rejected_files))

    def _read_cats(self, cat_file: str) -> pd.DataFrame:
        """Read the metadata of a given catalog and record the location of
        the file in a "location" column.

        Args:
            cat_file (str): path of the catalog file to load

        Returns:
            pd.DataFrame: the "metadata" dataset of the catalog, with an
                extra "location" column set to ``cat_file``
        """
        df = pd.read_hdf(cat_file, key="metadata")
        df["location"] = cat_file
        return df

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def metadata(self) -> pd.DataFrame:
        __doc__ = GWCatalogs.metadata.__doc__  # noqa: F841
        return self.__metadata

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def count(self) -> int:
        __doc__ = GWCatalogs.count.__doc__  # noqa: F841
        return len(self.metadata.index)

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def files(self) -> List[str]:
        __doc__ = GWCatalogs.files.__doc__  # noqa: F841
        return self.cat_files

    @UtilsMonitoring.io(level=logging.TRACE)
    def get_catalogs_name(self) -> List[str]:
        __doc__ = GWCatalogs.get_catalogs_name.__doc__  # noqa: F841
        return list(self.metadata.index)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_first_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_first_catalog.__doc__  # noqa: F841
        # Metadata is sorted by "Observation Time", so row 0 is the first.
        location = self.metadata.iloc[0]["location"]
        name = self.metadata.index[0]
        return UcbCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_last_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_last_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[self.count - 1]["location"]
        name = self.metadata.index[self.count - 1]
        return UcbCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog(self, idx: int) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[idx]["location"]
        name = self.metadata.index[idx]
        return UcbCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog_by(self, name: str) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog_by.__doc__  # noqa: F841
        cat_idx = self.metadata.index.get_loc(name)
        return self.get_catalog(cat_idx)

    def get_lineage(self, cat_name: str, src_name: str) -> pd.DataFrame:
        # Lineage tracking is not supported by the UCB catalog format.
        raise NotImplementedError(
            "Get_lineage is not implemented for this catalog !"
        )

    def get_lineage_data(self, lineage: pd.DataFrame) -> pd.DataFrame:
        # Lineage tracking is not supported by the UCB catalog format.
        raise NotImplementedError(
            "Get_lineage_data is not implemented for this catalog !"
        )

    def __repr__(self):
        return f"UcbCatalogs({self.path!r}, {self.accepted_pattern!r}, \
{self.rejected_pattern!r}, {self.extra_directories!r})"

    def __str__(self):
        return f"UcbCatalogs: {self.path} {self.accepted_pattern!r} \
{self.rejected_pattern!r} {self.extra_directories!r}"
class UcbCatalog(GWCatalog):
    """Implementation of the Ucb catalog.

    Wraps a single HDF5 catalog file and provides access to its datasets,
    detections and per-source posterior sample chains.
    """

    def __init__(self, catalog_name: str, location: str):
        """Init the LISA catalog with a name and a location.

        Args:
            catalog_name (str): name of the catalog
            location (str): location (file path) of the catalog
        """
        self.__name = catalog_name
        self.__location = location
        # Context manager guarantees the HDF store is closed even if
        # keys() raises; the original open/close pair leaked the handle
        # on error.
        with pd.HDFStore(location, "r") as store:
            self.__datasets = store.keys()

    @CacheManager.get_cache_pandas(
        keycache_argument=[1, 2], level=logging.INFO
    )
    def _read_chain_file(
        self, source_name: str, chain_file: str
    ) -> pd.DataFrame:
        """Read a source in a chain_file.

        The chain_file is resolved relative to the directory containing
        this catalog. Results are cached by (source_name, chain_file).

        Args:
            source_name (str): name of the source to extract from the
                chain_file
            chain_file (str): file to load

        Returns:
            pd.DataFrame: posterior samples stored under the
                "<source_name>_chain" key of the chain file
        """
        dirname = os.path.dirname(self.location)
        source_samples_file = os.path.join(dirname, chain_file)
        source_samples = pd.read_hdf(
            source_samples_file, key=f"{source_name}_chain"
        )
        return source_samples

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def datasets(self):
        """dataset.

        :getter: Returns the list of datasets (HDF store keys)
        :type: List
        """
        return self.__datasets

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_dataset(self, name: str) -> pd.DataFrame:
        """Returns a dataset based on its name.

        Args:
            name (str): name (HDF key) of the dataset

        Returns:
            pd.DataFrame: the dataset
        """
        return pd.read_hdf(self.location, key=name)

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def name(self) -> str:
        __doc__ = GWCatalog.name.__doc__  # noqa: F841
        return self.__name

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def location(self) -> str:
        __doc__ = GWCatalog.location.__doc__  # noqa: F841
        return self.__location

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_detections(
        self, attr: Union[List[str], str] = None
    ) -> Union[List[str], pd.DataFrame, pd.Series]:
        __doc__ = GWCatalog.get_detections.__doc__  # noqa: F841
        detections = self.get_dataset("detections")
        # No attr: return the source names only; otherwise a copy of the
        # requested column(s) so callers cannot mutate the cached frame.
        return (
            list(detections.index) if attr is None else detections[attr].copy()
        )

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_attr_detections(self) -> List[str]:
        __doc__ = GWCatalog.get_attr_detections.__doc__  # noqa: F841
        return list(self.get_dataset("detections").columns)

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_median_source(self, attr: str) -> pd.DataFrame:
        __doc__ = GWCatalog.get_median_source.__doc__  # noqa: F841
        val = self.get_detections(attr)
        # Pick the source whose attr value is closest to the median of
        # that attribute across all detections.
        source_idx = self.get_detections()[
            np.argmin(np.abs(np.array(val) - val.median()))
        ]
        return self.get_detections(self.get_attr_detections()).loc[
            [source_idx]
        ]

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_source_samples(
        self, source_name: str, attr: List[str] = None
    ) -> pd.DataFrame:
        __doc__ = GWCatalog.get_source_samples.__doc__  # noqa: F841
        # The "chain file" column of the detections dataset gives the
        # file holding this source's posterior chain.
        samples: pd.DataFrame = self.get_detections(["chain file"])
        chain_file: str = samples.loc[source_name]["chain file"]
        source_samples = self._read_chain_file(source_name, chain_file)
        return source_samples if attr is None else source_samples[attr].copy()

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG)
    def get_attr_source_samples(self, source_name: str) -> List[str]:
        __doc__ = GWCatalog.get_attr_source_samples.__doc__  # noqa: F841
        return list(self.get_source_samples(source_name).columns)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG)
    def describe_source_samples(self, source_name: str) -> pd.DataFrame:
        __doc__ = GWCatalog.describe_source_samples.__doc__  # noqa: F841
        return self.get_source_samples(source_name).describe()

    def __repr__(self):
        return f"UcbCatalog({self.__name!r}, {self.__location!r})"

    def __str__(self):
        return f"UcbCatalog: {self.__name} {self.__location}"