336 lines
12 KiB
Python
336 lines
12 KiB
Python
import numpy as np
|
|
|
|
from matplotlib import ticker as mticker
|
|
from matplotlib.transforms import Bbox, Transform
|
|
|
|
|
|
def _find_line_box_crossings(xys, bbox):
|
|
"""
|
|
Find the points where a polyline crosses a bbox, and the crossing angles.
|
|
|
|
Parameters
|
|
----------
|
|
xys : (N, 2) array
|
|
The polyline coordinates.
|
|
bbox : `.Bbox`
|
|
The bounding box.
|
|
|
|
Returns
|
|
-------
|
|
list of ((float, float), float)
|
|
Four separate lists of crossings, for the left, right, bottom, and top
|
|
sides of the bbox, respectively. For each list, the entries are the
|
|
``((x, y), ccw_angle_in_degrees)`` of the crossing, where an angle of 0
|
|
means that the polyline is moving to the right at the crossing point.
|
|
|
|
The entries are computed by linearly interpolating at each crossing
|
|
between the nearest points on either side of the bbox edges.
|
|
"""
|
|
crossings = []
|
|
dxys = xys[1:] - xys[:-1]
|
|
for sl in [slice(None), slice(None, None, -1)]:
|
|
us, vs = xys.T[sl] # "this" coord, "other" coord
|
|
dus, dvs = dxys.T[sl]
|
|
umin, vmin = bbox.min[sl]
|
|
umax, vmax = bbox.max[sl]
|
|
for u0, inside in [(umin, us > umin), (umax, us < umax)]:
|
|
crossings.append([])
|
|
idxs, = (inside[:-1] ^ inside[1:]).nonzero()
|
|
for idx in idxs:
|
|
v = vs[idx] + (u0 - us[idx]) * dvs[idx] / dus[idx]
|
|
if not vmin <= v <= vmax:
|
|
continue
|
|
crossing = (u0, v)[sl]
|
|
theta = np.degrees(np.arctan2(*dxys[idx][::-1]))
|
|
crossings[-1].append((crossing, theta))
|
|
return crossings
|
|
|
|
|
|
class ExtremeFinderSimple:
    """
    A helper class to figure out the range of grid lines that need to be drawn.
    """

    def __init__(self, nx, ny):
        """
        Parameters
        ----------
        nx, ny : int
            The number of samples in each direction.
        """
        self.nx = nx
        self.ny = ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Compute an approximation of the bounding box obtained by applying
        *transform_xy* to the box delimited by ``(x1, y1, x2, y2)``.

        The intended use is to have ``(x1, y1, x2, y2)`` in axes coordinates,
        and have *transform_xy* be the transform from axes coordinates to data
        coordinates; this method then returns the range of data coordinates
        that span the actual axes.

        The computation is done by sampling ``nx * ny`` equispaced points in
        the ``(x1, y1, x2, y2)`` box and finding the resulting points with
        extremal coordinates; then adding some padding to take into account the
        finite sampling.

        As each sampling step covers a relative range of *1/nx* or *1/ny*,
        the padding is computed by expanding the span covered by the extremal
        coordinates by these fractions.

        Samples that *transform_xy* maps to NaN (e.g. where the transform is
        undefined) are ignored, so that they do not poison the whole range;
        if every sample is NaN along a coordinate, that range is NaN.

        Returns
        -------
        x_min, x_max, y_min, y_max : float
            The padded extremes of the transformed samples.
        """
        x, y = np.meshgrid(
            np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
        xt, yt = transform_xy(np.ravel(x), np.ravel(y))
        x_min, x_max = self._nan_extent(xt)
        y_min, y_max = self._nan_extent(yt)
        return self._add_pad(x_min, x_max, y_min, y_max)

    @staticmethod
    def _nan_extent(values):
        """Return ``(min, max)`` of *values* ignoring NaNs, or NaNs if none."""
        values = np.asarray(values, dtype=float)
        kept = values[~np.isnan(values)]
        if kept.size == 0:
            # Avoid np.nanmin's "All-NaN slice" warning; result is NaN anyway.
            return np.nan, np.nan
        return kept.min(), kept.max()

    def _add_pad(self, x_min, x_max, y_min, y_max):
        """Perform the padding mentioned in `__call__`."""
        dx = (x_max - x_min) / self.nx
        dy = (y_max - y_min) / self.ny
        return x_min - dx, x_max + dx, y_min - dy, y_max + dy
|
|
|
|
|
|
class _User2DTransform(Transform):
    """A transform defined by two user-set functions."""

    input_dims = 2
    output_dims = 2

    def __init__(self, forward, backward):
        """
        Parameters
        ----------
        forward, backward : callable
            The forward and backward transforms, taking ``x`` and ``y`` as
            separate arguments and returning ``(tr_x, tr_y)``.
        """
        # Matplotlib transforms normally take and return an (N, 2) array,
        # whereas axisartist's user functions work on the transposed layout
        # (separate x and y arrays).
        super().__init__()
        self._forward, self._backward = forward, backward

    def transform_non_affine(self, values):
        # docstring inherited
        columns = np.transpose(values)
        return np.transpose(self._forward(*columns))

    def inverted(self):
        # docstring inherited
        # Same class, with the two directions swapped.
        cls = type(self)
        return cls(self._backward, self._forward)
|
|
|
|
|
|
class GridFinder:
    """
    Internal helper for `~.grid_helper_curvelinear.GridHelperCurveLinear`, with
    the same constructor parameters; should not be directly instantiated.
    """

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        # Fall back to default helpers for any component not supplied.
        if extreme_finder is None:
            extreme_finder = ExtremeFinderSimple(20, 20)
        if grid_locator1 is None:
            grid_locator1 = MaxNLocator()
        if grid_locator2 is None:
            grid_locator2 = MaxNLocator()
        if tick_formatter1 is None:
            tick_formatter1 = FormatterPrettyPrint()
        if tick_formatter2 is None:
            tick_formatter2 = FormatterPrettyPrint()
        self.extreme_finder = extreme_finder
        self.grid_locator1 = grid_locator1
        self.grid_locator2 = grid_locator2
        self.tick_formatter1 = tick_formatter1
        self.tick_formatter2 = tick_formatter2
        self.set_transform(transform)

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Compute grid lines and tick information for a rectangular region.

        Parameters
        ----------
        x1, y1, x2, y2 : float
            Extents of the region, in the output coordinates of the auxiliary
            transform (the space in which the grid lines are drawn).

        Returns
        -------
        dict
            With keys ``"extremes"`` (the ``(lon_min, lon_max, lat_min,
            lat_max)`` range returned by the extreme finder), ``"lon_lines"``
            and ``"lat_lines"`` (the transformed grid lines), and ``"lon"`` and
            ``"lat"`` (per-coordinate dicts holding levels, lines, and the
            tick levels, tick locations and tick labels keyed by box side).
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # min & max range of lat (or lon) over which each grid line is drawn,
        # i.e., the gridline of lon=0 is drawn from lat_min to lat_max.
        lon_min, lon_max, lat_min, lat_max = extremes
        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lon_levs = np.asarray(lon_levs)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
        lat_levs = np.asarray(lat_levs)

        # Locators return (levels, count, factor): only the first `count`
        # levels are used, and dividing by `factor` gives the grid values.
        lon_values = lon_levs[:lon_n] / lon_factor
        lat_values = lat_levs[:lat_n] / lat_factor

        lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
                                                        lat_values,
                                                        lon_min, lon_max,
                                                        lat_min, lat_max)

        # Pad the box by a tiny relative amount; presumably so that crossings
        # lying exactly on its edges are not dropped by rounding.
        ddx = (x2-x1)*1.e-10
        ddy = (y2-y1)*1.e-10
        bb = Bbox.from_extents(x1-ddx, y1-ddy, x2+ddx, y2+ddy)

        grid_info = {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
            "lon": self._clip_grid_lines_and_find_ticks(
                lon_lines, lon_values, lon_levs, bb),
            "lat": self._clip_grid_lines_and_find_ticks(
                lat_lines, lat_values, lat_levs, bb),
        }

        self._add_tick_labels(grid_info["lon"], self.tick_formatter1,
                              lon_factor)
        self._add_tick_labels(grid_info["lat"], self.tick_formatter2,
                              lat_factor)

        return grid_info

    @staticmethod
    def _add_tick_labels(info, formatter, factor):
        """Fill ``info["tick_labels"]`` with formatted labels per box side."""
        tck_labels = info["tick_labels"] = {}
        for direction in ["left", "bottom", "right", "top"]:
            levs = info["tick_levels"][direction]
            tck_labels[direction] = formatter(direction, factor, levs)

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Sample each grid line and map it through the auxiliary transform.

        Each lon line is sampled at 100 lat values spanning
        ``[lat_min, lat_max]`` (and vice versa for lat lines).  Returns the
        two lists of ``(x, y)`` arrays.
        """
        lons_i = np.linspace(lon_min, lon_max, 100)  # for interpolation
        lats_i = np.linspace(lat_min, lat_max, 100)

        lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
                     for lon in lon_values]
        lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
                     for lat in lat_values]

        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Collect the lines and their crossings with the sides of *bb*.

        Despite the name, lines are stored whole (not clipped).  For every
        crossing, the line's level and the ``((x, y), angle)`` location are
        appended under the corresponding side in ``"tick_levels"`` and
        ``"tick_locs"``.  The ``"values"`` entry is left empty.
        """
        gi = {
            "values": [],
            "levels": [],
            "tick_levels": dict(left=[], bottom=[], right=[], top=[]),
            "tick_locs": dict(left=[], bottom=[], right=[], top=[]),
            "lines": [],
        }

        tck_levels = gi["tick_levels"]
        tck_locs = gi["tick_locs"]
        for (lx, ly), v, lev in zip(lines, values, levs):
            tcks = _find_line_box_crossings(np.column_stack([lx, ly]), bb)
            gi["levels"].append(v)
            gi["lines"].append([(lx, ly)])

            # _find_line_box_crossings returns sides in this order.
            for tck, direction in zip(tcks,
                                      ["left", "right", "bottom", "top"]):
                for t in tck:
                    tck_levels[direction].append(lev)
                    tck_locs[direction].append(t)

        return gi

    def set_transform(self, aux_trans):
        """
        Set the auxiliary transform.

        *aux_trans* is either a `.Transform` instance or a pair of callables
        ``(forward, backward)`` each taking and returning separate x and y
        arrays.  Anything else raises `TypeError`.
        """
        if isinstance(aux_trans, Transform):
            self._aux_transform = aux_trans
            return
        try:
            is_pair = len(aux_trans) == 2 and all(map(callable, aux_trans))
        except TypeError:
            # Not sized/iterable (e.g. a single callable or None).
            is_pair = False
        if not is_pair:
            raise TypeError("'aux_trans' must be either a Transform "
                            "instance or a pair of callables")
        self._aux_transform = _User2DTransform(*aux_trans)

    def get_transform(self):
        """Return the auxiliary transform."""
        return self._aux_transform

    update_transform = set_transform  # backcompat alias.

    def transform_xy(self, x, y):
        """Apply the auxiliary transform to separate *x* and *y* arrays."""
        return self._aux_transform.transform(np.column_stack([x, y])).T

    def inv_transform_xy(self, x, y):
        """Apply the inverse auxiliary transform to *x* and *y* arrays."""
        return self._aux_transform.inverted().transform(
            np.column_stack([x, y])).T

    def update(self, **kwargs):
        """
        Replace helper components by keyword.

        Only ``extreme_finder``, ``grid_locator1``, ``grid_locator2``,
        ``tick_formatter1`` and ``tick_formatter2`` are accepted; any other
        keyword raises `ValueError` (attributes set earlier in the same call
        keep their new values).
        """
        for k, v in kwargs.items():
            if k in ["extreme_finder",
                     "grid_locator1",
                     "grid_locator2",
                     "tick_formatter1",
                     "tick_formatter2"]:
                setattr(self, k, v)
            else:
                raise ValueError(f"Unknown update property {k!r}")
|
|
|
|
|
|
class MaxNLocator(mticker.MaxNLocator):
    """
    `~matplotlib.ticker.MaxNLocator` adapted to the grid-finder locator API.

    Calling it returns a ``(levels, count, factor)`` triple.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        # *trim* is accepted only for API compatibility and is ignored.
        super().__init__(nbins, steps=steps, integer=integer,
                         symmetric=symmetric, prune=prune)
        # Presumably lets the locator work without being attached to a
        # real Axis.
        self.create_dummy_axis()

    def __call__(self, v1, v2):
        ticks = super().tick_values(v1, v2)
        # factor is always 1 here (see angle_helper for non-unit factors).
        return np.array(ticks), len(ticks), 1
|
|
|
|
|
|
class FixedLocator:
    """
    Locator that picks from a fixed collection of levels.

    Calling it with two bounds (in either order) returns the
    ``(levels, count, factor)`` triple for the stored levels within the
    closed interval between them.
    """

    def __init__(self, locs):
        self._locs = locs

    def __call__(self, v1, v2):
        lo, hi = sorted([v1, v2])
        selected = [loc for loc in self._locs if lo <= loc <= hi]
        # factor is always 1 here (see angle_helper for non-unit factors).
        return np.array(selected), len(selected), 1
|
|
|
|
|
|
# Tick Formatter
|
|
|
|
class FormatterPrettyPrint:
    """Tick formatter backed by `~matplotlib.ticker.ScalarFormatter`."""

    def __init__(self, useMathText=True):
        formatter = mticker.ScalarFormatter(useMathText=useMathText,
                                            useOffset=False)
        formatter.create_dummy_axis()
        self._fmt = formatter

    def __call__(self, direction, factor, values):
        """Return labels for *values*; *direction* and *factor* are unused."""
        return self._fmt.format_ticks(values)
|
|
|
|
|
|
class DictFormatter:
    """Tick formatter that looks labels up in a dict, with optional fallback."""

    def __init__(self, format_dict, formatter=None):
        """
        format_dict : mapping from tick values to the label strings to use.
        formatter : fall-back formatter, called as
            ``formatter(direction, factor, values)`` for values not found in
            *format_dict*; if not given (falsy), such values get ``""``.
        """
        super().__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per value; *factor* only affects fallback labels.
        """
        fallback = self._fallback_formatter
        defaults = (fallback(direction, factor, values) if fallback
                    else [""] * len(values))
        lookup = self._format_dict.get
        return [lookup(value, default)
                for value, default in zip(values, defaults)]
|