from __future__ import annotations
import warnings
from abc import ABC
from typing import Any, Tuple, Optional, List
import math
import numpy as np
from scipy import interpolate
import matplotlib.pyplot as plt
try:
import cupy as cp
except (ModuleNotFoundError, ImportError):
import numpy as cp
from . import detector as lisa_models
from .utils.utility import AET, get_array_module
from .utils.constants import *
from .stochastic import (
StochasticContribution,
FittedHyperbolicTangentGalacticForeground,
)
from .sensitivity import SensitivityMatrix
# Empty placeholder: binds the name ``DataResidualArray`` early so it can be
# referenced before the full class definition of the same name, which follows
# in this file and rebinds it.
class DataResidualArray:
    ...
# "[docs]" -- Sphinx source-link marker left over from HTML extraction; not part of the code.
class DataResidualArray:
    """Container to hold Data, residual, or template information.

    This class abstracts the connection with the sensitivity matrices to make this analysis
    as generic as possible for the user frontend, while handling
    special computations in the backend.

    Exactly one of ``dt``, ``f_arr`` or ``df`` must be provided, unless another
    :class:`DataResidualArray` is passed in, in which case its attributes are
    copied and the other arguments are ignored. When ``dt`` is given, the input
    is transformed with a real FFT (scaled by ``dt``); when ``df`` or ``f_arr``
    is given, the input is stored without transformation.

    Args:
        data_res_in: Data, residual, or template input information. Can be a list, numpy array
            or another :class:`DataResidualArray`.
        dt: Timestep in seconds.
        f_arr: Frequency array.
        df: Delta f in frequency domain.
        **kwargs: For future compatibility.

    Raises:
        ValueError: If none, or more than one, of ``dt``, ``f_arr`` and ``df``
            is provided, if the input channels are not arrays, or if the
            frequency array length does not match the channel length.

    """

    def __init__(
        self,
        data_res_in: List[np.ndarray] | np.ndarray | DataResidualArray,
        dt: Optional[float] = None,
        f_arr: Optional[np.ndarray] = None,
        df: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        if isinstance(data_res_in, DataResidualArray):
            # Copy-constructor path: every attribute of the input instance is
            # re-bound on this one (shallow copy; arrays are shared).
            for key, item in data_res_in.__dict__.items():
                setattr(self, key, item)
        else:
            self._check_inputs(dt=dt, f_arr=f_arr, df=df)
            # The setter normalizes the input and sets data_length / nchannels,
            # which the next call relies on.
            self.data_res_arr = data_res_in
            self._store_time_and_frequency_information(dt=dt, f_arr=f_arr, df=df)

    @property
    def init_kwargs(self) -> dict:
        """Initial dt, df, f_arr"""
        return self._init_kwargs

    @init_kwargs.setter
    def init_kwargs(self, init_kwargs: dict) -> None:
        """Set initial kwargs."""
        self._init_kwargs = init_kwargs

    def _check_inputs(
        self,
        dt: Optional[float] = None,
        f_arr: Optional[np.ndarray] = None,
        df: Optional[float] = None,
    ) -> None:
        """Validate that exactly one of ``dt``, ``f_arr``, ``df`` is given.

        On success the three values are stored in :attr:`init_kwargs`.

        Args:
            dt: Timestep in seconds.
            f_arr: Frequency array.
            df: Delta f in frequency domain.

        Raises:
            ValueError: If none of the three is provided, or if more than one is.

        """
        number_given = sum(value is not None for value in (dt, f_arr, df))

        if number_given == 0:
            raise ValueError("Must provide either df, dt, or f_arr.")

        # Covers both "two given" and "all three given".
        if number_given > 1:
            raise ValueError(
                "Can only provide one of dt, f_arr, or df. Not more than one."
            )

        self.init_kwargs = dict(dt=dt, f_arr=f_arr, df=df)

    def _store_time_and_frequency_information(
        self,
        dt: Optional[float] = None,
        f_arr: Optional[np.ndarray] = None,
        df: Optional[float] = None,
    ) -> None:
        """Derive and store ``dt``, ``df``, ``Tobs``, ``fmax`` and ``f_arr``.

        Must be called after ``data_res_arr`` has been set, because it reads
        ``self.data_length``. The first of ``dt``, ``df``, ``f_arr`` that is
        not ``None`` (checked in that order) decides how the quantities are
        derived. Quantities that cannot be determined are stored as ``None``;
        the corresponding public property then raises on access.

        Args:
            dt: Timestep in seconds. The stored data is replaced by its real
                FFT along the last axis, multiplied by ``dt``.
            f_arr: Frequency array matching the last axis of the stored data.
            df: Delta f in frequency domain. The frequency array is built as
                ``data_length`` bins starting at zero with spacing ``df``.

        Raises:
            ValueError: If all three arguments are ``None`` or if the resulting
                frequency array does not have the same length as the data.

        """
        if dt is not None:
            self._dt = dt
            # data_length is still the pre-transform length at this point.
            self._Tobs = self.data_length * dt
            self._df = 1 / self._Tobs
            self._fmax = 1 / (2 * dt)
            xp = get_array_module(self.data_res_arr)
            self._f_arr = xp.asarray(np.fft.rfftfreq(self.data_length, dt))

            # transform data
            self._data_res_arr = xp.fft.rfft(self.data_res_arr, axis=-1) * self._dt
            # From here on data_length is the number of frequency bins.
            self.data_length = self._data_res_arr.shape[-1]

        elif df is not None:
            self._df = df
            self._Tobs = 1 / self._df
            self._fmax = (self.data_length - 1) * df
            # A single bin gives fmax == 0, from which dt cannot be derived.
            self._dt = 1 / (2 * self._fmax) if self._fmax > 0 else None
            # Build exactly data_length bins, 0 .. fmax inclusive. A
            # np.arange(0, fmax, df) would drop the last bin because the stop
            # value is excluded.
            self._f_arr = np.arange(self.data_length) * self._df

        elif f_arr is not None:
            self._f_arr = f_arr
            self._fmax = f_arr.max()
            spacing = np.diff(f_arr)
            # constant spacing
            # NOTE(review): exact floating-point equality; grids that are
            # uniform only up to round-off are treated as non-uniform.
            if spacing.size > 0 and np.all(spacing == spacing[0]):
                self._df = spacing[0].item()
                if f_arr[0] == 0.0:
                    # could be fft because of constant spacing and f_arr[0] == 0.0
                    self._Tobs = 1 / self._df
                    self._dt = 1 / (2 * self._fmax)
                else:
                    # cannot be fft basis
                    self._Tobs = None
                    self._dt = None
            else:
                self._df = None
                self._Tobs = None
                self._dt = None

        else:
            raise ValueError("Must provide either df, dt, or f_arr.")

        if len(self.f_arr) != self.data_length:
            raise ValueError(
                "Entered or determined f_arr does not have the same length as the data channel inputs."
            )

    @property
    def fmax(self):
        """Maximum frequency."""
        return self._fmax

    @property
    def f_arr(self):
        """Frequency array."""
        return self._f_arr

    @property
    def dt(self):
        """Time step in seconds."""
        if self._dt is None:
            raise ValueError("dt cannot be determined from this f_arr input.")
        return self._dt

    @property
    def Tobs(self):
        """Observation time in seconds"""
        if self._Tobs is None:
            raise ValueError("Tobs cannot be determined from this f_arr input.")
        return self._Tobs

    @property
    def df(self):
        """Delta f in the frequency domain."""
        if self._df is None:
            raise ValueError("df cannot be determined from this f_arr input.")
        return self._df

    @property
    def frequency_arr(self) -> np.ndarray:
        """Frequency array"""
        return self._f_arr

    @property
    def data_res_arr(self) -> np.ndarray:
        """Actual data residual array"""
        return self._data_res_arr

    @data_res_arr.setter
    def data_res_arr(self, data_res_arr: List[np.ndarray] | np.ndarray) -> None:
        """Set ``data_res_arr``.

        The raw input is kept in ``_data_res_arr_input``. A 1D array is treated
        as a single channel and a 2D array as one channel per row; the channels
        are then stacked into one array. Also sets ``data_length`` (length of
        each channel) and ``nchannels``.

        Raises:
            ValueError: If no channel is provided or a channel is not a
                numpy/cupy array.
            AssertionError: If the channels do not all have the same length.

        """
        self._data_res_arr_input = data_res_arr

        array_types = (np.ndarray, cp.ndarray)
        if isinstance(data_res_arr, array_types):
            if data_res_arr.ndim == 1:
                data_res_arr = [data_res_arr]
            elif data_res_arr.ndim == 2:
                data_res_arr = list(data_res_arr)

        if len(data_res_arr) == 0:
            raise ValueError("data_res_arr must contain at least one channel.")

        channels = []
        self.data_length = None
        for current_data in data_res_arr:
            if not isinstance(current_data, array_types):
                raise ValueError(
                    "Each channel of data_res_arr must be a numpy or cupy array."
                )
            if self.data_length is None:
                # The first channel fixes the expected length.
                self.data_length = len(current_data)
            else:
                assert (
                    len(current_data) == self.data_length
                ), "All channels must have the same length."
            channels.append(current_data)

        self.nchannels = len(channels)
        # Stack with the array library and dtype of the first channel.
        xp = get_array_module(channels[0])
        self._data_res_arr = xp.asarray(channels, dtype=channels[0].dtype)

    def __getitem__(self, index: tuple) -> np.ndarray:
        """Index this class directly in ``self.data_res_arr``."""
        return self.data_res_arr[index]

    @property
    def ndim(self) -> int:
        """Number of dimensions in the `data_res_arr`."""
        return self.data_res_arr.ndim

    def flatten(self) -> np.ndarray:
        """Flatten the ``data_res_arr``."""
        return self.data_res_arr.flatten()

    @property
    def shape(self) -> tuple:
        """Shape of ``data_res_arr``."""
        return self.data_res_arr.shape

    def loglog(
        self,
        ax: Optional[List[plt.Axes] | plt.Axes] = None,
        fig: Optional[plt.Figure] = None,
        inds: Optional[List[int] | int] = None,
        char_strain: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[plt.Figure, plt.Axes]:
        """Produce a log-log plot of the data.

        Args:
            ax: Matplotlib Axes objects to add plots. Either a list of Axes objects or a single Axes object.
                If neither ``ax`` nor ``fig`` is given, a new figure with one
                subplot per channel is created.
            fig: Matplotlib figure object. Passing ``fig`` without ``ax`` is
                not implemented.
            inds: Integer index to select out which data to add to a single axis.
                A list can be provided if ax is a list. They must be the same length.
            char_strain: If ``True`` return plot in characteristic strain representation.
            **kwargs: Keyword arguments to be passed to ``loglog`` function in matplotlib.

        Returns:
            Matplotlib figure and axes objects in a 2-tuple. The figure entry
            is whatever was passed as ``fig`` (``None`` by default) when ``ax``
            is supplied by the caller.

        Raises:
            NotImplementedError: If ``fig`` is given without ``ax``.
            TypeError: If ``ax`` is neither a list nor a single Axes object.
            AssertionError: If ``ax``/``inds`` are inconsistent with each other
                or with the data shape.

        """
        if ax is None and fig is None:
            nrows = 1
            ncols = self.shape[0]
            # squeeze=False keeps the result an array even for one channel,
            # so that ravel() is always available.
            fig, ax = plt.subplots(
                nrows, ncols, sharex=True, sharey=True, squeeze=False
            )
            ax = ax.ravel()
            inds_list = range(len(ax))

        elif ax is not None:
            if isinstance(ax, list):
                number_of_sets = int(np.prod(self.shape[:-1]))
                assert len(ax) == number_of_sets
                if inds is None:
                    inds_list = list(range(number_of_sets))
                else:
                    assert isinstance(inds, list) and len(inds) == len(ax)
                    inds_list = inds

            elif isinstance(ax, plt.Axes):
                # A single Axes needs an explicit index of what to draw on it.
                assert inds is not None and isinstance(inds, (tuple, int))
                ax = [ax]
                inds_list = [inds]

            else:
                raise TypeError("ax must be a list of Axes objects or a single Axes.")

        elif fig is not None:
            raise NotImplementedError

        for i, ax_tmp in zip(inds_list, ax):
            plot_in = np.abs(self.data_res_arr[i])
            if char_strain:
                # NOTE(review): this scales by f, whereas the ``char_strain``
                # property scales by sqrt(f) -- confirm which is intended.
                # Out-of-place multiply so integer-valued data does not fail.
                plot_in = plot_in * self.frequency_arr
            ax_tmp.loglog(self.frequency_arr, plot_in, **kwargs)

        return (fig, ax)

    @property
    def char_strain(self) -> np.ndarray:
        """Characteristic strain representation of the data.

        Computed as ``sqrt(f_arr) * abs(data_res_arr)``.
        """
        return np.sqrt(self.f_arr) * np.abs(self.data_res_arr)