Coverage for lisacattools/plugins/ucb.py: 82%

173 statements  

coverage.py v7.0.5, created at 2023-02-06 17:36 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2021 - James I. Thorpe, Tyson B. Littenberg, Jean-Christophe
# Malapert
#
# This file is part of lisacattools.
#
# lisacattools is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lisacattools is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lisacattools. If not, see <https://www.gnu.org/licenses/>.
19"""Module implemented the UCB catalog.""" 

import glob
import logging
import os
from itertools import chain
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd

from ..catalog import GWCatalog
from ..catalog import GWCatalogs
from ..catalog import UtilsLogs
from ..catalog import UtilsMonitoring
from ..utils import CacheManager

UtilsLogs.addLoggingLevel("TRACE", 15)


class UcbCatalogs(GWCatalogs):
    """Implementation of the UCB catalogs."""

    EXTRA_DIR = "extra_directories"

    def __init__(
        self,
        path: str,
        accepted_pattern: Optional[str] = "*.h5",
        rejected_pattern: Optional[str] = "*chain*",
        *args,
        **kwargs,
    ):
        """Init the UcbCatalogs by reading all catalogs with a specific
        pattern in a given directory and rejecting files by another pattern.

        The list of catalogs is sorted by "Observation Time".

        Args:
            path (str): directory containing the catalogs
            accepted_pattern (str, optional): pattern to accept files.
                Defaults to "*.h5".
            rejected_pattern (str, optional): pattern to reject files.
                Defaults to "*chain*".

        Raises:
            ValueError: no files found matching the accepted and rejected
                patterns.
        """
        self.path = path
        self.accepted_pattern = accepted_pattern
        self.rejected_pattern = rejected_pattern
        self.extra_directories = (
            kwargs[UcbCatalogs.EXTRA_DIR]
            if UcbCatalogs.EXTRA_DIR in kwargs
            else list()
        )
        directories = self._search_directories(
            self.path, self.extra_directories
        )
        self.cat_files = self._search_files(
            directories, accepted_pattern, rejected_pattern
        )
        if len(self.cat_files) == 0:
            raise ValueError(
                f"no files found matching the accepted "
                f"({self.accepted_pattern}) and rejected "
                f"({self.rejected_pattern}) patterns in {directories}"
            )
        self.__metadata = pd.concat(
            [self._read_cats(cat_file) for cat_file in self.cat_files]
        )
        self.__metadata = self.__metadata.sort_values(by="Observation Time")
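    # A minimal usage sketch for the constructor above; the directory
    # "/data/ucb_catalogs" is a hypothetical example, not part of this repo:
    #
    #   catalogs = UcbCatalogs("/data/ucb_catalogs")
    #   print(catalogs.count)                 # number of catalogs found
    #   print(catalogs.get_catalogs_name())   # names, sorted by Observation Time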

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_directories(
        self, path: str, extra_directories: List[str]
    ) -> List[str]:
        """Compute the list of directories on which the pattern will be applied.

        Args:
            path (str): main path
            extra_directories (List[str]): other directories

        Returns:
            List[str]: list of directories on which the pattern will be applied
        """
        directories: List[str] = extra_directories[:]
        directories.append(path)
        return directories

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_files(
        self, directories: List[str], accepted_pattern, rejected_pattern
    ) -> List[str]:
        """Search files in directories according to a set of constraints:
        accepted and rejected patterns.

        Args:
            directories (List[str]): list of directories to scan
            accepted_pattern (str): pattern to accept files
            rejected_pattern (str): pattern to reject files

        Returns:
            List[str]: list of files
        """
        accepted_files = [
            glob.glob(path + os.path.sep + accepted_pattern)
            for path in directories
        ]
        accepted_files = list(chain(*accepted_files))
        if rejected_pattern is None:
            rejected_files = list()
        else:
            rejected_files = [
                glob.glob(path + os.path.sep + rejected_pattern)
                for path in directories
            ]
            rejected_files = list(chain(*rejected_files))
        cat_files = list(set(accepted_files) - set(rejected_files))
        return cat_files
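    # Illustration of the filtering above with the default patterns
    # ("*.h5" accepted, "*chain*" rejected); the file names are hypothetical:
    #
    #   /data/cat15728640_v2.h5        -> accepted (matches *.h5)
    #   /data/cat15728640_chain.h5     -> rejected (matches *chain*)
    #   /data/readme.txt               -> ignored (does not match *.h5)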

    def _read_cats(self, cat_file: str) -> pd.DataFrame:
        """Read the metadata of a given catalog and the location of the file.

        Args:
            cat_file (str): catalog to load

        Returns:
            pd.DataFrame: pandas data frame
        """
        df = pd.read_hdf(cat_file, key="metadata")
        df["location"] = cat_file
        return df
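    # Sketch of the combined metadata frame built in __init__ from the frames
    # returned above (catalog names and values are hypothetical; only the
    # "Observation Time" sort key and the added "location" column are
    # guaranteed by this module):
    #
    #                 Observation Time  ...  location
    #   cat_week_08        4838400.0    ...  /data/ucb_catalogs/cat_week_08.h5
    #   cat_week_12        7257600.0    ...  /data/ucb_catalogs/cat_week_12.h5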

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def metadata(self) -> pd.DataFrame:
        __doc__ = GWCatalogs.metadata.__doc__  # noqa: F841
        return self.__metadata

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def count(self) -> int:
        __doc__ = GWCatalogs.count.__doc__  # noqa: F841
        return len(self.metadata.index)

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def files(self) -> List[str]:
        __doc__ = GWCatalogs.files.__doc__  # noqa: F841
        return self.cat_files

    @UtilsMonitoring.io(level=logging.TRACE)
    def get_catalogs_name(self) -> List[str]:
        __doc__ = GWCatalogs.get_catalogs_name.__doc__  # noqa: F841
        return list(self.metadata.index)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_first_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_first_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[0]["location"]
        name = self.metadata.index[0]
        return UcbCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_last_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_last_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[self.count - 1]["location"]
        name = self.metadata.index[self.count - 1]
        return UcbCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog(self, idx: int) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[idx]["location"]
        name = self.metadata.index[idx]
        return UcbCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog_by(self, name: str) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog_by.__doc__  # noqa: F841
        cat_idx = self.metadata.index.get_loc(name)
        return self.get_catalog(cat_idx)
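    # How the accessors above are typically combined (the directory is a
    # hypothetical example):
    #
    #   catalogs = UcbCatalogs("/data/ucb_catalogs")
    #   latest = catalogs.get_last_catalog()          # newest Observation Time
    #   same = catalogs.get_catalog_by(latest.name)   # lookup by catalog name
    #   assert same.location == latest.location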

    def get_lineage(self, cat_name: str, src_name: str) -> pd.DataFrame:
        raise NotImplementedError(
            "get_lineage is not implemented for this catalog!"
        )

    def get_lineage_data(self, lineage: pd.DataFrame) -> pd.DataFrame:
        raise NotImplementedError(
            "get_lineage_data is not implemented for this catalog!"
        )

    def __repr__(self):
        return (
            f"UcbCatalogs({self.path!r}, {self.accepted_pattern!r}, "
            f"{self.rejected_pattern!r}, {self.extra_directories!r})"
        )

    def __str__(self):
        return (
            f"UcbCatalogs: {self.path} {self.accepted_pattern!r} "
            f"{self.rejected_pattern!r} {self.extra_directories!r}"
        )


class UcbCatalog(GWCatalog):
    """Implementation of the UCB catalog."""

    def __init__(self, catalog_name: str, location: str):
        """Init the UCB catalog with a name and a location.

        Args:
            catalog_name (str): name of the catalog
            location (str): location of the catalog
        """
        self.__name = catalog_name
        self.__location = location
        store = pd.HDFStore(location, "r")
        self.__datasets = store.keys()
        store.close()

    @CacheManager.get_cache_pandas(
        keycache_argument=[1, 2], level=logging.INFO
    )
    def _read_chain_file(
        self, source_name: str, chain_file: str
    ) -> pd.DataFrame:
        """Read the samples of a source from a chain file.

        Args:
            source_name (str): name of the source to extract from the
                chain_file
            chain_file (str): file to load

        Returns:
            pd.DataFrame: posterior samples of the source
        """
        dirname = os.path.dirname(self.location)
        source_samples_file = os.path.join(dirname, chain_file)
        source_samples = pd.read_hdf(
            source_samples_file, key=f"{source_name}_chain"
        )
        return source_samples
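    # The chain file is resolved relative to the catalog file, so both are
    # expected to live in the same directory. A hypothetical layout:
    #
    #   /data/ucb_catalogs/cat_week_12.h5         <- catalog (self.location)
    #   /data/ucb_catalogs/cat_week_12_chains.h5  <- value of chain_file
    #
    # The samples of a source named "LDC0014" would then be read from the
    # HDF5 key "LDC0014_chain" of the chain file.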

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def datasets(self):
        """Datasets available in the catalog file.

        :getter: Returns the list of datasets
        :type: List
        """
        return self.__datasets

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_dataset(self, name: str) -> pd.DataFrame:
        """Returns a dataset based on its name.

        Args:
            name (str): name of the dataset

        Returns:
            pd.DataFrame: the dataset
        """
        return pd.read_hdf(self.location, key=name)
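    # Hypothetical example of listing and loading datasets; "detections" is
    # the dataset name used by the accessors below:
    #
    #   catalog = UcbCatalog("cat_week_12", "/data/ucb_catalogs/cat_week_12.h5")
    #   print(catalog.datasets)            # e.g. ['/detections', ...]
    #   detections = catalog.get_dataset("detections")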

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def name(self) -> str:
        __doc__ = GWCatalog.name.__doc__  # noqa: F841
        return self.__name

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def location(self) -> str:
        __doc__ = GWCatalog.location.__doc__  # noqa: F841
        return self.__location

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_detections(
        self, attr: Union[List[str], str] = None
    ) -> Union[List[str], pd.DataFrame, pd.Series]:
        __doc__ = GWCatalog.get_detections.__doc__  # noqa: F841
        detections = self.get_dataset("detections")
        return (
            list(detections.index) if attr is None else detections[attr].copy()
        )
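    # Sketch of get_detections; the attribute names "SNR" and "Frequency" are
    # assumptions, the real column names come from the "detections" dataset:
    #
    #   names = catalog.get_detections()                       # list of source names
    #   snr = catalog.get_detections("SNR")                    # pd.Series
    #   subset = catalog.get_detections(["SNR", "Frequency"])  # pd.DataFrame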

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_attr_detections(self) -> List[str]:
        __doc__ = GWCatalog.get_attr_detections.__doc__  # noqa: F841
        return list(self.get_dataset("detections").columns)

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_median_source(self, attr: str) -> pd.DataFrame:
        __doc__ = GWCatalog.get_median_source.__doc__  # noqa: F841
        val = self.get_detections(attr)
        source_idx = self.get_detections()[
            np.argmin(np.abs(np.array(val) - val.median()))
        ]
        return self.get_detections(self.get_attr_detections()).loc[
            [source_idx]
        ]
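    # get_median_source returns the source whose value of `attr` is closest
    # to the median over all detections. A hypothetical worked example with
    # three sources and attr="SNR":
    #
    #   SNR values: A=8.0, B=12.0, C=30.0    -> median = 12.0
    #   |SNR - median|: A=4.0, B=0.0, C=18.0 -> source "B" is returned as a
    #   one-row DataFrame containing every detection attribute.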

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_source_samples(
        self, source_name: str, attr: List[str] = None
    ) -> pd.DataFrame:
        __doc__ = GWCatalog.get_source_samples.__doc__  # noqa: F841
        samples: pd.DataFrame = self.get_detections(["chain file"])
        chain_file: str = samples.loc[source_name]["chain file"]
        source_samples = self._read_chain_file(source_name, chain_file)
        return source_samples if attr is None else source_samples[attr].copy()
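    # Typical use of get_source_samples together with the helpers defined
    # below (the source and parameter names are hypothetical):
    #
    #   params = catalog.get_attr_source_samples("LDC0014")
    #   samples = catalog.get_source_samples("LDC0014", ["Frequency"])
    #   print(catalog.describe_source_samples("LDC0014"))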

    @UtilsMonitoring.io(level=logging.DEBUG)
    @UtilsMonitoring.time_spend(level=logging.DEBUG)
    def get_attr_source_samples(self, source_name: str) -> List[str]:
        __doc__ = GWCatalog.get_attr_source_samples.__doc__  # noqa: F841
        return list(self.get_source_samples(source_name).columns)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG)
    def describe_source_samples(self, source_name: str) -> pd.DataFrame:
        __doc__ = GWCatalog.describe_source_samples.__doc__  # noqa: F841
        return self.get_source_samples(source_name).describe()

    def __repr__(self):
        return f"UcbCatalog({self.__name!r}, {self.__location!r})"

    def __str__(self):
        return f"UcbCatalog: {self.__name} {self.__location}"