Source code for wcsaxes.coordinate_helpers

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This file defines the classes used to represent a 'coordinate', which includes
axes, ticks, tick labels, and grid lines.
"""

import numpy as np
from astropy import units as u
from astropy.extern import six

from matplotlib.ticker import Formatter
from matplotlib.transforms import Affine2D, ScaledTranslation
from matplotlib.patches import PathPatch
from matplotlib import rcParams

from .formatter_locator import AngleFormatterLocator, ScalarFormatterLocator
from .ticks import Ticks
from .ticklabels import TickLabels
from .axislabels import AxisLabels
from .grid_paths import get_lon_lat_path, get_gridline_path
from . import settings


__all__ = ['CoordinateHelper']


def wrap_angle_at(values, coord_wrap):
    """
    Wrap angles into the half-open interval ``[coord_wrap - 360, coord_wrap)``.

    Parameters
    ----------
    values : float or `~numpy.ndarray`
        The angle(s) to wrap, in degrees.
    coord_wrap : float
        The upper (exclusive) bound of the wrapped interval, in degrees.

    Returns
    -------
    wrapped : float or `~numpy.ndarray`
        The wrapped angle(s). NaN inputs stay NaN.
    """
    lower_offset = 360. - coord_wrap
    # np.mod warns about NaN values on ARM processors (this does not seem to
    # happen on other processors), so silence 'invalid' warnings here.
    with np.errstate(invalid='ignore'):
        folded = np.mod(values - coord_wrap, 360.)
        return folded - lower_offset


class CoordinateHelper(object):
    """
    Manage the ticks, tick labels, axis label and grid lines of a single
    world coordinate.

    Parameters
    ----------
    parent_axes : axes object
        The axes this coordinate is drawn on. Although the default is `None`,
        a real axes object is required: its ``transData`` and
        ``get_figure()`` are used during initialization.
    parent_map : object
        Object whose ``get_coord_range()`` gives the world coordinate range
        for each coordinate.
    transform : transform object
        Transformation from pixel to world coordinates (``transform``,
        ``inverted`` and ``get_coord_slices`` are used).
    coord_index : int
        Index of this coordinate in the world coordinate arrays.
    coord_type : str
        One of ``'scalar'``, ``'longitude'`` or ``'latitude'``.
    coord_unit : `~astropy.units.Unit`
        The native unit of the coordinate.
    coord_wrap : float, optional
        The value to wrap at for longitude coordinates.
    frame : frame object
        The frame of the axes (``sample``, ``patch`` and ``origin`` are used).
    """

    def __init__(self, parent_axes=None, parent_map=None, transform=None,
                 coord_index=None, coord_type='scalar', coord_unit=None,
                 coord_wrap=None, frame=None):

        # Keep a reference to the parent axes and the transform
        self.parent_axes = parent_axes
        self.parent_map = parent_map
        self.transform = transform
        self.coord_index = coord_index
        self.coord_unit = coord_unit
        self.frame = frame

        self.set_coord_type(coord_type, coord_wrap)

        # Initialize ticks
        self.dpi_transform = Affine2D()
        self.offset_transform = ScaledTranslation(0, 0, self.dpi_transform)
        self.ticks = Ticks(transform=parent_axes.transData + self.offset_transform)

        # Initialize tick labels
        self.ticklabels = TickLabels(self.frame,
                                     transform=None,  # display coordinates
                                     figure=parent_axes.get_figure())
        self.ticks.display_minor_ticks(False)
        self.minor_frequency = 5

        # Initialize axis labels
        self.axislabels = AxisLabels(self.frame,
                                     transform=None,  # display coordinates
                                     figure=parent_axes.get_figure())

        # Initialize container for the grid lines
        self.grid_lines = []

        # Initialize grid style. Take defaults from matplotlib.rcParams.
        # Based on matplotlib.axis.YTick._get_gridline.
        #
        # Matplotlib's gridlines use Line2D, but ours use PathPatch.
        # Patches take a slightly different format of linestyle argument.
        lines_to_patches_linestyle = {
            '-': 'solid',
            '--': 'dashed',
            '-.': 'dashdot',
            ':': 'dotted',
            'none': 'none',
            'None': 'none',
            ' ': 'none',
            '': 'none',
        }

        grid_linestyle = rcParams['grid.linestyle']
        try:
            patch_linestyle = lines_to_patches_linestyle[grid_linestyle]
        except (KeyError, TypeError):
            # Not one of the shorthands above (for example 'dashed', or an
            # unhashable dash specification): hand it to the patch unchanged
            # instead of failing at construction time.
            patch_linestyle = grid_linestyle

        self.grid_lines_kwargs = {'visible': False,
                                  'facecolor': 'none',
                                  'edgecolor': rcParams['grid.color'],
                                  'linestyle': patch_linestyle,
                                  'linewidth': rcParams['grid.linewidth'],
                                  'alpha': rcParams.get('grid.alpha', 1.0),
                                  'transform': self.parent_axes.transData}

    def grid(self, draw_grid=True, grid_type='lines', **kwargs):
        """
        Plot grid lines for this coordinate.

        Standard matplotlib appearance options (color, alpha, etc.) can be
        passed as keyword arguments.

        Parameters
        ----------
        draw_grid : bool or None
            Whether to show the gridlines. `None` toggles the current
            visibility.
        grid_type : { 'lines' | 'contours' }
            Whether to plot the contours by determining the grid lines in
            world coordinates and then plotting them in world coordinates
            (``'lines'``) or by determining the world coordinates at many
            positions in the image and then drawing contours
            (``'contours'``). The first is recommended for 2-d images, while
            for 3-d (or higher dimensional) cubes, the ``'contours'`` option
            is recommended.

        Raises
        ------
        ValueError
            If ``grid_type`` is neither ``'lines'`` nor ``'contours'``.
        """

        if grid_type in ('lines', 'contours'):
            self._grid_type = grid_type
        else:
            raise ValueError("grid_type should be 'lines' or 'contours'")

        # The grid is drawn with patches, whose line color is 'edgecolor'
        if 'color' in kwargs:
            kwargs['edgecolor'] = kwargs.pop('color')

        self.grid_lines_kwargs.update(kwargs)

        if draw_grid is None:
            self.grid_lines_kwargs['visible'] = not self.grid_lines_kwargs['visible']
        else:
            # Follow the request exactly: asking to hide an already hidden
            # grid must not make it visible.
            self.grid_lines_kwargs['visible'] = bool(draw_grid)

    def set_coord_type(self, coord_type, coord_wrap=None):
        """
        Set the coordinate type for the axis.

        Parameters
        ----------
        coord_type : str
            One of 'longitude', 'latitude' or 'scalar'
        coord_wrap : float, optional
            The value to wrap at for angular coordinates. Defaults to 360
            for longitudes.

        Raises
        ------
        NotImplementedError
            If ``coord_wrap`` is given for a non-longitude coordinate.
        ValueError
            If ``coord_type`` is not one of the supported types.
        """

        self.coord_type = coord_type

        if coord_type == 'longitude' and coord_wrap is None:
            self.coord_wrap = 360
        elif coord_type != 'longitude' and coord_wrap is not None:
            raise NotImplementedError('coord_wrap is not yet supported for non-longitude coordinates')
        else:
            self.coord_wrap = coord_wrap

        # Initialize tick formatter/locator
        if coord_type == 'scalar':
            self._coord_unit_scale = None
            self._formatter_locator = ScalarFormatterLocator(unit=self.coord_unit)
        elif coord_type in ['longitude', 'latitude']:
            # _coord_unit_scale converts native values to degrees; None
            # means the values are already in degrees.
            if self.coord_unit is u.deg:
                self._coord_unit_scale = None
            else:
                self._coord_unit_scale = self.coord_unit.to(u.deg)
            self._formatter_locator = AngleFormatterLocator()
        else:
            raise ValueError("coord_type should be one of 'scalar', 'longitude', or 'latitude'")

    def set_major_formatter(self, formatter):
        """
        Set the formatter to use for the major tick labels.

        Parameters
        ----------
        formatter : str or Formatter
            The format or formatter to use.

        Raises
        ------
        NotImplementedError
            If a `~matplotlib.ticker.Formatter` instance is given (not yet
            supported).
        TypeError
            If ``formatter`` is neither a string nor a Formatter.
        """
        if isinstance(formatter, Formatter):
            raise NotImplementedError()  # figure out how to swap out formatter
        elif isinstance(formatter, six.string_types):
            self._formatter_locator.format = formatter
        else:
            raise TypeError("formatter should be a string or a Formatter "
                            "instance")

    def format_coord(self, value):
        """
        Given the value of a coordinate, will format it according to the
        format of the formatter_locator.

        Returns an empty string if the ticks have not been computed yet.
        """
        if not hasattr(self, "_fl_spacing"):
            return ""  # _update_ticks has not been called yet

        fl = self._formatter_locator
        if isinstance(fl, AngleFormatterLocator):

            # Convert to degrees if needed. Build a new object rather than
            # multiplying in place so that an array passed by the caller is
            # left untouched.
            if self._coord_unit_scale is not None:
                value = value * self._coord_unit_scale

            if self.coord_type == 'longitude':
                value = wrap_angle_at(value, self.coord_wrap)
            value = value * u.degree
            value = value.to(fl._unit).value

        spacing = self._fl_spacing
        string = fl.formatter(values=[value] * fl._unit, spacing=spacing)

        return string[0]

    def set_separator(self, separator):
        """
        Set the separator to use for the angle major tick labels.

        Parameters
        ----------
        separator : str or tuple
            The separator between numbers in sexagesimal representation.

        Raises
        ------
        TypeError
            If this is not an angle coordinate, or if ``separator`` is
            neither a string nor a tuple.
        """
        if not isinstance(self._formatter_locator, AngleFormatterLocator):
            raise TypeError("Separator can only be specified for angle coordinates")
        if isinstance(separator, six.string_types) or isinstance(separator, tuple):
            self._formatter_locator.sep = separator
        else:
            raise TypeError("separator should be a string or a tuple")

    def set_format_unit(self, unit):
        """
        Set the unit for the major tick labels.

        Parameters
        ----------
        unit : :class:`~astropy.units.Unit`
            The unit to which the tick labels should be converted to.

        Raises
        ------
        TypeError
            If ``unit`` is not an astropy unit instance.
        """
        if not isinstance(unit, u.UnitBase):
            raise TypeError("unit should be an astropy UnitBase subclass")
        self._formatter_locator.format_unit = unit

    def set_ticks(self, values=None, spacing=None, number=None, size=None,
                  width=None, color=None, alpha=None,
                  exclude_overlapping=False):
        """
        Set the location and properties of the ticks.

        At most one of the options from ``values``, ``spacing``, or
        ``number`` can be specified.

        Parameters
        ----------
        values : iterable, optional
            The coordinate values at which to show the ticks.
        spacing : float, optional
            The spacing between ticks.
        number : float, optional
            The approximate number of ticks shown.
        size : float, optional
            The length of the ticks in points
        width : float, optional
            The line width of the ticks.
        color : str or tuple
            A valid Matplotlib color for the ticks
        alpha : float, optional
            The alpha value applied to the ticks.
        exclude_overlapping : bool, optional
            Whether to exclude tick labels that overlap over each other.
            This is applied on every call, so omitting it resets the
            setting to `False`.

        Raises
        ------
        ValueError
            If more than one of ``values``, ``spacing`` and ``number`` is
            given.
        """

        if sum([values is None, spacing is None, number is None]) < 2:
            raise ValueError("At most one of values, spacing, or number should "
                             "be specified")

        if values is not None:
            self._formatter_locator.values = values
        elif spacing is not None:
            self._formatter_locator.spacing = spacing
        elif number is not None:
            self._formatter_locator.number = number

        if size is not None:
            self.ticks.set_ticksize(size)

        if width is not None:
            self.ticks.set_linewidth(width)

        if color is not None:
            self.ticks.set_color(color)

        if alpha is not None:
            self.ticks.set_alpha(alpha)

        self.ticklabels.set_exclude_overlapping(exclude_overlapping)

    def set_ticks_position(self, position):
        """
        Set where ticks should appear

        Parameters
        ----------
        position : str
            The axes on which the ticks for this coordinate should appear.
            Should be a string containing zero or more of ``'b'``, ``'t'``,
            ``'l'``, ``'r'``. For example, ``'lb'`` will lead the ticks to be
            shown on the left and bottom axis.
        """
        self.ticks.set_visible_axes(position)

    def set_ticks_visible(self, visible):
        """
        Set whether ticks are visible or not.

        Parameters
        ----------
        visible : bool
            The visibility of ticks. Setting as ``False`` will hide ticks
            along this coordinate.
        """
        self.ticks.set_visible(visible)

    def set_ticklabel(self, **kwargs):
        """
        Set the visual properties for the tick labels.

        Parameters
        ----------
        kwargs
            Keyword arguments are passed to :class:`matplotlib.text.Text`. These
            can include keywords to set the ``color``, ``size``, ``weight``, and
            other text properties.
        """
        self.ticklabels.set(**kwargs)

    def set_ticklabel_position(self, position):
        """
        Set where tick labels should appear

        Parameters
        ----------
        position : str
            The axes on which the tick labels for this coordinate should
            appear. Should be a string containing zero or more of ``'b'``,
            ``'t'``, ``'l'``, ``'r'``. For example, ``'lb'`` will lead the
            tick labels to be shown on the left and bottom axis.
        """
        self.ticklabels.set_visible_axes(position)

    def set_ticklabel_visible(self, visible):
        """
        Set whether the tick labels are visible or not.

        Parameters
        ----------
        visible : bool
            The visibility of the tick labels. Setting as ``False`` will hide
            this coordinate's tick labels.
        """
        self.ticklabels.set_visible(visible)

    def set_axislabel(self, text, minpad=1, **kwargs):
        """
        Set the text and optionally visual properties for the axis label.

        Parameters
        ----------
        text : str
            The axis label text.
        minpad : float, optional
            The padding for the label in terms of axis label font size.
        kwargs
            Keywords are passed to :class:`matplotlib.text.Text`. These
            can include keywords to set the ``color``, ``size``, ``weight``, and
            other text properties.
        """
        self.axislabels.set_text(text)
        self.axislabels.set_minpad(minpad)
        self.axislabels.set(**kwargs)

    def get_axislabel(self):
        """
        Get the text for the axis label

        Returns
        -------
        label : str
            The axis label
        """
        return self.axislabels.get_text()

    def set_axislabel_position(self, position):
        """
        Set where axis labels should appear

        Parameters
        ----------
        position : str
            The axes on which the axis label for this coordinate should
            appear. Should be a string containing zero or more of ``'b'``,
            ``'t'``, ``'l'``, ``'r'``. For example, ``'lb'`` will lead the
            axis label to be shown on the left and bottom axis.
        """
        self.axislabels.set_visible_axes(position)

    @property
    def locator(self):
        """The tick locator of the current formatter/locator object."""
        return self._formatter_locator.locator

    @property
    def formatter(self):
        """The tick label formatter of the current formatter/locator object."""
        return self._formatter_locator.formatter

    def _draw(self, renderer, bboxes, ticklabels_bbox):
        """
        Recompute the ticks, then draw the ticks, the tick labels and, if
        the grid is visible, the grid lines.
        """

        renderer.open_group('coordinate_axis')

        self._update_ticks(renderer)

        self.ticks.draw(renderer)
        self.ticklabels.draw(renderer, bboxes=bboxes,
                             ticklabels_bbox=ticklabels_bbox)

        if self.grid_lines_kwargs['visible']:

            if self._grid_type == 'lines':
                self._update_grid_lines()
            else:
                self._update_grid_contour()

            if self._grid_type == 'lines':

                # Grid lines are clipped to the frame
                frame_patch = self.frame.patch
                for path in self.grid_lines:
                    p = PathPatch(path, **self.grid_lines_kwargs)
                    p.set_clip_path(frame_patch)
                    p.draw(renderer)

            elif self._grid is not None:

                for line in self._grid.collections:
                    line.set(**self.grid_lines_kwargs)
                    line.draw(renderer)

        renderer.close_group('coordinate_axis')

    def _draw_axislabels(self, renderer, bboxes, ticklabels_bbox, visible_ticks):
        """Draw the axis labels inside their own renderer group."""

        renderer.open_group('axis labels')

        self.axislabels.draw(renderer, bboxes=bboxes,
                             ticklabels_bbox_list=ticklabels_bbox,
                             visible_ticks=visible_ticks)

        renderer.close_group('axis labels')

    def _update_ticks(self, renderer):
        """
        Recompute the position and rotation of all ticks and tick labels.

        Clears ``self.ticks`` and ``self.ticklabels`` and refills them, and
        sets ``self._fl_spacing``, ``self.lblinfo`` and ``self.lbl_world``.
        The ``renderer`` argument is not used by this method.
        """

        # TODO: this method should be optimized for speed

        # Here we determine the location and rotation of all the ticks. For
        # each axis, we can check the intersections for the specific
        # coordinate and once we have the tick positions, we can use the WCS
        # to determine the rotations.

        # Find the range of coordinates in all directions
        coord_range = self.parent_map.get_coord_range()

        # First find the ticks we want to show
        tick_world_coordinates, self._fl_spacing = self.locator(*coord_range[self.coord_index])
        if self.ticks.get_display_minor_ticks():
            minor_ticks_w_coordinates = self._formatter_locator.minor_locator(
                self._fl_spacing, self.get_minor_frequency(),
                *coord_range[self.coord_index])

        # We want to allow non-standard rectangular frames, so we just rely on
        # the parent axes to tell us what the bounding frame is.
        frame = self.frame.sample(settings.FRAME_BOUNDARY_SAMPLES)

        self.ticks.clear()
        self.ticklabels.clear()
        self.lblinfo = []
        self.lbl_world = []

        # Look up parent axes' transform from data to figure coordinates.
        #
        # See:
        # http://matplotlib.org/users/transforms_tutorial.html#the-transformation-pipeline
        transData = self.parent_axes.transData
        invertedTransLimits = transData.inverted()

        for axis, spine in six.iteritems(frame):

            # Determine tick rotation in display coordinates and compare to
            # the normal angle in display coordinates.

            pixel0 = spine.data
            world0 = self.transform.transform(pixel0)[:, self.coord_index]
            axes0 = transData.transform(pixel0)

            # Advance 2 pixels in figure coordinates
            pixel1 = axes0.copy()
            pixel1[:, 0] += 2.0
            pixel1 = invertedTransLimits.transform(pixel1)
            world1 = self.transform.transform(pixel1)[:, self.coord_index]

            # Advance 2 pixels in figure coordinates
            pixel2 = axes0.copy()
            pixel2[:, 1] += 2.0 if self.frame.origin == 'lower' else -2.0
            pixel2 = invertedTransLimits.transform(pixel2)
            world2 = self.transform.transform(pixel2)[:, self.coord_index]

            dx = (world1 - world0)
            dy = (world2 - world0)

            # Rotate by 90 degrees
            dx, dy = -dy, dx

            if self._coord_unit_scale is not None:
                dx *= self._coord_unit_scale
                dy *= self._coord_unit_scale

            if self.coord_type == 'longitude':
                # Here we wrap at 180 not self.coord_wrap since we want to
                # always ensure abs(dx) < 180 and abs(dy) < 180
                dx = wrap_angle_at(dx, 180.)
                dy = wrap_angle_at(dy, 180.)

            tick_angle = np.degrees(np.arctan2(dy, dx))

            # Flip ticks that differ from the spine normal by more than 90
            # degrees in both directions
            normal_angle_full = np.hstack([spine.normal_angle, spine.normal_angle[-1]])
            with np.errstate(invalid='ignore'):
                reset = (((normal_angle_full - tick_angle) % 360 > 90.) &
                         ((tick_angle - normal_angle_full) % 360 > 90.))
            tick_angle[reset] -= 180.

            # We find for each interval the starting and ending coordinate,
            # ensuring that we take wrapping into account correctly for
            # longitudes.
            w1 = spine.world[:-1, self.coord_index]
            w2 = spine.world[1:, self.coord_index]
            if self._coord_unit_scale is not None:
                w1 = w1 * self._coord_unit_scale
                w2 = w2 * self._coord_unit_scale

            if self.coord_type == 'longitude':
                w1 = wrap_angle_at(w1, self.coord_wrap)
                w2 = wrap_angle_at(w2, self.coord_wrap)
                with np.errstate(invalid='ignore'):
                    w1[w2 - w1 > 180.] += 360
                    w2[w1 - w2 > 180.] += 360

            # For longitudes, we need to check ticks as well as ticks + 360,
            # since the above can produce pairs such as 359 to 361 or 0.5 to
            # 1.5, both of which would match a tick at 0.75. Otherwise we just
            # check the ticks determined above.
            self._compute_ticks(tick_world_coordinates, spine, axis, w1, w2, tick_angle)

            if self.ticks.get_display_minor_ticks():
                self._compute_ticks(minor_ticks_w_coordinates, spine, axis, w1,
                                    w2, tick_angle, ticks='minor')

        # format tick labels, add to scene
        text = self.formatter(self.lbl_world * tick_world_coordinates.unit,
                              spacing=self._fl_spacing)
        for kwargs, txt in zip(self.lblinfo, text):
            self.ticklabels.add(text=txt, **kwargs)

    def _compute_ticks(self, tick_world_coordinates, spine, axis, w1, w2,
                       tick_angle, ticks='major'):
        """
        Find where the given tick values cross one spine and register them.

        ``w1`` and ``w2`` hold the world coordinate at the start and end of
        each spine segment. Major ticks are added to ``self.ticks`` and
        queued in ``self.lblinfo``/``self.lbl_world`` for labelling; with
        any other value of ``ticks`` they are added as minor ticks.
        """

        tick_world_coordinates_values = tick_world_coordinates.value
        if self.coord_type == 'longitude':
            tick_world_coordinates_values = np.hstack([tick_world_coordinates_values,
                                                       tick_world_coordinates_values + 360])

        for t in tick_world_coordinates_values:

            # Find steps where a tick is present. We have to check
            # separately for the case where the tick falls exactly on the
            # frame points, otherwise we'll get two matches, one for w1 and
            # one for w2.
            with np.errstate(invalid='ignore'):
                intersections = np.hstack([np.nonzero((t - w1) == 0)[0],
                                           np.nonzero(((t - w1) * (t - w2)) < 0)[0]])

            # But we also need to check for intersection with the last w2
            if t - w2[-1] == 0:
                intersections = np.append(intersections, len(w2) - 1)

            # Loop over ticks, and find exact pixel coordinates by linear
            # interpolation
            for imin in intersections:

                imax = imin + 1

                if np.allclose(w1[imin], w2[imin], rtol=1.e-13, atol=1.e-13):
                    continue  # tick is exactly aligned with frame

                frac = (t - w1[imin]) / (w2[imin] - w1[imin])
                x_data_i = spine.data[imin, 0] + frac * (spine.data[imax, 0] - spine.data[imin, 0])
                y_data_i = spine.data[imin, 1] + frac * (spine.data[imax, 1] - spine.data[imin, 1])
                x_pix_i = spine.pixel[imin, 0] + frac * (spine.pixel[imax, 0] - spine.pixel[imin, 0])
                y_pix_i = spine.pixel[imin, 1] + frac * (spine.pixel[imax, 1] - spine.pixel[imin, 1])

                # Interpolate the angle along the shorter way around
                delta_angle = tick_angle[imax] - tick_angle[imin]
                if delta_angle > 180.:
                    delta_angle -= 360.
                elif delta_angle < -180.:
                    delta_angle += 360.
                angle_i = tick_angle[imin] + frac * delta_angle

                if self.coord_type == 'longitude':
                    world = wrap_angle_at(t, self.coord_wrap)
                else:
                    world = t

                if ticks == 'major':

                    self.ticks.add(axis=axis,
                                   pixel=(x_data_i, y_data_i),
                                   world=world,
                                   angle=angle_i,
                                   axis_displacement=imin + frac)

                    # store information to pass to ticklabels.add
                    # it's faster to format many ticklabels at once outside
                    # of the loop
                    self.lblinfo.append(dict(axis=axis,
                                             pixel=(x_pix_i, y_pix_i),
                                             world=world,
                                             angle=spine.normal_angle[imin],
                                             axis_displacement=imin + frac))
                    self.lbl_world.append(world)

                else:
                    self.ticks.add_minor(minor_axis=axis,
                                         minor_pixel=(x_data_i, y_data_i),
                                         minor_world=world,
                                         minor_angle=angle_i,
                                         minor_axis_displacement=imin + frac)

    def display_minor_ticks(self, display_minor_ticks):
        """
        Display minor ticks for this coordinate.

        Parameters
        ----------
        display_minor_ticks : bool
            Whether or not to display minor ticks.
        """
        self.ticks.display_minor_ticks(display_minor_ticks)

    def get_minor_frequency(self):
        """
        Get the frequency of minor ticks per major ticks.

        Returns
        -------
        frequency : int
            The number of minor ticks per major ticks.
        """
        return self.minor_frequency

    def set_minor_frequency(self, frequency):
        """
        Set the frequency of minor ticks per major ticks.

        Parameters
        ----------
        frequency : int
            The number of minor ticks per major ticks.
        """
        self.minor_frequency = frequency

    def _update_grid_lines(self):
        """
        Recompute ``self.grid_lines``: one path per tick value, sampled in
        world coordinates and converted to pixel coordinates.
        """

        # For 3-d WCS with a correlated third axis, the *proper* way of
        # drawing a grid should be to find the world coordinates of all pixels
        # and drawing contours. What we are doing here assumes that we can
        # define the grid lines with just two of the coordinates (and
        # therefore assumes that the other coordinates are fixed and set to
        # the value in the slice). Here we basically assume that if the WCS
        # had a third axis, it has been abstracted away in the transformation.

        coord_range = self.parent_map.get_coord_range()

        tick_world_coordinates, spacing = self.locator(*coord_range[self.coord_index])
        tick_world_coordinates_values = tick_world_coordinates.value

        n_coord = len(tick_world_coordinates_values)
        n_samples = settings.GRID_SAMPLES

        xy_world = np.zeros((n_samples * n_coord, 2))

        self.grid_lines = []
        for iw, w in enumerate(tick_world_coordinates_values):
            subset = slice(iw * n_samples, (iw + 1) * n_samples)
            if self.coord_index == 0:
                xy_world[subset, 0] = np.repeat(w, n_samples)
                xy_world[subset, 1] = np.linspace(coord_range[1][0], coord_range[1][1], n_samples)
            else:
                xy_world[subset, 0] = np.linspace(coord_range[0][0], coord_range[0][1], n_samples)
                xy_world[subset, 1] = np.repeat(w, n_samples)

        # We now convert all the world coordinates to pixel coordinates in a
        # single go rather than doing this in the gridline to path conversion
        # to fully benefit from vectorized coordinate transformations.

        # Currently xy_world is in deg, but transform function needs it in
        # native units
        # NOTE(review): both columns are divided by this coordinate's scale;
        # confirm that the other coordinate shares the same native unit.
        if self._coord_unit_scale is not None:
            xy_world /= self._coord_unit_scale

        # Transform line to pixel coordinates
        pixel = self.transform.inverted().transform(xy_world)

        # Create round-tripped values for checking
        xy_world_round = self.transform.transform(pixel)

        for iw in range(n_coord):
            subset = slice(iw * n_samples, (iw + 1) * n_samples)
            self.grid_lines.append(self._get_gridline(xy_world[subset],
                                                      pixel[subset],
                                                      xy_world_round[subset]))

    def _get_gridline(self, xy_world, pixel, xy_world_round):
        """Build the path of one grid line, depending on the coordinate type."""
        if self.coord_type == 'scalar':
            return get_gridline_path(xy_world, pixel)
        else:
            return get_lon_lat_path(xy_world, pixel, xy_world_round)

    def _update_grid_contour(self):
        """
        Recompute ``self._grid`` by contouring the world coordinate over the
        current axes limits at the tick values.

        ``self._grid`` is set to `None` when there are no ticks to draw.
        """

        # self._grid exists but is None when the previous update found no
        # ticks, so checking for the attribute alone is not enough.
        if getattr(self, '_grid', None) is not None:
            for line in self._grid.collections:
                line.remove()

        xmin, xmax = self.parent_axes.get_xlim()
        ymin, ymax = self.parent_axes.get_ylim()

        X, Y, field = self.transform.get_coord_slices(xmin, xmax, ymin, ymax, 200, 200)

        coord_range = self.parent_map.get_coord_range()

        tick_world_coordinates, spacing = self.locator(*coord_range[self.coord_index])

        field = field[self.coord_index]

        # tick_world_coordinates is a Quantities array and we only needs its values
        tick_world_coordinates_values = tick_world_coordinates.value

        if len(tick_world_coordinates_values) == 0:
            # Nothing to contour
            self._grid = None
            return

        if self.coord_type == 'longitude':
            # Find biggest gap in tick_world_coordinates and wrap in middle
            # For now just assume spacing is equal, so any mid-point will do
            if len(tick_world_coordinates_values) > 1:
                mid = 0.5 * (tick_world_coordinates_values[0] + tick_world_coordinates_values[1])
            else:
                # With a single tick there is no gap to bisect; wrap on the
                # opposite side of the circle so the discontinuity is as far
                # from the tick as possible.
                mid = tick_world_coordinates_values[0] + 180.
            field = wrap_angle_at(field, mid)
            tick_world_coordinates_values = wrap_angle_at(tick_world_coordinates_values, mid)

            # Replace wraps by NaN
            reset = ((np.abs(np.diff(field[:, :-1], axis=0)) > 180) |
                     (np.abs(np.diff(field[:-1, :], axis=1)) > 180))
            field[:-1, :-1][reset] = np.nan
            field[1:, :-1][reset] = np.nan
            field[:-1, 1:][reset] = np.nan
            field[1:, 1:][reset] = np.nan

        self._grid = self.parent_axes.contour(X, Y, field.transpose(),
                                              levels=np.sort(tick_world_coordinates_values))