# Licensed under a 3-clause BSD style license - see LICENSE.rst
from matplotlib.axes import Axes, subplot_class_factory
from matplotlib.transforms import Affine2D, Bbox, Transform
from astropy.wcs import WCS
from astropy.wcs.utils import wcs_to_celestial_frame
from astropy.extern import six
from .transforms import (WCSPixel2WorldTransform, WCSWorld2PixelTransform,
CoordinateTransform)
from .coordinates_map import CoordinatesMap
from .utils import get_coord_meta
from .frame import RectangularFrame
import numpy as np
__all__ = ['WCSAxes', 'WCSAxesSubplot']
VISUAL_PROPERTIES = ['facecolor', 'edgecolor', 'linewidth', 'alpha', 'linestyle']
IDENTITY = WCS(naxis=2)
IDENTITY.wcs.ctype = ["X", "Y"]
IDENTITY.wcs.crval = [1., 1.]
IDENTITY.wcs.crpix = [1., 1.]
IDENTITY.wcs.cdelt = [1., 1.]
[docs]class WCSAxes(Axes):
def __init__(self, fig, rect, wcs=None, transform=None, coord_meta=None,
transData=None, slices=None, frame_class=RectangularFrame,
**kwargs):
super(WCSAxes, self).__init__(fig, rect, **kwargs)
self._bboxes = []
self.frame_class = frame_class
if not (transData is None):
# User wants to override the transform for the final
# data->pixel mapping
self.transData = transData
self.reset_wcs(wcs=wcs, slices=slices, transform=transform, coord_meta=coord_meta)
self._hide_parent_artists()
self.format_coord = self._display_world_coords
self._display_coords_index = 0
fig.canvas.mpl_connect('key_press_event', self._set_cursor_prefs)
self.patch = self.coords.frame.patch
self._drawn = False
def _display_world_coords(self, x, y):
if not self._drawn:
return ""
if self._display_coords_index == -1:
return "%s %s (pixel)" % (x, y)
pixel = np.array([x, y])
coords = self._all_coords[self._display_coords_index]
world = coords._transform.transform(np.array([pixel]))[0]
xw = coords[self._x_index].format_coord(world[self._x_index])
yw = coords[self._y_index].format_coord(world[self._y_index])
if self._display_coords_index == 0:
system = "world"
else:
system = "world, overlay {0}".format(self._display_coords_index)
coord_string = "%s %s (%s)" % (xw, yw, system)
return coord_string
def _set_cursor_prefs(self, event, **kwargs):
if event.key == 'w':
self._display_coords_index += 1
if self._display_coords_index + 1 > len(self._all_coords):
self._display_coords_index = -1
def _hide_parent_artists(self):
# Turn off spines and current axes
for s in self.spines.values():
s.set_visible(False)
self.xaxis.set_visible(False)
self.yaxis.set_visible(False)
# We now overload ``imshow`` because we need to make sure that origin is
# set to ``lower`` for all images, which means that we need to flip RGB
# images.
[docs] def imshow(self, X, *args, **kwargs):
"""
Wrapper to Matplotlib's :meth:`~matplotlib.axes.Axes.imshow`.
If an RGB image is passed as a PIL object, it will be flipped
vertically and ``origin`` will be set to ``lower``, since WCS
transformations - like FITS files - assume that the origin is the lower
left pixel of the image (whereas RGB images have the origin in the top
left).
All arguments are passed to :meth:`~matplotlib.axes.Axes.imshow`.
"""
origin = kwargs.get('origin', None)
if origin == 'upper':
raise ValueError("Cannot use images with origin='upper' in WCSAxes.")
# To check whether the image is a PIL image we can check if the data
# has a 'getpixel' attribute - this is what Matplotlib's AxesImage does
try:
from PIL.Image import Image, FLIP_TOP_BOTTOM
except ImportError:
# We don't need to worry since PIL is not installed, so user cannot
# have passed RGB image.
pass
else:
if isinstance(X, Image) or hasattr(X, 'getpixel'):
X = X.transpose(FLIP_TOP_BOTTOM)
kwargs['origin'] = 'lower'
return super(WCSAxes, self).imshow(X, *args, **kwargs)
[docs] def reset_wcs(self, wcs=None, slices=None, transform=None, coord_meta=None):
"""
Reset the current Axes, to use a new WCS object.
"""
# Here determine all the coordinate axes that should be shown.
if wcs is None and transform is None:
self.wcs = IDENTITY
else:
# We now force call 'set', which ensures the WCS object is
# consistent, which will only be important if the WCS has been set
# by hand. For example if the user sets a celestial WCS by hand and
# forgets to set the units, WCS.wcs.set() will do this.
if wcs is not None:
wcs.wcs.set()
self.wcs = wcs
# If we are making a new WCS, we need to preserve the path object since
# it may already be used by objects that have been plotted, and we need
# to continue updating it. CoordinatesMap will create a new frame
# instance, but we can tell that instance to keep using the old path.
if hasattr(self, 'coords'):
previous_frame_path = self.coords.frame._path
else:
previous_frame_path = None
self.coords = CoordinatesMap(self, wcs=self.wcs, slice=slices,
transform=transform, coord_meta=coord_meta,
frame_class=self.frame_class,
previous_frame_path=previous_frame_path)
self._all_coords = [self.coords]
if slices is None:
self.slices = ('x', 'y')
self._x_index = 0
self._y_index = 1
else:
self.slices = slices
self._x_index = self.slices.index('x')
self._y_index = self.slices.index('y')
# Common default settings for Rectangular Frame
if self.frame_class is RectangularFrame:
for coord_index in range(len(self.slices)):
if self.slices[coord_index] == 'x':
self.coords[coord_index].set_axislabel_position('b')
self.coords[coord_index].set_ticklabel_position('b')
elif self.slices[coord_index] == 'y':
self.coords[coord_index].set_axislabel_position('l')
self.coords[coord_index].set_ticklabel_position('l')
else:
self.coords[coord_index].set_axislabel_position('')
self.coords[coord_index].set_ticklabel_position('')
self.coords[coord_index].set_ticks_position('')
[docs] def draw(self, renderer, inframe=False):
# In Axes.draw, the following code can result in the xlim and ylim
# values changing, so we need to force call this here to make sure that
# the limits are correct before we update the patch.
locator = self.get_axes_locator()
if locator:
pos = locator(self, renderer)
self.apply_aspect(pos)
else:
self.apply_aspect()
# We need to make sure that that frame path is up to date
self.coords.frame._update_patch_path()
super(WCSAxes, self).draw(renderer, inframe)
# Here need to find out range of all coordinates, and update range for
# each coordinate axis. For now, just assume it covers the whole sky.
self._bboxes = []
self._ticklabels_bbox = []
visible_ticks = []
for coords in self._all_coords:
coords.frame.update()
for coord in coords:
coord._draw(renderer, bboxes=self._bboxes,
ticklabels_bbox=self._ticklabels_bbox)
visible_ticks.extend(coord.ticklabels.get_visible_axes())
for coords in self._all_coords:
for coord in coords:
coord._draw_axislabels(renderer, bboxes=self._bboxes,
ticklabels_bbox=self._ticklabels_bbox,
visible_ticks=visible_ticks)
self.coords.frame.draw(renderer)
self._drawn = True
[docs] def set_xlabel(self, label):
self.coords[self._x_index].set_axislabel(label)
[docs] def set_ylabel(self, label):
self.coords[self._y_index].set_axislabel(label)
[docs] def get_xlabel(self):
return self.coords[self._x_index].get_axislabel()
[docs] def get_ylabel(self):
return self.coords[self._y_index].get_axislabel()
[docs] def get_coords_overlay(self, frame, coord_meta=None):
# Here we can't use get_transform because that deals with
# pixel-to-pixel transformations when passing a WCS object.
if isinstance(frame, WCS):
coords = CoordinatesMap(self, frame, frame_class=self.frame_class)
else:
if coord_meta is None:
coord_meta = get_coord_meta(frame)
transform = self._get_transform_no_transdata(frame)
coords = CoordinatesMap(self, transform=transform, coord_meta=coord_meta, frame_class=self.frame_class)
self._all_coords.append(coords)
# Common settings for overlay
coords[0].set_axislabel_position('t')
coords[1].set_axislabel_position('r')
coords[0].set_ticklabel_position('t')
coords[1].set_ticklabel_position('r')
self.overlay_coords = coords
return coords
def _get_transform_no_transdata(self, frame):
"""
Return a transform from data to the specified frame
"""
if self.wcs is None and frame != 'pixel':
raise ValueError('No WCS specified, so only pixel coordinates are available')
if isinstance(frame, WCS):
coord_in = wcs_to_celestial_frame(self.wcs)
coord_out = wcs_to_celestial_frame(frame)
if coord_in == coord_out:
return (WCSPixel2WorldTransform(self.wcs, slice=self.slices)
+ WCSWorld2PixelTransform(frame))
else:
return (WCSPixel2WorldTransform(self.wcs, slice=self.slices)
+ CoordinateTransform(self.wcs, frame)
+ WCSWorld2PixelTransform(frame))
elif frame == 'pixel':
return Affine2D()
elif isinstance(frame, Transform):
pixel2world = WCSPixel2WorldTransform(self.wcs, slice=self.slices)
return pixel2world + frame
else:
pixel2world = WCSPixel2WorldTransform(self.wcs, slice=self.slices)
if frame == 'world':
return pixel2world
else:
coordinate_transform = CoordinateTransform(self.wcs, frame)
if coordinate_transform.same_frames:
return pixel2world
else:
return pixel2world + CoordinateTransform(self.wcs, frame)
[docs] def get_tightbbox(self, renderer):
if not self.get_visible():
return
bb = [b for b in self._bboxes if b and (b.width != 0 or b.height != 0)]
if bb:
_bbox = Bbox.union(bb)
return _bbox
else:
return self.get_window_extent(renderer)
[docs] def grid(self, draw_grid=True, **kwargs):
"""
Plot gridlines for both coordinates.
Standard matplotlib appearance options (color, alpha, etc.) can be
passed as keyword arguments.
Parameters
----------
draw_grid : bool
Whether to show the gridlines
"""
if draw_grid and hasattr(self, 'coords'):
self.coords.grid(draw_grid=draw_grid, **kwargs)
# In the following, we put the generated subplot class in a temporary class and
# we then inherit it - if we don't do this, the generated class appears to
# belong in matplotlib, not in WCSAxes, from the API's point of view.
[docs]class WCSAxesSubplot(subplot_class_factory(WCSAxes)):
pass