From 9f6f89738d0bbc64cbfcc43778303afb07f2c401 Mon Sep 17 00:00:00 2001 From: ndilalla Date: Wed, 18 Feb 2026 19:45:37 -0800 Subject: [PATCH] Skymap inconsistency fixed. --- fermipy/skymap.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/fermipy/skymap.py b/fermipy/skymap.py index dc5bbdd1..869d51f0 100644 --- a/fermipy/skymap.py +++ b/fermipy/skymap.py @@ -133,11 +133,17 @@ def __init__(self, counts, wcs, ebins=None): Parameters ---------- counts : `~numpy.ndarray` - Counts array in row-wise ordering (LON is first dimension). + Counts array. Internal layout is (energy, lat, lon), i.e. first + dimension is energy, then spatial (lat, lon). So for 3D, + counts.shape = (n_energy, n_lat, n_lon). For 2D, counts.shape = + (n_lat, n_lon). This is the opposite of FITS/WCS axis order (which + is lon, lat, energy); see sum_over_energy and create_from_hdu. """ Map_Base.__init__(self, counts) self._wcs = wcs + # _npix is (n_lon, n_lat, n_energy) for 3D or (n_lon, n_lat) for 2D, + # i.e. counts.shape reversed (matches WCS axis order). self._npix = counts.shape[::-1] if len(self._npix) == 3: @@ -364,9 +370,9 @@ def get_map_values(self, lons, lats, ibin=None): Returns ---------- - vals : numpy.ndarray((n)) - Values of pixels in the flattened map, np.nan used to flag - coords outside of map + vals : numpy.ndarray + For 3D maps: shape (ne, np) with ne energy bins and np sky coordinates. + For 2D maps: shape (np,). np.nan flags coords outside of map. """ pix_idxs = self.get_pixel_indices(lons, lats, ibin) idxs = copy.copy(pix_idxs) @@ -376,10 +382,19 @@ def get_map_values(self, lons, lats, ibin=None): for i, p in enumerate(pix_idxs): m &= (pix_idxs[i] >= 0) & (pix_idxs[i] < self._npix[i]) idxs[i][~m] = 0 - - vals = self.counts.T[idxs] - vals[~m] = np.nan - return vals + + # Use tuple() so NumPy does multi-dim advanced indexing. + # counts is (energy, lat, lon) for 3D, (lat, lon) for 2D. pix_idxs is + # [lon, lat, energy] or [lon, lat], so index as (energy, lat, lon). + if len(self._npix) == 3: + vals = self.counts[(idxs[2], idxs[1], idxs[0])] + vals[~m] = np.nan + # Return (ne, np) to match HEALPix HpxMap.get_map_values convention + return vals.T + else: + vals = self.counts[(idxs[1], idxs[0])] + vals[~m] = np.nan + return vals def interpolate(self, lon, lat, egy=None):