Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions fermipy/skymap.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,17 @@ def __init__(self, counts, wcs, ebins=None):
Parameters
----------
counts : `~numpy.ndarray`
Counts array in row-wise ordering (LON is first dimension).
Counts array. Internal layout is (energy, lat, lon), i.e. first
dimension is energy, then spatial (lat, lon). So for 3D,
counts.shape = (n_energy, n_lat, n_lon). For 2D, counts.shape =
(n_lat, n_lon). This is the opposite of FITS/WCS axis order (which
is lon, lat, energy); see sum_over_energy and create_from_hdu.
"""
Map_Base.__init__(self, counts)
self._wcs = wcs

# _npix is (n_lon, n_lat, n_energy) for 3D or (n_lon, n_lat) for 2D,
# i.e. counts.shape reversed (matches WCS axis order).
self._npix = counts.shape[::-1]

if len(self._npix) == 3:
Expand Down Expand Up @@ -364,9 +370,9 @@ def get_map_values(self, lons, lats, ibin=None):

Returns
----------
vals : numpy.ndarray((n))
Values of pixels in the flattened map, np.nan used to flag
coords outside of map
vals : numpy.ndarray
For 3D maps: shape (ne, np) with ne energy bins and np sky coordinates.
For 2D maps: shape (np,). np.nan flags coords outside of map.
"""
pix_idxs = self.get_pixel_indices(lons, lats, ibin)
idxs = copy.copy(pix_idxs)
Expand All @@ -376,10 +382,19 @@ def get_map_values(self, lons, lats, ibin=None):
for i, p in enumerate(pix_idxs):
m &= (pix_idxs[i] >= 0) & (pix_idxs[i] < self._npix[i])
idxs[i][~m] = 0

vals = self.counts.T[idxs]
vals[~m] = np.nan
return vals

# Use tuple() so NumPy does multi-dim advanced indexing.
# counts is (energy, lat, lon) for 3D, (lat, lon) for 2D. pix_idxs is
# [lon, lat, energy] or [lon, lat], so index as (energy, lat, lon).
if len(self._npix) == 3:
vals = self.counts[(idxs[2], idxs[1], idxs[0])]
vals[~m] = np.nan
# Return (ne, np) to match HEALPix HpxMap.get_map_values convention
return vals.T
else:
vals = self.counts[(idxs[1], idxs[0])]
vals[~m] = np.nan
return vals

def interpolate(self, lon, lat, egy=None):

Expand Down
Loading