If you want to follow along, download this post and run it in the Jupyter Notebook.
A colleague came to me with a question. She had a Herschel FITS image with world coordinate system (WCS) information, as well as some derived data in the pixel coordinates of that image. She wanted to create a new FITS image with WCS information so that she could match the two images. It turns out the simplest way to explain how to do that was to write a blog post, so here we are.
To begin, use the usual incantations to bring in matplotlib and NumPy:
%matplotlib inline
%config InlineBackend.figure_formats = ['retina']
%config InlineBackend.print_figure_kwargs = {'facecolor': (1.0, 1.0, 1.0, 0.0)}
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
from matplotlib.colors import LogNorm
# correct image orientation and improve appearance:
matplotlib.rcParams.update({
    'image.origin': 'lower',
    'image.interpolation': 'nearest',
    'image.cmap': 'magma',
    'font.family': 'serif',
})
# not great practice for real code, but we can safely ignore
# warnings in this example:
import warnings
warnings.simplefilter('ignore')
Load a FITS image with world coordinate system info
As our example, we'll use an image from the Herschel M33 Extended Survey (HerM33es) (PI: Kramer), taken with the SPIRE instrument at a wavelength of 350 µm. We'll retrieve the image from the NASA/IPAC Infrared Science Archive with Python, but you could just as easily download it here and drop it in this directory.
from urllib.request import urlretrieve
from os.path import exists
url = 'https://irsa.ipac.caltech.edu/data/Herschel/HerM33es/images/spire_350_v6.fits'
dst = 'spire_350_v6.fits'
if not exists(dst):
    urlretrieve(url, dst)
dst
Let's see what we got:
from astropy.io import fits
hdul = fits.open(dst)
image = hdul[0].data
plt.imshow(image)
hdul.info()
Functionality to convert between sky coordinates (e.g. RA and Dec) and pixels lives in astropy.wcs, mostly accessed through instances of the WCS class.
from astropy.wcs import WCS
The WCS class can be initialized with an existing set of WCS transformations drawn from a FITS header, so we pass in the header of our Herschel image:
wcs = WCS(header=hdul[0].header)
To plot on axes labeled with the celestial coordinates of our image, we pass projection=wcs when creating our axes:
ax = plt.subplot(projection=wcs)
im = ax.imshow(image)
plt.colorbar(im)
Converting world coordinates to pixels
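Now let's do some manipulations in pixel space. First, suppose we're only interested in the region from 23° to 24° in (decimal) RA and 30° to 31° in Dec. WCSAxes (the underlying layer that translates between WCS and matplotlib axes coordinates) doesn't let us set limits in RA and Dec (yet?), so instead we convert the coordinates of our lower-left and upper-right corners to pixels. Because RA increases toward the east, which is to the left on a conventionally oriented (north up, east left) image, the lower-left corner is the one with the larger RA.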
The way to do this is the all_world2pix(ra, dec, origin) method, which takes arrays of RAs and Decs (or whatever world coordinates this file uses) plus an origin, and produces pixel coordinates following the WCS transformations. The origin is 0 for 0-based pixel indices, matching how Python and NumPy count, or 1 for the 1-based FITS convention.
dec_ll, ra_ll = 30, 24
dec_ur, ra_ur = 31, 23
(xmin, xmax), (ymin, ymax) = wcs.all_world2pix([ra_ll, ra_ur], [dec_ll, dec_ur], 0)
(xmin, xmax), (ymin, ymax)
Now we pass (xmin, xmax) and (ymin, ymax) to set the axis limits in pixels and crop our image to the region of interest:
ax = plt.subplot(projection=wcs)
im = ax.imshow(image)
plt.colorbar(im)
ax.set(xlim=(xmin, xmax), ylim=(ymin, ymax))
Say we wanted to cut the corresponding part of the underlying array. See how we have non-integer values for our axis limits? That won't work for array slicing. Since you usually want to err on the side of keeping more data, round the minima down and maxima up:
xmin_int, xmax_int = int(np.floor(xmin)), int(np.ceil(xmax))
ymin_int, ymax_int = int(np.floor(ymin)), int(np.ceil(ymax))
(xmin_int, xmax_int), (ymin_int, ymax_int)
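One more wrinkle: all_world2pix happily returns pixel coordinates outside the image if the requested region extends past its edges, and NumPy treats a negative slice start as counting from the end of the axis, so an unclipped floor could give an empty or wrong cutout. Here is a minimal sketch of a hypothetical helper, pixel_box_to_slices, that rounds outward as described above and also clips to the array shape, returning slices in NumPy's (y, x) order.

```python
import numpy as np

def pixel_box_to_slices(xmin, xmax, ymin, ymax, shape):
    """Hypothetical helper: turn fractional pixel limits into (y, x) slices.

    Floors the minima and ceils the maxima to keep more data, then clips
    to the array bounds so negative or oversized indices can't misbehave.
    """
    ny, nx = shape  # NumPy shape is (rows, columns) = (y, x)
    x0 = int(np.clip(np.floor(xmin), 0, nx))
    x1 = int(np.clip(np.ceil(xmax), 0, nx))
    y0 = int(np.clip(np.floor(ymin), 0, ny))
    y1 = int(np.clip(np.ceil(ymax), 0, ny))
    return slice(y0, y1), slice(x0, x1)

# Made-up limits that spill past the left edge of a 100 x 200 image
demo = np.zeros((100, 200))
ys, xs = pixel_box_to_slices(-3.2, 50.7, 10.1, 20.9, demo.shape)
print(ys, xs)               # slice(10, 21, None) slice(0, 51, None)
print(demo[ys, xs].shape)   # (11, 51)
```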
Recall that NumPy interprets array indices with the x coordinate last (so, (y, x) for our 2D image). Cut out the corresponding region with slice syntax:
subregion = image[ymin_int:ymax_int, xmin_int:xmax_int]
Make sure it looks right:
_ = plt.imshow(subregion)