Source code for specreduce.background

# Licensed under a 3-clause BSD style license - see LICENSE.rst

import warnings
from dataclasses import dataclass, field

import numpy as np
from astropy import units as u
from astropy.utils.decorators import deprecated_attribute

from specreduce.compat import SPECUTILS_LT_2, Spectrum
from specreduce.core import _ImageParser, MaskingOption, ImageLike
from specreduce.extract import _ap_weight_image
from specreduce.tracing import Trace, FlatTrace

# Public names exported on a star-import of this module.
__all__ = ["Background"]


@dataclass
class Background(_ImageParser):
    """
    Determine the background from an image for subtraction.

    Example: ::

        trace = FlatTrace(image, trace_pos)
        bg = Background.two_sided(image, trace, bkg_sep, width=bkg_width)
        subtracted_image = image - bg

    Parameters
    ----------
    image
        Image with 2-D spectral image data
    traces : List, `specreduce.tracing.Trace`, int, float
        Individual or list of trace object(s) (or integers/floats to define
        FlatTraces) to extract the background. If None (or an empty list, the
        default), a ``FlatTrace`` at the center of the image along the
        cross-dispersion axis will be used.
    width
        Width of extraction aperture in pixels.
    statistic
        Statistic to use when computing the background. ``average`` will
        account for partial pixel weights, ``median`` will include all partial
        pixels.
    disp_axis
        Dispersion axis.
    crossdisp_axis
        Cross-dispersion axis.
    mask_treatment
        Specifies how to handle masked or non-finite values in the input image.
        The accepted values are:

        - ``apply``: The image remains unchanged, and any existing mask is
          combined with a mask derived from non-finite values.
        - ``ignore``: The image remains unchanged, and any existing mask is
          dropped.
        - ``propagate``: The image remains unchanged, and any masked or
          non-finite pixel causes the mask to extend across the entire
          cross-dispersion axis.
        - ``zero_fill``: Pixels that are either masked or non-finite are
          replaced with 0.0, and the mask is dropped.
        - ``nan_fill``: Pixels that are either masked or non-finite are
          replaced with nan, and the mask is dropped.
        - ``apply_mask_only``: The image and mask are left unmodified.
        - ``apply_nan_only``: The image is left unmodified, the old mask is
          dropped, and a new mask is created based on non-finite values.

    Raises
    ------
    ValueError
        If ``width`` is negative, ``traces`` has an unsupported type, the
        background regions overlap, a background window leaves the image
        somewhere along the dispersion axis, the image is fully masked inside
        the background window, or ``statistic`` is not ``average``/``median``.
    """

    # required so numpy won't call __rsub__ on individual elements
    # https://stackoverflow.com/a/58409215
    __array_ufunc__ = None

    image: ImageLike
    traces: list = field(default_factory=list)
    width: float = 5
    statistic: str = "average"
    disp_axis: int = 1
    crossdisp_axis: int = 0
    mask_treatment: MaskingOption = "apply"

    # un-annotated on purpose: a plain class attribute, not a dataclass field
    _valid_mask_treatment_methods = (
        "apply",
        "ignore",
        "propagate",
        "zero_fill",
        "nan_fill",
        "apply_mask_only",
        "apply_nan_only",
    )

    # TO-DO: update bkg_array with Spectrum alternative (is bkg_image enough?)
    bkg_array = deprecated_attribute("bkg_array", "1.3")

    def __post_init__(self):
        """
        Parse the input image, normalise ``traces`` and compute the 1D
        background (``self._bkg_array``) along the dispersion axis.
        """
        self.image = self._parse_image(
            self.image, disp_axis=self.disp_axis, mask_treatment=self.mask_treatment
        )

        # always work with masked array, even if there is no masked
        # or non-finite data, in case padding is needed. if not, mask will be
        # dropped at the end and a regular array will be returned.
        img = np.ma.masked_array(self.image.data, self.image.mask)

        if self.width < 0:
            raise ValueError("width must be positive")
        if self.width == 0:
            # zero-width aperture: the background is identically zero and no
            # traces / weight image are needed.
            self._bkg_array = np.zeros(self.image.shape[self.disp_axis])
            return

        self._set_traces()

        n_crossdisp = self.image.shape[self.crossdisp_axis]
        bkg_wimage = np.zeros_like(self.image.data, dtype=np.float64)
        for trace in self.traces:
            # note: ArrayTrace can have masked values, but if it does a MaskedArray
            # will be returned so this should be reflected in the window size here
            # (i.e, np.nanmax is not required.)
            windows_max = trace.trace.data.max() + self.width / 2
            windows_min = trace.trace.data.min() - self.width / 2

            if windows_max > n_crossdisp:
                warnings.warn(
                    "background window extends beyond image boundaries "
                    + f"({windows_max} >= {n_crossdisp})"
                )
            if windows_min < 0:
                warnings.warn(
                    "background window extends beyond image boundaries "
                    + f"({windows_min} < 0)"
                )

            # accumulate this trace's aperture weights into the weight image
            bkg_wimage += _ap_weight_image(
                trace, self.width, self.disp_axis, self.crossdisp_axis, self.image.shape
            )

        # a summed weight above 1 means two background apertures share a pixel
        if np.any(bkg_wimage > 1):
            raise ValueError("background regions overlapped")
        if np.any(np.sum(bkg_wimage, axis=self.crossdisp_axis) == 0):
            raise ValueError(
                "background window does not remain in bounds across entire dispersion axis"
            )

        # check if image contained within background window is fully-nonfinite
        # and raise an error
        if np.all(img.mask[bkg_wimage > 0]):
            raise ValueError(
                "Image is fully masked within background window determined by `width`."
            )

        if self.statistic == "median":
            # make it clear in the expose image that partial pixels are fully-weighted
            bkg_wimage[bkg_wimage > 0] = 1

        self.bkg_wimage = bkg_wimage

        if self.statistic == "average":
            self._bkg_array = np.ma.average(
                img, weights=self.bkg_wimage, axis=self.crossdisp_axis
            )
        elif self.statistic == "median":
            # combine where background weight image is 0 with image masked (which
            # already accounts for non-finite data that wasn't already masked)
            img.mask = np.logical_or(self.bkg_wimage == 0, self.image.mask)
            self._bkg_array = np.ma.median(img, axis=self.crossdisp_axis)
        else:
            raise ValueError("statistic must be 'average' or 'median'")

    def _set_traces(self):
        """
        Normalise ``self.traces`` into a list of Trace objects.

        Accepted inputs are a single Trace, a list of Trace objects, a single
        int/float or a list of int/float (each converted to a ``FlatTrace`` at
        that position), or None / an empty list, in which case a single
        ``FlatTrace`` at the centre of the image along the cross-dispersion
        axis is used.

        Raises
        ------
        ValueError
            If ``traces`` is a collection mixing unsupported element types.
        """
        if self.traces is None or (isinstance(self.traces, list) and not self.traces):
            # assume a flat trace at the image center if nothing is passed in.
            # A trace position is a cross-dispersion coordinate, so the center
            # is taken along the cross-dispersion axis.
            trace_pos = self.image.shape[self.crossdisp_axis] / 2.0
            self.traces = [FlatTrace(self.image, trace_pos)]
            return

        if isinstance(self.traces, Trace):
            # if just one trace, turn it into iterable.
            self.traces = [self.traces]
            return

        # finally, if float/int is passed in convert to FlatTrace(s)
        if isinstance(self.traces, (float, int)):
            # for a single number
            self.traces = [self.traces]

        if all(isinstance(x, (float, int)) for x in self.traces):
            self.traces = [FlatTrace(self.image, trace_pos) for trace_pos in self.traces]
            return

        if not all(isinstance(x, Trace) for x in self.traces):
            raise ValueError(
                "`traces` must be a `Trace` object or list of "
                "`Trace` objects, a number or list of numbers to "
                "define FlatTraces, or None to use a FlatTrace in "
                "the middle of the image."
            )

    @classmethod
    def two_sided(cls, image, trace_object, separation, **kwargs):
        """
        Determine the background from an image for subtraction centered around
        an input trace.

        Example: ::

            trace = FitTrace(image, guess=trace_pos)
            bg = Background.two_sided(image, trace, bkg_sep, width=bkg_width)

        Parameters
        ----------
        image : `~astropy.nddata.NDData`-like or array-like
            Image with 2-D spectral image data. Assumes cross-dispersion
            (spatial) direction is axis 0 and dispersion (wavelength)
            direction is axis 1.
        trace_object: `~specreduce.tracing.Trace`
            estimated trace of the spectrum to center the background traces
        separation: float
            separation from ``trace_object`` for the background regions
        width : float
            width of each background aperture in pixels
        statistic: string
            statistic to use when computing the background. 'average' will
            account for partial pixel weights, 'median' will include all partial
            pixels.
        disp_axis : int
            dispersion axis
        crossdisp_axis : int
            cross-dispersion axis
        mask_treatment : string
            The method for handling masked or non-finite data. One of
            ``apply``, ``ignore``, ``propagate``, ``zero_fill``, ``nan_fill``,
            ``apply_mask_only`` or ``apply_nan_only``. It is forwarded
            unchanged to the `Background` initializer together with the other
            keyword arguments.

        Returns
        -------
        `Background`
            Instance whose traces are ``trace_object - separation`` and
            ``trace_object + separation``.
        """
        # NOTE(review): the ``cls.image`` fallback for ``image=None`` looks like
        # it would raise AttributeError (``image`` has no class-level default
        # here) -- confirm intended behavior.
        image = _ImageParser._get_data_from_image(image) if image is not None else cls.image
        kwargs["traces"] = [trace_object - separation, trace_object + separation]
        return cls(image=image, **kwargs)

    @classmethod
    def one_sided(cls, image, trace_object, separation, **kwargs):
        """
        Determine the background from an image for subtraction above
        or below an input trace.

        Example: ::

            trace = FitTrace(image, guess=trace_pos)
            bg = Background.one_sided(image, trace, bkg_sep, width=bkg_width)

        Parameters
        ----------
        image : `~astropy.nddata.NDData`-like or array-like
            Image with 2-D spectral image data. Assumes cross-dispersion
            (spatial) direction is axis 0 and dispersion (wavelength)
            direction is axis 1.
        trace_object: `~specreduce.tracing.Trace`
            estimated trace of the spectrum to center the background traces
        separation: float
            separation from ``trace_object`` for the background, positive will
            be above the trace, negative below.
        width : float
            width of each background aperture in pixels
        statistic: string
            statistic to use when computing the background. 'average' will
            account for partial pixel weights, 'median' will include all partial
            pixels.
        disp_axis : int
            dispersion axis
        crossdisp_axis : int
            cross-dispersion axis
        mask_treatment : string
            The method for handling masked or non-finite data. One of
            ``apply``, ``ignore``, ``propagate``, ``zero_fill``, ``nan_fill``,
            ``apply_mask_only`` or ``apply_nan_only``. It is forwarded
            unchanged to the `Background` initializer together with the other
            keyword arguments.

        Returns
        -------
        `Background`
            Instance whose single trace is ``trace_object + separation``.
        """
        # NOTE(review): the ``cls.image`` fallback for ``image=None`` looks like
        # it would raise AttributeError (``image`` has no class-level default
        # here) -- confirm intended behavior.
        image = _ImageParser._get_data_from_image(image) if image is not None else cls.image
        kwargs["traces"] = [trace_object + separation]
        return cls(image=image, **kwargs)

    def bkg_image(self, image=None):
        """
        Expose the background tiled to the dimension of ``image``.

        Parameters
        ----------
        image : `~astropy.nddata.NDData`-like or array-like, optional
            Image with 2-D spectral image data. Assumes cross-dispersion
            (spatial) direction is axis 0 and dispersion (wavelength)
            direction is axis 1. If None, will extract the background
            from ``image`` used to initialize the class. [default: None]

        Returns
        -------
        spec : `~specutils.Spectrum1D`
            Spectrum object with same shape as ``image``.
        """
        image = self._parse_image(image)
        # repeat the 1D background once per row (axis 0) of the image
        arr = np.tile(self._bkg_array, (image.shape[0], 1))
        if SPECUTILS_LT_2:
            kwargs = {}
        else:
            kwargs = {"spectral_axis_index": arr.ndim - 1}
        return Spectrum(arr * image.unit, spectral_axis=image.spectral_axis, **kwargs)

    def bkg_spectrum(self, image=None, bkg_statistic=None):
        """
        Expose the 1D spectrum of the background.

        Parameters
        ----------
        image : `~astropy.nddata.NDData`-like or array-like, optional
            Image with 2-D spectral image data. Assumes cross-dispersion
            (spatial) direction is axis 0 and dispersion (wavelength)
            direction is axis 1. If None, will extract the background
            from ``image`` used to initialize the class. [default: None]
        bkg_statistic : optional
            Deprecated and ignored; passing any value other than None only
            triggers a ``DeprecationWarning``. Use the ``statistic`` argument
            of the `Background` initializer instead.

        Returns
        -------
        spec : `~specutils.Spectrum1D`
            The background 1-D spectrum, with flux expressed in the same
            units as the input image (or DN if none were provided).
        """
        if bkg_statistic is not None:
            # stacklevel=2 attributes the warning to the caller, so that the
            # default warning filters actually display it.
            warnings.warn(
                "'bkg_statistic' is deprecated and will be removed in a future release. "
                "Please use the 'statistic' argument in the Background initializer instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        return Spectrum(self._bkg_array * self.image.unit, self.image.spectral_axis)

    def sub_image(self, image=None):
        """
        Subtract the computed background from ``image``.

        Parameters
        ----------
        image : nddata-compatible image or None
            image with 2-D spectral image data. If None, will extract
            the background from ``image`` used to initialize the class.

        Returns
        -------
        spec : `~specutils.Spectrum1D`
            Spectrum object with same shape as ``image``.
        """
        image = self._parse_image(image)

        if not SPECUTILS_LT_2:
            return image - self.bkg_image(image)

        # a compare_wcs argument is needed for Spectrum.subtract() in order to
        # avoid a TypeError from SpectralCoord when image's spectral axis is in
        # pixels. it is not needed when image's spectral axis has physical units
        kwargs = {"compare_wcs": None} if image.spectral_axis.unit == u.pix else {}

        # https://docs.astropy.org/en/stable/nddata/mixins/ndarithmetic.html
        return image.subtract(self.bkg_image(image), **kwargs)

    def sub_spectrum(self, image=None):
        """
        Expose the 1D spectrum of the background-subtracted image.

        Parameters
        ----------
        image : nddata-compatible image or None
            image with 2-D spectral image data. If None, will extract
            the background from ``image`` used to initialize the class.

        Returns
        -------
        spec : `~specutils.Spectrum1D`
            The background-subtracted 1-D spectrum, obtained by summing the
            background-subtracted image (ignoring NaNs) along the
            cross-dispersion axis.
        """
        sub_image = self.sub_image(image=image)

        try:
            return sub_image.collapse(np.nansum, axis=self.crossdisp_axis)
        except u.UnitTypeError:
            # can't collapse with a spectral axis in pixels because
            # SpectralCoord only allows frequency/wavelength equivalent units...
            ext1d = np.nansum(sub_image.flux, axis=self.crossdisp_axis)
            return Spectrum(ext1d, spectral_axis=sub_image.spectral_axis)

    def __rsub__(self, image):
        """
        Subtract the background from an image (enables ``image - background``).
        """
        return self.sub_image(image)