up follow livre
This commit is contained in:
parent
b4b4398bb0
commit
3a7a3849ae
12242 changed files with 2564461 additions and 6914 deletions
|
@ -0,0 +1,10 @@
|
|||
from . import axes_size as Size
|
||||
from .axes_divider import Divider, SubplotDivider, make_axes_locatable
|
||||
from .axes_grid import AxesGrid, Grid, ImageGrid
|
||||
|
||||
from .parasite_axes import host_subplot, host_axes
|
||||
|
||||
__all__ = ["Size",
|
||||
"Divider", "SubplotDivider", "make_axes_locatable",
|
||||
"AxesGrid", "Grid", "ImageGrid",
|
||||
"host_subplot", "host_axes"]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,414 @@
|
|||
from matplotlib import transforms
|
||||
from matplotlib.offsetbox import (AnchoredOffsetbox, AuxTransformBox,
|
||||
DrawingArea, TextArea, VPacker)
|
||||
from matplotlib.patches import (Rectangle, ArrowStyle,
|
||||
FancyArrowPatch, PathPatch)
|
||||
from matplotlib.text import TextPath
|
||||
|
||||
__all__ = ['AnchoredDrawingArea', 'AnchoredAuxTransformBox',
|
||||
'AnchoredSizeBar', 'AnchoredDirectionArrows']
|
||||
|
||||
|
||||
class AnchoredDrawingArea(AnchoredOffsetbox):
|
||||
def __init__(self, width, height, xdescent, ydescent,
|
||||
loc, pad=0.4, borderpad=0.5, prop=None, frameon=True,
|
||||
**kwargs):
|
||||
"""
|
||||
An anchored container with a fixed size and fillable `.DrawingArea`.
|
||||
|
||||
Artists added to the *drawing_area* will have their coordinates
|
||||
interpreted as pixels. Any transformations set on the artists will be
|
||||
overridden.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
width, height : float
|
||||
Width and height of the container, in pixels.
|
||||
xdescent, ydescent : float
|
||||
Descent of the container in the x- and y- direction, in pixels.
|
||||
loc : str
|
||||
Location of this artist. Valid locations are
|
||||
'upper left', 'upper center', 'upper right',
|
||||
'center left', 'center', 'center right',
|
||||
'lower left', 'lower center', 'lower right'.
|
||||
For backward compatibility, numeric values are accepted as well.
|
||||
See the parameter *loc* of `.Legend` for details.
|
||||
pad : float, default: 0.4
|
||||
Padding around the child objects, in fraction of the font size.
|
||||
borderpad : float, default: 0.5
|
||||
Border padding, in fraction of the font size.
|
||||
prop : `~matplotlib.font_manager.FontProperties`, optional
|
||||
Font property used as a reference for paddings.
|
||||
frameon : bool, default: True
|
||||
If True, draw a box around this artist.
|
||||
**kwargs
|
||||
Keyword arguments forwarded to `.AnchoredOffsetbox`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
drawing_area : `~matplotlib.offsetbox.DrawingArea`
|
||||
A container for artists to display.
|
||||
|
||||
Examples
|
||||
--------
|
||||
To display blue and red circles of different sizes in the upper right
|
||||
of an Axes *ax*:
|
||||
|
||||
>>> ada = AnchoredDrawingArea(20, 20, 0, 0,
|
||||
... loc='upper right', frameon=False)
|
||||
>>> ada.drawing_area.add_artist(Circle((10, 10), 10, fc="b"))
|
||||
>>> ada.drawing_area.add_artist(Circle((30, 10), 5, fc="r"))
|
||||
>>> ax.add_artist(ada)
|
||||
"""
|
||||
self.da = DrawingArea(width, height, xdescent, ydescent)
|
||||
self.drawing_area = self.da
|
||||
|
||||
super().__init__(
|
||||
loc, pad=pad, borderpad=borderpad, child=self.da, prop=None,
|
||||
frameon=frameon, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class AnchoredAuxTransformBox(AnchoredOffsetbox):
|
||||
def __init__(self, transform, loc,
|
||||
pad=0.4, borderpad=0.5, prop=None, frameon=True, **kwargs):
|
||||
"""
|
||||
An anchored container with transformed coordinates.
|
||||
|
||||
Artists added to the *drawing_area* are scaled according to the
|
||||
coordinates of the transformation used. The dimensions of this artist
|
||||
will scale to contain the artists added.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transform : `~matplotlib.transforms.Transform`
|
||||
The transformation object for the coordinate system in use, i.e.,
|
||||
:attr:`matplotlib.axes.Axes.transData`.
|
||||
loc : str
|
||||
Location of this artist. Valid locations are
|
||||
'upper left', 'upper center', 'upper right',
|
||||
'center left', 'center', 'center right',
|
||||
'lower left', 'lower center', 'lower right'.
|
||||
For backward compatibility, numeric values are accepted as well.
|
||||
See the parameter *loc* of `.Legend` for details.
|
||||
pad : float, default: 0.4
|
||||
Padding around the child objects, in fraction of the font size.
|
||||
borderpad : float, default: 0.5
|
||||
Border padding, in fraction of the font size.
|
||||
prop : `~matplotlib.font_manager.FontProperties`, optional
|
||||
Font property used as a reference for paddings.
|
||||
frameon : bool, default: True
|
||||
If True, draw a box around this artist.
|
||||
**kwargs
|
||||
Keyword arguments forwarded to `.AnchoredOffsetbox`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
drawing_area : `~matplotlib.offsetbox.AuxTransformBox`
|
||||
A container for artists to display.
|
||||
|
||||
Examples
|
||||
--------
|
||||
To display an ellipse in the upper left, with a width of 0.1 and
|
||||
height of 0.4 in data coordinates:
|
||||
|
||||
>>> box = AnchoredAuxTransformBox(ax.transData, loc='upper left')
|
||||
>>> el = Ellipse((0, 0), width=0.1, height=0.4, angle=30)
|
||||
>>> box.drawing_area.add_artist(el)
|
||||
>>> ax.add_artist(box)
|
||||
"""
|
||||
self.drawing_area = AuxTransformBox(transform)
|
||||
|
||||
super().__init__(loc, pad=pad, borderpad=borderpad,
|
||||
child=self.drawing_area, prop=prop, frameon=frameon,
|
||||
**kwargs)
|
||||
|
||||
|
||||
class AnchoredSizeBar(AnchoredOffsetbox):
|
||||
def __init__(self, transform, size, label, loc,
|
||||
pad=0.1, borderpad=0.1, sep=2,
|
||||
frameon=True, size_vertical=0, color='black',
|
||||
label_top=False, fontproperties=None, fill_bar=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Draw a horizontal scale bar with a center-aligned label underneath.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transform : `~matplotlib.transforms.Transform`
|
||||
The transformation object for the coordinate system in use, i.e.,
|
||||
:attr:`matplotlib.axes.Axes.transData`.
|
||||
size : float
|
||||
Horizontal length of the size bar, given in coordinates of
|
||||
*transform*.
|
||||
label : str
|
||||
Label to display.
|
||||
loc : str
|
||||
Location of the size bar. Valid locations are
|
||||
'upper left', 'upper center', 'upper right',
|
||||
'center left', 'center', 'center right',
|
||||
'lower left', 'lower center', 'lower right'.
|
||||
For backward compatibility, numeric values are accepted as well.
|
||||
See the parameter *loc* of `.Legend` for details.
|
||||
pad : float, default: 0.1
|
||||
Padding around the label and size bar, in fraction of the font
|
||||
size.
|
||||
borderpad : float, default: 0.1
|
||||
Border padding, in fraction of the font size.
|
||||
sep : float, default: 2
|
||||
Separation between the label and the size bar, in points.
|
||||
frameon : bool, default: True
|
||||
If True, draw a box around the horizontal bar and label.
|
||||
size_vertical : float, default: 0
|
||||
Vertical length of the size bar, given in coordinates of
|
||||
*transform*.
|
||||
color : str, default: 'black'
|
||||
Color for the size bar and label.
|
||||
label_top : bool, default: False
|
||||
If True, the label will be over the size bar.
|
||||
fontproperties : `~matplotlib.font_manager.FontProperties`, optional
|
||||
Font properties for the label text.
|
||||
fill_bar : bool, optional
|
||||
If True and if *size_vertical* is nonzero, the size bar will
|
||||
be filled in with the color specified by the size bar.
|
||||
Defaults to True if *size_vertical* is greater than
|
||||
zero and False otherwise.
|
||||
**kwargs
|
||||
Keyword arguments forwarded to `.AnchoredOffsetbox`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
size_bar : `~matplotlib.offsetbox.AuxTransformBox`
|
||||
Container for the size bar.
|
||||
txt_label : `~matplotlib.offsetbox.TextArea`
|
||||
Container for the label of the size bar.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If *prop* is passed as a keyword argument, but *fontproperties* is
|
||||
not, then *prop* is assumed to be the intended *fontproperties*.
|
||||
Using both *prop* and *fontproperties* is not supported.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> import numpy as np
|
||||
>>> from mpl_toolkits.axes_grid1.anchored_artists import (
|
||||
... AnchoredSizeBar)
|
||||
>>> fig, ax = plt.subplots()
|
||||
>>> ax.imshow(np.random.random((10, 10)))
|
||||
>>> bar = AnchoredSizeBar(ax.transData, 3, '3 data units', 4)
|
||||
>>> ax.add_artist(bar)
|
||||
>>> fig.show()
|
||||
|
||||
Using all the optional parameters
|
||||
|
||||
>>> import matplotlib.font_manager as fm
|
||||
>>> fontprops = fm.FontProperties(size=14, family='monospace')
|
||||
>>> bar = AnchoredSizeBar(ax.transData, 3, '3 units', 4, pad=0.5,
|
||||
... sep=5, borderpad=0.5, frameon=False,
|
||||
... size_vertical=0.5, color='white',
|
||||
... fontproperties=fontprops)
|
||||
"""
|
||||
if fill_bar is None:
|
||||
fill_bar = size_vertical > 0
|
||||
|
||||
self.size_bar = AuxTransformBox(transform)
|
||||
self.size_bar.add_artist(Rectangle((0, 0), size, size_vertical,
|
||||
fill=fill_bar, facecolor=color,
|
||||
edgecolor=color))
|
||||
|
||||
if fontproperties is None and 'prop' in kwargs:
|
||||
fontproperties = kwargs.pop('prop')
|
||||
|
||||
if fontproperties is None:
|
||||
textprops = {'color': color}
|
||||
else:
|
||||
textprops = {'color': color, 'fontproperties': fontproperties}
|
||||
|
||||
self.txt_label = TextArea(label, textprops=textprops)
|
||||
|
||||
if label_top:
|
||||
_box_children = [self.txt_label, self.size_bar]
|
||||
else:
|
||||
_box_children = [self.size_bar, self.txt_label]
|
||||
|
||||
self._box = VPacker(children=_box_children,
|
||||
align="center",
|
||||
pad=0, sep=sep)
|
||||
|
||||
super().__init__(loc, pad=pad, borderpad=borderpad, child=self._box,
|
||||
prop=fontproperties, frameon=frameon, **kwargs)
|
||||
|
||||
|
||||
class AnchoredDirectionArrows(AnchoredOffsetbox):
|
||||
def __init__(self, transform, label_x, label_y, length=0.15,
|
||||
fontsize=0.08, loc='upper left', angle=0, aspect_ratio=1,
|
||||
pad=0.4, borderpad=0.4, frameon=False, color='w', alpha=1,
|
||||
sep_x=0.01, sep_y=0, fontproperties=None, back_length=0.15,
|
||||
head_width=10, head_length=15, tail_width=2,
|
||||
text_props=None, arrow_props=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Draw two perpendicular arrows to indicate directions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transform : `~matplotlib.transforms.Transform`
|
||||
The transformation object for the coordinate system in use, i.e.,
|
||||
:attr:`matplotlib.axes.Axes.transAxes`.
|
||||
label_x, label_y : str
|
||||
Label text for the x and y arrows
|
||||
length : float, default: 0.15
|
||||
Length of the arrow, given in coordinates of *transform*.
|
||||
fontsize : float, default: 0.08
|
||||
Size of label strings, given in coordinates of *transform*.
|
||||
loc : str, default: 'upper left'
|
||||
Location of the arrow. Valid locations are
|
||||
'upper left', 'upper center', 'upper right',
|
||||
'center left', 'center', 'center right',
|
||||
'lower left', 'lower center', 'lower right'.
|
||||
For backward compatibility, numeric values are accepted as well.
|
||||
See the parameter *loc* of `.Legend` for details.
|
||||
angle : float, default: 0
|
||||
The angle of the arrows in degrees.
|
||||
aspect_ratio : float, default: 1
|
||||
The ratio of the length of arrow_x and arrow_y.
|
||||
Negative numbers can be used to change the direction.
|
||||
pad : float, default: 0.4
|
||||
Padding around the labels and arrows, in fraction of the font size.
|
||||
borderpad : float, default: 0.4
|
||||
Border padding, in fraction of the font size.
|
||||
frameon : bool, default: False
|
||||
If True, draw a box around the arrows and labels.
|
||||
color : str, default: 'white'
|
||||
Color for the arrows and labels.
|
||||
alpha : float, default: 1
|
||||
Alpha values of the arrows and labels
|
||||
sep_x, sep_y : float, default: 0.01 and 0 respectively
|
||||
Separation between the arrows and labels in coordinates of
|
||||
*transform*.
|
||||
fontproperties : `~matplotlib.font_manager.FontProperties`, optional
|
||||
Font properties for the label text.
|
||||
back_length : float, default: 0.15
|
||||
Fraction of the arrow behind the arrow crossing.
|
||||
head_width : float, default: 10
|
||||
Width of arrow head, sent to `.ArrowStyle`.
|
||||
head_length : float, default: 15
|
||||
Length of arrow head, sent to `.ArrowStyle`.
|
||||
tail_width : float, default: 2
|
||||
Width of arrow tail, sent to `.ArrowStyle`.
|
||||
text_props, arrow_props : dict
|
||||
Properties of the text and arrows, passed to `.TextPath` and
|
||||
`.FancyArrowPatch`.
|
||||
**kwargs
|
||||
Keyword arguments forwarded to `.AnchoredOffsetbox`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
arrow_x, arrow_y : `~matplotlib.patches.FancyArrowPatch`
|
||||
Arrow x and y
|
||||
text_path_x, text_path_y : `~matplotlib.text.TextPath`
|
||||
Path for arrow labels
|
||||
p_x, p_y : `~matplotlib.patches.PathPatch`
|
||||
Patch for arrow labels
|
||||
box : `~matplotlib.offsetbox.AuxTransformBox`
|
||||
Container for the arrows and labels.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If *prop* is passed as a keyword argument, but *fontproperties* is
|
||||
not, then *prop* is assumed to be the intended *fontproperties*.
|
||||
Using both *prop* and *fontproperties* is not supported.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> import numpy as np
|
||||
>>> from mpl_toolkits.axes_grid1.anchored_artists import (
|
||||
... AnchoredDirectionArrows)
|
||||
>>> fig, ax = plt.subplots()
|
||||
>>> ax.imshow(np.random.random((10, 10)))
|
||||
>>> arrows = AnchoredDirectionArrows(ax.transAxes, '111', '110')
|
||||
>>> ax.add_artist(arrows)
|
||||
>>> fig.show()
|
||||
|
||||
Using several of the optional parameters, creating downward pointing
|
||||
arrow and high contrast text labels.
|
||||
|
||||
>>> import matplotlib.font_manager as fm
|
||||
>>> fontprops = fm.FontProperties(family='monospace')
|
||||
>>> arrows = AnchoredDirectionArrows(ax.transAxes, 'East', 'South',
|
||||
... loc='lower left', color='k',
|
||||
... aspect_ratio=-1, sep_x=0.02,
|
||||
... sep_y=-0.01,
|
||||
... text_props={'ec':'w', 'fc':'k'},
|
||||
... fontproperties=fontprops)
|
||||
"""
|
||||
if arrow_props is None:
|
||||
arrow_props = {}
|
||||
|
||||
if text_props is None:
|
||||
text_props = {}
|
||||
|
||||
arrowstyle = ArrowStyle("Simple",
|
||||
head_width=head_width,
|
||||
head_length=head_length,
|
||||
tail_width=tail_width)
|
||||
|
||||
if fontproperties is None and 'prop' in kwargs:
|
||||
fontproperties = kwargs.pop('prop')
|
||||
|
||||
if 'color' not in arrow_props:
|
||||
arrow_props['color'] = color
|
||||
|
||||
if 'alpha' not in arrow_props:
|
||||
arrow_props['alpha'] = alpha
|
||||
|
||||
if 'color' not in text_props:
|
||||
text_props['color'] = color
|
||||
|
||||
if 'alpha' not in text_props:
|
||||
text_props['alpha'] = alpha
|
||||
|
||||
t_start = transform
|
||||
t_end = t_start + transforms.Affine2D().rotate_deg(angle)
|
||||
|
||||
self.box = AuxTransformBox(t_end)
|
||||
|
||||
length_x = length
|
||||
length_y = length*aspect_ratio
|
||||
|
||||
self.arrow_x = FancyArrowPatch(
|
||||
(0, back_length*length_y),
|
||||
(length_x, back_length*length_y),
|
||||
arrowstyle=arrowstyle,
|
||||
shrinkA=0.0,
|
||||
shrinkB=0.0,
|
||||
**arrow_props)
|
||||
|
||||
self.arrow_y = FancyArrowPatch(
|
||||
(back_length*length_x, 0),
|
||||
(back_length*length_x, length_y),
|
||||
arrowstyle=arrowstyle,
|
||||
shrinkA=0.0,
|
||||
shrinkB=0.0,
|
||||
**arrow_props)
|
||||
|
||||
self.box.add_artist(self.arrow_x)
|
||||
self.box.add_artist(self.arrow_y)
|
||||
|
||||
text_path_x = TextPath((
|
||||
length_x+sep_x, back_length*length_y+sep_y), label_x,
|
||||
size=fontsize, prop=fontproperties)
|
||||
self.p_x = PathPatch(text_path_x, transform=t_start, **text_props)
|
||||
self.box.add_artist(self.p_x)
|
||||
|
||||
text_path_y = TextPath((
|
||||
length_x*back_length+sep_x, length_y*(1-back_length)+sep_y),
|
||||
label_y, size=fontsize, prop=fontproperties)
|
||||
self.p_y = PathPatch(text_path_y, **text_props)
|
||||
self.box.add_artist(self.p_y)
|
||||
|
||||
super().__init__(loc, pad=pad, borderpad=borderpad, child=self.box,
|
||||
frameon=frameon, **kwargs)
|
|
@ -0,0 +1,618 @@
|
|||
"""
|
||||
Helper classes to adjust the positions of multiple axes at drawing time.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import _api
|
||||
from matplotlib.gridspec import SubplotSpec
|
||||
import matplotlib.transforms as mtransforms
|
||||
from . import axes_size as Size
|
||||
|
||||
|
||||
class Divider:
|
||||
"""
|
||||
An Axes positioning class.
|
||||
|
||||
The divider is initialized with lists of horizontal and vertical sizes
|
||||
(:mod:`mpl_toolkits.axes_grid1.axes_size`) based on which a given
|
||||
rectangular area will be divided.
|
||||
|
||||
The `new_locator` method then creates a callable object
|
||||
that can be used as the *axes_locator* of the axes.
|
||||
"""
|
||||
|
||||
def __init__(self, fig, pos, horizontal, vertical,
|
||||
aspect=None, anchor="C"):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
fig : Figure
|
||||
pos : tuple of 4 floats
|
||||
Position of the rectangle that will be divided.
|
||||
horizontal : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
|
||||
Sizes for horizontal division.
|
||||
vertical : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
|
||||
Sizes for vertical division.
|
||||
aspect : bool, optional
|
||||
Whether overall rectangular area is reduced so that the relative
|
||||
part of the horizontal and vertical scales have the same scale.
|
||||
anchor : (float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', 'N', \
|
||||
'NW', 'W'}, default: 'C'
|
||||
Placement of the reduced rectangle, when *aspect* is True.
|
||||
"""
|
||||
|
||||
self._fig = fig
|
||||
self._pos = pos
|
||||
self._horizontal = horizontal
|
||||
self._vertical = vertical
|
||||
self._anchor = anchor
|
||||
self.set_anchor(anchor)
|
||||
self._aspect = aspect
|
||||
self._xrefindex = 0
|
||||
self._yrefindex = 0
|
||||
self._locator = None
|
||||
|
||||
def get_horizontal_sizes(self, renderer):
|
||||
return np.array([s.get_size(renderer) for s in self.get_horizontal()])
|
||||
|
||||
def get_vertical_sizes(self, renderer):
|
||||
return np.array([s.get_size(renderer) for s in self.get_vertical()])
|
||||
|
||||
def set_position(self, pos):
|
||||
"""
|
||||
Set the position of the rectangle.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : tuple of 4 floats
|
||||
position of the rectangle that will be divided
|
||||
"""
|
||||
self._pos = pos
|
||||
|
||||
def get_position(self):
|
||||
"""Return the position of the rectangle."""
|
||||
return self._pos
|
||||
|
||||
def set_anchor(self, anchor):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
anchor : (float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', 'N', \
|
||||
'NW', 'W'}
|
||||
Either an (*x*, *y*) pair of relative coordinates (0 is left or
|
||||
bottom, 1 is right or top), 'C' (center), or a cardinal direction
|
||||
('SW', southwest, is bottom left, etc.).
|
||||
|
||||
See Also
|
||||
--------
|
||||
.Axes.set_anchor
|
||||
"""
|
||||
if isinstance(anchor, str):
|
||||
_api.check_in_list(mtransforms.Bbox.coefs, anchor=anchor)
|
||||
elif not isinstance(anchor, (tuple, list)) or len(anchor) != 2:
|
||||
raise TypeError("anchor must be str or 2-tuple")
|
||||
self._anchor = anchor
|
||||
|
||||
def get_anchor(self):
|
||||
"""Return the anchor."""
|
||||
return self._anchor
|
||||
|
||||
def get_subplotspec(self):
|
||||
return None
|
||||
|
||||
def set_horizontal(self, h):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
h : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
|
||||
sizes for horizontal division
|
||||
"""
|
||||
self._horizontal = h
|
||||
|
||||
def get_horizontal(self):
|
||||
"""Return horizontal sizes."""
|
||||
return self._horizontal
|
||||
|
||||
def set_vertical(self, v):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
v : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
|
||||
sizes for vertical division
|
||||
"""
|
||||
self._vertical = v
|
||||
|
||||
def get_vertical(self):
|
||||
"""Return vertical sizes."""
|
||||
return self._vertical
|
||||
|
||||
def set_aspect(self, aspect=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
aspect : bool
|
||||
"""
|
||||
self._aspect = aspect
|
||||
|
||||
def get_aspect(self):
|
||||
"""Return aspect."""
|
||||
return self._aspect
|
||||
|
||||
def set_locator(self, _locator):
|
||||
self._locator = _locator
|
||||
|
||||
def get_locator(self):
|
||||
return self._locator
|
||||
|
||||
def get_position_runtime(self, ax, renderer):
|
||||
if self._locator is None:
|
||||
return self.get_position()
|
||||
else:
|
||||
return self._locator(ax, renderer).bounds
|
||||
|
||||
@staticmethod
|
||||
def _calc_k(sizes, total):
|
||||
# sizes is a (n, 2) array of (rel_size, abs_size); this method finds
|
||||
# the k factor such that sum(rel_size * k + abs_size) == total.
|
||||
rel_sum, abs_sum = sizes.sum(0)
|
||||
return (total - abs_sum) / rel_sum if rel_sum else 0
|
||||
|
||||
@staticmethod
|
||||
def _calc_offsets(sizes, k):
|
||||
# Apply k factors to (n, 2) sizes array of (rel_size, abs_size); return
|
||||
# the resulting cumulative offset positions.
|
||||
return np.cumsum([0, *(sizes @ [k, 1])])
|
||||
|
||||
def new_locator(self, nx, ny, nx1=None, ny1=None):
|
||||
"""
|
||||
Return an axes locator callable for the specified cell.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nx, nx1 : int
|
||||
Integers specifying the column-position of the
|
||||
cell. When *nx1* is None, a single *nx*-th column is
|
||||
specified. Otherwise, location of columns spanning between *nx*
|
||||
to *nx1* (but excluding *nx1*-th column) is specified.
|
||||
ny, ny1 : int
|
||||
Same as *nx* and *nx1*, but for row positions.
|
||||
"""
|
||||
if nx1 is None:
|
||||
nx1 = nx + 1
|
||||
if ny1 is None:
|
||||
ny1 = ny + 1
|
||||
# append_size("left") adds a new size at the beginning of the
|
||||
# horizontal size lists; this shift transforms e.g.
|
||||
# new_locator(nx=2, ...) into effectively new_locator(nx=3, ...). To
|
||||
# take that into account, instead of recording nx, we record
|
||||
# nx-self._xrefindex, where _xrefindex is shifted by 1 by each
|
||||
# append_size("left"), and re-add self._xrefindex back to nx in
|
||||
# _locate, when the actual axes position is computed. Ditto for y.
|
||||
xref = self._xrefindex
|
||||
yref = self._yrefindex
|
||||
locator = functools.partial(
|
||||
self._locate, nx - xref, ny - yref, nx1 - xref, ny1 - yref)
|
||||
locator.get_subplotspec = self.get_subplotspec
|
||||
return locator
|
||||
|
||||
def _locate(self, nx, ny, nx1, ny1, axes, renderer):
|
||||
"""
|
||||
Implementation of ``divider.new_locator().__call__``.
|
||||
|
||||
The axes locator callable returned by ``new_locator()`` is created as
|
||||
a `functools.partial` of this method with *nx*, *ny*, *nx1*, and *ny1*
|
||||
specifying the requested cell.
|
||||
"""
|
||||
nx += self._xrefindex
|
||||
nx1 += self._xrefindex
|
||||
ny += self._yrefindex
|
||||
ny1 += self._yrefindex
|
||||
|
||||
fig_w, fig_h = self._fig.bbox.size / self._fig.dpi
|
||||
x, y, w, h = self.get_position_runtime(axes, renderer)
|
||||
|
||||
hsizes = self.get_horizontal_sizes(renderer)
|
||||
vsizes = self.get_vertical_sizes(renderer)
|
||||
k_h = self._calc_k(hsizes, fig_w * w)
|
||||
k_v = self._calc_k(vsizes, fig_h * h)
|
||||
|
||||
if self.get_aspect():
|
||||
k = min(k_h, k_v)
|
||||
ox = self._calc_offsets(hsizes, k)
|
||||
oy = self._calc_offsets(vsizes, k)
|
||||
|
||||
ww = (ox[-1] - ox[0]) / fig_w
|
||||
hh = (oy[-1] - oy[0]) / fig_h
|
||||
pb = mtransforms.Bbox.from_bounds(x, y, w, h)
|
||||
pb1 = mtransforms.Bbox.from_bounds(x, y, ww, hh)
|
||||
x0, y0 = pb1.anchored(self.get_anchor(), pb).p0
|
||||
|
||||
else:
|
||||
ox = self._calc_offsets(hsizes, k_h)
|
||||
oy = self._calc_offsets(vsizes, k_v)
|
||||
x0, y0 = x, y
|
||||
|
||||
if nx1 is None:
|
||||
nx1 = -1
|
||||
if ny1 is None:
|
||||
ny1 = -1
|
||||
|
||||
x1, w1 = x0 + ox[nx] / fig_w, (ox[nx1] - ox[nx]) / fig_w
|
||||
y1, h1 = y0 + oy[ny] / fig_h, (oy[ny1] - oy[ny]) / fig_h
|
||||
|
||||
return mtransforms.Bbox.from_bounds(x1, y1, w1, h1)
|
||||
|
||||
def append_size(self, position, size):
|
||||
_api.check_in_list(["left", "right", "bottom", "top"],
|
||||
position=position)
|
||||
if position == "left":
|
||||
self._horizontal.insert(0, size)
|
||||
self._xrefindex += 1
|
||||
elif position == "right":
|
||||
self._horizontal.append(size)
|
||||
elif position == "bottom":
|
||||
self._vertical.insert(0, size)
|
||||
self._yrefindex += 1
|
||||
else: # 'top'
|
||||
self._vertical.append(size)
|
||||
|
||||
def add_auto_adjustable_area(self, use_axes, pad=0.1, adjust_dirs=None):
|
||||
"""
|
||||
Add auto-adjustable padding around *use_axes* to take their decorations
|
||||
(title, labels, ticks, ticklabels) into account during layout.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
use_axes : `~matplotlib.axes.Axes` or list of `~matplotlib.axes.Axes`
|
||||
The Axes whose decorations are taken into account.
|
||||
pad : float, default: 0.1
|
||||
Additional padding in inches.
|
||||
adjust_dirs : list of {"left", "right", "bottom", "top"}, optional
|
||||
The sides where padding is added; defaults to all four sides.
|
||||
"""
|
||||
if adjust_dirs is None:
|
||||
adjust_dirs = ["left", "right", "bottom", "top"]
|
||||
for d in adjust_dirs:
|
||||
self.append_size(d, Size._AxesDecorationsSize(use_axes, d) + pad)
|
||||
|
||||
|
||||
class SubplotDivider(Divider):
|
||||
"""
|
||||
The Divider class whose rectangle area is specified as a subplot geometry.
|
||||
"""
|
||||
|
||||
def __init__(self, fig, *args, horizontal=None, vertical=None,
|
||||
aspect=None, anchor='C'):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
fig : `~matplotlib.figure.Figure`
|
||||
|
||||
*args : tuple (*nrows*, *ncols*, *index*) or int
|
||||
The array of subplots in the figure has dimensions ``(nrows,
|
||||
ncols)``, and *index* is the index of the subplot being created.
|
||||
*index* starts at 1 in the upper left corner and increases to the
|
||||
right.
|
||||
|
||||
If *nrows*, *ncols*, and *index* are all single digit numbers, then
|
||||
*args* can be passed as a single 3-digit number (e.g. 234 for
|
||||
(2, 3, 4)).
|
||||
horizontal : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`, optional
|
||||
Sizes for horizontal division.
|
||||
vertical : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`, optional
|
||||
Sizes for vertical division.
|
||||
aspect : bool, optional
|
||||
Whether overall rectangular area is reduced so that the relative
|
||||
part of the horizontal and vertical scales have the same scale.
|
||||
anchor : (float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', 'N', \
|
||||
'NW', 'W'}, default: 'C'
|
||||
Placement of the reduced rectangle, when *aspect* is True.
|
||||
"""
|
||||
self.figure = fig
|
||||
super().__init__(fig, [0, 0, 1, 1],
|
||||
horizontal=horizontal or [], vertical=vertical or [],
|
||||
aspect=aspect, anchor=anchor)
|
||||
self.set_subplotspec(SubplotSpec._from_subplot_args(fig, args))
|
||||
|
||||
def get_position(self):
|
||||
"""Return the bounds of the subplot box."""
|
||||
return self.get_subplotspec().get_position(self.figure).bounds
|
||||
|
||||
def get_subplotspec(self):
|
||||
"""Get the SubplotSpec instance."""
|
||||
return self._subplotspec
|
||||
|
||||
def set_subplotspec(self, subplotspec):
|
||||
"""Set the SubplotSpec instance."""
|
||||
self._subplotspec = subplotspec
|
||||
self.set_position(subplotspec.get_position(self.figure))
|
||||
|
||||
|
||||
class AxesDivider(Divider):
|
||||
"""
|
||||
Divider based on the preexisting axes.
|
||||
"""
|
||||
|
||||
def __init__(self, axes, xref=None, yref=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
axes : :class:`~matplotlib.axes.Axes`
|
||||
xref
|
||||
yref
|
||||
"""
|
||||
self._axes = axes
|
||||
if xref is None:
|
||||
self._xref = Size.AxesX(axes)
|
||||
else:
|
||||
self._xref = xref
|
||||
if yref is None:
|
||||
self._yref = Size.AxesY(axes)
|
||||
else:
|
||||
self._yref = yref
|
||||
|
||||
super().__init__(fig=axes.get_figure(), pos=None,
|
||||
horizontal=[self._xref], vertical=[self._yref],
|
||||
aspect=None, anchor="C")
|
||||
|
||||
def _get_new_axes(self, *, axes_class=None, **kwargs):
|
||||
axes = self._axes
|
||||
if axes_class is None:
|
||||
axes_class = type(axes)
|
||||
return axes_class(axes.get_figure(), axes.get_position(original=True),
|
||||
**kwargs)
|
||||
|
||||
def new_horizontal(self, size, pad=None, pack_start=False, **kwargs):
|
||||
"""
|
||||
Helper method for ``append_axes("left")`` and ``append_axes("right")``.
|
||||
|
||||
See the documentation of `append_axes` for more details.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
if pad is None:
|
||||
pad = mpl.rcParams["figure.subplot.wspace"] * self._xref
|
||||
pos = "left" if pack_start else "right"
|
||||
if pad:
|
||||
if not isinstance(pad, Size._Base):
|
||||
pad = Size.from_any(pad, fraction_ref=self._xref)
|
||||
self.append_size(pos, pad)
|
||||
if not isinstance(size, Size._Base):
|
||||
size = Size.from_any(size, fraction_ref=self._xref)
|
||||
self.append_size(pos, size)
|
||||
locator = self.new_locator(
|
||||
nx=0 if pack_start else len(self._horizontal) - 1,
|
||||
ny=self._yrefindex)
|
||||
ax = self._get_new_axes(**kwargs)
|
||||
ax.set_axes_locator(locator)
|
||||
return ax
|
||||
|
||||
def new_vertical(self, size, pad=None, pack_start=False, **kwargs):
|
||||
"""
|
||||
Helper method for ``append_axes("top")`` and ``append_axes("bottom")``.
|
||||
|
||||
See the documentation of `append_axes` for more details.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
if pad is None:
|
||||
pad = mpl.rcParams["figure.subplot.hspace"] * self._yref
|
||||
pos = "bottom" if pack_start else "top"
|
||||
if pad:
|
||||
if not isinstance(pad, Size._Base):
|
||||
pad = Size.from_any(pad, fraction_ref=self._yref)
|
||||
self.append_size(pos, pad)
|
||||
if not isinstance(size, Size._Base):
|
||||
size = Size.from_any(size, fraction_ref=self._yref)
|
||||
self.append_size(pos, size)
|
||||
locator = self.new_locator(
|
||||
nx=self._xrefindex,
|
||||
ny=0 if pack_start else len(self._vertical) - 1)
|
||||
ax = self._get_new_axes(**kwargs)
|
||||
ax.set_axes_locator(locator)
|
||||
return ax
|
||||
|
||||
def append_axes(self, position, size, pad=None, *, axes_class=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Add a new axes on a given side of the main axes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : {"left", "right", "bottom", "top"}
|
||||
Where the new axes is positioned relative to the main axes.
|
||||
size : :mod:`~mpl_toolkits.axes_grid1.axes_size` or float or str
|
||||
The axes width or height. float or str arguments are interpreted
|
||||
as ``axes_size.from_any(size, AxesX(<main_axes>))`` for left or
|
||||
right axes, and likewise with ``AxesY`` for bottom or top axes.
|
||||
pad : :mod:`~mpl_toolkits.axes_grid1.axes_size` or float or str
|
||||
Padding between the axes. float or str arguments are interpreted
|
||||
as for *size*. Defaults to :rc:`figure.subplot.wspace` times the
|
||||
main Axes width (left or right axes) or :rc:`figure.subplot.hspace`
|
||||
times the main Axes height (bottom or top axes).
|
||||
axes_class : subclass type of `~.axes.Axes`, optional
|
||||
The type of the new axes. Defaults to the type of the main axes.
|
||||
**kwargs
|
||||
All extra keywords arguments are passed to the created axes.
|
||||
"""
|
||||
create_axes, pack_start = _api.check_getitem({
|
||||
"left": (self.new_horizontal, True),
|
||||
"right": (self.new_horizontal, False),
|
||||
"bottom": (self.new_vertical, True),
|
||||
"top": (self.new_vertical, False),
|
||||
}, position=position)
|
||||
ax = create_axes(
|
||||
size, pad, pack_start=pack_start, axes_class=axes_class, **kwargs)
|
||||
self._fig.add_axes(ax)
|
||||
return ax
|
||||
|
||||
def get_aspect(self):
|
||||
if self._aspect is None:
|
||||
aspect = self._axes.get_aspect()
|
||||
if aspect == "auto":
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
return self._aspect
|
||||
|
||||
def get_position(self):
|
||||
if self._pos is None:
|
||||
bbox = self._axes.get_position(original=True)
|
||||
return bbox.bounds
|
||||
else:
|
||||
return self._pos
|
||||
|
||||
def get_anchor(self):
|
||||
if self._anchor is None:
|
||||
return self._axes.get_anchor()
|
||||
else:
|
||||
return self._anchor
|
||||
|
||||
def get_subplotspec(self):
|
||||
return self._axes.get_subplotspec()
|
||||
|
||||
|
||||
# Helper for HBoxDivider/VBoxDivider.
|
||||
# The variable names are written for a horizontal layout, but the calculations
|
||||
# work identically for vertical layouts.
|
||||
def _locate(x, y, w, h, summed_widths, equal_heights, fig_w, fig_h, anchor):
|
||||
|
||||
total_width = fig_w * w
|
||||
max_height = fig_h * h
|
||||
|
||||
# Determine the k factors.
|
||||
n = len(equal_heights)
|
||||
eq_rels, eq_abss = equal_heights.T
|
||||
sm_rels, sm_abss = summed_widths.T
|
||||
A = np.diag([*eq_rels, 0])
|
||||
A[:n, -1] = -1
|
||||
A[-1, :-1] = sm_rels
|
||||
B = [*(-eq_abss), total_width - sm_abss.sum()]
|
||||
# A @ K = B: This finds factors {k_0, ..., k_{N-1}, H} so that
|
||||
# eq_rel_i * k_i + eq_abs_i = H for all i: all axes have the same height
|
||||
# sum(sm_rel_i * k_i + sm_abs_i) = total_width: fixed total width
|
||||
# (foo_rel_i * k_i + foo_abs_i will end up being the size of foo.)
|
||||
*karray, height = np.linalg.solve(A, B)
|
||||
if height > max_height: # Additionally, upper-bound the height.
|
||||
karray = (max_height - eq_abss) / eq_rels
|
||||
|
||||
# Compute the offsets corresponding to these factors.
|
||||
ox = np.cumsum([0, *(sm_rels * karray + sm_abss)])
|
||||
ww = (ox[-1] - ox[0]) / fig_w
|
||||
h0_rel, h0_abs = equal_heights[0]
|
||||
hh = (karray[0]*h0_rel + h0_abs) / fig_h
|
||||
pb = mtransforms.Bbox.from_bounds(x, y, w, h)
|
||||
pb1 = mtransforms.Bbox.from_bounds(x, y, ww, hh)
|
||||
x0, y0 = pb1.anchored(anchor, pb).p0
|
||||
|
||||
return x0, y0, ox, hh
|
||||
|
||||
|
||||
class HBoxDivider(SubplotDivider):
|
||||
"""
|
||||
A `.SubplotDivider` for laying out axes horizontally, while ensuring that
|
||||
they have equal heights.
|
||||
|
||||
Examples
|
||||
--------
|
||||
.. plot:: gallery/axes_grid1/demo_axes_hbox_divider.py
|
||||
"""
|
||||
|
||||
def new_locator(self, nx, nx1=None):
|
||||
"""
|
||||
Create an axes locator callable for the specified cell.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nx, nx1 : int
|
||||
Integers specifying the column-position of the
|
||||
cell. When *nx1* is None, a single *nx*-th column is
|
||||
specified. Otherwise, location of columns spanning between *nx*
|
||||
to *nx1* (but excluding *nx1*-th column) is specified.
|
||||
"""
|
||||
return super().new_locator(nx, 0, nx1, 0)
|
||||
|
||||
def _locate(self, nx, ny, nx1, ny1, axes, renderer):
|
||||
# docstring inherited
|
||||
nx += self._xrefindex
|
||||
nx1 += self._xrefindex
|
||||
fig_w, fig_h = self._fig.bbox.size / self._fig.dpi
|
||||
x, y, w, h = self.get_position_runtime(axes, renderer)
|
||||
summed_ws = self.get_horizontal_sizes(renderer)
|
||||
equal_hs = self.get_vertical_sizes(renderer)
|
||||
x0, y0, ox, hh = _locate(
|
||||
x, y, w, h, summed_ws, equal_hs, fig_w, fig_h, self.get_anchor())
|
||||
if nx1 is None:
|
||||
nx1 = -1
|
||||
x1, w1 = x0 + ox[nx] / fig_w, (ox[nx1] - ox[nx]) / fig_w
|
||||
y1, h1 = y0, hh
|
||||
return mtransforms.Bbox.from_bounds(x1, y1, w1, h1)
|
||||
|
||||
|
||||
class VBoxDivider(SubplotDivider):
|
||||
"""
|
||||
A `.SubplotDivider` for laying out axes vertically, while ensuring that
|
||||
they have equal widths.
|
||||
"""
|
||||
|
||||
def new_locator(self, ny, ny1=None):
|
||||
"""
|
||||
Create an axes locator callable for the specified cell.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ny, ny1 : int
|
||||
Integers specifying the row-position of the
|
||||
cell. When *ny1* is None, a single *ny*-th row is
|
||||
specified. Otherwise, location of rows spanning between *ny*
|
||||
to *ny1* (but excluding *ny1*-th row) is specified.
|
||||
"""
|
||||
return super().new_locator(0, ny, 0, ny1)
|
||||
|
||||
def _locate(self, nx, ny, nx1, ny1, axes, renderer):
|
||||
# docstring inherited
|
||||
ny += self._yrefindex
|
||||
ny1 += self._yrefindex
|
||||
fig_w, fig_h = self._fig.bbox.size / self._fig.dpi
|
||||
x, y, w, h = self.get_position_runtime(axes, renderer)
|
||||
summed_hs = self.get_vertical_sizes(renderer)
|
||||
equal_ws = self.get_horizontal_sizes(renderer)
|
||||
y0, x0, oy, ww = _locate(
|
||||
y, x, h, w, summed_hs, equal_ws, fig_h, fig_w, self.get_anchor())
|
||||
if ny1 is None:
|
||||
ny1 = -1
|
||||
x1, w1 = x0, ww
|
||||
y1, h1 = y0 + oy[ny] / fig_h, (oy[ny1] - oy[ny]) / fig_h
|
||||
return mtransforms.Bbox.from_bounds(x1, y1, w1, h1)
|
||||
|
||||
|
||||
def make_axes_locatable(axes):
|
||||
divider = AxesDivider(axes)
|
||||
locator = divider.new_locator(nx=0, ny=0)
|
||||
axes.set_axes_locator(locator)
|
||||
|
||||
return divider
|
||||
|
||||
|
||||
def make_axes_area_auto_adjustable(
|
||||
ax, use_axes=None, pad=0.1, adjust_dirs=None):
|
||||
"""
|
||||
Add auto-adjustable padding around *ax* to take its decorations (title,
|
||||
labels, ticks, ticklabels) into account during layout, using
|
||||
`.Divider.add_auto_adjustable_area`.
|
||||
|
||||
By default, padding is determined from the decorations of *ax*.
|
||||
Pass *use_axes* to consider the decorations of other Axes instead.
|
||||
"""
|
||||
if adjust_dirs is None:
|
||||
adjust_dirs = ["left", "right", "bottom", "top"]
|
||||
divider = make_axes_locatable(ax)
|
||||
if use_axes is None:
|
||||
use_axes = ax
|
||||
divider.add_auto_adjustable_area(use_axes=use_axes, pad=pad,
|
||||
adjust_dirs=adjust_dirs)
|
|
@ -0,0 +1,563 @@
|
|||
from numbers import Number
|
||||
import functools
|
||||
from types import MethodType
|
||||
|
||||
import numpy as np
|
||||
|
||||
from matplotlib import _api, cbook
|
||||
from matplotlib.gridspec import SubplotSpec
|
||||
|
||||
from .axes_divider import Size, SubplotDivider, Divider
|
||||
from .mpl_axes import Axes, SimpleAxisArtist
|
||||
|
||||
|
||||
class CbarAxesBase:
|
||||
def __init__(self, *args, orientation, **kwargs):
|
||||
self.orientation = orientation
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def colorbar(self, mappable, **kwargs):
|
||||
return self.get_figure(root=False).colorbar(
|
||||
mappable, cax=self, location=self.orientation, **kwargs)
|
||||
|
||||
|
||||
_cbaraxes_class_factory = cbook._make_class_factory(CbarAxesBase, "Cbar{}")
|
||||
|
||||
|
||||
class Grid:
|
||||
"""
|
||||
A grid of Axes.
|
||||
|
||||
In Matplotlib, the Axes location (and size) is specified in normalized
|
||||
figure coordinates. This may not be ideal for images that needs to be
|
||||
displayed with a given aspect ratio; for example, it is difficult to
|
||||
display multiple images of a same size with some fixed padding between
|
||||
them. AxesGrid can be used in such case.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
axes_all : list of Axes
|
||||
A flat list of Axes. Note that you can also access this directly
|
||||
from the grid. The following is equivalent ::
|
||||
|
||||
grid[i] == grid.axes_all[i]
|
||||
len(grid) == len(grid.axes_all)
|
||||
|
||||
axes_column : list of list of Axes
|
||||
A 2D list of Axes where the first index is the column. This results
|
||||
in the usage pattern ``grid.axes_column[col][row]``.
|
||||
axes_row : list of list of Axes
|
||||
A 2D list of Axes where the first index is the row. This results
|
||||
in the usage pattern ``grid.axes_row[row][col]``.
|
||||
axes_llc : Axes
|
||||
The Axes in the lower left corner.
|
||||
ngrids : int
|
||||
Number of Axes in the grid.
|
||||
"""
|
||||
|
||||
_defaultAxesClass = Axes
|
||||
|
||||
def __init__(self, fig,
|
||||
rect,
|
||||
nrows_ncols,
|
||||
ngrids=None,
|
||||
direction="row",
|
||||
axes_pad=0.02,
|
||||
*,
|
||||
share_all=False,
|
||||
share_x=True,
|
||||
share_y=True,
|
||||
label_mode="L",
|
||||
axes_class=None,
|
||||
aspect=False,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
fig : `.Figure`
|
||||
The parent figure.
|
||||
rect : (float, float, float, float), (int, int, int), int, or \
|
||||
`~.SubplotSpec`
|
||||
The axes position, as a ``(left, bottom, width, height)`` tuple,
|
||||
as a three-digit subplot position code (e.g., ``(1, 2, 1)`` or
|
||||
``121``), or as a `~.SubplotSpec`.
|
||||
nrows_ncols : (int, int)
|
||||
Number of rows and columns in the grid.
|
||||
ngrids : int or None, default: None
|
||||
If not None, only the first *ngrids* axes in the grid are created.
|
||||
direction : {"row", "column"}, default: "row"
|
||||
Whether axes are created in row-major ("row by row") or
|
||||
column-major order ("column by column"). This also affects the
|
||||
order in which axes are accessed using indexing (``grid[index]``).
|
||||
axes_pad : float or (float, float), default: 0.02
|
||||
Padding or (horizontal padding, vertical padding) between axes, in
|
||||
inches.
|
||||
share_all : bool, default: False
|
||||
Whether all axes share their x- and y-axis. Overrides *share_x*
|
||||
and *share_y*.
|
||||
share_x : bool, default: True
|
||||
Whether all axes of a column share their x-axis.
|
||||
share_y : bool, default: True
|
||||
Whether all axes of a row share their y-axis.
|
||||
label_mode : {"L", "1", "all", "keep"}, default: "L"
|
||||
Determines which axes will get tick labels:
|
||||
|
||||
- "L": All axes on the left column get vertical tick labels;
|
||||
all axes on the bottom row get horizontal tick labels.
|
||||
- "1": Only the bottom left axes is labelled.
|
||||
- "all": All axes are labelled.
|
||||
- "keep": Do not do anything.
|
||||
|
||||
axes_class : subclass of `matplotlib.axes.Axes`, default: `.mpl_axes.Axes`
|
||||
The type of Axes to create.
|
||||
aspect : bool, default: False
|
||||
Whether the axes aspect ratio follows the aspect ratio of the data
|
||||
limits.
|
||||
"""
|
||||
self._nrows, self._ncols = nrows_ncols
|
||||
|
||||
if ngrids is None:
|
||||
ngrids = self._nrows * self._ncols
|
||||
else:
|
||||
if not 0 < ngrids <= self._nrows * self._ncols:
|
||||
raise ValueError(
|
||||
"ngrids must be positive and not larger than nrows*ncols")
|
||||
|
||||
self.ngrids = ngrids
|
||||
|
||||
self._horiz_pad_size, self._vert_pad_size = map(
|
||||
Size.Fixed, np.broadcast_to(axes_pad, 2))
|
||||
|
||||
_api.check_in_list(["column", "row"], direction=direction)
|
||||
self._direction = direction
|
||||
|
||||
if axes_class is None:
|
||||
axes_class = self._defaultAxesClass
|
||||
elif isinstance(axes_class, (list, tuple)):
|
||||
cls, kwargs = axes_class
|
||||
axes_class = functools.partial(cls, **kwargs)
|
||||
|
||||
kw = dict(horizontal=[], vertical=[], aspect=aspect)
|
||||
if isinstance(rect, (Number, SubplotSpec)):
|
||||
self._divider = SubplotDivider(fig, rect, **kw)
|
||||
elif len(rect) == 3:
|
||||
self._divider = SubplotDivider(fig, *rect, **kw)
|
||||
elif len(rect) == 4:
|
||||
self._divider = Divider(fig, rect, **kw)
|
||||
else:
|
||||
raise TypeError("Incorrect rect format")
|
||||
|
||||
rect = self._divider.get_position()
|
||||
|
||||
axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
|
||||
for i in range(self.ngrids):
|
||||
col, row = self._get_col_row(i)
|
||||
if share_all:
|
||||
sharex = sharey = axes_array[0, 0]
|
||||
else:
|
||||
sharex = axes_array[0, col] if share_x else None
|
||||
sharey = axes_array[row, 0] if share_y else None
|
||||
axes_array[row, col] = axes_class(
|
||||
fig, rect, sharex=sharex, sharey=sharey)
|
||||
self.axes_all = axes_array.ravel(
|
||||
order="C" if self._direction == "row" else "F").tolist()
|
||||
self.axes_column = axes_array.T.tolist()
|
||||
self.axes_row = axes_array.tolist()
|
||||
self.axes_llc = self.axes_column[0][-1]
|
||||
|
||||
self._init_locators()
|
||||
|
||||
for ax in self.axes_all:
|
||||
fig.add_axes(ax)
|
||||
|
||||
self.set_label_mode(label_mode)
|
||||
|
||||
def _init_locators(self):
|
||||
self._divider.set_horizontal(
|
||||
[Size.Scaled(1), self._horiz_pad_size] * (self._ncols-1) + [Size.Scaled(1)])
|
||||
self._divider.set_vertical(
|
||||
[Size.Scaled(1), self._vert_pad_size] * (self._nrows-1) + [Size.Scaled(1)])
|
||||
for i in range(self.ngrids):
|
||||
col, row = self._get_col_row(i)
|
||||
self.axes_all[i].set_axes_locator(
|
||||
self._divider.new_locator(nx=2 * col, ny=2 * (self._nrows - 1 - row)))
|
||||
|
||||
def _get_col_row(self, n):
|
||||
if self._direction == "column":
|
||||
col, row = divmod(n, self._nrows)
|
||||
else:
|
||||
row, col = divmod(n, self._ncols)
|
||||
|
||||
return col, row
|
||||
|
||||
# Good to propagate __len__ if we have __getitem__
|
||||
def __len__(self):
|
||||
return len(self.axes_all)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.axes_all[i]
|
||||
|
||||
def get_geometry(self):
|
||||
"""
|
||||
Return the number of rows and columns of the grid as (nrows, ncols).
|
||||
"""
|
||||
return self._nrows, self._ncols
|
||||
|
||||
def set_axes_pad(self, axes_pad):
|
||||
"""
|
||||
Set the padding between the axes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
axes_pad : (float, float)
|
||||
The padding (horizontal pad, vertical pad) in inches.
|
||||
"""
|
||||
self._horiz_pad_size.fixed_size = axes_pad[0]
|
||||
self._vert_pad_size.fixed_size = axes_pad[1]
|
||||
|
||||
def get_axes_pad(self):
|
||||
"""
|
||||
Return the axes padding.
|
||||
|
||||
Returns
|
||||
-------
|
||||
hpad, vpad
|
||||
Padding (horizontal pad, vertical pad) in inches.
|
||||
"""
|
||||
return (self._horiz_pad_size.fixed_size,
|
||||
self._vert_pad_size.fixed_size)
|
||||
|
||||
def set_aspect(self, aspect):
|
||||
"""Set the aspect of the SubplotDivider."""
|
||||
self._divider.set_aspect(aspect)
|
||||
|
||||
def get_aspect(self):
|
||||
"""Return the aspect of the SubplotDivider."""
|
||||
return self._divider.get_aspect()
|
||||
|
||||
def set_label_mode(self, mode):
|
||||
"""
|
||||
Define which axes have tick labels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mode : {"L", "1", "all", "keep"}
|
||||
The label mode:
|
||||
|
||||
- "L": All axes on the left column get vertical tick labels;
|
||||
all axes on the bottom row get horizontal tick labels.
|
||||
- "1": Only the bottom left axes is labelled.
|
||||
- "all": All axes are labelled.
|
||||
- "keep": Do not do anything.
|
||||
"""
|
||||
_api.check_in_list(["all", "L", "1", "keep"], mode=mode)
|
||||
is_last_row, is_first_col = (
|
||||
np.mgrid[:self._nrows, :self._ncols] == [[[self._nrows - 1]], [[0]]])
|
||||
if mode == "all":
|
||||
bottom = left = np.full((self._nrows, self._ncols), True)
|
||||
elif mode == "L":
|
||||
bottom = is_last_row
|
||||
left = is_first_col
|
||||
elif mode == "1":
|
||||
bottom = left = is_last_row & is_first_col
|
||||
else:
|
||||
return
|
||||
for i in range(self._nrows):
|
||||
for j in range(self._ncols):
|
||||
ax = self.axes_row[i][j]
|
||||
if isinstance(ax.axis, MethodType):
|
||||
bottom_axis = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
|
||||
left_axis = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
|
||||
else:
|
||||
bottom_axis = ax.axis["bottom"]
|
||||
left_axis = ax.axis["left"]
|
||||
bottom_axis.toggle(ticklabels=bottom[i, j], label=bottom[i, j])
|
||||
left_axis.toggle(ticklabels=left[i, j], label=left[i, j])
|
||||
|
||||
def get_divider(self):
|
||||
return self._divider
|
||||
|
||||
def set_axes_locator(self, locator):
|
||||
self._divider.set_locator(locator)
|
||||
|
||||
def get_axes_locator(self):
|
||||
return self._divider.get_locator()
|
||||
|
||||
|
||||
class ImageGrid(Grid):
|
||||
"""
|
||||
A grid of Axes for Image display.
|
||||
|
||||
This class is a specialization of `~.axes_grid1.axes_grid.Grid` for displaying a
|
||||
grid of images. In particular, it forces all axes in a column to share their x-axis
|
||||
and all axes in a row to share their y-axis. It further provides helpers to add
|
||||
colorbars to some or all axes.
|
||||
"""
|
||||
|
||||
def __init__(self, fig,
|
||||
rect,
|
||||
nrows_ncols,
|
||||
ngrids=None,
|
||||
direction="row",
|
||||
axes_pad=0.02,
|
||||
*,
|
||||
share_all=False,
|
||||
aspect=True,
|
||||
label_mode="L",
|
||||
cbar_mode=None,
|
||||
cbar_location="right",
|
||||
cbar_pad=None,
|
||||
cbar_size="5%",
|
||||
cbar_set_cax=True,
|
||||
axes_class=None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
fig : `.Figure`
|
||||
The parent figure.
|
||||
rect : (float, float, float, float) or int
|
||||
The axes position, as a ``(left, bottom, width, height)`` tuple or
|
||||
as a three-digit subplot position code (e.g., "121").
|
||||
nrows_ncols : (int, int)
|
||||
Number of rows and columns in the grid.
|
||||
ngrids : int or None, default: None
|
||||
If not None, only the first *ngrids* axes in the grid are created.
|
||||
direction : {"row", "column"}, default: "row"
|
||||
Whether axes are created in row-major ("row by row") or
|
||||
column-major order ("column by column"). This also affects the
|
||||
order in which axes are accessed using indexing (``grid[index]``).
|
||||
axes_pad : float or (float, float), default: 0.02in
|
||||
Padding or (horizontal padding, vertical padding) between axes, in
|
||||
inches.
|
||||
share_all : bool, default: False
|
||||
Whether all axes share their x- and y-axis. Note that in any case,
|
||||
all axes in a column share their x-axis and all axes in a row share
|
||||
their y-axis.
|
||||
aspect : bool, default: True
|
||||
Whether the axes aspect ratio follows the aspect ratio of the data
|
||||
limits.
|
||||
label_mode : {"L", "1", "all"}, default: "L"
|
||||
Determines which axes will get tick labels:
|
||||
|
||||
- "L": All axes on the left column get vertical tick labels;
|
||||
all axes on the bottom row get horizontal tick labels.
|
||||
- "1": Only the bottom left axes is labelled.
|
||||
- "all": all axes are labelled.
|
||||
|
||||
cbar_mode : {"each", "single", "edge", None}, default: None
|
||||
Whether to create a colorbar for "each" axes, a "single" colorbar
|
||||
for the entire grid, colorbars only for axes on the "edge"
|
||||
determined by *cbar_location*, or no colorbars. The colorbars are
|
||||
stored in the :attr:`cbar_axes` attribute.
|
||||
cbar_location : {"left", "right", "bottom", "top"}, default: "right"
|
||||
cbar_pad : float, default: None
|
||||
Padding between the image axes and the colorbar axes.
|
||||
|
||||
.. versionchanged:: 3.10
|
||||
``cbar_mode="single"`` no longer adds *axes_pad* between the axes
|
||||
and the colorbar if the *cbar_location* is "left" or "bottom".
|
||||
|
||||
cbar_size : size specification (see `.Size.from_any`), default: "5%"
|
||||
Colorbar size.
|
||||
cbar_set_cax : bool, default: True
|
||||
If True, each axes in the grid has a *cax* attribute that is bound
|
||||
to associated *cbar_axes*.
|
||||
axes_class : subclass of `matplotlib.axes.Axes`, default: None
|
||||
"""
|
||||
_api.check_in_list(["each", "single", "edge", None],
|
||||
cbar_mode=cbar_mode)
|
||||
_api.check_in_list(["left", "right", "bottom", "top"],
|
||||
cbar_location=cbar_location)
|
||||
self._colorbar_mode = cbar_mode
|
||||
self._colorbar_location = cbar_location
|
||||
self._colorbar_pad = cbar_pad
|
||||
self._colorbar_size = cbar_size
|
||||
# The colorbar axes are created in _init_locators().
|
||||
|
||||
super().__init__(
|
||||
fig, rect, nrows_ncols, ngrids,
|
||||
direction=direction, axes_pad=axes_pad,
|
||||
share_all=share_all, share_x=True, share_y=True, aspect=aspect,
|
||||
label_mode=label_mode, axes_class=axes_class)
|
||||
|
||||
for ax in self.cbar_axes:
|
||||
fig.add_axes(ax)
|
||||
|
||||
if cbar_set_cax:
|
||||
if self._colorbar_mode == "single":
|
||||
for ax in self.axes_all:
|
||||
ax.cax = self.cbar_axes[0]
|
||||
elif self._colorbar_mode == "edge":
|
||||
for index, ax in enumerate(self.axes_all):
|
||||
col, row = self._get_col_row(index)
|
||||
if self._colorbar_location in ("left", "right"):
|
||||
ax.cax = self.cbar_axes[row]
|
||||
else:
|
||||
ax.cax = self.cbar_axes[col]
|
||||
else:
|
||||
for ax, cax in zip(self.axes_all, self.cbar_axes):
|
||||
ax.cax = cax
|
||||
|
||||
def _init_locators(self):
|
||||
# Slightly abusing this method to inject colorbar creation into init.
|
||||
|
||||
if self._colorbar_pad is None:
|
||||
# horizontal or vertical arrangement?
|
||||
if self._colorbar_location in ("left", "right"):
|
||||
self._colorbar_pad = self._horiz_pad_size.fixed_size
|
||||
else:
|
||||
self._colorbar_pad = self._vert_pad_size.fixed_size
|
||||
self.cbar_axes = [
|
||||
_cbaraxes_class_factory(self._defaultAxesClass)(
|
||||
self.axes_all[0].get_figure(root=False), self._divider.get_position(),
|
||||
orientation=self._colorbar_location)
|
||||
for _ in range(self.ngrids)]
|
||||
|
||||
cb_mode = self._colorbar_mode
|
||||
cb_location = self._colorbar_location
|
||||
|
||||
h = []
|
||||
v = []
|
||||
|
||||
h_ax_pos = []
|
||||
h_cb_pos = []
|
||||
if cb_mode == "single" and cb_location in ("left", "bottom"):
|
||||
if cb_location == "left":
|
||||
sz = self._nrows * Size.AxesX(self.axes_llc)
|
||||
h.append(Size.from_any(self._colorbar_size, sz))
|
||||
h.append(Size.from_any(self._colorbar_pad, sz))
|
||||
locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
|
||||
elif cb_location == "bottom":
|
||||
sz = self._ncols * Size.AxesY(self.axes_llc)
|
||||
v.append(Size.from_any(self._colorbar_size, sz))
|
||||
v.append(Size.from_any(self._colorbar_pad, sz))
|
||||
locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
|
||||
for i in range(self.ngrids):
|
||||
self.cbar_axes[i].set_visible(False)
|
||||
self.cbar_axes[0].set_axes_locator(locator)
|
||||
self.cbar_axes[0].set_visible(True)
|
||||
|
||||
for col, ax in enumerate(self.axes_row[0]):
|
||||
if col != 0:
|
||||
h.append(self._horiz_pad_size)
|
||||
|
||||
if ax:
|
||||
sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
|
||||
else:
|
||||
sz = Size.AxesX(self.axes_all[0],
|
||||
aspect="axes", ref_ax=self.axes_all[0])
|
||||
|
||||
if (cb_location == "left"
|
||||
and (cb_mode == "each"
|
||||
or (cb_mode == "edge" and col == 0))):
|
||||
h_cb_pos.append(len(h))
|
||||
h.append(Size.from_any(self._colorbar_size, sz))
|
||||
h.append(Size.from_any(self._colorbar_pad, sz))
|
||||
|
||||
h_ax_pos.append(len(h))
|
||||
h.append(sz)
|
||||
|
||||
if (cb_location == "right"
|
||||
and (cb_mode == "each"
|
||||
or (cb_mode == "edge" and col == self._ncols - 1))):
|
||||
h.append(Size.from_any(self._colorbar_pad, sz))
|
||||
h_cb_pos.append(len(h))
|
||||
h.append(Size.from_any(self._colorbar_size, sz))
|
||||
|
||||
v_ax_pos = []
|
||||
v_cb_pos = []
|
||||
for row, ax in enumerate(self.axes_column[0][::-1]):
|
||||
if row != 0:
|
||||
v.append(self._vert_pad_size)
|
||||
|
||||
if ax:
|
||||
sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
|
||||
else:
|
||||
sz = Size.AxesY(self.axes_all[0],
|
||||
aspect="axes", ref_ax=self.axes_all[0])
|
||||
|
||||
if (cb_location == "bottom"
|
||||
and (cb_mode == "each"
|
||||
or (cb_mode == "edge" and row == 0))):
|
||||
v_cb_pos.append(len(v))
|
||||
v.append(Size.from_any(self._colorbar_size, sz))
|
||||
v.append(Size.from_any(self._colorbar_pad, sz))
|
||||
|
||||
v_ax_pos.append(len(v))
|
||||
v.append(sz)
|
||||
|
||||
if (cb_location == "top"
|
||||
and (cb_mode == "each"
|
||||
or (cb_mode == "edge" and row == self._nrows - 1))):
|
||||
v.append(Size.from_any(self._colorbar_pad, sz))
|
||||
v_cb_pos.append(len(v))
|
||||
v.append(Size.from_any(self._colorbar_size, sz))
|
||||
|
||||
for i in range(self.ngrids):
|
||||
col, row = self._get_col_row(i)
|
||||
locator = self._divider.new_locator(nx=h_ax_pos[col],
|
||||
ny=v_ax_pos[self._nrows-1-row])
|
||||
self.axes_all[i].set_axes_locator(locator)
|
||||
|
||||
if cb_mode == "each":
|
||||
if cb_location in ("right", "left"):
|
||||
locator = self._divider.new_locator(
|
||||
nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
|
||||
|
||||
elif cb_location in ("top", "bottom"):
|
||||
locator = self._divider.new_locator(
|
||||
nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])
|
||||
|
||||
self.cbar_axes[i].set_axes_locator(locator)
|
||||
elif cb_mode == "edge":
|
||||
if (cb_location == "left" and col == 0
|
||||
or cb_location == "right" and col == self._ncols - 1):
|
||||
locator = self._divider.new_locator(
|
||||
nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
|
||||
self.cbar_axes[row].set_axes_locator(locator)
|
||||
elif (cb_location == "bottom" and row == self._nrows - 1
|
||||
or cb_location == "top" and row == 0):
|
||||
locator = self._divider.new_locator(nx=h_ax_pos[col],
|
||||
ny=v_cb_pos[0])
|
||||
self.cbar_axes[col].set_axes_locator(locator)
|
||||
|
||||
if cb_mode == "single":
|
||||
if cb_location == "right":
|
||||
sz = self._nrows * Size.AxesX(self.axes_llc)
|
||||
h.append(Size.from_any(self._colorbar_pad, sz))
|
||||
h.append(Size.from_any(self._colorbar_size, sz))
|
||||
locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
|
||||
elif cb_location == "top":
|
||||
sz = self._ncols * Size.AxesY(self.axes_llc)
|
||||
v.append(Size.from_any(self._colorbar_pad, sz))
|
||||
v.append(Size.from_any(self._colorbar_size, sz))
|
||||
locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
|
||||
if cb_location in ("right", "top"):
|
||||
for i in range(self.ngrids):
|
||||
self.cbar_axes[i].set_visible(False)
|
||||
self.cbar_axes[0].set_axes_locator(locator)
|
||||
self.cbar_axes[0].set_visible(True)
|
||||
elif cb_mode == "each":
|
||||
for i in range(self.ngrids):
|
||||
self.cbar_axes[i].set_visible(True)
|
||||
elif cb_mode == "edge":
|
||||
if cb_location in ("right", "left"):
|
||||
count = self._nrows
|
||||
else:
|
||||
count = self._ncols
|
||||
for i in range(count):
|
||||
self.cbar_axes[i].set_visible(True)
|
||||
for j in range(i + 1, self.ngrids):
|
||||
self.cbar_axes[j].set_visible(False)
|
||||
else:
|
||||
for i in range(self.ngrids):
|
||||
self.cbar_axes[i].set_visible(False)
|
||||
self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
|
||||
which="active")
|
||||
|
||||
self._divider.set_horizontal(h)
|
||||
self._divider.set_vertical(v)
|
||||
|
||||
|
||||
AxesGrid = ImageGrid
|
|
@ -0,0 +1,157 @@
|
|||
from types import MethodType
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .axes_divider import make_axes_locatable, Size
|
||||
from .mpl_axes import Axes, SimpleAxisArtist
|
||||
|
||||
|
||||
def make_rgb_axes(ax, pad=0.01, axes_class=None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
ax : `~matplotlib.axes.Axes`
|
||||
Axes instance to create the RGB Axes in.
|
||||
pad : float, optional
|
||||
Fraction of the Axes height to pad.
|
||||
axes_class : `matplotlib.axes.Axes` or None, optional
|
||||
Axes class to use for the R, G, and B Axes. If None, use
|
||||
the same class as *ax*.
|
||||
**kwargs
|
||||
Forwarded to *axes_class* init for the R, G, and B Axes.
|
||||
"""
|
||||
|
||||
divider = make_axes_locatable(ax)
|
||||
|
||||
pad_size = pad * Size.AxesY(ax)
|
||||
|
||||
xsize = ((1-2*pad)/3) * Size.AxesX(ax)
|
||||
ysize = ((1-2*pad)/3) * Size.AxesY(ax)
|
||||
|
||||
divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
|
||||
divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
|
||||
|
||||
ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
|
||||
|
||||
ax_rgb = []
|
||||
if axes_class is None:
|
||||
axes_class = type(ax)
|
||||
|
||||
for ny in [4, 2, 0]:
|
||||
ax1 = axes_class(ax.get_figure(), ax.get_position(original=True),
|
||||
sharex=ax, sharey=ax, **kwargs)
|
||||
locator = divider.new_locator(nx=2, ny=ny)
|
||||
ax1.set_axes_locator(locator)
|
||||
for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
|
||||
t.set_visible(False)
|
||||
try:
|
||||
for axis in ax1.axis.values():
|
||||
axis.major_ticklabels.set_visible(False)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
ax_rgb.append(ax1)
|
||||
|
||||
fig = ax.get_figure()
|
||||
for ax1 in ax_rgb:
|
||||
fig.add_axes(ax1)
|
||||
|
||||
return ax_rgb
|
||||
|
||||
|
||||
class RGBAxes:
|
||||
"""
|
||||
4-panel `~.Axes.imshow` (RGB, R, G, B).
|
||||
|
||||
Layout::
|
||||
|
||||
┌───────────────┬─────┐
|
||||
│ │ R │
|
||||
│ ├─────┤
|
||||
│ RGB │ G │
|
||||
│ ├─────┤
|
||||
│ │ B │
|
||||
└───────────────┴─────┘
|
||||
|
||||
Subclasses can override the ``_defaultAxesClass`` attribute.
|
||||
By default RGBAxes uses `.mpl_axes.Axes`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
RGB : ``_defaultAxesClass``
|
||||
The Axes object for the three-channel `~.Axes.imshow`.
|
||||
R : ``_defaultAxesClass``
|
||||
The Axes object for the red channel `~.Axes.imshow`.
|
||||
G : ``_defaultAxesClass``
|
||||
The Axes object for the green channel `~.Axes.imshow`.
|
||||
B : ``_defaultAxesClass``
|
||||
The Axes object for the blue channel `~.Axes.imshow`.
|
||||
"""
|
||||
|
||||
_defaultAxesClass = Axes
|
||||
|
||||
def __init__(self, *args, pad=0, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
pad : float, default: 0
|
||||
Fraction of the Axes height to put as padding.
|
||||
axes_class : `~matplotlib.axes.Axes`
|
||||
Axes class to use. If not provided, ``_defaultAxesClass`` is used.
|
||||
*args
|
||||
Forwarded to *axes_class* init for the RGB Axes
|
||||
**kwargs
|
||||
Forwarded to *axes_class* init for the RGB, R, G, and B Axes
|
||||
"""
|
||||
axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
|
||||
self.RGB = ax = axes_class(*args, **kwargs)
|
||||
ax.get_figure().add_axes(ax)
|
||||
self.R, self.G, self.B = make_rgb_axes(
|
||||
ax, pad=pad, axes_class=axes_class, **kwargs)
|
||||
# Set the line color and ticks for the axes.
|
||||
for ax1 in [self.RGB, self.R, self.G, self.B]:
|
||||
if isinstance(ax1.axis, MethodType):
|
||||
ad = Axes.AxisDict(self)
|
||||
ad.update(
|
||||
bottom=SimpleAxisArtist(ax1.xaxis, 1, ax1.spines["bottom"]),
|
||||
top=SimpleAxisArtist(ax1.xaxis, 2, ax1.spines["top"]),
|
||||
left=SimpleAxisArtist(ax1.yaxis, 1, ax1.spines["left"]),
|
||||
right=SimpleAxisArtist(ax1.yaxis, 2, ax1.spines["right"]))
|
||||
else:
|
||||
ad = ax1.axis
|
||||
ad[:].line.set_color("w")
|
||||
ad[:].major_ticks.set_markeredgecolor("w")
|
||||
|
||||
def imshow_rgb(self, r, g, b, **kwargs):
|
||||
"""
|
||||
Create the four images {rgb, r, g, b}.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
r, g, b : array-like
|
||||
The red, green, and blue arrays.
|
||||
**kwargs
|
||||
Forwarded to `~.Axes.imshow` calls for the four images.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rgb : `~matplotlib.image.AxesImage`
|
||||
r : `~matplotlib.image.AxesImage`
|
||||
g : `~matplotlib.image.AxesImage`
|
||||
b : `~matplotlib.image.AxesImage`
|
||||
"""
|
||||
if not (r.shape == g.shape == b.shape):
|
||||
raise ValueError(
|
||||
f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
|
||||
RGB = np.dstack([r, g, b])
|
||||
R = np.zeros_like(RGB)
|
||||
R[:, :, 0] = r
|
||||
G = np.zeros_like(RGB)
|
||||
G[:, :, 1] = g
|
||||
B = np.zeros_like(RGB)
|
||||
B[:, :, 2] = b
|
||||
im_rgb = self.RGB.imshow(RGB, **kwargs)
|
||||
im_r = self.R.imshow(R, **kwargs)
|
||||
im_g = self.G.imshow(G, **kwargs)
|
||||
im_b = self.B.imshow(B, **kwargs)
|
||||
return im_rgb, im_r, im_g, im_b
|
|
@ -0,0 +1,271 @@
|
|||
"""
|
||||
Provides classes of simple units that will be used with `.AxesDivider`
|
||||
class (or others) to determine the size of each Axes. The unit
|
||||
classes define `get_size` method that returns a tuple of two floats,
|
||||
meaning relative and absolute sizes, respectively.
|
||||
|
||||
Note that this class is nothing more than a simple tuple of two
|
||||
floats. Take a look at the Divider class to see how these two
|
||||
values are used.
|
||||
|
||||
Once created, the unit classes can be modified by simple arithmetic
|
||||
operations: addition /subtraction with another unit type or a real number and scaling
|
||||
(multiplication or division) by a real number.
|
||||
"""
|
||||
|
||||
from numbers import Real
|
||||
|
||||
from matplotlib import _api
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
class _Base:
|
||||
def __rmul__(self, other):
|
||||
return self * other
|
||||
|
||||
def __mul__(self, other):
|
||||
if not isinstance(other, Real):
|
||||
return NotImplemented
|
||||
return Fraction(other, self)
|
||||
|
||||
def __div__(self, other):
|
||||
return (1 / other) * self
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, _Base):
|
||||
return Add(self, other)
|
||||
else:
|
||||
return Add(self, Fixed(other))
|
||||
|
||||
def __neg__(self):
|
||||
return -1 * self
|
||||
|
||||
def __radd__(self, other):
|
||||
# other cannot be a _Base instance, because A + B would trigger
|
||||
# A.__add__(B) first.
|
||||
return Add(self, Fixed(other))
|
||||
|
||||
def __sub__(self, other):
|
||||
return self + (-other)
|
||||
|
||||
def get_size(self, renderer):
|
||||
"""
|
||||
Return two-float tuple with relative and absolute sizes.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement")
|
||||
|
||||
|
||||
class Add(_Base):
|
||||
"""
|
||||
Sum of two sizes.
|
||||
"""
|
||||
|
||||
def __init__(self, a, b):
|
||||
self._a = a
|
||||
self._b = b
|
||||
|
||||
def get_size(self, renderer):
|
||||
a_rel_size, a_abs_size = self._a.get_size(renderer)
|
||||
b_rel_size, b_abs_size = self._b.get_size(renderer)
|
||||
return a_rel_size + b_rel_size, a_abs_size + b_abs_size
|
||||
|
||||
|
||||
class Fixed(_Base):
|
||||
"""
|
||||
Simple fixed size with absolute part = *fixed_size* and relative part = 0.
|
||||
"""
|
||||
|
||||
def __init__(self, fixed_size):
|
||||
_api.check_isinstance(Real, fixed_size=fixed_size)
|
||||
self.fixed_size = fixed_size
|
||||
|
||||
def get_size(self, renderer):
|
||||
rel_size = 0.
|
||||
abs_size = self.fixed_size
|
||||
return rel_size, abs_size
|
||||
|
||||
|
||||
class Scaled(_Base):
|
||||
"""
|
||||
Simple scaled(?) size with absolute part = 0 and
|
||||
relative part = *scalable_size*.
|
||||
"""
|
||||
|
||||
def __init__(self, scalable_size):
|
||||
self._scalable_size = scalable_size
|
||||
|
||||
def get_size(self, renderer):
|
||||
rel_size = self._scalable_size
|
||||
abs_size = 0.
|
||||
return rel_size, abs_size
|
||||
|
||||
Scalable = Scaled
|
||||
|
||||
|
||||
def _get_axes_aspect(ax):
|
||||
aspect = ax.get_aspect()
|
||||
if aspect == "auto":
|
||||
aspect = 1.
|
||||
return aspect
|
||||
|
||||
|
||||
class AxesX(_Base):
|
||||
"""
|
||||
Scaled size whose relative part corresponds to the data width
|
||||
of the *axes* multiplied by the *aspect*.
|
||||
"""
|
||||
|
||||
def __init__(self, axes, aspect=1., ref_ax=None):
|
||||
self._axes = axes
|
||||
self._aspect = aspect
|
||||
if aspect == "axes" and ref_ax is None:
|
||||
raise ValueError("ref_ax must be set when aspect='axes'")
|
||||
self._ref_ax = ref_ax
|
||||
|
||||
def get_size(self, renderer):
|
||||
l1, l2 = self._axes.get_xlim()
|
||||
if self._aspect == "axes":
|
||||
ref_aspect = _get_axes_aspect(self._ref_ax)
|
||||
aspect = ref_aspect / _get_axes_aspect(self._axes)
|
||||
else:
|
||||
aspect = self._aspect
|
||||
|
||||
rel_size = abs(l2-l1)*aspect
|
||||
abs_size = 0.
|
||||
return rel_size, abs_size
|
||||
|
||||
|
||||
class AxesY(_Base):
|
||||
"""
|
||||
Scaled size whose relative part corresponds to the data height
|
||||
of the *axes* multiplied by the *aspect*.
|
||||
"""
|
||||
|
||||
def __init__(self, axes, aspect=1., ref_ax=None):
|
||||
self._axes = axes
|
||||
self._aspect = aspect
|
||||
if aspect == "axes" and ref_ax is None:
|
||||
raise ValueError("ref_ax must be set when aspect='axes'")
|
||||
self._ref_ax = ref_ax
|
||||
|
||||
def get_size(self, renderer):
|
||||
l1, l2 = self._axes.get_ylim()
|
||||
|
||||
if self._aspect == "axes":
|
||||
ref_aspect = _get_axes_aspect(self._ref_ax)
|
||||
aspect = _get_axes_aspect(self._axes)
|
||||
else:
|
||||
aspect = self._aspect
|
||||
|
||||
rel_size = abs(l2-l1)*aspect
|
||||
abs_size = 0.
|
||||
return rel_size, abs_size
|
||||
|
||||
|
||||
class MaxExtent(_Base):
|
||||
"""
|
||||
Size whose absolute part is either the largest width or the largest height
|
||||
of the given *artist_list*.
|
||||
"""
|
||||
|
||||
def __init__(self, artist_list, w_or_h):
|
||||
self._artist_list = artist_list
|
||||
_api.check_in_list(["width", "height"], w_or_h=w_or_h)
|
||||
self._w_or_h = w_or_h
|
||||
|
||||
def add_artist(self, a):
|
||||
self._artist_list.append(a)
|
||||
|
||||
def get_size(self, renderer):
|
||||
rel_size = 0.
|
||||
extent_list = [
|
||||
getattr(a.get_window_extent(renderer), self._w_or_h) / a.figure.dpi
|
||||
for a in self._artist_list]
|
||||
abs_size = max(extent_list, default=0)
|
||||
return rel_size, abs_size
|
||||
|
||||
|
||||
class MaxWidth(MaxExtent):
|
||||
"""
|
||||
Size whose absolute part is the largest width of the given *artist_list*.
|
||||
"""
|
||||
|
||||
def __init__(self, artist_list):
|
||||
super().__init__(artist_list, "width")
|
||||
|
||||
|
||||
class MaxHeight(MaxExtent):
|
||||
"""
|
||||
Size whose absolute part is the largest height of the given *artist_list*.
|
||||
"""
|
||||
|
||||
def __init__(self, artist_list):
|
||||
super().__init__(artist_list, "height")
|
||||
|
||||
|
||||
class Fraction(_Base):
|
||||
"""
|
||||
An instance whose size is a *fraction* of the *ref_size*.
|
||||
|
||||
>>> s = Fraction(0.3, AxesX(ax))
|
||||
"""
|
||||
|
||||
def __init__(self, fraction, ref_size):
|
||||
_api.check_isinstance(Real, fraction=fraction)
|
||||
self._fraction_ref = ref_size
|
||||
self._fraction = fraction
|
||||
|
||||
def get_size(self, renderer):
|
||||
if self._fraction_ref is None:
|
||||
return self._fraction, 0.
|
||||
else:
|
||||
r, a = self._fraction_ref.get_size(renderer)
|
||||
rel_size = r*self._fraction
|
||||
abs_size = a*self._fraction
|
||||
return rel_size, abs_size
|
||||
|
||||
|
||||
def from_any(size, fraction_ref=None):
|
||||
"""
|
||||
Create a Fixed unit when the first argument is a float, or a
|
||||
Fraction unit if that is a string that ends with %. The second
|
||||
argument is only meaningful when Fraction unit is created.
|
||||
|
||||
>>> from mpl_toolkits.axes_grid1.axes_size import from_any
|
||||
>>> a = from_any(1.2) # => Fixed(1.2)
|
||||
>>> from_any("50%", a) # => Fraction(0.5, a)
|
||||
"""
|
||||
if isinstance(size, Real):
|
||||
return Fixed(size)
|
||||
elif isinstance(size, str):
|
||||
if size[-1] == "%":
|
||||
return Fraction(float(size[:-1]) / 100, fraction_ref)
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
class _AxesDecorationsSize(_Base):
|
||||
"""
|
||||
Fixed size, corresponding to the size of decorations on a given Axes side.
|
||||
"""
|
||||
|
||||
_get_size_map = {
|
||||
"left": lambda tight_bb, axes_bb: axes_bb.xmin - tight_bb.xmin,
|
||||
"right": lambda tight_bb, axes_bb: tight_bb.xmax - axes_bb.xmax,
|
||||
"bottom": lambda tight_bb, axes_bb: axes_bb.ymin - tight_bb.ymin,
|
||||
"top": lambda tight_bb, axes_bb: tight_bb.ymax - axes_bb.ymax,
|
||||
}
|
||||
|
||||
def __init__(self, ax, direction):
|
||||
_api.check_in_list(self._get_size_map, direction=direction)
|
||||
self._direction = direction
|
||||
self._ax_list = [ax] if isinstance(ax, Axes) else ax
|
||||
|
||||
def get_size(self, renderer):
|
||||
sz = max([
|
||||
self._get_size_map[self._direction](
|
||||
ax.get_tightbbox(renderer, call_axes_locator=False), ax.bbox)
|
||||
for ax in self._ax_list])
|
||||
dpi = renderer.points_to_pixels(72)
|
||||
abs_size = sz / dpi
|
||||
rel_size = 0
|
||||
return rel_size, abs_size
|
|
@ -0,0 +1,519 @@
|
|||
"""
|
||||
A collection of functions and objects for creating or placing inset axes.
|
||||
"""
|
||||
|
||||
from matplotlib import _api, _docstring
|
||||
from matplotlib.offsetbox import AnchoredOffsetbox
|
||||
from matplotlib.patches import Patch, Rectangle
|
||||
from matplotlib.path import Path
|
||||
from matplotlib.transforms import Bbox
|
||||
from matplotlib.transforms import IdentityTransform, TransformedBbox
|
||||
|
||||
from . import axes_size as Size
|
||||
from .parasite_axes import HostAxes
|
||||
|
||||
|
||||
class AnchoredLocatorBase(AnchoredOffsetbox):
|
||||
def __init__(self, bbox_to_anchor, offsetbox, loc,
|
||||
borderpad=0.5, bbox_transform=None):
|
||||
super().__init__(
|
||||
loc, pad=0., child=None, borderpad=borderpad,
|
||||
bbox_to_anchor=bbox_to_anchor, bbox_transform=bbox_transform
|
||||
)
|
||||
|
||||
def draw(self, renderer):
|
||||
raise RuntimeError("No draw method should be called")
|
||||
|
||||
def __call__(self, ax, renderer):
|
||||
fig = ax.get_figure(root=False)
|
||||
if renderer is None:
|
||||
renderer = fig._get_renderer()
|
||||
self.axes = ax
|
||||
bbox = self.get_window_extent(renderer)
|
||||
px, py = self.get_offset(bbox.width, bbox.height, 0, 0, renderer)
|
||||
bbox_canvas = Bbox.from_bounds(px, py, bbox.width, bbox.height)
|
||||
tr = fig.transSubfigure.inverted()
|
||||
return TransformedBbox(bbox_canvas, tr)
|
||||
|
||||
|
||||
class AnchoredSizeLocator(AnchoredLocatorBase):
|
||||
def __init__(self, bbox_to_anchor, x_size, y_size, loc,
|
||||
borderpad=0.5, bbox_transform=None):
|
||||
super().__init__(
|
||||
bbox_to_anchor, None, loc,
|
||||
borderpad=borderpad, bbox_transform=bbox_transform
|
||||
)
|
||||
|
||||
self.x_size = Size.from_any(x_size)
|
||||
self.y_size = Size.from_any(y_size)
|
||||
|
||||
def get_bbox(self, renderer):
|
||||
bbox = self.get_bbox_to_anchor()
|
||||
dpi = renderer.points_to_pixels(72.)
|
||||
|
||||
r, a = self.x_size.get_size(renderer)
|
||||
width = bbox.width * r + a * dpi
|
||||
r, a = self.y_size.get_size(renderer)
|
||||
height = bbox.height * r + a * dpi
|
||||
|
||||
fontsize = renderer.points_to_pixels(self.prop.get_size_in_points())
|
||||
pad = self.pad * fontsize
|
||||
|
||||
return Bbox.from_bounds(0, 0, width, height).padded(pad)
|
||||
|
||||
|
||||
class AnchoredZoomLocator(AnchoredLocatorBase):
|
||||
def __init__(self, parent_axes, zoom, loc,
|
||||
borderpad=0.5,
|
||||
bbox_to_anchor=None,
|
||||
bbox_transform=None):
|
||||
self.parent_axes = parent_axes
|
||||
self.zoom = zoom
|
||||
if bbox_to_anchor is None:
|
||||
bbox_to_anchor = parent_axes.bbox
|
||||
super().__init__(
|
||||
bbox_to_anchor, None, loc, borderpad=borderpad,
|
||||
bbox_transform=bbox_transform)
|
||||
|
||||
def get_bbox(self, renderer):
|
||||
bb = self.parent_axes.transData.transform_bbox(self.axes.viewLim)
|
||||
fontsize = renderer.points_to_pixels(self.prop.get_size_in_points())
|
||||
pad = self.pad * fontsize
|
||||
return (
|
||||
Bbox.from_bounds(
|
||||
0, 0, abs(bb.width * self.zoom), abs(bb.height * self.zoom))
|
||||
.padded(pad))
|
||||
|
||||
|
||||
class BboxPatch(Patch):
|
||||
@_docstring.interpd
|
||||
def __init__(self, bbox, **kwargs):
|
||||
"""
|
||||
Patch showing the shape bounded by a Bbox.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox : `~matplotlib.transforms.Bbox`
|
||||
Bbox to use for the extents of this patch.
|
||||
|
||||
**kwargs
|
||||
Patch properties. Valid arguments include:
|
||||
|
||||
%(Patch:kwdoc)s
|
||||
"""
|
||||
if "transform" in kwargs:
|
||||
raise ValueError("transform should not be set")
|
||||
|
||||
kwargs["transform"] = IdentityTransform()
|
||||
super().__init__(**kwargs)
|
||||
self.bbox = bbox
|
||||
|
||||
def get_path(self):
|
||||
# docstring inherited
|
||||
x0, y0, x1, y1 = self.bbox.extents
|
||||
return Path._create_closed([(x0, y0), (x1, y0), (x1, y1), (x0, y1)])
|
||||
|
||||
|
||||
class BboxConnector(Patch):
|
||||
@staticmethod
|
||||
def get_bbox_edge_pos(bbox, loc):
|
||||
"""
|
||||
Return the ``(x, y)`` coordinates of corner *loc* of *bbox*; parameters
|
||||
behave as documented for the `.BboxConnector` constructor.
|
||||
"""
|
||||
x0, y0, x1, y1 = bbox.extents
|
||||
if loc == 1:
|
||||
return x1, y1
|
||||
elif loc == 2:
|
||||
return x0, y1
|
||||
elif loc == 3:
|
||||
return x0, y0
|
||||
elif loc == 4:
|
||||
return x1, y0
|
||||
|
||||
@staticmethod
|
||||
def connect_bbox(bbox1, bbox2, loc1, loc2=None):
|
||||
"""
|
||||
Construct a `.Path` connecting corner *loc1* of *bbox1* to corner
|
||||
*loc2* of *bbox2*, where parameters behave as documented as for the
|
||||
`.BboxConnector` constructor.
|
||||
"""
|
||||
if isinstance(bbox1, Rectangle):
|
||||
bbox1 = TransformedBbox(Bbox.unit(), bbox1.get_transform())
|
||||
if isinstance(bbox2, Rectangle):
|
||||
bbox2 = TransformedBbox(Bbox.unit(), bbox2.get_transform())
|
||||
if loc2 is None:
|
||||
loc2 = loc1
|
||||
x1, y1 = BboxConnector.get_bbox_edge_pos(bbox1, loc1)
|
||||
x2, y2 = BboxConnector.get_bbox_edge_pos(bbox2, loc2)
|
||||
return Path([[x1, y1], [x2, y2]])
|
||||
|
||||
@_docstring.interpd
|
||||
def __init__(self, bbox1, bbox2, loc1, loc2=None, **kwargs):
|
||||
"""
|
||||
Connect two bboxes with a straight line.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox1, bbox2 : `~matplotlib.transforms.Bbox`
|
||||
Bounding boxes to connect.
|
||||
|
||||
loc1, loc2 : {1, 2, 3, 4}
|
||||
Corner of *bbox1* and *bbox2* to draw the line. Valid values are::
|
||||
|
||||
'upper right' : 1,
|
||||
'upper left' : 2,
|
||||
'lower left' : 3,
|
||||
'lower right' : 4
|
||||
|
||||
*loc2* is optional and defaults to *loc1*.
|
||||
|
||||
**kwargs
|
||||
Patch properties for the line drawn. Valid arguments include:
|
||||
|
||||
%(Patch:kwdoc)s
|
||||
"""
|
||||
if "transform" in kwargs:
|
||||
raise ValueError("transform should not be set")
|
||||
|
||||
kwargs["transform"] = IdentityTransform()
|
||||
kwargs.setdefault(
|
||||
"fill", bool({'fc', 'facecolor', 'color'}.intersection(kwargs)))
|
||||
super().__init__(**kwargs)
|
||||
self.bbox1 = bbox1
|
||||
self.bbox2 = bbox2
|
||||
self.loc1 = loc1
|
||||
self.loc2 = loc2
|
||||
|
||||
def get_path(self):
|
||||
# docstring inherited
|
||||
return self.connect_bbox(self.bbox1, self.bbox2,
|
||||
self.loc1, self.loc2)
|
||||
|
||||
|
||||
class BboxConnectorPatch(BboxConnector):
|
||||
@_docstring.interpd
|
||||
def __init__(self, bbox1, bbox2, loc1a, loc2a, loc1b, loc2b, **kwargs):
|
||||
"""
|
||||
Connect two bboxes with a quadrilateral.
|
||||
|
||||
The quadrilateral is specified by two lines that start and end at
|
||||
corners of the bboxes. The four sides of the quadrilateral are defined
|
||||
by the two lines given, the line between the two corners specified in
|
||||
*bbox1* and the line between the two corners specified in *bbox2*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox1, bbox2 : `~matplotlib.transforms.Bbox`
|
||||
Bounding boxes to connect.
|
||||
|
||||
loc1a, loc2a, loc1b, loc2b : {1, 2, 3, 4}
|
||||
The first line connects corners *loc1a* of *bbox1* and *loc2a* of
|
||||
*bbox2*; the second line connects corners *loc1b* of *bbox1* and
|
||||
*loc2b* of *bbox2*. Valid values are::
|
||||
|
||||
'upper right' : 1,
|
||||
'upper left' : 2,
|
||||
'lower left' : 3,
|
||||
'lower right' : 4
|
||||
|
||||
**kwargs
|
||||
Patch properties for the line drawn:
|
||||
|
||||
%(Patch:kwdoc)s
|
||||
"""
|
||||
if "transform" in kwargs:
|
||||
raise ValueError("transform should not be set")
|
||||
super().__init__(bbox1, bbox2, loc1a, loc2a, **kwargs)
|
||||
self.loc1b = loc1b
|
||||
self.loc2b = loc2b
|
||||
|
||||
def get_path(self):
|
||||
# docstring inherited
|
||||
path1 = self.connect_bbox(self.bbox1, self.bbox2, self.loc1, self.loc2)
|
||||
path2 = self.connect_bbox(self.bbox2, self.bbox1,
|
||||
self.loc2b, self.loc1b)
|
||||
path_merged = [*path1.vertices, *path2.vertices, path1.vertices[0]]
|
||||
return Path(path_merged)
|
||||
|
||||
|
||||
def _add_inset_axes(parent_axes, axes_class, axes_kwargs, axes_locator):
|
||||
"""Helper function to add an inset axes and disable navigation in it."""
|
||||
if axes_class is None:
|
||||
axes_class = HostAxes
|
||||
if axes_kwargs is None:
|
||||
axes_kwargs = {}
|
||||
fig = parent_axes.get_figure(root=False)
|
||||
inset_axes = axes_class(
|
||||
fig, parent_axes.get_position(),
|
||||
**{"navigate": False, **axes_kwargs, "axes_locator": axes_locator})
|
||||
return fig.add_axes(inset_axes)
|
||||
|
||||
|
||||
@_docstring.interpd
|
||||
def inset_axes(parent_axes, width, height, loc='upper right',
|
||||
bbox_to_anchor=None, bbox_transform=None,
|
||||
axes_class=None, axes_kwargs=None,
|
||||
borderpad=0.5):
|
||||
"""
|
||||
Create an inset axes with a given width and height.
|
||||
|
||||
Both sizes used can be specified either in inches or percentage.
|
||||
For example,::
|
||||
|
||||
inset_axes(parent_axes, width='40%%', height='30%%', loc='lower left')
|
||||
|
||||
creates in inset axes in the lower left corner of *parent_axes* which spans
|
||||
over 30%% in height and 40%% in width of the *parent_axes*. Since the usage
|
||||
of `.inset_axes` may become slightly tricky when exceeding such standard
|
||||
cases, it is recommended to read :doc:`the examples
|
||||
</gallery/axes_grid1/inset_locator_demo>`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The meaning of *bbox_to_anchor* and *bbox_to_transform* is interpreted
|
||||
differently from that of legend. The value of bbox_to_anchor
|
||||
(or the return value of its get_points method; the default is
|
||||
*parent_axes.bbox*) is transformed by the bbox_transform (the default
|
||||
is Identity transform) and then interpreted as points in the pixel
|
||||
coordinate (which is dpi dependent).
|
||||
|
||||
Thus, following three calls are identical and creates an inset axes
|
||||
with respect to the *parent_axes*::
|
||||
|
||||
axins = inset_axes(parent_axes, "30%%", "40%%")
|
||||
axins = inset_axes(parent_axes, "30%%", "40%%",
|
||||
bbox_to_anchor=parent_axes.bbox)
|
||||
axins = inset_axes(parent_axes, "30%%", "40%%",
|
||||
bbox_to_anchor=(0, 0, 1, 1),
|
||||
bbox_transform=parent_axes.transAxes)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent_axes : `matplotlib.axes.Axes`
|
||||
Axes to place the inset axes.
|
||||
|
||||
width, height : float or str
|
||||
Size of the inset axes to create. If a float is provided, it is
|
||||
the size in inches, e.g. *width=1.3*. If a string is provided, it is
|
||||
the size in relative units, e.g. *width='40%%'*. By default, i.e. if
|
||||
neither *bbox_to_anchor* nor *bbox_transform* are specified, those
|
||||
are relative to the parent_axes. Otherwise, they are to be understood
|
||||
relative to the bounding box provided via *bbox_to_anchor*.
|
||||
|
||||
loc : str, default: 'upper right'
|
||||
Location to place the inset axes. Valid locations are
|
||||
'upper left', 'upper center', 'upper right',
|
||||
'center left', 'center', 'center right',
|
||||
'lower left', 'lower center', 'lower right'.
|
||||
For backward compatibility, numeric values are accepted as well.
|
||||
See the parameter *loc* of `.Legend` for details.
|
||||
|
||||
bbox_to_anchor : tuple or `~matplotlib.transforms.BboxBase`, optional
|
||||
Bbox that the inset axes will be anchored to. If None,
|
||||
a tuple of (0, 0, 1, 1) is used if *bbox_transform* is set
|
||||
to *parent_axes.transAxes* or *parent_axes.figure.transFigure*.
|
||||
Otherwise, *parent_axes.bbox* is used. If a tuple, can be either
|
||||
[left, bottom, width, height], or [left, bottom].
|
||||
If the kwargs *width* and/or *height* are specified in relative units,
|
||||
the 2-tuple [left, bottom] cannot be used. Note that,
|
||||
unless *bbox_transform* is set, the units of the bounding box
|
||||
are interpreted in the pixel coordinate. When using *bbox_to_anchor*
|
||||
with tuple, it almost always makes sense to also specify
|
||||
a *bbox_transform*. This might often be the axes transform
|
||||
*parent_axes.transAxes*.
|
||||
|
||||
bbox_transform : `~matplotlib.transforms.Transform`, optional
|
||||
Transformation for the bbox that contains the inset axes.
|
||||
If None, a `.transforms.IdentityTransform` is used. The value
|
||||
of *bbox_to_anchor* (or the return value of its get_points method)
|
||||
is transformed by the *bbox_transform* and then interpreted
|
||||
as points in the pixel coordinate (which is dpi dependent).
|
||||
You may provide *bbox_to_anchor* in some normalized coordinate,
|
||||
and give an appropriate transform (e.g., *parent_axes.transAxes*).
|
||||
|
||||
axes_class : `~matplotlib.axes.Axes` type, default: `.HostAxes`
|
||||
The type of the newly created inset axes.
|
||||
|
||||
axes_kwargs : dict, optional
|
||||
Keyword arguments to pass to the constructor of the inset axes.
|
||||
Valid arguments include:
|
||||
|
||||
%(Axes:kwdoc)s
|
||||
|
||||
borderpad : float, default: 0.5
|
||||
Padding between inset axes and the bbox_to_anchor.
|
||||
The units are axes font size, i.e. for a default font size of 10 points
|
||||
*borderpad = 0.5* is equivalent to a padding of 5 points.
|
||||
|
||||
Returns
|
||||
-------
|
||||
inset_axes : *axes_class*
|
||||
Inset axes object created.
|
||||
"""
|
||||
|
||||
if (bbox_transform in [parent_axes.transAxes,
|
||||
parent_axes.get_figure(root=False).transFigure]
|
||||
and bbox_to_anchor is None):
|
||||
_api.warn_external("Using the axes or figure transform requires a "
|
||||
"bounding box in the respective coordinates. "
|
||||
"Using bbox_to_anchor=(0, 0, 1, 1) now.")
|
||||
bbox_to_anchor = (0, 0, 1, 1)
|
||||
if bbox_to_anchor is None:
|
||||
bbox_to_anchor = parent_axes.bbox
|
||||
if (isinstance(bbox_to_anchor, tuple) and
|
||||
(isinstance(width, str) or isinstance(height, str))):
|
||||
if len(bbox_to_anchor) != 4:
|
||||
raise ValueError("Using relative units for width or height "
|
||||
"requires to provide a 4-tuple or a "
|
||||
"`Bbox` instance to `bbox_to_anchor.")
|
||||
return _add_inset_axes(
|
||||
parent_axes, axes_class, axes_kwargs,
|
||||
AnchoredSizeLocator(
|
||||
bbox_to_anchor, width, height, loc=loc,
|
||||
bbox_transform=bbox_transform, borderpad=borderpad))
|
||||
|
||||
|
||||
@_docstring.interpd
|
||||
def zoomed_inset_axes(parent_axes, zoom, loc='upper right',
|
||||
bbox_to_anchor=None, bbox_transform=None,
|
||||
axes_class=None, axes_kwargs=None,
|
||||
borderpad=0.5):
|
||||
"""
|
||||
Create an anchored inset axes by scaling a parent axes. For usage, also see
|
||||
:doc:`the examples </gallery/axes_grid1/inset_locator_demo2>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent_axes : `~matplotlib.axes.Axes`
|
||||
Axes to place the inset axes.
|
||||
|
||||
zoom : float
|
||||
Scaling factor of the data axes. *zoom* > 1 will enlarge the
|
||||
coordinates (i.e., "zoomed in"), while *zoom* < 1 will shrink the
|
||||
coordinates (i.e., "zoomed out").
|
||||
|
||||
loc : str, default: 'upper right'
|
||||
Location to place the inset axes. Valid locations are
|
||||
'upper left', 'upper center', 'upper right',
|
||||
'center left', 'center', 'center right',
|
||||
'lower left', 'lower center', 'lower right'.
|
||||
For backward compatibility, numeric values are accepted as well.
|
||||
See the parameter *loc* of `.Legend` for details.
|
||||
|
||||
bbox_to_anchor : tuple or `~matplotlib.transforms.BboxBase`, optional
|
||||
Bbox that the inset axes will be anchored to. If None,
|
||||
*parent_axes.bbox* is used. If a tuple, can be either
|
||||
[left, bottom, width, height], or [left, bottom].
|
||||
If the kwargs *width* and/or *height* are specified in relative units,
|
||||
the 2-tuple [left, bottom] cannot be used. Note that
|
||||
the units of the bounding box are determined through the transform
|
||||
in use. When using *bbox_to_anchor* it almost always makes sense to
|
||||
also specify a *bbox_transform*. This might often be the axes transform
|
||||
*parent_axes.transAxes*.
|
||||
|
||||
bbox_transform : `~matplotlib.transforms.Transform`, optional
|
||||
Transformation for the bbox that contains the inset axes.
|
||||
If None, a `.transforms.IdentityTransform` is used (i.e. pixel
|
||||
coordinates). This is useful when not providing any argument to
|
||||
*bbox_to_anchor*. When using *bbox_to_anchor* it almost always makes
|
||||
sense to also specify a *bbox_transform*. This might often be the
|
||||
axes transform *parent_axes.transAxes*. Inversely, when specifying
|
||||
the axes- or figure-transform here, be aware that not specifying
|
||||
*bbox_to_anchor* will use *parent_axes.bbox*, the units of which are
|
||||
in display (pixel) coordinates.
|
||||
|
||||
axes_class : `~matplotlib.axes.Axes` type, default: `.HostAxes`
|
||||
The type of the newly created inset axes.
|
||||
|
||||
axes_kwargs : dict, optional
|
||||
Keyword arguments to pass to the constructor of the inset axes.
|
||||
Valid arguments include:
|
||||
|
||||
%(Axes:kwdoc)s
|
||||
|
||||
borderpad : float, default: 0.5
|
||||
Padding between inset axes and the bbox_to_anchor.
|
||||
The units are axes font size, i.e. for a default font size of 10 points
|
||||
*borderpad = 0.5* is equivalent to a padding of 5 points.
|
||||
|
||||
Returns
|
||||
-------
|
||||
inset_axes : *axes_class*
|
||||
Inset axes object created.
|
||||
"""
|
||||
|
||||
return _add_inset_axes(
|
||||
parent_axes, axes_class, axes_kwargs,
|
||||
AnchoredZoomLocator(
|
||||
parent_axes, zoom=zoom, loc=loc,
|
||||
bbox_to_anchor=bbox_to_anchor, bbox_transform=bbox_transform,
|
||||
borderpad=borderpad))
|
||||
|
||||
|
||||
class _TransformedBboxWithCallback(TransformedBbox):
|
||||
"""
|
||||
Variant of `.TransformBbox` which calls *callback* before returning points.
|
||||
|
||||
Used by `.mark_inset` to unstale the parent axes' viewlim as needed.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, callback, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._callback = callback
|
||||
|
||||
def get_points(self):
|
||||
self._callback()
|
||||
return super().get_points()
|
||||
|
||||
|
||||
@_docstring.interpd
|
||||
def mark_inset(parent_axes, inset_axes, loc1, loc2, **kwargs):
|
||||
"""
|
||||
Draw a box to mark the location of an area represented by an inset axes.
|
||||
|
||||
This function draws a box in *parent_axes* at the bounding box of
|
||||
*inset_axes*, and shows a connection with the inset axes by drawing lines
|
||||
at the corners, giving a "zoomed in" effect.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent_axes : `~matplotlib.axes.Axes`
|
||||
Axes which contains the area of the inset axes.
|
||||
|
||||
inset_axes : `~matplotlib.axes.Axes`
|
||||
The inset axes.
|
||||
|
||||
loc1, loc2 : {1, 2, 3, 4}
|
||||
Corners to use for connecting the inset axes and the area in the
|
||||
parent axes.
|
||||
|
||||
**kwargs
|
||||
Patch properties for the lines and box drawn:
|
||||
|
||||
%(Patch:kwdoc)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
pp : `~matplotlib.patches.Patch`
|
||||
The patch drawn to represent the area of the inset axes.
|
||||
|
||||
p1, p2 : `~matplotlib.patches.Patch`
|
||||
The patches connecting two corners of the inset axes and its area.
|
||||
"""
|
||||
rect = _TransformedBboxWithCallback(
|
||||
inset_axes.viewLim, parent_axes.transData,
|
||||
callback=parent_axes._unstale_viewLim)
|
||||
|
||||
kwargs.setdefault("fill", bool({'fc', 'facecolor', 'color'}.intersection(kwargs)))
|
||||
pp = BboxPatch(rect, **kwargs)
|
||||
parent_axes.add_patch(pp)
|
||||
|
||||
p1 = BboxConnector(inset_axes.bbox, rect, loc1=loc1, **kwargs)
|
||||
inset_axes.add_patch(p1)
|
||||
p1.set_clip_on(False)
|
||||
p2 = BboxConnector(inset_axes.bbox, rect, loc1=loc2, **kwargs)
|
||||
inset_axes.add_patch(p2)
|
||||
p2.set_clip_on(False)
|
||||
|
||||
return pp, p1, p2
|
|
@ -0,0 +1,128 @@
|
|||
import matplotlib.axes as maxes
|
||||
from matplotlib.artist import Artist
|
||||
from matplotlib.axis import XAxis, YAxis
|
||||
|
||||
|
||||
class SimpleChainedObjects:
|
||||
def __init__(self, objects):
|
||||
self._objects = objects
|
||||
|
||||
def __getattr__(self, k):
|
||||
_a = SimpleChainedObjects([getattr(a, k) for a in self._objects])
|
||||
return _a
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
for m in self._objects:
|
||||
m(*args, **kwargs)
|
||||
|
||||
|
||||
class Axes(maxes.Axes):
|
||||
|
||||
class AxisDict(dict):
|
||||
def __init__(self, axes):
|
||||
self.axes = axes
|
||||
super().__init__()
|
||||
|
||||
def __getitem__(self, k):
|
||||
if isinstance(k, tuple):
|
||||
r = SimpleChainedObjects(
|
||||
# super() within a list comprehension needs explicit args.
|
||||
[super(Axes.AxisDict, self).__getitem__(k1) for k1 in k])
|
||||
return r
|
||||
elif isinstance(k, slice):
|
||||
if k.start is None and k.stop is None and k.step is None:
|
||||
return SimpleChainedObjects(list(self.values()))
|
||||
else:
|
||||
raise ValueError("Unsupported slice")
|
||||
else:
|
||||
return dict.__getitem__(self, k)
|
||||
|
||||
def __call__(self, *v, **kwargs):
|
||||
return maxes.Axes.axis(self.axes, *v, **kwargs)
|
||||
|
||||
@property
|
||||
def axis(self):
|
||||
return self._axislines
|
||||
|
||||
def clear(self):
|
||||
# docstring inherited
|
||||
super().clear()
|
||||
# Init axis artists.
|
||||
self._axislines = self.AxisDict(self)
|
||||
self._axislines.update(
|
||||
bottom=SimpleAxisArtist(self.xaxis, 1, self.spines["bottom"]),
|
||||
top=SimpleAxisArtist(self.xaxis, 2, self.spines["top"]),
|
||||
left=SimpleAxisArtist(self.yaxis, 1, self.spines["left"]),
|
||||
right=SimpleAxisArtist(self.yaxis, 2, self.spines["right"]))
|
||||
|
||||
|
||||
class SimpleAxisArtist(Artist):
|
||||
def __init__(self, axis, axisnum, spine):
|
||||
self._axis = axis
|
||||
self._axisnum = axisnum
|
||||
self.line = spine
|
||||
|
||||
if isinstance(axis, XAxis):
|
||||
self._axis_direction = ["bottom", "top"][axisnum-1]
|
||||
elif isinstance(axis, YAxis):
|
||||
self._axis_direction = ["left", "right"][axisnum-1]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"axis must be instance of XAxis or YAxis, but got {axis}")
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def major_ticks(self):
|
||||
tickline = "tick%dline" % self._axisnum
|
||||
return SimpleChainedObjects([getattr(tick, tickline)
|
||||
for tick in self._axis.get_major_ticks()])
|
||||
|
||||
@property
|
||||
def major_ticklabels(self):
|
||||
label = "label%d" % self._axisnum
|
||||
return SimpleChainedObjects([getattr(tick, label)
|
||||
for tick in self._axis.get_major_ticks()])
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return self._axis.label
|
||||
|
||||
def set_visible(self, b):
|
||||
self.toggle(all=b)
|
||||
self.line.set_visible(b)
|
||||
self._axis.set_visible(True)
|
||||
super().set_visible(b)
|
||||
|
||||
def set_label(self, txt):
|
||||
self._axis.set_label_text(txt)
|
||||
|
||||
def toggle(self, all=None, ticks=None, ticklabels=None, label=None):
|
||||
|
||||
if all:
|
||||
_ticks, _ticklabels, _label = True, True, True
|
||||
elif all is not None:
|
||||
_ticks, _ticklabels, _label = False, False, False
|
||||
else:
|
||||
_ticks, _ticklabels, _label = None, None, None
|
||||
|
||||
if ticks is not None:
|
||||
_ticks = ticks
|
||||
if ticklabels is not None:
|
||||
_ticklabels = ticklabels
|
||||
if label is not None:
|
||||
_label = label
|
||||
|
||||
if _ticks is not None:
|
||||
tickparam = {f"tick{self._axisnum}On": _ticks}
|
||||
self._axis.set_tick_params(**tickparam)
|
||||
if _ticklabels is not None:
|
||||
tickparam = {f"label{self._axisnum}On": _ticklabels}
|
||||
self._axis.set_tick_params(**tickparam)
|
||||
|
||||
if _label is not None:
|
||||
pos = self._axis.get_label_position()
|
||||
if (pos == self._axis_direction) and not _label:
|
||||
self._axis.label.set_visible(False)
|
||||
elif _label:
|
||||
self._axis.label.set_visible(True)
|
||||
self._axis.set_label_position(self._axis_direction)
|
|
@ -0,0 +1,257 @@
|
|||
from matplotlib import _api, cbook
|
||||
import matplotlib.artist as martist
|
||||
import matplotlib.transforms as mtransforms
|
||||
from matplotlib.transforms import Bbox
|
||||
from .mpl_axes import Axes
|
||||
|
||||
|
||||
class ParasiteAxesBase:
|
||||
|
||||
def __init__(self, parent_axes, aux_transform=None,
|
||||
*, viewlim_mode=None, **kwargs):
|
||||
self._parent_axes = parent_axes
|
||||
self.transAux = aux_transform
|
||||
self.set_viewlim_mode(viewlim_mode)
|
||||
kwargs["frameon"] = False
|
||||
super().__init__(parent_axes.get_figure(root=False),
|
||||
parent_axes._position, **kwargs)
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
martist.setp(self.get_children(), visible=False)
|
||||
self._get_lines = self._parent_axes._get_lines
|
||||
self._parent_axes.callbacks._connect_picklable(
|
||||
"xlim_changed", self._sync_lims)
|
||||
self._parent_axes.callbacks._connect_picklable(
|
||||
"ylim_changed", self._sync_lims)
|
||||
|
||||
def pick(self, mouseevent):
|
||||
# This most likely goes to Artist.pick (depending on axes_class given
|
||||
# to the factory), which only handles pick events registered on the
|
||||
# axes associated with each child:
|
||||
super().pick(mouseevent)
|
||||
# But parasite axes are additionally given pick events from their host
|
||||
# axes (cf. HostAxesBase.pick), which we handle here:
|
||||
for a in self.get_children():
|
||||
if (hasattr(mouseevent.inaxes, "parasites")
|
||||
and self in mouseevent.inaxes.parasites):
|
||||
a.pick(mouseevent)
|
||||
|
||||
# aux_transform support
|
||||
|
||||
def _set_lim_and_transforms(self):
|
||||
if self.transAux is not None:
|
||||
self.transAxes = self._parent_axes.transAxes
|
||||
self.transData = self.transAux + self._parent_axes.transData
|
||||
self._xaxis_transform = mtransforms.blended_transform_factory(
|
||||
self.transData, self.transAxes)
|
||||
self._yaxis_transform = mtransforms.blended_transform_factory(
|
||||
self.transAxes, self.transData)
|
||||
else:
|
||||
super()._set_lim_and_transforms()
|
||||
|
||||
def set_viewlim_mode(self, mode):
|
||||
_api.check_in_list([None, "equal", "transform"], mode=mode)
|
||||
self._viewlim_mode = mode
|
||||
|
||||
def get_viewlim_mode(self):
|
||||
return self._viewlim_mode
|
||||
|
||||
def _sync_lims(self, parent):
|
||||
viewlim = parent.viewLim.frozen()
|
||||
mode = self.get_viewlim_mode()
|
||||
if mode is None:
|
||||
pass
|
||||
elif mode == "equal":
|
||||
self.viewLim.set(viewlim)
|
||||
elif mode == "transform":
|
||||
self.viewLim.set(viewlim.transformed(self.transAux.inverted()))
|
||||
else:
|
||||
_api.check_in_list([None, "equal", "transform"], mode=mode)
|
||||
|
||||
# end of aux_transform support
|
||||
|
||||
|
||||
parasite_axes_class_factory = cbook._make_class_factory(
|
||||
ParasiteAxesBase, "{}Parasite")
|
||||
ParasiteAxes = parasite_axes_class_factory(Axes)
|
||||
|
||||
|
||||
class HostAxesBase:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.parasites = []
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_aux_axes(
|
||||
self, tr=None, viewlim_mode="equal", axes_class=None, **kwargs):
|
||||
"""
|
||||
Add a parasite axes to this host.
|
||||
|
||||
Despite this method's name, this should actually be thought of as an
|
||||
``add_parasite_axes`` method.
|
||||
|
||||
.. versionchanged:: 3.7
|
||||
Defaults to same base axes class as host axes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tr : `~matplotlib.transforms.Transform` or None, default: None
|
||||
If a `.Transform`, the following relation will hold:
|
||||
``parasite.transData = tr + host.transData``.
|
||||
If None, the parasite's and the host's ``transData`` are unrelated.
|
||||
viewlim_mode : {"equal", "transform", None}, default: "equal"
|
||||
How the parasite's view limits are set: directly equal to the
|
||||
parent axes ("equal"), equal after application of *tr*
|
||||
("transform"), or independently (None).
|
||||
axes_class : subclass type of `~matplotlib.axes.Axes`, optional
|
||||
The `~.axes.Axes` subclass that is instantiated. If None, the base
|
||||
class of the host axes is used.
|
||||
**kwargs
|
||||
Other parameters are forwarded to the parasite axes constructor.
|
||||
"""
|
||||
if axes_class is None:
|
||||
axes_class = self._base_axes_class
|
||||
parasite_axes_class = parasite_axes_class_factory(axes_class)
|
||||
ax2 = parasite_axes_class(
|
||||
self, tr, viewlim_mode=viewlim_mode, **kwargs)
|
||||
# note that ax2.transData == tr + ax1.transData
|
||||
# Anything you draw in ax2 will match the ticks and grids of ax1.
|
||||
self.parasites.append(ax2)
|
||||
ax2._remove_method = self.parasites.remove
|
||||
return ax2
|
||||
|
||||
def draw(self, renderer):
|
||||
orig_children_len = len(self._children)
|
||||
|
||||
locator = self.get_axes_locator()
|
||||
if locator:
|
||||
pos = locator(self, renderer)
|
||||
self.set_position(pos, which="active")
|
||||
self.apply_aspect(pos)
|
||||
else:
|
||||
self.apply_aspect()
|
||||
|
||||
rect = self.get_position()
|
||||
for ax in self.parasites:
|
||||
ax.apply_aspect(rect)
|
||||
self._children.extend(ax.get_children())
|
||||
|
||||
super().draw(renderer)
|
||||
del self._children[orig_children_len:]
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
for ax in self.parasites:
|
||||
ax.clear()
|
||||
|
||||
def pick(self, mouseevent):
|
||||
super().pick(mouseevent)
|
||||
# Also pass pick events on to parasite axes and, in turn, their
|
||||
# children (cf. ParasiteAxesBase.pick)
|
||||
for a in self.parasites:
|
||||
a.pick(mouseevent)
|
||||
|
||||
def twinx(self, axes_class=None):
|
||||
"""
|
||||
Create a twin of Axes with a shared x-axis but independent y-axis.
|
||||
|
||||
The y-axis of self will have ticks on the left and the returned axes
|
||||
will have ticks on the right.
|
||||
"""
|
||||
ax = self._add_twin_axes(axes_class, sharex=self)
|
||||
self.axis["right"].set_visible(False)
|
||||
ax.axis["right"].set_visible(True)
|
||||
ax.axis["left", "top", "bottom"].set_visible(False)
|
||||
return ax
|
||||
|
||||
def twiny(self, axes_class=None):
|
||||
"""
|
||||
Create a twin of Axes with a shared y-axis but independent x-axis.
|
||||
|
||||
The x-axis of self will have ticks on the bottom and the returned axes
|
||||
will have ticks on the top.
|
||||
"""
|
||||
ax = self._add_twin_axes(axes_class, sharey=self)
|
||||
self.axis["top"].set_visible(False)
|
||||
ax.axis["top"].set_visible(True)
|
||||
ax.axis["left", "right", "bottom"].set_visible(False)
|
||||
return ax
|
||||
|
||||
def twin(self, aux_trans=None, axes_class=None):
|
||||
"""
|
||||
Create a twin of Axes with no shared axis.
|
||||
|
||||
While self will have ticks on the left and bottom axis, the returned
|
||||
axes will have ticks on the top and right axis.
|
||||
"""
|
||||
if aux_trans is None:
|
||||
aux_trans = mtransforms.IdentityTransform()
|
||||
ax = self._add_twin_axes(
|
||||
axes_class, aux_transform=aux_trans, viewlim_mode="transform")
|
||||
self.axis["top", "right"].set_visible(False)
|
||||
ax.axis["top", "right"].set_visible(True)
|
||||
ax.axis["left", "bottom"].set_visible(False)
|
||||
return ax
|
||||
|
||||
def _add_twin_axes(self, axes_class, **kwargs):
|
||||
"""
|
||||
Helper for `.twinx`/`.twiny`/`.twin`.
|
||||
|
||||
*kwargs* are forwarded to the parasite axes constructor.
|
||||
"""
|
||||
if axes_class is None:
|
||||
axes_class = self._base_axes_class
|
||||
ax = parasite_axes_class_factory(axes_class)(self, **kwargs)
|
||||
self.parasites.append(ax)
|
||||
ax._remove_method = self._remove_any_twin
|
||||
return ax
|
||||
|
||||
def _remove_any_twin(self, ax):
|
||||
self.parasites.remove(ax)
|
||||
restore = ["top", "right"]
|
||||
if ax._sharex:
|
||||
restore.remove("top")
|
||||
if ax._sharey:
|
||||
restore.remove("right")
|
||||
self.axis[tuple(restore)].set_visible(True)
|
||||
self.axis[tuple(restore)].toggle(ticklabels=False, label=False)
|
||||
|
||||
def get_tightbbox(self, renderer=None, *, call_axes_locator=True,
|
||||
bbox_extra_artists=None):
|
||||
bbs = [
|
||||
*[ax.get_tightbbox(renderer, call_axes_locator=call_axes_locator)
|
||||
for ax in self.parasites],
|
||||
super().get_tightbbox(renderer,
|
||||
call_axes_locator=call_axes_locator,
|
||||
bbox_extra_artists=bbox_extra_artists)]
|
||||
return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0])
|
||||
|
||||
|
||||
host_axes_class_factory = host_subplot_class_factory = \
|
||||
cbook._make_class_factory(HostAxesBase, "{}HostAxes", "_base_axes_class")
|
||||
HostAxes = SubplotHost = host_axes_class_factory(Axes)
|
||||
|
||||
|
||||
def host_axes(*args, axes_class=Axes, figure=None, **kwargs):
|
||||
"""
|
||||
Create axes that can act as a hosts to parasitic axes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : `~matplotlib.figure.Figure`
|
||||
Figure to which the axes will be added. Defaults to the current figure
|
||||
`.pyplot.gcf()`.
|
||||
|
||||
*args, **kwargs
|
||||
Will be passed on to the underlying `~.axes.Axes` object creation.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
host_axes_class = host_axes_class_factory(axes_class)
|
||||
if figure is None:
|
||||
figure = plt.gcf()
|
||||
ax = host_axes_class(figure, *args, **kwargs)
|
||||
figure.add_axes(ax)
|
||||
return ax
|
||||
|
||||
|
||||
host_subplot = host_axes
|
|
@ -0,0 +1,10 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
# Check that the test directories exist
|
||||
if not (Path(__file__).parent / "baseline_images").exists():
|
||||
raise OSError(
|
||||
'The baseline image directory does not exist. '
|
||||
'This is most likely because the test data is not installed. '
|
||||
'You may need to install matplotlib from source to get the '
|
||||
'test data.')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,2 @@
|
|||
from matplotlib.testing.conftest import (mpl_test_settings, # noqa
|
||||
pytest_configure, pytest_unconfigure)
|
|
@ -0,0 +1,782 @@
|
|||
from itertools import product
|
||||
import io
|
||||
import platform
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mticker
|
||||
from matplotlib import cbook
|
||||
from matplotlib.backend_bases import MouseEvent
|
||||
from matplotlib.colors import LogNorm
|
||||
from matplotlib.patches import Circle, Ellipse
|
||||
from matplotlib.transforms import Bbox, TransformedBbox
|
||||
from matplotlib.testing.decorators import (
|
||||
check_figures_equal, image_comparison, remove_ticks_and_titles)
|
||||
|
||||
from mpl_toolkits.axes_grid1 import (
|
||||
axes_size as Size,
|
||||
host_subplot, make_axes_locatable,
|
||||
Grid, AxesGrid, ImageGrid)
|
||||
from mpl_toolkits.axes_grid1.anchored_artists import (
|
||||
AnchoredAuxTransformBox, AnchoredDrawingArea,
|
||||
AnchoredDirectionArrows, AnchoredSizeBar)
|
||||
from mpl_toolkits.axes_grid1.axes_divider import (
|
||||
Divider, HBoxDivider, make_axes_area_auto_adjustable, SubplotDivider,
|
||||
VBoxDivider)
|
||||
from mpl_toolkits.axes_grid1.axes_rgb import RGBAxes
|
||||
from mpl_toolkits.axes_grid1.inset_locator import (
|
||||
zoomed_inset_axes, mark_inset, inset_axes, BboxConnectorPatch)
|
||||
import mpl_toolkits.axes_grid1.mpl_axes
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_equal, assert_array_almost_equal
|
||||
|
||||
|
||||
def test_divider_append_axes():
|
||||
fig, ax = plt.subplots()
|
||||
divider = make_axes_locatable(ax)
|
||||
axs = {
|
||||
"main": ax,
|
||||
"top": divider.append_axes("top", 1.2, pad=0.1, sharex=ax),
|
||||
"bottom": divider.append_axes("bottom", 1.2, pad=0.1, sharex=ax),
|
||||
"left": divider.append_axes("left", 1.2, pad=0.1, sharey=ax),
|
||||
"right": divider.append_axes("right", 1.2, pad=0.1, sharey=ax),
|
||||
}
|
||||
fig.canvas.draw()
|
||||
bboxes = {k: axs[k].get_window_extent() for k in axs}
|
||||
dpi = fig.dpi
|
||||
assert bboxes["top"].height == pytest.approx(1.2 * dpi)
|
||||
assert bboxes["bottom"].height == pytest.approx(1.2 * dpi)
|
||||
assert bboxes["left"].width == pytest.approx(1.2 * dpi)
|
||||
assert bboxes["right"].width == pytest.approx(1.2 * dpi)
|
||||
assert bboxes["top"].y0 - bboxes["main"].y1 == pytest.approx(0.1 * dpi)
|
||||
assert bboxes["main"].y0 - bboxes["bottom"].y1 == pytest.approx(0.1 * dpi)
|
||||
assert bboxes["main"].x0 - bboxes["left"].x1 == pytest.approx(0.1 * dpi)
|
||||
assert bboxes["right"].x0 - bboxes["main"].x1 == pytest.approx(0.1 * dpi)
|
||||
assert bboxes["left"].y0 == bboxes["main"].y0 == bboxes["right"].y0
|
||||
assert bboxes["left"].y1 == bboxes["main"].y1 == bboxes["right"].y1
|
||||
assert bboxes["top"].x0 == bboxes["main"].x0 == bboxes["bottom"].x0
|
||||
assert bboxes["top"].x1 == bboxes["main"].x1 == bboxes["bottom"].x1
|
||||
|
||||
|
||||
# Update style when regenerating the test image
|
||||
@image_comparison(['twin_axes_empty_and_removed'], extensions=["png"], tol=1,
|
||||
style=('classic', '_classic_test_patch'))
|
||||
def test_twin_axes_empty_and_removed():
|
||||
# Purely cosmetic font changes (avoid overlap)
|
||||
mpl.rcParams.update(
|
||||
{"font.size": 8, "xtick.labelsize": 8, "ytick.labelsize": 8})
|
||||
generators = ["twinx", "twiny", "twin"]
|
||||
modifiers = ["", "host invisible", "twin removed", "twin invisible",
|
||||
"twin removed\nhost invisible"]
|
||||
# Unmodified host subplot at the beginning for reference
|
||||
h = host_subplot(len(modifiers)+1, len(generators), 2)
|
||||
h.text(0.5, 0.5, "host_subplot",
|
||||
horizontalalignment="center", verticalalignment="center")
|
||||
# Host subplots with various modifications (twin*, visibility) applied
|
||||
for i, (mod, gen) in enumerate(product(modifiers, generators),
|
||||
len(generators) + 1):
|
||||
h = host_subplot(len(modifiers)+1, len(generators), i)
|
||||
t = getattr(h, gen)()
|
||||
if "twin invisible" in mod:
|
||||
t.axis[:].set_visible(False)
|
||||
if "twin removed" in mod:
|
||||
t.remove()
|
||||
if "host invisible" in mod:
|
||||
h.axis[:].set_visible(False)
|
||||
h.text(0.5, 0.5, gen + ("\n" + mod if mod else ""),
|
||||
horizontalalignment="center", verticalalignment="center")
|
||||
plt.subplots_adjust(wspace=0.5, hspace=1)
|
||||
|
||||
|
||||
def test_twin_axes_both_with_units():
|
||||
host = host_subplot(111)
|
||||
with pytest.warns(mpl.MatplotlibDeprecationWarning):
|
||||
host.plot_date([0, 1, 2], [0, 1, 2], xdate=False, ydate=True)
|
||||
twin = host.twinx()
|
||||
twin.plot(["a", "b", "c"])
|
||||
assert host.get_yticklabels()[0].get_text() == "00:00:00"
|
||||
assert twin.get_yticklabels()[0].get_text() == "a"
|
||||
|
||||
|
||||
def test_axesgrid_colorbar_log_smoketest():
|
||||
fig = plt.figure()
|
||||
grid = AxesGrid(fig, 111, # modified to be only subplot
|
||||
nrows_ncols=(1, 1),
|
||||
ngrids=1,
|
||||
label_mode="L",
|
||||
cbar_location="top",
|
||||
cbar_mode="single",
|
||||
)
|
||||
|
||||
Z = 10000 * np.random.rand(10, 10)
|
||||
im = grid[0].imshow(Z, interpolation="nearest", norm=LogNorm())
|
||||
|
||||
grid.cbar_axes[0].colorbar(im)
|
||||
|
||||
|
||||
def test_inset_colorbar_tight_layout_smoketest():
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
pts = ax.scatter([0, 1], [0, 1], c=[1, 5])
|
||||
|
||||
cax = inset_axes(ax, width="3%", height="70%")
|
||||
plt.colorbar(pts, cax=cax)
|
||||
|
||||
with pytest.warns(UserWarning, match="This figure includes Axes"):
|
||||
# Will warn, but not raise an error
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
@image_comparison(['inset_locator.png'], style='default', remove_text=True)
|
||||
def test_inset_locator():
|
||||
fig, ax = plt.subplots(figsize=[5, 4])
|
||||
|
||||
# prepare the demo image
|
||||
# Z is a 15x15 array
|
||||
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy")
|
||||
extent = (-3, 4, -4, 3)
|
||||
Z2 = np.zeros((150, 150))
|
||||
ny, nx = Z.shape
|
||||
Z2[30:30+ny, 30:30+nx] = Z
|
||||
|
||||
ax.imshow(Z2, extent=extent, interpolation="nearest",
|
||||
origin="lower")
|
||||
|
||||
axins = zoomed_inset_axes(ax, zoom=6, loc='upper right')
|
||||
axins.imshow(Z2, extent=extent, interpolation="nearest",
|
||||
origin="lower")
|
||||
axins.yaxis.get_major_locator().set_params(nbins=7)
|
||||
axins.xaxis.get_major_locator().set_params(nbins=7)
|
||||
# sub region of the original image
|
||||
x1, x2, y1, y2 = -1.5, -0.9, -2.5, -1.9
|
||||
axins.set_xlim(x1, x2)
|
||||
axins.set_ylim(y1, y2)
|
||||
|
||||
plt.xticks(visible=False)
|
||||
plt.yticks(visible=False)
|
||||
|
||||
# draw a bbox of the region of the inset axes in the parent axes and
|
||||
# connecting lines between the bbox and the inset axes area
|
||||
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
|
||||
|
||||
asb = AnchoredSizeBar(ax.transData,
|
||||
0.5,
|
||||
'0.5',
|
||||
loc='lower center',
|
||||
pad=0.1, borderpad=0.5, sep=5,
|
||||
frameon=False)
|
||||
ax.add_artist(asb)
|
||||
|
||||
|
||||
@image_comparison(['inset_axes.png'], style='default', remove_text=True)
|
||||
def test_inset_axes():
|
||||
fig, ax = plt.subplots(figsize=[5, 4])
|
||||
|
||||
# prepare the demo image
|
||||
# Z is a 15x15 array
|
||||
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy")
|
||||
extent = (-3, 4, -4, 3)
|
||||
Z2 = np.zeros((150, 150))
|
||||
ny, nx = Z.shape
|
||||
Z2[30:30+ny, 30:30+nx] = Z
|
||||
|
||||
ax.imshow(Z2, extent=extent, interpolation="nearest",
|
||||
origin="lower")
|
||||
|
||||
# creating our inset axes with a bbox_transform parameter
|
||||
axins = inset_axes(ax, width=1., height=1., bbox_to_anchor=(1, 1),
|
||||
bbox_transform=ax.transAxes)
|
||||
|
||||
axins.imshow(Z2, extent=extent, interpolation="nearest",
|
||||
origin="lower")
|
||||
axins.yaxis.get_major_locator().set_params(nbins=7)
|
||||
axins.xaxis.get_major_locator().set_params(nbins=7)
|
||||
# sub region of the original image
|
||||
x1, x2, y1, y2 = -1.5, -0.9, -2.5, -1.9
|
||||
axins.set_xlim(x1, x2)
|
||||
axins.set_ylim(y1, y2)
|
||||
|
||||
plt.xticks(visible=False)
|
||||
plt.yticks(visible=False)
|
||||
|
||||
# draw a bbox of the region of the inset axes in the parent axes and
|
||||
# connecting lines between the bbox and the inset axes area
|
||||
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
|
||||
|
||||
asb = AnchoredSizeBar(ax.transData,
|
||||
0.5,
|
||||
'0.5',
|
||||
loc='lower center',
|
||||
pad=0.1, borderpad=0.5, sep=5,
|
||||
frameon=False)
|
||||
ax.add_artist(asb)
|
||||
|
||||
|
||||
def test_inset_axes_complete():
|
||||
dpi = 100
|
||||
figsize = (6, 5)
|
||||
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
||||
fig.subplots_adjust(.1, .1, .9, .9)
|
||||
|
||||
ins = inset_axes(ax, width=2., height=2., borderpad=0)
|
||||
fig.canvas.draw()
|
||||
assert_array_almost_equal(
|
||||
ins.get_position().extents,
|
||||
[(0.9*figsize[0]-2.)/figsize[0], (0.9*figsize[1]-2.)/figsize[1],
|
||||
0.9, 0.9])
|
||||
|
||||
ins = inset_axes(ax, width="40%", height="30%", borderpad=0)
|
||||
fig.canvas.draw()
|
||||
assert_array_almost_equal(
|
||||
ins.get_position().extents, [.9-.8*.4, .9-.8*.3, 0.9, 0.9])
|
||||
|
||||
ins = inset_axes(ax, width=1., height=1.2, bbox_to_anchor=(200, 100),
|
||||
loc=3, borderpad=0)
|
||||
fig.canvas.draw()
|
||||
assert_array_almost_equal(
|
||||
ins.get_position().extents,
|
||||
[200/dpi/figsize[0], 100/dpi/figsize[1],
|
||||
(200/dpi+1)/figsize[0], (100/dpi+1.2)/figsize[1]])
|
||||
|
||||
ins1 = inset_axes(ax, width="35%", height="60%", loc=3, borderpad=1)
|
||||
ins2 = inset_axes(ax, width="100%", height="100%",
|
||||
bbox_to_anchor=(0, 0, .35, .60),
|
||||
bbox_transform=ax.transAxes, loc=3, borderpad=1)
|
||||
fig.canvas.draw()
|
||||
assert_array_equal(ins1.get_position().extents,
|
||||
ins2.get_position().extents)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ins = inset_axes(ax, width="40%", height="30%",
|
||||
bbox_to_anchor=(0.4, 0.5))
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
ins = inset_axes(ax, width="40%", height="30%",
|
||||
bbox_transform=ax.transAxes)
|
||||
|
||||
|
||||
def test_inset_axes_tight():
|
||||
# gh-26287 found that inset_axes raised with bbox_inches=tight
|
||||
fig, ax = plt.subplots()
|
||||
inset_axes(ax, width=1.3, height=0.9)
|
||||
|
||||
f = io.BytesIO()
|
||||
fig.savefig(f, bbox_inches="tight")
|
||||
|
||||
|
||||
@image_comparison(['fill_facecolor.png'], remove_text=True, style='mpl20')
|
||||
def test_fill_facecolor():
|
||||
fig, ax = plt.subplots(1, 5)
|
||||
fig.set_size_inches(5, 5)
|
||||
for i in range(1, 4):
|
||||
ax[i].yaxis.set_visible(False)
|
||||
ax[4].yaxis.tick_right()
|
||||
bbox = Bbox.from_extents(0, 0.4, 1, 0.6)
|
||||
|
||||
# fill with blue by setting 'fc' field
|
||||
bbox1 = TransformedBbox(bbox, ax[0].transData)
|
||||
bbox2 = TransformedBbox(bbox, ax[1].transData)
|
||||
# set color to BboxConnectorPatch
|
||||
p = BboxConnectorPatch(
|
||||
bbox1, bbox2, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
|
||||
ec="r", fc="b")
|
||||
p.set_clip_on(False)
|
||||
ax[0].add_patch(p)
|
||||
# set color to marked area
|
||||
axins = zoomed_inset_axes(ax[0], 1, loc='upper right')
|
||||
axins.set_xlim(0, 0.2)
|
||||
axins.set_ylim(0, 0.2)
|
||||
plt.gca().axes.xaxis.set_ticks([])
|
||||
plt.gca().axes.yaxis.set_ticks([])
|
||||
mark_inset(ax[0], axins, loc1=2, loc2=4, fc="b", ec="0.5")
|
||||
|
||||
# fill with yellow by setting 'facecolor' field
|
||||
bbox3 = TransformedBbox(bbox, ax[1].transData)
|
||||
bbox4 = TransformedBbox(bbox, ax[2].transData)
|
||||
# set color to BboxConnectorPatch
|
||||
p = BboxConnectorPatch(
|
||||
bbox3, bbox4, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
|
||||
ec="r", facecolor="y")
|
||||
p.set_clip_on(False)
|
||||
ax[1].add_patch(p)
|
||||
# set color to marked area
|
||||
axins = zoomed_inset_axes(ax[1], 1, loc='upper right')
|
||||
axins.set_xlim(0, 0.2)
|
||||
axins.set_ylim(0, 0.2)
|
||||
plt.gca().axes.xaxis.set_ticks([])
|
||||
plt.gca().axes.yaxis.set_ticks([])
|
||||
mark_inset(ax[1], axins, loc1=2, loc2=4, facecolor="y", ec="0.5")
|
||||
|
||||
# fill with green by setting 'color' field
|
||||
bbox5 = TransformedBbox(bbox, ax[2].transData)
|
||||
bbox6 = TransformedBbox(bbox, ax[3].transData)
|
||||
# set color to BboxConnectorPatch
|
||||
p = BboxConnectorPatch(
|
||||
bbox5, bbox6, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
|
||||
ec="r", color="g")
|
||||
p.set_clip_on(False)
|
||||
ax[2].add_patch(p)
|
||||
# set color to marked area
|
||||
axins = zoomed_inset_axes(ax[2], 1, loc='upper right')
|
||||
axins.set_xlim(0, 0.2)
|
||||
axins.set_ylim(0, 0.2)
|
||||
plt.gca().axes.xaxis.set_ticks([])
|
||||
plt.gca().axes.yaxis.set_ticks([])
|
||||
mark_inset(ax[2], axins, loc1=2, loc2=4, color="g", ec="0.5")
|
||||
|
||||
# fill with green but color won't show if set fill to False
|
||||
bbox7 = TransformedBbox(bbox, ax[3].transData)
|
||||
bbox8 = TransformedBbox(bbox, ax[4].transData)
|
||||
# BboxConnectorPatch won't show green
|
||||
p = BboxConnectorPatch(
|
||||
bbox7, bbox8, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
|
||||
ec="r", fc="g", fill=False)
|
||||
p.set_clip_on(False)
|
||||
ax[3].add_patch(p)
|
||||
# marked area won't show green
|
||||
axins = zoomed_inset_axes(ax[3], 1, loc='upper right')
|
||||
axins.set_xlim(0, 0.2)
|
||||
axins.set_ylim(0, 0.2)
|
||||
axins.xaxis.set_ticks([])
|
||||
axins.yaxis.set_ticks([])
|
||||
mark_inset(ax[3], axins, loc1=2, loc2=4, fc="g", ec="0.5", fill=False)
|
||||
|
||||
|
||||
# Update style when regenerating the test image
|
||||
@image_comparison(['zoomed_axes.png', 'inverted_zoomed_axes.png'],
|
||||
style=('classic', '_classic_test_patch'),
|
||||
tol=0 if platform.machine() == 'x86_64' else 0.02)
|
||||
def test_zooming_with_inverted_axes():
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot([1, 2, 3], [1, 2, 3])
|
||||
ax.axis([1, 3, 1, 3])
|
||||
inset_ax = zoomed_inset_axes(ax, zoom=2.5, loc='lower right')
|
||||
inset_ax.axis([1.1, 1.4, 1.1, 1.4])
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot([1, 2, 3], [1, 2, 3])
|
||||
ax.axis([3, 1, 3, 1])
|
||||
inset_ax = zoomed_inset_axes(ax, zoom=2.5, loc='lower right')
|
||||
inset_ax.axis([1.4, 1.1, 1.4, 1.1])
|
||||
|
||||
|
||||
# Update style when regenerating the test image
|
||||
@image_comparison(['anchored_direction_arrows.png'],
|
||||
tol=0 if platform.machine() == 'x86_64' else 0.01,
|
||||
style=('classic', '_classic_test_patch'))
|
||||
def test_anchored_direction_arrows():
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(np.zeros((10, 10)), interpolation='nearest')
|
||||
|
||||
simple_arrow = AnchoredDirectionArrows(ax.transAxes, 'X', 'Y')
|
||||
ax.add_artist(simple_arrow)
|
||||
|
||||
|
||||
# Update style when regenerating the test image
|
||||
@image_comparison(['anchored_direction_arrows_many_args.png'],
|
||||
style=('classic', '_classic_test_patch'))
|
||||
def test_anchored_direction_arrows_many_args():
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(np.ones((10, 10)))
|
||||
|
||||
direction_arrows = AnchoredDirectionArrows(
|
||||
ax.transAxes, 'A', 'B', loc='upper right', color='red',
|
||||
aspect_ratio=-0.5, pad=0.6, borderpad=2, frameon=True, alpha=0.7,
|
||||
sep_x=-0.06, sep_y=-0.08, back_length=0.1, head_width=9,
|
||||
head_length=10, tail_width=5)
|
||||
ax.add_artist(direction_arrows)
|
||||
|
||||
|
||||
def test_axes_locatable_position():
|
||||
fig, ax = plt.subplots()
|
||||
divider = make_axes_locatable(ax)
|
||||
with mpl.rc_context({"figure.subplot.wspace": 0.02}):
|
||||
cax = divider.append_axes('right', size='5%')
|
||||
fig.canvas.draw()
|
||||
assert np.isclose(cax.get_position(original=False).width,
|
||||
0.03621495327102808)
|
||||
|
||||
|
||||
@image_comparison(['image_grid_each_left_label_mode_all.png'], style='mpl20',
|
||||
savefig_kwarg={'bbox_inches': 'tight'})
|
||||
def test_image_grid_each_left_label_mode_all():
|
||||
imdata = np.arange(100).reshape((10, 10))
|
||||
|
||||
fig = plt.figure(1, (3, 3))
|
||||
grid = ImageGrid(fig, (1, 1, 1), nrows_ncols=(3, 2), axes_pad=(0.5, 0.3),
|
||||
cbar_mode="each", cbar_location="left", cbar_size="15%",
|
||||
label_mode="all")
|
||||
# 3-tuple rect => SubplotDivider
|
||||
assert isinstance(grid.get_divider(), SubplotDivider)
|
||||
assert grid.get_axes_pad() == (0.5, 0.3)
|
||||
assert grid.get_aspect() # True by default for ImageGrid
|
||||
for ax, cax in zip(grid, grid.cbar_axes):
|
||||
im = ax.imshow(imdata, interpolation='none')
|
||||
cax.colorbar(im)
|
||||
|
||||
|
||||
@image_comparison(['image_grid_single_bottom_label_mode_1.png'], style='mpl20',
|
||||
savefig_kwarg={'bbox_inches': 'tight'})
|
||||
def test_image_grid_single_bottom():
|
||||
imdata = np.arange(100).reshape((10, 10))
|
||||
|
||||
fig = plt.figure(1, (2.5, 1.5))
|
||||
grid = ImageGrid(fig, (0, 0, 1, 1), nrows_ncols=(1, 3),
|
||||
axes_pad=(0.2, 0.15), cbar_mode="single", cbar_pad=0.3,
|
||||
cbar_location="bottom", cbar_size="10%", label_mode="1")
|
||||
# 4-tuple rect => Divider, isinstance will give True for SubplotDivider
|
||||
assert type(grid.get_divider()) is Divider
|
||||
for i in range(3):
|
||||
im = grid[i].imshow(imdata, interpolation='none')
|
||||
grid.cbar_axes[0].colorbar(im)
|
||||
|
||||
|
||||
def test_image_grid_label_mode_invalid():
|
||||
fig = plt.figure()
|
||||
with pytest.raises(ValueError, match="'foo' is not a valid value for mode"):
|
||||
ImageGrid(fig, (0, 0, 1, 1), (2, 1), label_mode="foo")
|
||||
|
||||
|
||||
@image_comparison(['image_grid.png'],
|
||||
remove_text=True, style='mpl20',
|
||||
savefig_kwarg={'bbox_inches': 'tight'})
|
||||
def test_image_grid():
|
||||
# test that image grid works with bbox_inches=tight.
|
||||
im = np.arange(100).reshape((10, 10))
|
||||
|
||||
fig = plt.figure(1, (4, 4))
|
||||
grid = ImageGrid(fig, 111, nrows_ncols=(2, 2), axes_pad=0.1)
|
||||
assert grid.get_axes_pad() == (0.1, 0.1)
|
||||
for i in range(4):
|
||||
grid[i].imshow(im, interpolation='nearest')
|
||||
|
||||
|
||||
def test_gettightbbox():
|
||||
fig, ax = plt.subplots(figsize=(8, 6))
|
||||
|
||||
l, = ax.plot([1, 2, 3], [0, 1, 0])
|
||||
|
||||
ax_zoom = zoomed_inset_axes(ax, 4)
|
||||
ax_zoom.plot([1, 2, 3], [0, 1, 0])
|
||||
|
||||
mark_inset(ax, ax_zoom, loc1=1, loc2=3, fc="none", ec='0.3')
|
||||
|
||||
remove_ticks_and_titles(fig)
|
||||
bbox = fig.get_tightbbox(fig.canvas.get_renderer())
|
||||
np.testing.assert_array_almost_equal(bbox.extents,
|
||||
[-17.7, -13.9, 7.2, 5.4])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("click_on", ["big", "small"])
|
||||
@pytest.mark.parametrize("big_on_axes,small_on_axes", [
|
||||
("gca", "gca"),
|
||||
("host", "host"),
|
||||
("host", "parasite"),
|
||||
("parasite", "host"),
|
||||
("parasite", "parasite")
|
||||
])
|
||||
def test_picking_callbacks_overlap(big_on_axes, small_on_axes, click_on):
|
||||
"""Test pick events on normal, host or parasite axes."""
|
||||
# Two rectangles are drawn and "clicked on", a small one and a big one
|
||||
# enclosing the small one. The axis on which they are drawn as well as the
|
||||
# rectangle that is clicked on are varied.
|
||||
# In each case we expect that both rectangles are picked if we click on the
|
||||
# small one and only the big one is picked if we click on the big one.
|
||||
# Also tests picking on normal axes ("gca") as a control.
|
||||
big = plt.Rectangle((0.25, 0.25), 0.5, 0.5, picker=5)
|
||||
small = plt.Rectangle((0.4, 0.4), 0.2, 0.2, facecolor="r", picker=5)
|
||||
# Machinery for "receiving" events
|
||||
received_events = []
|
||||
def on_pick(event):
|
||||
received_events.append(event)
|
||||
plt.gcf().canvas.mpl_connect('pick_event', on_pick)
|
||||
# Shortcut
|
||||
rectangles_on_axes = (big_on_axes, small_on_axes)
|
||||
# Axes setup
|
||||
axes = {"gca": None, "host": None, "parasite": None}
|
||||
if "gca" in rectangles_on_axes:
|
||||
axes["gca"] = plt.gca()
|
||||
if "host" in rectangles_on_axes or "parasite" in rectangles_on_axes:
|
||||
axes["host"] = host_subplot(111)
|
||||
axes["parasite"] = axes["host"].twin()
|
||||
# Add rectangles to axes
|
||||
axes[big_on_axes].add_patch(big)
|
||||
axes[small_on_axes].add_patch(small)
|
||||
# Simulate picking with click mouse event
|
||||
if click_on == "big":
|
||||
click_axes = axes[big_on_axes]
|
||||
axes_coords = (0.3, 0.3)
|
||||
else:
|
||||
click_axes = axes[small_on_axes]
|
||||
axes_coords = (0.5, 0.5)
|
||||
# In reality mouse events never happen on parasite axes, only host axes
|
||||
if click_axes is axes["parasite"]:
|
||||
click_axes = axes["host"]
|
||||
(x, y) = click_axes.transAxes.transform(axes_coords)
|
||||
m = MouseEvent("button_press_event", click_axes.get_figure(root=True).canvas, x, y,
|
||||
button=1)
|
||||
click_axes.pick(m)
|
||||
# Checks
|
||||
expected_n_events = 2 if click_on == "small" else 1
|
||||
assert len(received_events) == expected_n_events
|
||||
event_rects = [event.artist for event in received_events]
|
||||
assert big in event_rects
|
||||
if click_on == "small":
|
||||
assert small in event_rects
|
||||
|
||||
|
||||
@image_comparison(['anchored_artists.png'], remove_text=True, style='mpl20')
|
||||
def test_anchored_artists():
|
||||
fig, ax = plt.subplots(figsize=(3, 3))
|
||||
ada = AnchoredDrawingArea(40, 20, 0, 0, loc='upper right', pad=0.,
|
||||
frameon=False)
|
||||
p1 = Circle((10, 10), 10)
|
||||
ada.drawing_area.add_artist(p1)
|
||||
p2 = Circle((30, 10), 5, fc="r")
|
||||
ada.drawing_area.add_artist(p2)
|
||||
ax.add_artist(ada)
|
||||
|
||||
box = AnchoredAuxTransformBox(ax.transData, loc='upper left')
|
||||
el = Ellipse((0, 0), width=0.1, height=0.4, angle=30, color='cyan')
|
||||
box.drawing_area.add_artist(el)
|
||||
ax.add_artist(box)
|
||||
|
||||
# This block used to test the AnchoredEllipse class, but that was removed. The block
|
||||
# remains, though it duplicates the above ellipse, so that the test image doesn't
|
||||
# need to be regenerated.
|
||||
box = AnchoredAuxTransformBox(ax.transData, loc='lower left', frameon=True,
|
||||
pad=0.5, borderpad=0.4)
|
||||
el = Ellipse((0, 0), width=0.1, height=0.25, angle=-60)
|
||||
box.drawing_area.add_artist(el)
|
||||
ax.add_artist(box)
|
||||
|
||||
asb = AnchoredSizeBar(ax.transData, 0.2, r"0.2 units", loc='lower right',
|
||||
pad=0.3, borderpad=0.4, sep=4, fill_bar=True,
|
||||
frameon=False, label_top=True, prop={'size': 20},
|
||||
size_vertical=0.05, color='green')
|
||||
ax.add_artist(asb)
|
||||
|
||||
|
||||
def test_hbox_divider():
|
||||
arr1 = np.arange(20).reshape((4, 5))
|
||||
arr2 = np.arange(20).reshape((5, 4))
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2)
|
||||
ax1.imshow(arr1)
|
||||
ax2.imshow(arr2)
|
||||
|
||||
pad = 0.5 # inches.
|
||||
divider = HBoxDivider(
|
||||
fig, 111, # Position of combined axes.
|
||||
horizontal=[Size.AxesX(ax1), Size.Fixed(pad), Size.AxesX(ax2)],
|
||||
vertical=[Size.AxesY(ax1), Size.Scaled(1), Size.AxesY(ax2)])
|
||||
ax1.set_axes_locator(divider.new_locator(0))
|
||||
ax2.set_axes_locator(divider.new_locator(2))
|
||||
|
||||
fig.canvas.draw()
|
||||
p1 = ax1.get_position()
|
||||
p2 = ax2.get_position()
|
||||
assert p1.height == p2.height
|
||||
assert p2.width / p1.width == pytest.approx((4 / 5) ** 2)
|
||||
|
||||
|
||||
def test_vbox_divider():
|
||||
arr1 = np.arange(20).reshape((4, 5))
|
||||
arr2 = np.arange(20).reshape((5, 4))
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2)
|
||||
ax1.imshow(arr1)
|
||||
ax2.imshow(arr2)
|
||||
|
||||
pad = 0.5 # inches.
|
||||
divider = VBoxDivider(
|
||||
fig, 111, # Position of combined axes.
|
||||
horizontal=[Size.AxesX(ax1), Size.Scaled(1), Size.AxesX(ax2)],
|
||||
vertical=[Size.AxesY(ax1), Size.Fixed(pad), Size.AxesY(ax2)])
|
||||
ax1.set_axes_locator(divider.new_locator(0))
|
||||
ax2.set_axes_locator(divider.new_locator(2))
|
||||
|
||||
fig.canvas.draw()
|
||||
p1 = ax1.get_position()
|
||||
p2 = ax2.get_position()
|
||||
assert p1.width == p2.width
|
||||
assert p1.height / p2.height == pytest.approx((4 / 5) ** 2)
|
||||
|
||||
|
||||
def test_axes_class_tuple():
|
||||
fig = plt.figure()
|
||||
axes_class = (mpl_toolkits.axes_grid1.mpl_axes.Axes, {})
|
||||
gr = AxesGrid(fig, 111, nrows_ncols=(1, 1), axes_class=axes_class)
|
||||
|
||||
|
||||
def test_grid_axes_lists():
|
||||
"""Test Grid axes_all, axes_row and axes_column relationship."""
|
||||
fig = plt.figure()
|
||||
grid = Grid(fig, 111, (2, 3), direction="row")
|
||||
assert_array_equal(grid, grid.axes_all)
|
||||
assert_array_equal(grid.axes_row, np.transpose(grid.axes_column))
|
||||
assert_array_equal(grid, np.ravel(grid.axes_row), "row")
|
||||
assert grid.get_geometry() == (2, 3)
|
||||
grid = Grid(fig, 111, (2, 3), direction="column")
|
||||
assert_array_equal(grid, np.ravel(grid.axes_column), "column")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('direction', ('row', 'column'))
|
||||
def test_grid_axes_position(direction):
|
||||
"""Test positioning of the axes in Grid."""
|
||||
fig = plt.figure()
|
||||
grid = Grid(fig, 111, (2, 2), direction=direction)
|
||||
loc = [ax.get_axes_locator() for ax in np.ravel(grid.axes_row)]
|
||||
# Test nx.
|
||||
assert loc[1].args[0] > loc[0].args[0]
|
||||
assert loc[0].args[0] == loc[2].args[0]
|
||||
assert loc[3].args[0] == loc[1].args[0]
|
||||
# Test ny.
|
||||
assert loc[2].args[1] < loc[0].args[1]
|
||||
assert loc[0].args[1] == loc[1].args[1]
|
||||
assert loc[3].args[1] == loc[2].args[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rect, ngrids, error, message', (
|
||||
((1, 1), None, TypeError, "Incorrect rect format"),
|
||||
(111, -1, ValueError, "ngrids must be positive"),
|
||||
(111, 7, ValueError, "ngrids must be positive"),
|
||||
))
|
||||
def test_grid_errors(rect, ngrids, error, message):
|
||||
fig = plt.figure()
|
||||
with pytest.raises(error, match=message):
|
||||
Grid(fig, rect, (2, 3), ngrids=ngrids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('anchor, error, message', (
|
||||
(None, TypeError, "anchor must be str"),
|
||||
("CC", ValueError, "'CC' is not a valid value for anchor"),
|
||||
((1, 1, 1), TypeError, "anchor must be str"),
|
||||
))
|
||||
def test_divider_errors(anchor, error, message):
|
||||
fig = plt.figure()
|
||||
with pytest.raises(error, match=message):
|
||||
Divider(fig, [0, 0, 1, 1], [Size.Fixed(1)], [Size.Fixed(1)],
|
||||
anchor=anchor)
|
||||
|
||||
|
||||
@check_figures_equal(extensions=["png"])
|
||||
def test_mark_inset_unstales_viewlim(fig_test, fig_ref):
|
||||
inset, full = fig_test.subplots(1, 2)
|
||||
full.plot([0, 5], [0, 5])
|
||||
inset.set(xlim=(1, 2), ylim=(1, 2))
|
||||
# Check that mark_inset unstales full's viewLim before drawing the marks.
|
||||
mark_inset(full, inset, 1, 4)
|
||||
|
||||
inset, full = fig_ref.subplots(1, 2)
|
||||
full.plot([0, 5], [0, 5])
|
||||
inset.set(xlim=(1, 2), ylim=(1, 2))
|
||||
mark_inset(full, inset, 1, 4)
|
||||
# Manually unstale the full's viewLim.
|
||||
fig_ref.canvas.draw()
|
||||
|
||||
|
||||
def test_auto_adjustable():
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1])
|
||||
pad = 0.1
|
||||
make_axes_area_auto_adjustable(ax, pad=pad)
|
||||
fig.canvas.draw()
|
||||
tbb = ax.get_tightbbox()
|
||||
assert tbb.x0 == pytest.approx(pad * fig.dpi)
|
||||
assert tbb.x1 == pytest.approx(fig.bbox.width - pad * fig.dpi)
|
||||
assert tbb.y0 == pytest.approx(pad * fig.dpi)
|
||||
assert tbb.y1 == pytest.approx(fig.bbox.height - pad * fig.dpi)
|
||||
|
||||
|
||||
# Update style when regenerating the test image
|
||||
@image_comparison(['rgb_axes.png'], remove_text=True,
|
||||
style=('classic', '_classic_test_patch'))
|
||||
def test_rgb_axes():
|
||||
fig = plt.figure()
|
||||
ax = RGBAxes(fig, (0.1, 0.1, 0.8, 0.8), pad=0.1)
|
||||
rng = np.random.default_rng(19680801)
|
||||
r = rng.random((5, 5))
|
||||
g = rng.random((5, 5))
|
||||
b = rng.random((5, 5))
|
||||
ax.imshow_rgb(r, g, b, interpolation='none')
|
||||
|
||||
|
||||
# The original version of this test relied on mpl_toolkits's slightly different
|
||||
# colorbar implementation; moving to matplotlib's own colorbar implementation
|
||||
# caused the small image comparison error.
|
||||
@image_comparison(['imagegrid_cbar_mode.png'],
|
||||
remove_text=True, style='mpl20', tol=0.3)
|
||||
def test_imagegrid_cbar_mode_edge():
|
||||
arr = np.arange(16).reshape((4, 4))
|
||||
|
||||
fig = plt.figure(figsize=(18, 9))
|
||||
|
||||
positions = (241, 242, 243, 244, 245, 246, 247, 248)
|
||||
directions = ['row']*4 + ['column']*4
|
||||
cbar_locations = ['left', 'right', 'top', 'bottom']*2
|
||||
|
||||
for position, direction, location in zip(
|
||||
positions, directions, cbar_locations):
|
||||
grid = ImageGrid(fig, position,
|
||||
nrows_ncols=(2, 2),
|
||||
direction=direction,
|
||||
cbar_location=location,
|
||||
cbar_size='20%',
|
||||
cbar_mode='edge')
|
||||
ax1, ax2, ax3, ax4 = grid
|
||||
|
||||
ax1.imshow(arr, cmap='nipy_spectral')
|
||||
ax2.imshow(arr.T, cmap='hot')
|
||||
ax3.imshow(np.hypot(arr, arr.T), cmap='jet')
|
||||
ax4.imshow(np.arctan2(arr, arr.T), cmap='hsv')
|
||||
|
||||
# In each row/column, the "first" colorbars must be overwritten by the
|
||||
# "second" ones. To achieve this, clear out the axes first.
|
||||
for ax in grid:
|
||||
ax.cax.cla()
|
||||
cb = ax.cax.colorbar(ax.images[0])
|
||||
|
||||
|
||||
def test_imagegrid():
|
||||
fig = plt.figure()
|
||||
grid = ImageGrid(fig, 111, nrows_ncols=(1, 1))
|
||||
ax = grid[0]
|
||||
im = ax.imshow([[1, 2]], norm=mpl.colors.LogNorm())
|
||||
cb = ax.cax.colorbar(im)
|
||||
assert isinstance(cb.locator, mticker.LogLocator)
|
||||
|
||||
|
||||
def test_removal():
|
||||
import matplotlib.pyplot as plt
|
||||
import mpl_toolkits.axisartist as AA
|
||||
fig = plt.figure()
|
||||
ax = host_subplot(111, axes_class=AA.Axes, figure=fig)
|
||||
col = ax.fill_between(range(5), 0, range(5))
|
||||
fig.canvas.draw()
|
||||
col.remove()
|
||||
fig.canvas.draw()
|
||||
|
||||
|
||||
@image_comparison(['anchored_locator_base_call.png'], style="mpl20")
|
||||
def test_anchored_locator_base_call():
|
||||
fig = plt.figure(figsize=(3, 3))
|
||||
fig1, fig2 = fig.subfigures(nrows=2, ncols=1)
|
||||
|
||||
ax = fig1.subplots()
|
||||
ax.set(aspect=1, xlim=(-15, 15), ylim=(-20, 5))
|
||||
ax.set(xticks=[], yticks=[])
|
||||
|
||||
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy")
|
||||
extent = (-3, 4, -4, 3)
|
||||
|
||||
axins = zoomed_inset_axes(ax, zoom=2, loc="upper left")
|
||||
axins.set(xticks=[], yticks=[])
|
||||
|
||||
axins.imshow(Z, extent=extent, origin="lower")
|
||||
|
||||
|
||||
def test_grid_with_axes_class_not_overriding_axis():
|
||||
Grid(plt.figure(), 111, (2, 2), axes_class=mpl.axes.Axes)
|
||||
RGBAxes(plt.figure(), 111, axes_class=mpl.axes.Axes)
|
|
@ -0,0 +1,14 @@
|
|||
from .axislines import Axes
|
||||
from .axislines import ( # noqa: F401
|
||||
AxesZero, AxisArtistHelper, AxisArtistHelperRectlinear,
|
||||
GridHelperBase, GridHelperRectlinear, Subplot, SubplotZero)
|
||||
from .axis_artist import AxisArtist, GridlinesCollection # noqa: F401
|
||||
from .grid_helper_curvelinear import GridHelperCurveLinear # noqa: F401
|
||||
from .floating_axes import FloatingAxes, FloatingSubplot # noqa: F401
|
||||
from mpl_toolkits.axes_grid1.parasite_axes import (
|
||||
host_axes_class_factory, parasite_axes_class_factory)
|
||||
|
||||
|
||||
ParasiteAxes = parasite_axes_class_factory(Axes)
|
||||
HostAxes = host_axes_class_factory(Axes)
|
||||
SubplotHost = HostAxes
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,394 @@
|
|||
import numpy as np
|
||||
import math
|
||||
|
||||
from mpl_toolkits.axisartist.grid_finder import ExtremeFinderSimple
|
||||
|
||||
|
||||
def select_step_degree(dv):
|
||||
|
||||
degree_limits_ = [1.5, 3, 7, 13, 20, 40, 70, 120, 270, 520]
|
||||
degree_steps_ = [1, 2, 5, 10, 15, 30, 45, 90, 180, 360]
|
||||
degree_factors = [1.] * len(degree_steps_)
|
||||
|
||||
minsec_limits_ = [1.5, 2.5, 3.5, 8, 11, 18, 25, 45]
|
||||
minsec_steps_ = [1, 2, 3, 5, 10, 15, 20, 30]
|
||||
|
||||
minute_limits_ = np.array(minsec_limits_) / 60
|
||||
minute_factors = [60.] * len(minute_limits_)
|
||||
|
||||
second_limits_ = np.array(minsec_limits_) / 3600
|
||||
second_factors = [3600.] * len(second_limits_)
|
||||
|
||||
degree_limits = [*second_limits_, *minute_limits_, *degree_limits_]
|
||||
degree_steps = [*minsec_steps_, *minsec_steps_, *degree_steps_]
|
||||
degree_factors = [*second_factors, *minute_factors, *degree_factors]
|
||||
|
||||
n = np.searchsorted(degree_limits, dv)
|
||||
step = degree_steps[n]
|
||||
factor = degree_factors[n]
|
||||
|
||||
return step, factor
|
||||
|
||||
|
||||
def select_step_hour(dv):
|
||||
|
||||
hour_limits_ = [1.5, 2.5, 3.5, 5, 7, 10, 15, 21, 36]
|
||||
hour_steps_ = [1, 2, 3, 4, 6, 8, 12, 18, 24]
|
||||
hour_factors = [1.] * len(hour_steps_)
|
||||
|
||||
minsec_limits_ = [1.5, 2.5, 3.5, 4.5, 5.5, 8, 11, 14, 18, 25, 45]
|
||||
minsec_steps_ = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30]
|
||||
|
||||
minute_limits_ = np.array(minsec_limits_) / 60
|
||||
minute_factors = [60.] * len(minute_limits_)
|
||||
|
||||
second_limits_ = np.array(minsec_limits_) / 3600
|
||||
second_factors = [3600.] * len(second_limits_)
|
||||
|
||||
hour_limits = [*second_limits_, *minute_limits_, *hour_limits_]
|
||||
hour_steps = [*minsec_steps_, *minsec_steps_, *hour_steps_]
|
||||
hour_factors = [*second_factors, *minute_factors, *hour_factors]
|
||||
|
||||
n = np.searchsorted(hour_limits, dv)
|
||||
step = hour_steps[n]
|
||||
factor = hour_factors[n]
|
||||
|
||||
return step, factor
|
||||
|
||||
|
||||
def select_step_sub(dv):
|
||||
|
||||
# subarcsec or degree
|
||||
tmp = 10.**(int(math.log10(dv))-1.)
|
||||
|
||||
factor = 1./tmp
|
||||
|
||||
if 1.5*tmp >= dv:
|
||||
step = 1
|
||||
elif 3.*tmp >= dv:
|
||||
step = 2
|
||||
elif 7.*tmp >= dv:
|
||||
step = 5
|
||||
else:
|
||||
step = 1
|
||||
factor = 0.1*factor
|
||||
|
||||
return step, factor
|
||||
|
||||
|
||||
def select_step(v1, v2, nv, hour=False, include_last=True,
|
||||
threshold_factor=3600.):
|
||||
|
||||
if v1 > v2:
|
||||
v1, v2 = v2, v1
|
||||
|
||||
dv = (v2 - v1) / nv
|
||||
|
||||
if hour:
|
||||
_select_step = select_step_hour
|
||||
cycle = 24.
|
||||
else:
|
||||
_select_step = select_step_degree
|
||||
cycle = 360.
|
||||
|
||||
# for degree
|
||||
if dv > 1 / threshold_factor:
|
||||
step, factor = _select_step(dv)
|
||||
else:
|
||||
step, factor = select_step_sub(dv*threshold_factor)
|
||||
|
||||
factor = factor * threshold_factor
|
||||
|
||||
levs = np.arange(np.floor(v1 * factor / step),
|
||||
np.ceil(v2 * factor / step) + 0.5,
|
||||
dtype=int) * step
|
||||
|
||||
# n : number of valid levels. If there is a cycle, e.g., [0, 90, 180,
|
||||
# 270, 360], the grid line needs to be extended from 0 to 360, so
|
||||
# we need to return the whole array. However, the last level (360)
|
||||
# needs to be ignored often. In this case, so we return n=4.
|
||||
|
||||
n = len(levs)
|
||||
|
||||
# we need to check the range of values
|
||||
# for example, -90 to 90, 0 to 360,
|
||||
|
||||
if factor == 1. and levs[-1] >= levs[0] + cycle: # check for cycle
|
||||
nv = int(cycle / step)
|
||||
if include_last:
|
||||
levs = levs[0] + np.arange(0, nv+1, 1) * step
|
||||
else:
|
||||
levs = levs[0] + np.arange(0, nv, 1) * step
|
||||
|
||||
n = len(levs)
|
||||
|
||||
return np.array(levs), n, factor
|
||||
|
||||
|
||||
def select_step24(v1, v2, nv, include_last=True, threshold_factor=3600):
|
||||
v1, v2 = v1 / 15, v2 / 15
|
||||
levs, n, factor = select_step(v1, v2, nv, hour=True,
|
||||
include_last=include_last,
|
||||
threshold_factor=threshold_factor)
|
||||
return levs * 15, n, factor
|
||||
|
||||
|
||||
def select_step360(v1, v2, nv, include_last=True, threshold_factor=3600):
|
||||
return select_step(v1, v2, nv, hour=False,
|
||||
include_last=include_last,
|
||||
threshold_factor=threshold_factor)
|
||||
|
||||
|
||||
class LocatorBase:
|
||||
def __init__(self, nbins, include_last=True):
|
||||
self.nbins = nbins
|
||||
self._include_last = include_last
|
||||
|
||||
def set_params(self, nbins=None):
|
||||
if nbins is not None:
|
||||
self.nbins = int(nbins)
|
||||
|
||||
|
||||
class LocatorHMS(LocatorBase):
|
||||
def __call__(self, v1, v2):
|
||||
return select_step24(v1, v2, self.nbins, self._include_last)
|
||||
|
||||
|
||||
class LocatorHM(LocatorBase):
|
||||
def __call__(self, v1, v2):
|
||||
return select_step24(v1, v2, self.nbins, self._include_last,
|
||||
threshold_factor=60)
|
||||
|
||||
|
||||
class LocatorH(LocatorBase):
|
||||
def __call__(self, v1, v2):
|
||||
return select_step24(v1, v2, self.nbins, self._include_last,
|
||||
threshold_factor=1)
|
||||
|
||||
|
||||
class LocatorDMS(LocatorBase):
|
||||
def __call__(self, v1, v2):
|
||||
return select_step360(v1, v2, self.nbins, self._include_last)
|
||||
|
||||
|
||||
class LocatorDM(LocatorBase):
|
||||
def __call__(self, v1, v2):
|
||||
return select_step360(v1, v2, self.nbins, self._include_last,
|
||||
threshold_factor=60)
|
||||
|
||||
|
||||
class LocatorD(LocatorBase):
|
||||
def __call__(self, v1, v2):
|
||||
return select_step360(v1, v2, self.nbins, self._include_last,
|
||||
threshold_factor=1)
|
||||
|
||||
|
||||
class FormatterDMS:
|
||||
deg_mark = r"^{\circ}"
|
||||
min_mark = r"^{\prime}"
|
||||
sec_mark = r"^{\prime\prime}"
|
||||
|
||||
fmt_d = "$%d" + deg_mark + "$"
|
||||
fmt_ds = r"$%d.%s" + deg_mark + "$"
|
||||
|
||||
# %s for sign
|
||||
fmt_d_m = r"$%s%d" + deg_mark + r"\,%02d" + min_mark + "$"
|
||||
fmt_d_ms = r"$%s%d" + deg_mark + r"\,%02d.%s" + min_mark + "$"
|
||||
|
||||
fmt_d_m_partial = "$%s%d" + deg_mark + r"\,%02d" + min_mark + r"\,"
|
||||
fmt_s_partial = "%02d" + sec_mark + "$"
|
||||
fmt_ss_partial = "%02d.%s" + sec_mark + "$"
|
||||
|
||||
def _get_number_fraction(self, factor):
|
||||
## check for fractional numbers
|
||||
number_fraction = None
|
||||
# check for 60
|
||||
|
||||
for threshold in [1, 60, 3600]:
|
||||
if factor <= threshold:
|
||||
break
|
||||
|
||||
d = factor // threshold
|
||||
int_log_d = int(np.floor(np.log10(d)))
|
||||
if 10**int_log_d == d and d != 1:
|
||||
number_fraction = int_log_d
|
||||
factor = factor // 10**int_log_d
|
||||
return factor, number_fraction
|
||||
|
||||
return factor, number_fraction
|
||||
|
||||
def __call__(self, direction, factor, values):
|
||||
if len(values) == 0:
|
||||
return []
|
||||
|
||||
ss = np.sign(values)
|
||||
signs = ["-" if v < 0 else "" for v in values]
|
||||
|
||||
factor, number_fraction = self._get_number_fraction(factor)
|
||||
|
||||
values = np.abs(values)
|
||||
|
||||
if number_fraction is not None:
|
||||
values, frac_part = divmod(values, 10 ** number_fraction)
|
||||
frac_fmt = "%%0%dd" % (number_fraction,)
|
||||
frac_str = [frac_fmt % (f1,) for f1 in frac_part]
|
||||
|
||||
if factor == 1:
|
||||
if number_fraction is None:
|
||||
return [self.fmt_d % (s * int(v),) for s, v in zip(ss, values)]
|
||||
else:
|
||||
return [self.fmt_ds % (s * int(v), f1)
|
||||
for s, v, f1 in zip(ss, values, frac_str)]
|
||||
elif factor == 60:
|
||||
deg_part, min_part = divmod(values, 60)
|
||||
if number_fraction is None:
|
||||
return [self.fmt_d_m % (s1, d1, m1)
|
||||
for s1, d1, m1 in zip(signs, deg_part, min_part)]
|
||||
else:
|
||||
return [self.fmt_d_ms % (s, d1, m1, f1)
|
||||
for s, d1, m1, f1
|
||||
in zip(signs, deg_part, min_part, frac_str)]
|
||||
|
||||
elif factor == 3600:
|
||||
if ss[-1] == -1:
|
||||
inverse_order = True
|
||||
values = values[::-1]
|
||||
signs = signs[::-1]
|
||||
else:
|
||||
inverse_order = False
|
||||
|
||||
l_hm_old = ""
|
||||
r = []
|
||||
|
||||
deg_part, min_part_ = divmod(values, 3600)
|
||||
min_part, sec_part = divmod(min_part_, 60)
|
||||
|
||||
if number_fraction is None:
|
||||
sec_str = [self.fmt_s_partial % (s1,) for s1 in sec_part]
|
||||
else:
|
||||
sec_str = [self.fmt_ss_partial % (s1, f1)
|
||||
for s1, f1 in zip(sec_part, frac_str)]
|
||||
|
||||
for s, d1, m1, s1 in zip(signs, deg_part, min_part, sec_str):
|
||||
l_hm = self.fmt_d_m_partial % (s, d1, m1)
|
||||
if l_hm != l_hm_old:
|
||||
l_hm_old = l_hm
|
||||
l = l_hm + s1
|
||||
else:
|
||||
l = "$" + s + s1
|
||||
r.append(l)
|
||||
|
||||
if inverse_order:
|
||||
return r[::-1]
|
||||
else:
|
||||
return r
|
||||
|
||||
else: # factor > 3600.
|
||||
return [r"$%s^{\circ}$" % v for v in ss*values]
|
||||
|
||||
|
||||
class FormatterHMS(FormatterDMS):
|
||||
deg_mark = r"^\mathrm{h}"
|
||||
min_mark = r"^\mathrm{m}"
|
||||
sec_mark = r"^\mathrm{s}"
|
||||
|
||||
fmt_d = "$%d" + deg_mark + "$"
|
||||
fmt_ds = r"$%d.%s" + deg_mark + "$"
|
||||
|
||||
# %s for sign
|
||||
fmt_d_m = r"$%s%d" + deg_mark + r"\,%02d" + min_mark+"$"
|
||||
fmt_d_ms = r"$%s%d" + deg_mark + r"\,%02d.%s" + min_mark+"$"
|
||||
|
||||
fmt_d_m_partial = "$%s%d" + deg_mark + r"\,%02d" + min_mark + r"\,"
|
||||
fmt_s_partial = "%02d" + sec_mark + "$"
|
||||
fmt_ss_partial = "%02d.%s" + sec_mark + "$"
|
||||
|
||||
def __call__(self, direction, factor, values): # hour
|
||||
return super().__call__(direction, factor, np.asarray(values) / 15)
|
||||
|
||||
|
||||
class ExtremeFinderCycle(ExtremeFinderSimple):
|
||||
# docstring inherited
|
||||
|
||||
def __init__(self, nx, ny,
|
||||
lon_cycle=360., lat_cycle=None,
|
||||
lon_minmax=None, lat_minmax=(-90, 90)):
|
||||
"""
|
||||
This subclass handles the case where one or both coordinates should be
|
||||
taken modulo 360, or be restricted to not exceed a specific range.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nx, ny : int
|
||||
The number of samples in each direction.
|
||||
|
||||
lon_cycle, lat_cycle : 360 or None
|
||||
If not None, values in the corresponding direction are taken modulo
|
||||
*lon_cycle* or *lat_cycle*; in theory this can be any number but
|
||||
the implementation actually assumes that it is 360 (if not None);
|
||||
other values give nonsensical results.
|
||||
|
||||
This is done by "unwrapping" the transformed grid coordinates so
|
||||
that jumps are less than a half-cycle; then normalizing the span to
|
||||
no more than a full cycle.
|
||||
|
||||
For example, if values are in the union of the [0, 2] and
|
||||
[358, 360] intervals (typically, angles measured modulo 360), the
|
||||
values in the second interval are normalized to [-2, 0] instead so
|
||||
that the values now cover [-2, 2]. If values are in a range of
|
||||
[5, 1000], this gets normalized to [5, 365].
|
||||
|
||||
lon_minmax, lat_minmax : (float, float) or None
|
||||
If not None, the computed bounding box is clipped to the given
|
||||
range in the corresponding direction.
|
||||
"""
|
||||
self.nx, self.ny = nx, ny
|
||||
self.lon_cycle, self.lat_cycle = lon_cycle, lat_cycle
|
||||
self.lon_minmax = lon_minmax
|
||||
self.lat_minmax = lat_minmax
|
||||
|
||||
def __call__(self, transform_xy, x1, y1, x2, y2):
|
||||
# docstring inherited
|
||||
x, y = np.meshgrid(
|
||||
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
|
||||
lon, lat = transform_xy(np.ravel(x), np.ravel(y))
|
||||
|
||||
# iron out jumps, but algorithm should be improved.
|
||||
# This is just naive way of doing and my fail for some cases.
|
||||
# Consider replacing this with numpy.unwrap
|
||||
# We are ignoring invalid warnings. They are triggered when
|
||||
# comparing arrays with NaNs using > We are already handling
|
||||
# that correctly using np.nanmin and np.nanmax
|
||||
with np.errstate(invalid='ignore'):
|
||||
if self.lon_cycle is not None:
|
||||
lon0 = np.nanmin(lon)
|
||||
lon -= 360. * ((lon - lon0) > 180.)
|
||||
if self.lat_cycle is not None:
|
||||
lat0 = np.nanmin(lat)
|
||||
lat -= 360. * ((lat - lat0) > 180.)
|
||||
|
||||
lon_min, lon_max = np.nanmin(lon), np.nanmax(lon)
|
||||
lat_min, lat_max = np.nanmin(lat), np.nanmax(lat)
|
||||
|
||||
lon_min, lon_max, lat_min, lat_max = \
|
||||
self._add_pad(lon_min, lon_max, lat_min, lat_max)
|
||||
|
||||
# check cycle
|
||||
if self.lon_cycle:
|
||||
lon_max = min(lon_max, lon_min + self.lon_cycle)
|
||||
if self.lat_cycle:
|
||||
lat_max = min(lat_max, lat_min + self.lat_cycle)
|
||||
|
||||
if self.lon_minmax is not None:
|
||||
min0 = self.lon_minmax[0]
|
||||
lon_min = max(min0, lon_min)
|
||||
max0 = self.lon_minmax[1]
|
||||
lon_max = min(max0, lon_max)
|
||||
|
||||
if self.lat_minmax is not None:
|
||||
min0 = self.lat_minmax[0]
|
||||
lat_min = max(min0, lat_min)
|
||||
max0 = self.lat_minmax[1]
|
||||
lat_max = min(max0, lat_max)
|
||||
|
||||
return lon_min, lon_max, lat_min, lat_max
|
|
@ -0,0 +1,2 @@
|
|||
from mpl_toolkits.axes_grid1.axes_divider import ( # noqa
|
||||
Divider, SubplotDivider, AxesDivider, make_axes_locatable)
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,193 @@
|
|||
"""
|
||||
Provides classes to style the axis lines.
|
||||
"""
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib.patches import _Style, FancyArrowPatch
|
||||
from matplotlib.path import Path
|
||||
from matplotlib.transforms import IdentityTransform
|
||||
|
||||
|
||||
class _FancyAxislineStyle:
|
||||
class SimpleArrow(FancyArrowPatch):
|
||||
"""The artist class that will be returned for SimpleArrow style."""
|
||||
_ARROW_STYLE = "->"
|
||||
|
||||
def __init__(self, axis_artist, line_path, transform,
|
||||
line_mutation_scale):
|
||||
self._axis_artist = axis_artist
|
||||
self._line_transform = transform
|
||||
self._line_path = line_path
|
||||
self._line_mutation_scale = line_mutation_scale
|
||||
|
||||
FancyArrowPatch.__init__(self,
|
||||
path=self._line_path,
|
||||
arrowstyle=self._ARROW_STYLE,
|
||||
patchA=None,
|
||||
patchB=None,
|
||||
shrinkA=0.,
|
||||
shrinkB=0.,
|
||||
mutation_scale=line_mutation_scale,
|
||||
mutation_aspect=None,
|
||||
transform=IdentityTransform(),
|
||||
)
|
||||
|
||||
def set_line_mutation_scale(self, scale):
|
||||
self.set_mutation_scale(scale*self._line_mutation_scale)
|
||||
|
||||
def _extend_path(self, path, mutation_size=10):
|
||||
"""
|
||||
Extend the path to make a room for drawing arrow.
|
||||
"""
|
||||
(x0, y0), (x1, y1) = path.vertices[-2:]
|
||||
theta = math.atan2(y1 - y0, x1 - x0)
|
||||
x2 = x1 + math.cos(theta) * mutation_size
|
||||
y2 = y1 + math.sin(theta) * mutation_size
|
||||
if path.codes is None:
|
||||
return Path(np.concatenate([path.vertices, [[x2, y2]]]))
|
||||
else:
|
||||
return Path(np.concatenate([path.vertices, [[x2, y2]]]),
|
||||
np.concatenate([path.codes, [Path.LINETO]]))
|
||||
|
||||
def set_path(self, path):
|
||||
self._line_path = path
|
||||
|
||||
def draw(self, renderer):
|
||||
"""
|
||||
Draw the axis line.
|
||||
1) Transform the path to the display coordinate.
|
||||
2) Extend the path to make a room for arrow.
|
||||
3) Update the path of the FancyArrowPatch.
|
||||
4) Draw.
|
||||
"""
|
||||
path_in_disp = self._line_transform.transform_path(self._line_path)
|
||||
mutation_size = self.get_mutation_scale() # line_mutation_scale()
|
||||
extended_path = self._extend_path(path_in_disp,
|
||||
mutation_size=mutation_size)
|
||||
self._path_original = extended_path
|
||||
FancyArrowPatch.draw(self, renderer)
|
||||
|
||||
def get_window_extent(self, renderer=None):
|
||||
|
||||
path_in_disp = self._line_transform.transform_path(self._line_path)
|
||||
mutation_size = self.get_mutation_scale() # line_mutation_scale()
|
||||
extended_path = self._extend_path(path_in_disp,
|
||||
mutation_size=mutation_size)
|
||||
self._path_original = extended_path
|
||||
return FancyArrowPatch.get_window_extent(self, renderer)
|
||||
|
||||
class FilledArrow(SimpleArrow):
|
||||
"""The artist class that will be returned for FilledArrow style."""
|
||||
_ARROW_STYLE = "-|>"
|
||||
|
||||
def __init__(self, axis_artist, line_path, transform,
|
||||
line_mutation_scale, facecolor):
|
||||
super().__init__(axis_artist, line_path, transform,
|
||||
line_mutation_scale)
|
||||
self.set_facecolor(facecolor)
|
||||
|
||||
|
||||
class AxislineStyle(_Style):
|
||||
"""
|
||||
A container class which defines style classes for AxisArtists.
|
||||
|
||||
An instance of any axisline style class is a callable object,
|
||||
whose call signature is ::
|
||||
|
||||
__call__(self, axis_artist, path, transform)
|
||||
|
||||
When called, this should return an `.Artist` with the following methods::
|
||||
|
||||
def set_path(self, path):
|
||||
# set the path for axisline.
|
||||
|
||||
def set_line_mutation_scale(self, scale):
|
||||
# set the scale
|
||||
|
||||
def draw(self, renderer):
|
||||
# draw
|
||||
"""
|
||||
|
||||
_style_list = {}
|
||||
|
||||
class _Base:
|
||||
# The derived classes are required to be able to be initialized
|
||||
# w/o arguments, i.e., all its argument (except self) must have
|
||||
# the default values.
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
initialization.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, axis_artist, transform):
|
||||
"""
|
||||
Given the AxisArtist instance, and transform for the path (set_path
|
||||
method), return the Matplotlib artist for drawing the axis line.
|
||||
"""
|
||||
return self.new_line(axis_artist, transform)
|
||||
|
||||
class SimpleArrow(_Base):
|
||||
"""
|
||||
A simple arrow.
|
||||
"""
|
||||
|
||||
ArrowAxisClass = _FancyAxislineStyle.SimpleArrow
|
||||
|
||||
def __init__(self, size=1):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
size : float
|
||||
Size of the arrow as a fraction of the ticklabel size.
|
||||
"""
|
||||
|
||||
self.size = size
|
||||
super().__init__()
|
||||
|
||||
def new_line(self, axis_artist, transform):
|
||||
|
||||
linepath = Path([(0, 0), (0, 1)])
|
||||
axisline = self.ArrowAxisClass(axis_artist, linepath, transform,
|
||||
line_mutation_scale=self.size)
|
||||
return axisline
|
||||
|
||||
_style_list["->"] = SimpleArrow
|
||||
|
||||
class FilledArrow(SimpleArrow):
|
||||
"""
|
||||
An arrow with a filled head.
|
||||
"""
|
||||
|
||||
ArrowAxisClass = _FancyAxislineStyle.FilledArrow
|
||||
|
||||
def __init__(self, size=1, facecolor=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
size : float
|
||||
Size of the arrow as a fraction of the ticklabel size.
|
||||
facecolor : :mpltype:`color`, default: :rc:`axes.edgecolor`
|
||||
Fill color.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
"""
|
||||
|
||||
if facecolor is None:
|
||||
facecolor = mpl.rcParams['axes.edgecolor']
|
||||
self.size = size
|
||||
self._facecolor = facecolor
|
||||
super().__init__(size=size)
|
||||
|
||||
def new_line(self, axis_artist, transform):
|
||||
linepath = Path([(0, 0), (0, 1)])
|
||||
axisline = self.ArrowAxisClass(axis_artist, linepath, transform,
|
||||
line_mutation_scale=self.size,
|
||||
facecolor=self._facecolor)
|
||||
return axisline
|
||||
|
||||
_style_list["-|>"] = FilledArrow
|
|
@ -0,0 +1,479 @@
|
|||
"""
|
||||
Axislines includes modified implementation of the Axes class. The
|
||||
biggest difference is that the artists responsible for drawing the axis spine,
|
||||
ticks, ticklabels and axis labels are separated out from Matplotlib's Axis
|
||||
class. Originally, this change was motivated to support curvilinear
|
||||
grid. Here are a few reasons that I came up with a new axes class:
|
||||
|
||||
* "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
|
||||
different ticks (tick locations and labels). This is not possible
|
||||
with the current Matplotlib, although some twin axes trick can help.
|
||||
|
||||
* Curvilinear grid.
|
||||
|
||||
* angled ticks.
|
||||
|
||||
In the new axes class, xaxis and yaxis is set to not visible by
|
||||
default, and new set of artist (AxisArtist) are defined to draw axis
|
||||
line, ticks, ticklabels and axis label. Axes.axis attribute serves as
|
||||
a dictionary of these artists, i.e., ax.axis["left"] is a AxisArtist
|
||||
instance responsible to draw left y-axis. The default Axes.axis contains
|
||||
"bottom", "left", "top" and "right".
|
||||
|
||||
AxisArtist can be considered as a container artist and has the following
|
||||
children artists which will draw ticks, labels, etc.
|
||||
|
||||
* line
|
||||
* major_ticks, major_ticklabels
|
||||
* minor_ticks, minor_ticklabels
|
||||
* offsetText
|
||||
* label
|
||||
|
||||
Note that these are separate artists from `matplotlib.axis.Axis`, thus most
|
||||
tick-related functions in Matplotlib won't work. For example, color and
|
||||
markerwidth of the ``ax.axis["bottom"].major_ticks`` will follow those of
|
||||
Axes.xaxis unless explicitly specified.
|
||||
|
||||
In addition to AxisArtist, the Axes will have *gridlines* attribute,
|
||||
which obviously draws grid lines. The gridlines needs to be separated
|
||||
from the axis as some gridlines can never pass any axis.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import _api
|
||||
import matplotlib.axes as maxes
|
||||
from matplotlib.path import Path
|
||||
from mpl_toolkits.axes_grid1 import mpl_axes
|
||||
from .axisline_style import AxislineStyle # noqa
|
||||
from .axis_artist import AxisArtist, GridlinesCollection
|
||||
|
||||
|
||||
class _AxisArtistHelperBase:
|
||||
"""
|
||||
Base class for axis helper.
|
||||
|
||||
Subclasses should define the methods listed below. The *axes*
|
||||
argument will be the ``.axes`` attribute of the caller artist. ::
|
||||
|
||||
# Construct the spine.
|
||||
|
||||
def get_line_transform(self, axes):
|
||||
return transform
|
||||
|
||||
def get_line(self, axes):
|
||||
return path
|
||||
|
||||
# Construct the label.
|
||||
|
||||
def get_axislabel_transform(self, axes):
|
||||
return transform
|
||||
|
||||
def get_axislabel_pos_angle(self, axes):
|
||||
return (x, y), angle
|
||||
|
||||
# Construct the ticks.
|
||||
|
||||
def get_tick_transform(self, axes):
|
||||
return transform
|
||||
|
||||
def get_tick_iterators(self, axes):
|
||||
# A pair of iterables (one for major ticks, one for minor ticks)
|
||||
# that yield (tick_position, tick_angle, tick_label).
|
||||
return iter_major, iter_minor
|
||||
"""
|
||||
|
||||
def __init__(self, nth_coord):
|
||||
self.nth_coord = nth_coord
|
||||
|
||||
def update_lim(self, axes):
|
||||
pass
|
||||
|
||||
def get_nth_coord(self):
|
||||
return self.nth_coord
|
||||
|
||||
def _to_xy(self, values, const):
|
||||
"""
|
||||
Create a (*values.shape, 2)-shape array representing (x, y) pairs.
|
||||
|
||||
The other coordinate is filled with the constant *const*.
|
||||
|
||||
Example::
|
||||
|
||||
>>> self.nth_coord = 0
|
||||
>>> self._to_xy([1, 2, 3], const=0)
|
||||
array([[1, 0],
|
||||
[2, 0],
|
||||
[3, 0]])
|
||||
"""
|
||||
if self.nth_coord == 0:
|
||||
return np.stack(np.broadcast_arrays(values, const), axis=-1)
|
||||
elif self.nth_coord == 1:
|
||||
return np.stack(np.broadcast_arrays(const, values), axis=-1)
|
||||
else:
|
||||
raise ValueError("Unexpected nth_coord")
|
||||
|
||||
|
||||
class _FixedAxisArtistHelperBase(_AxisArtistHelperBase):
|
||||
"""Helper class for a fixed (in the axes coordinate) axis."""
|
||||
|
||||
@_api.delete_parameter("3.9", "nth_coord")
|
||||
def __init__(self, loc, nth_coord=None):
|
||||
"""``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
|
||||
super().__init__(_api.check_getitem(
|
||||
{"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc))
|
||||
self._loc = loc
|
||||
self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc]
|
||||
# axis line in transAxes
|
||||
self._path = Path(self._to_xy((0, 1), const=self._pos))
|
||||
|
||||
# LINE
|
||||
|
||||
def get_line(self, axes):
|
||||
return self._path
|
||||
|
||||
def get_line_transform(self, axes):
|
||||
return axes.transAxes
|
||||
|
||||
# LABEL
|
||||
|
||||
def get_axislabel_transform(self, axes):
|
||||
return axes.transAxes
|
||||
|
||||
def get_axislabel_pos_angle(self, axes):
|
||||
"""
|
||||
Return the label reference position in transAxes.
|
||||
|
||||
get_label_transform() returns a transform of (transAxes+offset)
|
||||
"""
|
||||
return dict(left=((0., 0.5), 90), # (position, angle_tangent)
|
||||
right=((1., 0.5), 90),
|
||||
bottom=((0.5, 0.), 0),
|
||||
top=((0.5, 1.), 0))[self._loc]
|
||||
|
||||
# TICK
|
||||
|
||||
def get_tick_transform(self, axes):
|
||||
return [axes.get_xaxis_transform(), axes.get_yaxis_transform()][self.nth_coord]
|
||||
|
||||
|
||||
class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase):
|
||||
def __init__(self, nth_coord, value):
|
||||
self._value = value
|
||||
super().__init__(nth_coord)
|
||||
|
||||
def get_line(self, axes):
|
||||
raise RuntimeError("get_line method should be defined by the derived class")
|
||||
|
||||
|
||||
class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase):
|
||||
|
||||
@_api.delete_parameter("3.9", "nth_coord")
|
||||
def __init__(self, axes, loc, nth_coord=None):
|
||||
"""
|
||||
nth_coord = along which coordinate value varies
|
||||
in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
|
||||
"""
|
||||
super().__init__(loc)
|
||||
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
|
||||
|
||||
# TICK
|
||||
|
||||
def get_tick_iterators(self, axes):
|
||||
"""tick_loc, tick_angle, tick_label"""
|
||||
angle_normal, angle_tangent = {0: (90, 0), 1: (0, 90)}[self.nth_coord]
|
||||
|
||||
major = self.axis.major
|
||||
major_locs = major.locator()
|
||||
major_labels = major.formatter.format_ticks(major_locs)
|
||||
|
||||
minor = self.axis.minor
|
||||
minor_locs = minor.locator()
|
||||
minor_labels = minor.formatter.format_ticks(minor_locs)
|
||||
|
||||
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
|
||||
|
||||
def _f(locs, labels):
|
||||
for loc, label in zip(locs, labels):
|
||||
c = self._to_xy(loc, const=self._pos)
|
||||
# check if the tick point is inside axes
|
||||
c2 = tick_to_axes.transform(c)
|
||||
if mpl.transforms._interval_contains_close((0, 1), c2[self.nth_coord]):
|
||||
yield c, angle_normal, angle_tangent, label
|
||||
|
||||
return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
|
||||
|
||||
|
||||
class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase):
|
||||
|
||||
def __init__(self, axes, nth_coord,
|
||||
passingthrough_point, axis_direction="bottom"):
|
||||
super().__init__(nth_coord, passingthrough_point)
|
||||
self._axis_direction = axis_direction
|
||||
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
|
||||
|
||||
def get_line(self, axes):
|
||||
fixed_coord = 1 - self.nth_coord
|
||||
data_to_axes = axes.transData - axes.transAxes
|
||||
p = data_to_axes.transform([self._value, self._value])
|
||||
return Path(self._to_xy((0, 1), const=p[fixed_coord]))
|
||||
|
||||
def get_line_transform(self, axes):
|
||||
return axes.transAxes
|
||||
|
||||
def get_axislabel_transform(self, axes):
|
||||
return axes.transAxes
|
||||
|
||||
def get_axislabel_pos_angle(self, axes):
|
||||
"""
|
||||
Return the label reference position in transAxes.
|
||||
|
||||
get_label_transform() returns a transform of (transAxes+offset)
|
||||
"""
|
||||
angle = [0, 90][self.nth_coord]
|
||||
fixed_coord = 1 - self.nth_coord
|
||||
data_to_axes = axes.transData - axes.transAxes
|
||||
p = data_to_axes.transform([self._value, self._value])
|
||||
verts = self._to_xy(0.5, const=p[fixed_coord])
|
||||
return (verts, angle) if 0 <= verts[fixed_coord] <= 1 else (None, None)
|
||||
|
||||
def get_tick_transform(self, axes):
|
||||
return axes.transData
|
||||
|
||||
def get_tick_iterators(self, axes):
|
||||
"""tick_loc, tick_angle, tick_label"""
|
||||
angle_normal, angle_tangent = {0: (90, 0), 1: (0, 90)}[self.nth_coord]
|
||||
|
||||
major = self.axis.major
|
||||
major_locs = major.locator()
|
||||
major_labels = major.formatter.format_ticks(major_locs)
|
||||
|
||||
minor = self.axis.minor
|
||||
minor_locs = minor.locator()
|
||||
minor_labels = minor.formatter.format_ticks(minor_locs)
|
||||
|
||||
data_to_axes = axes.transData - axes.transAxes
|
||||
|
||||
def _f(locs, labels):
|
||||
for loc, label in zip(locs, labels):
|
||||
c = self._to_xy(loc, const=self._value)
|
||||
c1, c2 = data_to_axes.transform(c)
|
||||
if 0 <= c1 <= 1 and 0 <= c2 <= 1:
|
||||
yield c, angle_normal, angle_tangent, label
|
||||
|
||||
return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
|
||||
|
||||
|
||||
class AxisArtistHelper: # Backcompat.
|
||||
Fixed = _FixedAxisArtistHelperBase
|
||||
Floating = _FloatingAxisArtistHelperBase
|
||||
|
||||
|
||||
class AxisArtistHelperRectlinear: # Backcompat.
|
||||
Fixed = FixedAxisArtistHelperRectilinear
|
||||
Floating = FloatingAxisArtistHelperRectilinear
|
||||
|
||||
|
||||
class GridHelperBase:
|
||||
|
||||
def __init__(self):
|
||||
self._old_limits = None
|
||||
super().__init__()
|
||||
|
||||
def update_lim(self, axes):
|
||||
x1, x2 = axes.get_xlim()
|
||||
y1, y2 = axes.get_ylim()
|
||||
if self._old_limits != (x1, x2, y1, y2):
|
||||
self._update_grid(x1, y1, x2, y2)
|
||||
self._old_limits = (x1, x2, y1, y2)
|
||||
|
||||
def _update_grid(self, x1, y1, x2, y2):
|
||||
"""Cache relevant computations when the axes limits have changed."""
|
||||
|
||||
def get_gridlines(self, which, axis):
|
||||
"""
|
||||
Return list of grid lines as a list of paths (list of points).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
which : {"both", "major", "minor"}
|
||||
axis : {"both", "x", "y"}
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class GridHelperRectlinear(GridHelperBase):
|
||||
|
||||
def __init__(self, axes):
|
||||
super().__init__()
|
||||
self.axes = axes
|
||||
|
||||
@_api.delete_parameter(
|
||||
"3.9", "nth_coord", addendum="'nth_coord' is now inferred from 'loc'.")
|
||||
def new_fixed_axis(
|
||||
self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
|
||||
if axes is None:
|
||||
_api.warn_external(
|
||||
"'new_fixed_axis' explicitly requires the axes keyword.")
|
||||
axes = self.axes
|
||||
if axis_direction is None:
|
||||
axis_direction = loc
|
||||
return AxisArtist(axes, FixedAxisArtistHelperRectilinear(axes, loc),
|
||||
offset=offset, axis_direction=axis_direction)
|
||||
|
||||
def new_floating_axis(self, nth_coord, value, axis_direction="bottom", axes=None):
|
||||
if axes is None:
|
||||
_api.warn_external(
|
||||
"'new_floating_axis' explicitly requires the axes keyword.")
|
||||
axes = self.axes
|
||||
helper = FloatingAxisArtistHelperRectilinear(
|
||||
axes, nth_coord, value, axis_direction)
|
||||
axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
|
||||
axisline.line.set_clip_on(True)
|
||||
axisline.line.set_clip_box(axisline.axes.bbox)
|
||||
return axisline
|
||||
|
||||
def get_gridlines(self, which="major", axis="both"):
|
||||
"""
|
||||
Return list of gridline coordinates in data coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
which : {"both", "major", "minor"}
|
||||
axis : {"both", "x", "y"}
|
||||
"""
|
||||
_api.check_in_list(["both", "major", "minor"], which=which)
|
||||
_api.check_in_list(["both", "x", "y"], axis=axis)
|
||||
gridlines = []
|
||||
|
||||
if axis in ("both", "x"):
|
||||
locs = []
|
||||
y1, y2 = self.axes.get_ylim()
|
||||
if which in ("both", "major"):
|
||||
locs.extend(self.axes.xaxis.major.locator())
|
||||
if which in ("both", "minor"):
|
||||
locs.extend(self.axes.xaxis.minor.locator())
|
||||
gridlines.extend([[x, x], [y1, y2]] for x in locs)
|
||||
|
||||
if axis in ("both", "y"):
|
||||
x1, x2 = self.axes.get_xlim()
|
||||
locs = []
|
||||
if self.axes.yaxis._major_tick_kw["gridOn"]:
|
||||
locs.extend(self.axes.yaxis.major.locator())
|
||||
if self.axes.yaxis._minor_tick_kw["gridOn"]:
|
||||
locs.extend(self.axes.yaxis.minor.locator())
|
||||
gridlines.extend([[x1, x2], [y, y]] for y in locs)
|
||||
|
||||
return gridlines
|
||||
|
||||
|
||||
class Axes(maxes.Axes):
|
||||
|
||||
def __init__(self, *args, grid_helper=None, **kwargs):
|
||||
self._axisline_on = True
|
||||
self._grid_helper = grid_helper if grid_helper else GridHelperRectlinear(self)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.toggle_axisline(True)
|
||||
|
||||
def toggle_axisline(self, b=None):
|
||||
if b is None:
|
||||
b = not self._axisline_on
|
||||
if b:
|
||||
self._axisline_on = True
|
||||
self.spines[:].set_visible(False)
|
||||
self.xaxis.set_visible(False)
|
||||
self.yaxis.set_visible(False)
|
||||
else:
|
||||
self._axisline_on = False
|
||||
self.spines[:].set_visible(True)
|
||||
self.xaxis.set_visible(True)
|
||||
self.yaxis.set_visible(True)
|
||||
|
||||
@property
|
||||
def axis(self):
|
||||
return self._axislines
|
||||
|
||||
def clear(self):
|
||||
# docstring inherited
|
||||
|
||||
# Init gridlines before clear() as clear() calls grid().
|
||||
self.gridlines = gridlines = GridlinesCollection(
|
||||
[],
|
||||
colors=mpl.rcParams['grid.color'],
|
||||
linestyles=mpl.rcParams['grid.linestyle'],
|
||||
linewidths=mpl.rcParams['grid.linewidth'])
|
||||
self._set_artist_props(gridlines)
|
||||
gridlines.set_grid_helper(self.get_grid_helper())
|
||||
|
||||
super().clear()
|
||||
|
||||
# clip_path is set after Axes.clear(): that's when a patch is created.
|
||||
gridlines.set_clip_path(self.axes.patch)
|
||||
|
||||
# Init axis artists.
|
||||
self._axislines = mpl_axes.Axes.AxisDict(self)
|
||||
new_fixed_axis = self.get_grid_helper().new_fixed_axis
|
||||
self._axislines.update({
|
||||
loc: new_fixed_axis(loc=loc, axes=self, axis_direction=loc)
|
||||
for loc in ["bottom", "top", "left", "right"]})
|
||||
for axisline in [self._axislines["top"], self._axislines["right"]]:
|
||||
axisline.label.set_visible(False)
|
||||
axisline.major_ticklabels.set_visible(False)
|
||||
axisline.minor_ticklabels.set_visible(False)
|
||||
|
||||
def get_grid_helper(self):
|
||||
return self._grid_helper
|
||||
|
||||
def grid(self, visible=None, which='major', axis="both", **kwargs):
|
||||
"""
|
||||
Toggle the gridlines, and optionally set the properties of the lines.
|
||||
"""
|
||||
# There are some discrepancies in the behavior of grid() between
|
||||
# axes_grid and Matplotlib, because axes_grid explicitly sets the
|
||||
# visibility of the gridlines.
|
||||
super().grid(visible, which=which, axis=axis, **kwargs)
|
||||
if not self._axisline_on:
|
||||
return
|
||||
if visible is None:
|
||||
visible = (self.axes.xaxis._minor_tick_kw["gridOn"]
|
||||
or self.axes.xaxis._major_tick_kw["gridOn"]
|
||||
or self.axes.yaxis._minor_tick_kw["gridOn"]
|
||||
or self.axes.yaxis._major_tick_kw["gridOn"])
|
||||
self.gridlines.set(which=which, axis=axis, visible=visible)
|
||||
self.gridlines.set(**kwargs)
|
||||
|
||||
def get_children(self):
|
||||
if self._axisline_on:
|
||||
children = [*self._axislines.values(), self.gridlines]
|
||||
else:
|
||||
children = []
|
||||
children.extend(super().get_children())
|
||||
return children
|
||||
|
||||
def new_fixed_axis(self, loc, offset=None):
|
||||
return self.get_grid_helper().new_fixed_axis(loc, offset=offset, axes=self)
|
||||
|
||||
def new_floating_axis(self, nth_coord, value, axis_direction="bottom"):
|
||||
return self.get_grid_helper().new_floating_axis(
|
||||
nth_coord, value, axis_direction=axis_direction, axes=self)
|
||||
|
||||
|
||||
class AxesZero(Axes):
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
new_floating_axis = self.get_grid_helper().new_floating_axis
|
||||
self._axislines.update(
|
||||
xzero=new_floating_axis(
|
||||
nth_coord=0, value=0., axis_direction="bottom", axes=self),
|
||||
yzero=new_floating_axis(
|
||||
nth_coord=1, value=0., axis_direction="left", axes=self),
|
||||
)
|
||||
for k in ["xzero", "yzero"]:
|
||||
self._axislines[k].line.set_clip_path(self.patch)
|
||||
self._axislines[k].set_visible(False)
|
||||
|
||||
|
||||
Subplot = Axes
|
||||
SubplotZero = AxesZero
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
An experimental support for curvilinear grid.
|
||||
"""
|
||||
|
||||
# TODO :
|
||||
# see if tick_iterator method can be simplified by reusing the parent method.
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import _api, cbook
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.path import Path
|
||||
|
||||
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
|
||||
|
||||
from . import axislines, grid_helper_curvelinear
|
||||
from .axis_artist import AxisArtist
|
||||
from .grid_finder import ExtremeFinderSimple
|
||||
|
||||
|
||||
class FloatingAxisArtistHelper(
|
||||
grid_helper_curvelinear.FloatingAxisArtistHelper):
|
||||
pass
|
||||
|
||||
|
||||
class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
|
||||
|
||||
def __init__(self, grid_helper, side, nth_coord_ticks=None):
|
||||
"""
|
||||
nth_coord = along which coordinate value varies.
|
||||
nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
|
||||
"""
|
||||
lon1, lon2, lat1, lat2 = grid_helper.grid_finder.extreme_finder(*[None] * 5)
|
||||
value, nth_coord = _api.check_getitem(
|
||||
dict(left=(lon1, 0), right=(lon2, 0), bottom=(lat1, 1), top=(lat2, 1)),
|
||||
side=side)
|
||||
super().__init__(grid_helper, nth_coord, value, axis_direction=side)
|
||||
if nth_coord_ticks is None:
|
||||
nth_coord_ticks = nth_coord
|
||||
self.nth_coord_ticks = nth_coord_ticks
|
||||
|
||||
self.value = value
|
||||
self.grid_helper = grid_helper
|
||||
self._side = side
|
||||
|
||||
def update_lim(self, axes):
|
||||
self.grid_helper.update_lim(axes)
|
||||
self._grid_info = self.grid_helper._grid_info
|
||||
|
||||
def get_tick_iterators(self, axes):
|
||||
"""tick_loc, tick_angle, tick_label, (optionally) tick_label"""
|
||||
|
||||
grid_finder = self.grid_helper.grid_finder
|
||||
|
||||
lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
|
||||
yy0 = lat_levs / lat_factor
|
||||
|
||||
lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
|
||||
xx0 = lon_levs / lon_factor
|
||||
|
||||
extremes = self.grid_helper.grid_finder.extreme_finder(*[None] * 5)
|
||||
xmin, xmax = sorted(extremes[:2])
|
||||
ymin, ymax = sorted(extremes[2:])
|
||||
|
||||
def trf_xy(x, y):
|
||||
trf = grid_finder.get_transform() + axes.transData
|
||||
return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T
|
||||
|
||||
if self.nth_coord == 0:
|
||||
mask = (ymin <= yy0) & (yy0 <= ymax)
|
||||
(xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = \
|
||||
grid_helper_curvelinear._value_and_jacobian(
|
||||
trf_xy, self.value, yy0[mask], (xmin, xmax), (ymin, ymax))
|
||||
labels = self._grid_info["lat_labels"]
|
||||
|
||||
elif self.nth_coord == 1:
|
||||
mask = (xmin <= xx0) & (xx0 <= xmax)
|
||||
(xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = \
|
||||
grid_helper_curvelinear._value_and_jacobian(
|
||||
trf_xy, xx0[mask], self.value, (xmin, xmax), (ymin, ymax))
|
||||
labels = self._grid_info["lon_labels"]
|
||||
|
||||
labels = [l for l, m in zip(labels, mask) if m]
|
||||
|
||||
angle_normal = np.arctan2(dyy1, dxx1)
|
||||
angle_tangent = np.arctan2(dyy2, dxx2)
|
||||
mm = (dyy1 == 0) & (dxx1 == 0) # points with degenerate normal
|
||||
angle_normal[mm] = angle_tangent[mm] + np.pi / 2
|
||||
|
||||
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
|
||||
in_01 = functools.partial(
|
||||
mpl.transforms._interval_contains_close, (0, 1))
|
||||
|
||||
def f1():
|
||||
for x, y, normal, tangent, lab \
|
||||
in zip(xx1, yy1, angle_normal, angle_tangent, labels):
|
||||
c2 = tick_to_axes.transform((x, y))
|
||||
if in_01(c2[0]) and in_01(c2[1]):
|
||||
yield [x, y], *np.rad2deg([normal, tangent]), lab
|
||||
|
||||
return f1(), iter([])
|
||||
|
||||
def get_line(self, axes):
|
||||
self.update_lim(axes)
|
||||
k, v = dict(left=("lon_lines0", 0),
|
||||
right=("lon_lines0", 1),
|
||||
bottom=("lat_lines0", 0),
|
||||
top=("lat_lines0", 1))[self._side]
|
||||
xx, yy = self._grid_info[k][v]
|
||||
return Path(np.column_stack([xx, yy]))
|
||||
|
||||
|
||||
class ExtremeFinderFixed(ExtremeFinderSimple):
|
||||
# docstring inherited
|
||||
|
||||
def __init__(self, extremes):
|
||||
"""
|
||||
This subclass always returns the same bounding box.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
extremes : (float, float, float, float)
|
||||
The bounding box that this helper always returns.
|
||||
"""
|
||||
self._extremes = extremes
|
||||
|
||||
def __call__(self, transform_xy, x1, y1, x2, y2):
|
||||
# docstring inherited
|
||||
return self._extremes
|
||||
|
||||
|
||||
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
|
||||
|
||||
def __init__(self, aux_trans, extremes,
|
||||
grid_locator1=None,
|
||||
grid_locator2=None,
|
||||
tick_formatter1=None,
|
||||
tick_formatter2=None):
|
||||
# docstring inherited
|
||||
super().__init__(aux_trans,
|
||||
extreme_finder=ExtremeFinderFixed(extremes),
|
||||
grid_locator1=grid_locator1,
|
||||
grid_locator2=grid_locator2,
|
||||
tick_formatter1=tick_formatter1,
|
||||
tick_formatter2=tick_formatter2)
|
||||
|
||||
def new_fixed_axis(
|
||||
self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
|
||||
if axes is None:
|
||||
axes = self.axes
|
||||
if axis_direction is None:
|
||||
axis_direction = loc
|
||||
# This is not the same as the FixedAxisArtistHelper class used by
|
||||
# grid_helper_curvelinear.GridHelperCurveLinear.new_fixed_axis!
|
||||
helper = FixedAxisArtistHelper(
|
||||
self, loc, nth_coord_ticks=nth_coord)
|
||||
axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
|
||||
# Perhaps should be moved to the base class?
|
||||
axisline.line.set_clip_on(True)
|
||||
axisline.line.set_clip_box(axisline.axes.bbox)
|
||||
return axisline
|
||||
|
||||
# new_floating_axis will inherit the grid_helper's extremes.
|
||||
|
||||
# def new_floating_axis(self, nth_coord, value, axes=None, axis_direction="bottom"):
|
||||
# axis = super(GridHelperCurveLinear,
|
||||
# self).new_floating_axis(nth_coord,
|
||||
# value, axes=axes,
|
||||
# axis_direction=axis_direction)
|
||||
# # set extreme values of the axis helper
|
||||
# if nth_coord == 1:
|
||||
# axis.get_helper().set_extremes(*self._extremes[:2])
|
||||
# elif nth_coord == 0:
|
||||
# axis.get_helper().set_extremes(*self._extremes[2:])
|
||||
# return axis
|
||||
|
||||
def _update_grid(self, x1, y1, x2, y2):
|
||||
if self._grid_info is None:
|
||||
self._grid_info = dict()
|
||||
|
||||
grid_info = self._grid_info
|
||||
|
||||
grid_finder = self.grid_finder
|
||||
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
|
||||
x1, y1, x2, y2)
|
||||
|
||||
lon_min, lon_max = sorted(extremes[:2])
|
||||
lat_min, lat_max = sorted(extremes[2:])
|
||||
grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes
|
||||
|
||||
lon_levs, lon_n, lon_factor = \
|
||||
grid_finder.grid_locator1(lon_min, lon_max)
|
||||
lon_levs = np.asarray(lon_levs)
|
||||
lat_levs, lat_n, lat_factor = \
|
||||
grid_finder.grid_locator2(lat_min, lat_max)
|
||||
lat_levs = np.asarray(lat_levs)
|
||||
|
||||
grid_info["lon_info"] = lon_levs, lon_n, lon_factor
|
||||
grid_info["lat_info"] = lat_levs, lat_n, lat_factor
|
||||
|
||||
grid_info["lon_labels"] = grid_finder._format_ticks(
|
||||
1, "bottom", lon_factor, lon_levs)
|
||||
grid_info["lat_labels"] = grid_finder._format_ticks(
|
||||
2, "bottom", lat_factor, lat_levs)
|
||||
|
||||
lon_values = lon_levs[:lon_n] / lon_factor
|
||||
lat_values = lat_levs[:lat_n] / lat_factor
|
||||
|
||||
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
|
||||
lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
|
||||
lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
|
||||
lon_min, lon_max, lat_min, lat_max)
|
||||
|
||||
grid_info["lon_lines"] = lon_lines
|
||||
grid_info["lat_lines"] = lat_lines
|
||||
|
||||
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
|
||||
# lon_min, lon_max, lat_min, lat_max)
|
||||
extremes[:2], extremes[2:], *extremes)
|
||||
|
||||
grid_info["lon_lines0"] = lon_lines
|
||||
grid_info["lat_lines0"] = lat_lines
|
||||
|
||||
def get_gridlines(self, which="major", axis="both"):
|
||||
grid_lines = []
|
||||
if axis in ["both", "x"]:
|
||||
grid_lines.extend(self._grid_info["lon_lines"])
|
||||
if axis in ["both", "y"]:
|
||||
grid_lines.extend(self._grid_info["lat_lines"])
|
||||
return grid_lines
|
||||
|
||||
|
||||
class FloatingAxesBase:
|
||||
|
||||
def __init__(self, *args, grid_helper, **kwargs):
|
||||
_api.check_isinstance(GridHelperCurveLinear, grid_helper=grid_helper)
|
||||
super().__init__(*args, grid_helper=grid_helper, **kwargs)
|
||||
self.set_aspect(1.)
|
||||
|
||||
def _gen_axes_patch(self):
|
||||
# docstring inherited
|
||||
x0, x1, y0, y1 = self.get_grid_helper().grid_finder.extreme_finder(*[None] * 5)
|
||||
patch = mpatches.Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1)])
|
||||
patch.get_path()._interpolation_steps = 100
|
||||
return patch
|
||||
|
||||
def clear(self):
|
||||
super().clear()
|
||||
self.patch.set_transform(
|
||||
self.get_grid_helper().grid_finder.get_transform()
|
||||
+ self.transData)
|
||||
# The original patch is not in the draw tree; it is only used for
|
||||
# clipping purposes.
|
||||
orig_patch = super()._gen_axes_patch()
|
||||
orig_patch.set_figure(self.get_figure(root=False))
|
||||
orig_patch.set_transform(self.transAxes)
|
||||
self.patch.set_clip_path(orig_patch)
|
||||
self.gridlines.set_clip_path(orig_patch)
|
||||
self.adjust_axes_lim()
|
||||
|
||||
def adjust_axes_lim(self):
|
||||
bbox = self.patch.get_path().get_extents(
|
||||
# First transform to pixel coords, then to parent data coords.
|
||||
self.patch.get_transform() - self.transData)
|
||||
bbox = bbox.expanded(1.02, 1.02)
|
||||
self.set_xlim(bbox.xmin, bbox.xmax)
|
||||
self.set_ylim(bbox.ymin, bbox.ymax)
|
||||
|
||||
|
||||
floatingaxes_class_factory = cbook._make_class_factory(FloatingAxesBase, "Floating{}")
|
||||
FloatingAxes = floatingaxes_class_factory(host_axes_class_factory(axislines.Axes))
|
||||
FloatingSubplot = FloatingAxes
|
|
@ -0,0 +1,326 @@
|
|||
import numpy as np
|
||||
|
||||
from matplotlib import ticker as mticker, _api
|
||||
from matplotlib.transforms import Bbox, Transform
|
||||
|
||||
|
||||
def _find_line_box_crossings(xys, bbox):
|
||||
"""
|
||||
Find the points where a polyline crosses a bbox, and the crossing angles.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xys : (N, 2) array
|
||||
The polyline coordinates.
|
||||
bbox : `.Bbox`
|
||||
The bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of ((float, float), float)
|
||||
Four separate lists of crossings, for the left, right, bottom, and top
|
||||
sides of the bbox, respectively. For each list, the entries are the
|
||||
``((x, y), ccw_angle_in_degrees)`` of the crossing, where an angle of 0
|
||||
means that the polyline is moving to the right at the crossing point.
|
||||
|
||||
The entries are computed by linearly interpolating at each crossing
|
||||
between the nearest points on either side of the bbox edges.
|
||||
"""
|
||||
crossings = []
|
||||
dxys = xys[1:] - xys[:-1]
|
||||
for sl in [slice(None), slice(None, None, -1)]:
|
||||
us, vs = xys.T[sl] # "this" coord, "other" coord
|
||||
dus, dvs = dxys.T[sl]
|
||||
umin, vmin = bbox.min[sl]
|
||||
umax, vmax = bbox.max[sl]
|
||||
for u0, inside in [(umin, us > umin), (umax, us < umax)]:
|
||||
cross = []
|
||||
idxs, = (inside[:-1] ^ inside[1:]).nonzero()
|
||||
for idx in idxs:
|
||||
v = vs[idx] + (u0 - us[idx]) * dvs[idx] / dus[idx]
|
||||
if not vmin <= v <= vmax:
|
||||
continue
|
||||
crossing = (u0, v)[sl]
|
||||
theta = np.degrees(np.arctan2(*dxys[idx][::-1]))
|
||||
cross.append((crossing, theta))
|
||||
crossings.append(cross)
|
||||
return crossings
|
||||
|
||||
|
||||
class ExtremeFinderSimple:
|
||||
"""
|
||||
A helper class to figure out the range of grid lines that need to be drawn.
|
||||
"""
|
||||
|
||||
def __init__(self, nx, ny):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
nx, ny : int
|
||||
The number of samples in each direction.
|
||||
"""
|
||||
self.nx = nx
|
||||
self.ny = ny
|
||||
|
||||
def __call__(self, transform_xy, x1, y1, x2, y2):
|
||||
"""
|
||||
Compute an approximation of the bounding box obtained by applying
|
||||
*transform_xy* to the box delimited by ``(x1, y1, x2, y2)``.
|
||||
|
||||
The intended use is to have ``(x1, y1, x2, y2)`` in axes coordinates,
|
||||
and have *transform_xy* be the transform from axes coordinates to data
|
||||
coordinates; this method then returns the range of data coordinates
|
||||
that span the actual axes.
|
||||
|
||||
The computation is done by sampling ``nx * ny`` equispaced points in
|
||||
the ``(x1, y1, x2, y2)`` box and finding the resulting points with
|
||||
extremal coordinates; then adding some padding to take into account the
|
||||
finite sampling.
|
||||
|
||||
As each sampling step covers a relative range of *1/nx* or *1/ny*,
|
||||
the padding is computed by expanding the span covered by the extremal
|
||||
coordinates by these fractions.
|
||||
"""
|
||||
x, y = np.meshgrid(
|
||||
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
|
||||
xt, yt = transform_xy(np.ravel(x), np.ravel(y))
|
||||
return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())
|
||||
|
||||
def _add_pad(self, x_min, x_max, y_min, y_max):
|
||||
"""Perform the padding mentioned in `__call__`."""
|
||||
dx = (x_max - x_min) / self.nx
|
||||
dy = (y_max - y_min) / self.ny
|
||||
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
|
||||
|
||||
|
||||
class _User2DTransform(Transform):
|
||||
"""A transform defined by two user-set functions."""
|
||||
|
||||
input_dims = output_dims = 2
|
||||
|
||||
def __init__(self, forward, backward):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
forward, backward : callable
|
||||
The forward and backward transforms, taking ``x`` and ``y`` as
|
||||
separate arguments and returning ``(tr_x, tr_y)``.
|
||||
"""
|
||||
# The normal Matplotlib convention would be to take and return an
|
||||
# (N, 2) array but axisartist uses the transposed version.
|
||||
super().__init__()
|
||||
self._forward = forward
|
||||
self._backward = backward
|
||||
|
||||
def transform_non_affine(self, values):
|
||||
# docstring inherited
|
||||
return np.transpose(self._forward(*np.transpose(values)))
|
||||
|
||||
def inverted(self):
|
||||
# docstring inherited
|
||||
return type(self)(self._backward, self._forward)
|
||||
|
||||
|
||||
class GridFinder:
|
||||
"""
|
||||
Internal helper for `~.grid_helper_curvelinear.GridHelperCurveLinear`, with
|
||||
the same constructor parameters; should not be directly instantiated.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
transform,
|
||||
extreme_finder=None,
|
||||
grid_locator1=None,
|
||||
grid_locator2=None,
|
||||
tick_formatter1=None,
|
||||
tick_formatter2=None):
|
||||
if extreme_finder is None:
|
||||
extreme_finder = ExtremeFinderSimple(20, 20)
|
||||
if grid_locator1 is None:
|
||||
grid_locator1 = MaxNLocator()
|
||||
if grid_locator2 is None:
|
||||
grid_locator2 = MaxNLocator()
|
||||
if tick_formatter1 is None:
|
||||
tick_formatter1 = FormatterPrettyPrint()
|
||||
if tick_formatter2 is None:
|
||||
tick_formatter2 = FormatterPrettyPrint()
|
||||
self.extreme_finder = extreme_finder
|
||||
self.grid_locator1 = grid_locator1
|
||||
self.grid_locator2 = grid_locator2
|
||||
self.tick_formatter1 = tick_formatter1
|
||||
self.tick_formatter2 = tick_formatter2
|
||||
self.set_transform(transform)
|
||||
|
||||
def _format_ticks(self, idx, direction, factor, levels):
|
||||
"""
|
||||
Helper to support both standard formatters (inheriting from
|
||||
`.mticker.Formatter`) and axisartist-specific ones; should be called instead of
|
||||
directly calling ``self.tick_formatter1`` and ``self.tick_formatter2``. This
|
||||
method should be considered as a temporary workaround which will be removed in
|
||||
the future at the same time as axisartist-specific formatters.
|
||||
"""
|
||||
fmt = _api.check_getitem(
|
||||
{1: self.tick_formatter1, 2: self.tick_formatter2}, idx=idx)
|
||||
return (fmt.format_ticks(levels) if isinstance(fmt, mticker.Formatter)
|
||||
else fmt(direction, factor, levels))
|
||||
|
||||
def get_grid_info(self, x1, y1, x2, y2):
|
||||
"""
|
||||
lon_values, lat_values : list of grid values. if integer is given,
|
||||
rough number of grids in each direction.
|
||||
"""
|
||||
|
||||
extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)
|
||||
|
||||
# min & max rage of lat (or lon) for each grid line will be drawn.
|
||||
# i.e., gridline of lon=0 will be drawn from lat_min to lat_max.
|
||||
|
||||
lon_min, lon_max, lat_min, lat_max = extremes
|
||||
lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
|
||||
lon_levs = np.asarray(lon_levs)
|
||||
lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
|
||||
lat_levs = np.asarray(lat_levs)
|
||||
|
||||
lon_values = lon_levs[:lon_n] / lon_factor
|
||||
lat_values = lat_levs[:lat_n] / lat_factor
|
||||
|
||||
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
|
||||
lat_values,
|
||||
lon_min, lon_max,
|
||||
lat_min, lat_max)
|
||||
|
||||
bb = Bbox.from_extents(x1, y1, x2, y2).expanded(1 + 2e-10, 1 + 2e-10)
|
||||
|
||||
grid_info = {
|
||||
"extremes": extremes,
|
||||
# "lon", "lat", filled below.
|
||||
}
|
||||
|
||||
for idx, lon_or_lat, levs, factor, values, lines in [
|
||||
(1, "lon", lon_levs, lon_factor, lon_values, lon_lines),
|
||||
(2, "lat", lat_levs, lat_factor, lat_values, lat_lines),
|
||||
]:
|
||||
grid_info[lon_or_lat] = gi = {
|
||||
"lines": [[l] for l in lines],
|
||||
"ticks": {"left": [], "right": [], "bottom": [], "top": []},
|
||||
}
|
||||
for (lx, ly), v, level in zip(lines, values, levs):
|
||||
all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb)
|
||||
for side, crossings in zip(
|
||||
["left", "right", "bottom", "top"], all_crossings):
|
||||
for crossing in crossings:
|
||||
gi["ticks"][side].append({"level": level, "loc": crossing})
|
||||
for side in gi["ticks"]:
|
||||
levs = [tick["level"] for tick in gi["ticks"][side]]
|
||||
labels = self._format_ticks(idx, side, factor, levs)
|
||||
for tick, label in zip(gi["ticks"][side], labels):
|
||||
tick["label"] = label
|
||||
|
||||
return grid_info
|
||||
|
||||
def _get_raw_grid_lines(self,
|
||||
lon_values, lat_values,
|
||||
lon_min, lon_max, lat_min, lat_max):
|
||||
|
||||
lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation
|
||||
lats_i = np.linspace(lat_min, lat_max, 100)
|
||||
|
||||
lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
|
||||
for lon in lon_values]
|
||||
lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
|
||||
for lat in lat_values]
|
||||
|
||||
return lon_lines, lat_lines
|
||||
|
||||
def set_transform(self, aux_trans):
|
||||
if isinstance(aux_trans, Transform):
|
||||
self._aux_transform = aux_trans
|
||||
elif len(aux_trans) == 2 and all(map(callable, aux_trans)):
|
||||
self._aux_transform = _User2DTransform(*aux_trans)
|
||||
else:
|
||||
raise TypeError("'aux_trans' must be either a Transform "
|
||||
"instance or a pair of callables")
|
||||
|
||||
def get_transform(self):
|
||||
return self._aux_transform
|
||||
|
||||
update_transform = set_transform # backcompat alias.
|
||||
|
||||
def transform_xy(self, x, y):
|
||||
return self._aux_transform.transform(np.column_stack([x, y])).T
|
||||
|
||||
def inv_transform_xy(self, x, y):
|
||||
return self._aux_transform.inverted().transform(
|
||||
np.column_stack([x, y])).T
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k in ["extreme_finder",
|
||||
"grid_locator1",
|
||||
"grid_locator2",
|
||||
"tick_formatter1",
|
||||
"tick_formatter2"]:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise ValueError(f"Unknown update property {k!r}")
|
||||
|
||||
|
||||
class MaxNLocator(mticker.MaxNLocator):
|
||||
def __init__(self, nbins=10, steps=None,
|
||||
trim=True,
|
||||
integer=False,
|
||||
symmetric=False,
|
||||
prune=None):
|
||||
# trim argument has no effect. It has been left for API compatibility
|
||||
super().__init__(nbins, steps=steps, integer=integer,
|
||||
symmetric=symmetric, prune=prune)
|
||||
self.create_dummy_axis()
|
||||
|
||||
def __call__(self, v1, v2):
|
||||
locs = super().tick_values(v1, v2)
|
||||
return np.array(locs), len(locs), 1 # 1: factor (see angle_helper)
|
||||
|
||||
|
||||
class FixedLocator:
|
||||
def __init__(self, locs):
|
||||
self._locs = locs
|
||||
|
||||
def __call__(self, v1, v2):
|
||||
v1, v2 = sorted([v1, v2])
|
||||
locs = np.array([l for l in self._locs if v1 <= l <= v2])
|
||||
return locs, len(locs), 1 # 1: factor (see angle_helper)
|
||||
|
||||
|
||||
# Tick Formatter
|
||||
|
||||
class FormatterPrettyPrint:
|
||||
def __init__(self, useMathText=True):
|
||||
self._fmt = mticker.ScalarFormatter(
|
||||
useMathText=useMathText, useOffset=False)
|
||||
self._fmt.create_dummy_axis()
|
||||
|
||||
def __call__(self, direction, factor, values):
|
||||
return self._fmt.format_ticks(values)
|
||||
|
||||
|
||||
class DictFormatter:
|
||||
def __init__(self, format_dict, formatter=None):
|
||||
"""
|
||||
format_dict : dictionary for format strings to be used.
|
||||
formatter : fall-back formatter
|
||||
"""
|
||||
super().__init__()
|
||||
self._format_dict = format_dict
|
||||
self._fallback_formatter = formatter
|
||||
|
||||
def __call__(self, direction, factor, values):
|
||||
"""
|
||||
factor is ignored if value is found in the dictionary
|
||||
"""
|
||||
if self._fallback_formatter:
|
||||
fallback_strings = self._fallback_formatter(
|
||||
direction, factor, values)
|
||||
else:
|
||||
fallback_strings = [""] * len(values)
|
||||
return [self._format_dict.get(k, v)
|
||||
for k, v in zip(values, fallback_strings)]
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
An experimental support for curvilinear grid.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import _api
|
||||
from matplotlib.path import Path
|
||||
from matplotlib.transforms import Affine2D, IdentityTransform
|
||||
from .axislines import (
|
||||
_FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase)
|
||||
from .axis_artist import AxisArtist
|
||||
from .grid_finder import GridFinder
|
||||
|
||||
|
||||
def _value_and_jacobian(func, xs, ys, xlims, ylims):
|
||||
"""
|
||||
Compute *func* and its derivatives along x and y at positions *xs*, *ys*,
|
||||
while ensuring that finite difference calculations don't try to evaluate
|
||||
values outside of *xlims*, *ylims*.
|
||||
"""
|
||||
eps = np.finfo(float).eps ** (1/2) # see e.g. scipy.optimize.approx_fprime
|
||||
val = func(xs, ys)
|
||||
# Take the finite difference step in the direction where the bound is the
|
||||
# furthest; the step size is min of epsilon and distance to that bound.
|
||||
xlo, xhi = sorted(xlims)
|
||||
dxlo = xs - xlo
|
||||
dxhi = xhi - xs
|
||||
xeps = (np.take([-1, 1], dxhi >= dxlo)
|
||||
* np.minimum(eps, np.maximum(dxlo, dxhi)))
|
||||
val_dx = func(xs + xeps, ys)
|
||||
ylo, yhi = sorted(ylims)
|
||||
dylo = ys - ylo
|
||||
dyhi = yhi - ys
|
||||
yeps = (np.take([-1, 1], dyhi >= dylo)
|
||||
* np.minimum(eps, np.maximum(dylo, dyhi)))
|
||||
val_dy = func(xs, ys + yeps)
|
||||
return (val, (val_dx - val) / xeps, (val_dy - val) / yeps)
|
||||
|
||||
|
||||
class FixedAxisArtistHelper(_FixedAxisArtistHelperBase):
|
||||
"""
|
||||
Helper class for a fixed axis.
|
||||
"""
|
||||
|
||||
def __init__(self, grid_helper, side, nth_coord_ticks=None):
|
||||
"""
|
||||
nth_coord = along which coordinate value varies.
|
||||
nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
|
||||
"""
|
||||
|
||||
super().__init__(loc=side)
|
||||
|
||||
self.grid_helper = grid_helper
|
||||
if nth_coord_ticks is None:
|
||||
nth_coord_ticks = self.nth_coord
|
||||
self.nth_coord_ticks = nth_coord_ticks
|
||||
|
||||
self.side = side
|
||||
|
||||
def update_lim(self, axes):
|
||||
self.grid_helper.update_lim(axes)
|
||||
|
||||
def get_tick_transform(self, axes):
|
||||
return axes.transData
|
||||
|
||||
def get_tick_iterators(self, axes):
|
||||
"""tick_loc, tick_angle, tick_label"""
|
||||
v1, v2 = axes.get_ylim() if self.nth_coord == 0 else axes.get_xlim()
|
||||
if v1 > v2: # Inverted limits.
|
||||
side = {"left": "right", "right": "left",
|
||||
"top": "bottom", "bottom": "top"}[self.side]
|
||||
else:
|
||||
side = self.side
|
||||
|
||||
angle_tangent = dict(left=90, right=90, bottom=0, top=0)[side]
|
||||
|
||||
def iter_major():
|
||||
for nth_coord, show_labels in [
|
||||
(self.nth_coord_ticks, True), (1 - self.nth_coord_ticks, False)]:
|
||||
gi = self.grid_helper._grid_info[["lon", "lat"][nth_coord]]
|
||||
for tick in gi["ticks"][side]:
|
||||
yield (*tick["loc"], angle_tangent,
|
||||
(tick["label"] if show_labels else ""))
|
||||
|
||||
return iter_major(), iter([])
|
||||
|
||||
|
||||
class FloatingAxisArtistHelper(_FloatingAxisArtistHelperBase):
|
||||
|
||||
def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
|
||||
"""
|
||||
nth_coord = along which coordinate value varies.
|
||||
nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
|
||||
"""
|
||||
super().__init__(nth_coord, value)
|
||||
self.value = value
|
||||
self.grid_helper = grid_helper
|
||||
self._extremes = -np.inf, np.inf
|
||||
self._line_num_points = 100 # number of points to create a line
|
||||
|
||||
def set_extremes(self, e1, e2):
|
||||
if e1 is None:
|
||||
e1 = -np.inf
|
||||
if e2 is None:
|
||||
e2 = np.inf
|
||||
self._extremes = e1, e2
|
||||
|
||||
def update_lim(self, axes):
|
||||
self.grid_helper.update_lim(axes)
|
||||
|
||||
x1, x2 = axes.get_xlim()
|
||||
y1, y2 = axes.get_ylim()
|
||||
grid_finder = self.grid_helper.grid_finder
|
||||
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
|
||||
x1, y1, x2, y2)
|
||||
|
||||
lon_min, lon_max, lat_min, lat_max = extremes
|
||||
e_min, e_max = self._extremes # ranges of other coordinates
|
||||
if self.nth_coord == 0:
|
||||
lat_min = max(e_min, lat_min)
|
||||
lat_max = min(e_max, lat_max)
|
||||
elif self.nth_coord == 1:
|
||||
lon_min = max(e_min, lon_min)
|
||||
lon_max = min(e_max, lon_max)
|
||||
|
||||
lon_levs, lon_n, lon_factor = \
|
||||
grid_finder.grid_locator1(lon_min, lon_max)
|
||||
lat_levs, lat_n, lat_factor = \
|
||||
grid_finder.grid_locator2(lat_min, lat_max)
|
||||
|
||||
if self.nth_coord == 0:
|
||||
xx0 = np.full(self._line_num_points, self.value)
|
||||
yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
|
||||
xx, yy = grid_finder.transform_xy(xx0, yy0)
|
||||
elif self.nth_coord == 1:
|
||||
xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
|
||||
yy0 = np.full(self._line_num_points, self.value)
|
||||
xx, yy = grid_finder.transform_xy(xx0, yy0)
|
||||
|
||||
self._grid_info = {
|
||||
"extremes": (lon_min, lon_max, lat_min, lat_max),
|
||||
"lon_info": (lon_levs, lon_n, np.asarray(lon_factor)),
|
||||
"lat_info": (lat_levs, lat_n, np.asarray(lat_factor)),
|
||||
"lon_labels": grid_finder._format_ticks(
|
||||
1, "bottom", lon_factor, lon_levs),
|
||||
"lat_labels": grid_finder._format_ticks(
|
||||
2, "bottom", lat_factor, lat_levs),
|
||||
"line_xy": (xx, yy),
|
||||
}
|
||||
|
||||
def get_axislabel_transform(self, axes):
|
||||
return Affine2D() # axes.transData
|
||||
|
||||
def get_axislabel_pos_angle(self, axes):
|
||||
def trf_xy(x, y):
|
||||
trf = self.grid_helper.grid_finder.get_transform() + axes.transData
|
||||
return trf.transform([x, y]).T
|
||||
|
||||
xmin, xmax, ymin, ymax = self._grid_info["extremes"]
|
||||
if self.nth_coord == 0:
|
||||
xx0 = self.value
|
||||
yy0 = (ymin + ymax) / 2
|
||||
elif self.nth_coord == 1:
|
||||
xx0 = (xmin + xmax) / 2
|
||||
yy0 = self.value
|
||||
xy1, dxy1_dx, dxy1_dy = _value_and_jacobian(
|
||||
trf_xy, xx0, yy0, (xmin, xmax), (ymin, ymax))
|
||||
p = axes.transAxes.inverted().transform(xy1)
|
||||
if 0 <= p[0] <= 1 and 0 <= p[1] <= 1:
|
||||
d = [dxy1_dy, dxy1_dx][self.nth_coord]
|
||||
return xy1, np.rad2deg(np.arctan2(*d[::-1]))
|
||||
else:
|
||||
return None, None
|
||||
|
||||
def get_tick_transform(self, axes):
|
||||
return IdentityTransform() # axes.transData
|
||||
|
||||
def get_tick_iterators(self, axes):
|
||||
"""tick_loc, tick_angle, tick_label, (optionally) tick_label"""
|
||||
|
||||
lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
|
||||
yy0 = lat_levs / lat_factor
|
||||
|
||||
lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
|
||||
xx0 = lon_levs / lon_factor
|
||||
|
||||
e0, e1 = self._extremes
|
||||
|
||||
def trf_xy(x, y):
|
||||
trf = self.grid_helper.grid_finder.get_transform() + axes.transData
|
||||
return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T
|
||||
|
||||
# find angles
|
||||
if self.nth_coord == 0:
|
||||
mask = (e0 <= yy0) & (yy0 <= e1)
|
||||
(xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = _value_and_jacobian(
|
||||
trf_xy, self.value, yy0[mask], (-np.inf, np.inf), (e0, e1))
|
||||
labels = self._grid_info["lat_labels"]
|
||||
|
||||
elif self.nth_coord == 1:
|
||||
mask = (e0 <= xx0) & (xx0 <= e1)
|
||||
(xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = _value_and_jacobian(
|
||||
trf_xy, xx0[mask], self.value, (-np.inf, np.inf), (e0, e1))
|
||||
labels = self._grid_info["lon_labels"]
|
||||
|
||||
labels = [l for l, m in zip(labels, mask) if m]
|
||||
|
||||
angle_normal = np.arctan2(dyy1, dxx1)
|
||||
angle_tangent = np.arctan2(dyy2, dxx2)
|
||||
mm = (dyy1 == 0) & (dxx1 == 0) # points with degenerate normal
|
||||
angle_normal[mm] = angle_tangent[mm] + np.pi / 2
|
||||
|
||||
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
|
||||
in_01 = functools.partial(
|
||||
mpl.transforms._interval_contains_close, (0, 1))
|
||||
|
||||
def iter_major():
|
||||
for x, y, normal, tangent, lab \
|
||||
in zip(xx1, yy1, angle_normal, angle_tangent, labels):
|
||||
c2 = tick_to_axes.transform((x, y))
|
||||
if in_01(c2[0]) and in_01(c2[1]):
|
||||
yield [x, y], *np.rad2deg([normal, tangent]), lab
|
||||
|
||||
return iter_major(), iter([])
|
||||
|
||||
def get_line_transform(self, axes):
|
||||
return axes.transData
|
||||
|
||||
def get_line(self, axes):
|
||||
self.update_lim(axes)
|
||||
x, y = self._grid_info["line_xy"]
|
||||
return Path(np.column_stack([x, y]))
|
||||
|
||||
|
||||
class GridHelperCurveLinear(GridHelperBase):
|
||||
def __init__(self, aux_trans,
|
||||
extreme_finder=None,
|
||||
grid_locator1=None,
|
||||
grid_locator2=None,
|
||||
tick_formatter1=None,
|
||||
tick_formatter2=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
aux_trans : `.Transform` or tuple[Callable, Callable]
|
||||
The transform from curved coordinates to rectilinear coordinate:
|
||||
either a `.Transform` instance (which provides also its inverse),
|
||||
or a pair of callables ``(trans, inv_trans)`` that define the
|
||||
transform and its inverse. The callables should have signature::
|
||||
|
||||
x_rect, y_rect = trans(x_curved, y_curved)
|
||||
x_curved, y_curved = inv_trans(x_rect, y_rect)
|
||||
|
||||
extreme_finder
|
||||
|
||||
grid_locator1, grid_locator2
|
||||
Grid locators for each axis.
|
||||
|
||||
tick_formatter1, tick_formatter2
|
||||
Tick formatters for each axis.
|
||||
"""
|
||||
super().__init__()
|
||||
self._grid_info = None
|
||||
self.grid_finder = GridFinder(aux_trans,
|
||||
extreme_finder,
|
||||
grid_locator1,
|
||||
grid_locator2,
|
||||
tick_formatter1,
|
||||
tick_formatter2)
|
||||
|
||||
def update_grid_finder(self, aux_trans=None, **kwargs):
|
||||
if aux_trans is not None:
|
||||
self.grid_finder.update_transform(aux_trans)
|
||||
self.grid_finder.update(**kwargs)
|
||||
self._old_limits = None # Force revalidation.
|
||||
|
||||
@_api.make_keyword_only("3.9", "nth_coord")
|
||||
def new_fixed_axis(
|
||||
self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
|
||||
if axes is None:
|
||||
axes = self.axes
|
||||
if axis_direction is None:
|
||||
axis_direction = loc
|
||||
helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
|
||||
axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
|
||||
# Why is clip not set on axisline, unlike in new_floating_axis or in
|
||||
# the floating_axig.GridHelperCurveLinear subclass?
|
||||
return axisline
|
||||
|
||||
def new_floating_axis(self, nth_coord, value, axes=None, axis_direction="bottom"):
|
||||
if axes is None:
|
||||
axes = self.axes
|
||||
helper = FloatingAxisArtistHelper(
|
||||
self, nth_coord, value, axis_direction)
|
||||
axisline = AxisArtist(axes, helper)
|
||||
axisline.line.set_clip_on(True)
|
||||
axisline.line.set_clip_box(axisline.axes.bbox)
|
||||
# axisline.major_ticklabels.set_visible(True)
|
||||
# axisline.minor_ticklabels.set_visible(False)
|
||||
return axisline
|
||||
|
||||
def _update_grid(self, x1, y1, x2, y2):
|
||||
self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)
|
||||
|
||||
def get_gridlines(self, which="major", axis="both"):
|
||||
grid_lines = []
|
||||
if axis in ["both", "x"]:
|
||||
for gl in self._grid_info["lon"]["lines"]:
|
||||
grid_lines.extend(gl)
|
||||
if axis in ["both", "y"]:
|
||||
for gl in self._grid_info["lat"]["lines"]:
|
||||
grid_lines.extend(gl)
|
||||
return grid_lines
|
||||
|
||||
@_api.deprecated("3.9")
|
||||
def get_tick_iterator(self, nth_coord, axis_side, minor=False):
|
||||
angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side]
|
||||
lon_or_lat = ["lon", "lat"][nth_coord]
|
||||
if not minor: # major ticks
|
||||
for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]:
|
||||
yield *tick["loc"], angle_tangent, tick["label"]
|
||||
else:
|
||||
for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]:
|
||||
yield *tick["loc"], angle_tangent, ""
|
|
@ -0,0 +1,7 @@
|
|||
from mpl_toolkits.axes_grid1.parasite_axes import (
|
||||
host_axes_class_factory, parasite_axes_class_factory)
|
||||
from .axislines import Axes
|
||||
|
||||
|
||||
ParasiteAxes = parasite_axes_class_factory(Axes)
|
||||
HostAxes = SubplotHost = host_axes_class_factory(Axes)
|
|
@ -0,0 +1,10 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
# Check that the test directories exist
|
||||
if not (Path(__file__).parent / "baseline_images").exists():
|
||||
raise OSError(
|
||||
'The baseline image directory does not exist. '
|
||||
'This is most likely because the test data is not installed. '
|
||||
'You may need to install matplotlib from source to get the '
|
||||
'test data.')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,2 @@
|
|||
from matplotlib.testing.conftest import (mpl_test_settings, # noqa
|
||||
pytest_configure, pytest_unconfigure)
|
|
@ -0,0 +1,141 @@
|
|||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mpl_toolkits.axisartist.angle_helper import (
|
||||
FormatterDMS, FormatterHMS, select_step, select_step24, select_step360)
|
||||
|
||||
|
||||
_MS_RE = (
|
||||
r'''\$ # Mathtext
|
||||
(
|
||||
# The sign sometimes appears on a 0 when a fraction is shown.
|
||||
# Check later that there's only one.
|
||||
(?P<degree_sign>-)?
|
||||
(?P<degree>[0-9.]+) # Degrees value
|
||||
{degree} # Degree symbol (to be replaced by format.)
|
||||
)?
|
||||
(
|
||||
(?(degree)\\,) # Separator if degrees are also visible.
|
||||
(?P<minute_sign>-)?
|
||||
(?P<minute>[0-9.]+) # Minutes value
|
||||
{minute} # Minute symbol (to be replaced by format.)
|
||||
)?
|
||||
(
|
||||
(?(minute)\\,) # Separator if minutes are also visible.
|
||||
(?P<second_sign>-)?
|
||||
(?P<second>[0-9.]+) # Seconds value
|
||||
{second} # Second symbol (to be replaced by format.)
|
||||
)?
|
||||
\$ # Mathtext
|
||||
'''
|
||||
)
|
||||
DMS_RE = re.compile(_MS_RE.format(degree=re.escape(FormatterDMS.deg_mark),
|
||||
minute=re.escape(FormatterDMS.min_mark),
|
||||
second=re.escape(FormatterDMS.sec_mark)),
|
||||
re.VERBOSE)
|
||||
HMS_RE = re.compile(_MS_RE.format(degree=re.escape(FormatterHMS.deg_mark),
|
||||
minute=re.escape(FormatterHMS.min_mark),
|
||||
second=re.escape(FormatterHMS.sec_mark)),
|
||||
re.VERBOSE)
|
||||
|
||||
|
||||
def dms2float(degrees, minutes=0, seconds=0):
|
||||
return degrees + minutes / 60.0 + seconds / 3600.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('args, kwargs, expected_levels, expected_factor', [
|
||||
((-180, 180, 10), {'hour': False}, np.arange(-180, 181, 30), 1.0),
|
||||
((-12, 12, 10), {'hour': True}, np.arange(-12, 13, 2), 1.0)
|
||||
])
|
||||
def test_select_step(args, kwargs, expected_levels, expected_factor):
|
||||
levels, n, factor = select_step(*args, **kwargs)
|
||||
|
||||
assert n == len(levels)
|
||||
np.testing.assert_array_equal(levels, expected_levels)
|
||||
assert factor == expected_factor
|
||||
|
||||
|
||||
@pytest.mark.parametrize('args, kwargs, expected_levels, expected_factor', [
|
||||
((-180, 180, 10), {}, np.arange(-180, 181, 30), 1.0),
|
||||
((-12, 12, 10), {}, np.arange(-750, 751, 150), 60.0)
|
||||
])
|
||||
def test_select_step24(args, kwargs, expected_levels, expected_factor):
|
||||
levels, n, factor = select_step24(*args, **kwargs)
|
||||
|
||||
assert n == len(levels)
|
||||
np.testing.assert_array_equal(levels, expected_levels)
|
||||
assert factor == expected_factor
|
||||
|
||||
|
||||
@pytest.mark.parametrize('args, kwargs, expected_levels, expected_factor', [
|
||||
((dms2float(20, 21.2), dms2float(21, 33.3), 5), {},
|
||||
np.arange(1215, 1306, 15), 60.0),
|
||||
((dms2float(20.5, seconds=21.2), dms2float(20.5, seconds=33.3), 5), {},
|
||||
np.arange(73820, 73835, 2), 3600.0),
|
||||
((dms2float(20, 21.2), dms2float(20, 53.3), 5), {},
|
||||
np.arange(1220, 1256, 5), 60.0),
|
||||
((21.2, 33.3, 5), {},
|
||||
np.arange(20, 35, 2), 1.0),
|
||||
((dms2float(20, 21.2), dms2float(21, 33.3), 5), {},
|
||||
np.arange(1215, 1306, 15), 60.0),
|
||||
((dms2float(20.5, seconds=21.2), dms2float(20.5, seconds=33.3), 5), {},
|
||||
np.arange(73820, 73835, 2), 3600.0),
|
||||
((dms2float(20.5, seconds=21.2), dms2float(20.5, seconds=21.4), 5), {},
|
||||
np.arange(7382120, 7382141, 5), 360000.0),
|
||||
# test threshold factor
|
||||
((dms2float(20.5, seconds=11.2), dms2float(20.5, seconds=53.3), 5),
|
||||
{'threshold_factor': 60}, np.arange(12301, 12310), 600.0),
|
||||
((dms2float(20.5, seconds=11.2), dms2float(20.5, seconds=53.3), 5),
|
||||
{'threshold_factor': 1}, np.arange(20502, 20517, 2), 1000.0),
|
||||
])
|
||||
def test_select_step360(args, kwargs, expected_levels, expected_factor):
|
||||
levels, n, factor = select_step360(*args, **kwargs)
|
||||
|
||||
assert n == len(levels)
|
||||
np.testing.assert_array_equal(levels, expected_levels)
|
||||
assert factor == expected_factor
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Formatter, regex',
|
||||
[(FormatterDMS, DMS_RE),
|
||||
(FormatterHMS, HMS_RE)],
|
||||
ids=['Degree/Minute/Second', 'Hour/Minute/Second'])
|
||||
@pytest.mark.parametrize('direction, factor, values', [
|
||||
("left", 60, [0, -30, -60]),
|
||||
("left", 600, [12301, 12302, 12303]),
|
||||
("left", 3600, [0, -30, -60]),
|
||||
("left", 36000, [738210, 738215, 738220]),
|
||||
("left", 360000, [7382120, 7382125, 7382130]),
|
||||
("left", 1., [45, 46, 47]),
|
||||
("left", 10., [452, 453, 454]),
|
||||
])
|
||||
def test_formatters(Formatter, regex, direction, factor, values):
|
||||
fmt = Formatter()
|
||||
result = fmt(direction, factor, values)
|
||||
|
||||
prev_degree = prev_minute = prev_second = None
|
||||
for tick, value in zip(result, values):
|
||||
m = regex.match(tick)
|
||||
assert m is not None, f'{tick!r} is not an expected tick format.'
|
||||
|
||||
sign = sum(m.group(sign + '_sign') is not None
|
||||
for sign in ('degree', 'minute', 'second'))
|
||||
assert sign <= 1, f'Only one element of tick {tick!r} may have a sign.'
|
||||
sign = 1 if sign == 0 else -1
|
||||
|
||||
degree = float(m.group('degree') or prev_degree or 0)
|
||||
minute = float(m.group('minute') or prev_minute or 0)
|
||||
second = float(m.group('second') or prev_second or 0)
|
||||
if Formatter == FormatterHMS:
|
||||
# 360 degrees as plot range -> 24 hours as labelled range
|
||||
expected_value = pytest.approx((value // 15) / factor)
|
||||
else:
|
||||
expected_value = pytest.approx(value / factor)
|
||||
assert sign * dms2float(degree, minute, second) == expected_value, \
|
||||
f'{tick!r} does not match expected tick value.'
|
||||
|
||||
prev_degree = degree
|
||||
prev_minute = minute
|
||||
prev_second = second
|
|
@ -0,0 +1,99 @@
|
|||
import matplotlib.pyplot as plt
|
||||
from matplotlib.testing.decorators import image_comparison
|
||||
|
||||
from mpl_toolkits.axisartist import AxisArtistHelperRectlinear
|
||||
from mpl_toolkits.axisartist.axis_artist import (AxisArtist, AxisLabel,
|
||||
LabelBase, Ticks, TickLabels)
|
||||
|
||||
|
||||
@image_comparison(['axis_artist_ticks.png'], style='default')
|
||||
def test_ticks():
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.xaxis.set_visible(False)
|
||||
ax.yaxis.set_visible(False)
|
||||
|
||||
locs_angles = [((i / 10, 0.0), i * 30) for i in range(-1, 12)]
|
||||
|
||||
ticks_in = Ticks(ticksize=10, axis=ax.xaxis)
|
||||
ticks_in.set_locs_angles(locs_angles)
|
||||
ax.add_artist(ticks_in)
|
||||
|
||||
ticks_out = Ticks(ticksize=10, tick_out=True, color='C3', axis=ax.xaxis)
|
||||
ticks_out.set_locs_angles(locs_angles)
|
||||
ax.add_artist(ticks_out)
|
||||
|
||||
|
||||
@image_comparison(['axis_artist_labelbase.png'], style='default')
|
||||
def test_labelbase():
|
||||
# Remove this line when this test image is regenerated.
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot([0.5], [0.5], "o")
|
||||
|
||||
label = LabelBase(0.5, 0.5, "Test")
|
||||
label._ref_angle = -90
|
||||
label._offset_radius = 50
|
||||
label.set_rotation(-90)
|
||||
label.set(ha="center", va="top")
|
||||
ax.add_artist(label)
|
||||
|
||||
|
||||
@image_comparison(['axis_artist_ticklabels.png'], style='default')
|
||||
def test_ticklabels():
|
||||
# Remove this line when this test image is regenerated.
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.xaxis.set_visible(False)
|
||||
ax.yaxis.set_visible(False)
|
||||
|
||||
ax.plot([0.2, 0.4], [0.5, 0.5], "o")
|
||||
|
||||
ticks = Ticks(ticksize=10, axis=ax.xaxis)
|
||||
ax.add_artist(ticks)
|
||||
locs_angles_labels = [((0.2, 0.5), -90, "0.2"),
|
||||
((0.4, 0.5), -120, "0.4")]
|
||||
tick_locs_angles = [(xy, a + 180) for xy, a, l in locs_angles_labels]
|
||||
ticks.set_locs_angles(tick_locs_angles)
|
||||
|
||||
ticklabels = TickLabels(axis_direction="left")
|
||||
ticklabels._locs_angles_labels = locs_angles_labels
|
||||
ticklabels.set_pad(10)
|
||||
ax.add_artist(ticklabels)
|
||||
|
||||
ax.plot([0.5], [0.5], "s")
|
||||
axislabel = AxisLabel(0.5, 0.5, "Test")
|
||||
axislabel._offset_radius = 20
|
||||
axislabel._ref_angle = 0
|
||||
axislabel.set_axis_direction("bottom")
|
||||
ax.add_artist(axislabel)
|
||||
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
|
||||
|
||||
@image_comparison(['axis_artist.png'], style='default')
|
||||
def test_axis_artist():
|
||||
# Remove this line when this test image is regenerated.
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.xaxis.set_visible(False)
|
||||
ax.yaxis.set_visible(False)
|
||||
|
||||
for loc in ('left', 'right', 'bottom'):
|
||||
helper = AxisArtistHelperRectlinear.Fixed(ax, loc=loc)
|
||||
axisline = AxisArtist(ax, helper, offset=None, axis_direction=loc)
|
||||
ax.add_artist(axisline)
|
||||
|
||||
# Settings for bottom AxisArtist.
|
||||
axisline.set_label("TTT")
|
||||
axisline.major_ticks.set_tick_out(False)
|
||||
axisline.label.set_pad(5)
|
||||
|
||||
ax.set_ylabel("Test")
|
|
@ -0,0 +1,145 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.testing.decorators import image_comparison
|
||||
from matplotlib.transforms import IdentityTransform
|
||||
|
||||
from mpl_toolkits.axisartist.axislines import AxesZero, SubplotZero, Subplot
|
||||
from mpl_toolkits.axisartist import Axes, SubplotHost
|
||||
|
||||
|
||||
@image_comparison(['SubplotZero.png'], style='default')
|
||||
def test_SubplotZero():
|
||||
# Remove this line when this test image is regenerated.
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig = plt.figure()
|
||||
|
||||
ax = SubplotZero(fig, 1, 1, 1)
|
||||
fig.add_subplot(ax)
|
||||
|
||||
ax.axis["xzero"].set_visible(True)
|
||||
ax.axis["xzero"].label.set_text("Axis Zero")
|
||||
|
||||
for n in ["top", "right"]:
|
||||
ax.axis[n].set_visible(False)
|
||||
|
||||
xx = np.arange(0, 2 * np.pi, 0.01)
|
||||
ax.plot(xx, np.sin(xx))
|
||||
ax.set_ylabel("Test")
|
||||
|
||||
|
||||
@image_comparison(['Subplot.png'], style='default')
|
||||
def test_Subplot():
|
||||
# Remove this line when this test image is regenerated.
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig = plt.figure()
|
||||
|
||||
ax = Subplot(fig, 1, 1, 1)
|
||||
fig.add_subplot(ax)
|
||||
|
||||
xx = np.arange(0, 2 * np.pi, 0.01)
|
||||
ax.plot(xx, np.sin(xx))
|
||||
ax.set_ylabel("Test")
|
||||
|
||||
ax.axis["top"].major_ticks.set_tick_out(True)
|
||||
ax.axis["bottom"].major_ticks.set_tick_out(True)
|
||||
|
||||
ax.axis["bottom"].set_label("Tk0")
|
||||
|
||||
|
||||
def test_Axes():
|
||||
fig = plt.figure()
|
||||
ax = Axes(fig, [0.15, 0.1, 0.65, 0.8])
|
||||
fig.add_axes(ax)
|
||||
ax.plot([1, 2, 3], [0, 1, 2])
|
||||
ax.set_xscale('log')
|
||||
fig.canvas.draw()
|
||||
|
||||
|
||||
@image_comparison(['ParasiteAxesAuxTrans_meshplot.png'],
|
||||
remove_text=True, style='default', tol=0.075)
|
||||
def test_ParasiteAxesAuxTrans():
|
||||
data = np.ones((6, 6))
|
||||
data[2, 2] = 2
|
||||
data[0, :] = 0
|
||||
data[-2, :] = 0
|
||||
data[:, 0] = 0
|
||||
data[:, -2] = 0
|
||||
x = np.arange(6)
|
||||
y = np.arange(6)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
funcnames = ['pcolor', 'pcolormesh', 'contourf']
|
||||
|
||||
fig = plt.figure()
|
||||
for i, name in enumerate(funcnames):
|
||||
|
||||
ax1 = SubplotHost(fig, 1, 3, i+1)
|
||||
fig.add_subplot(ax1)
|
||||
|
||||
ax2 = ax1.get_aux_axes(IdentityTransform(), viewlim_mode=None)
|
||||
if name.startswith('pcolor'):
|
||||
getattr(ax2, name)(xx, yy, data[:-1, :-1])
|
||||
else:
|
||||
getattr(ax2, name)(xx, yy, data)
|
||||
ax1.set_xlim((0, 5))
|
||||
ax1.set_ylim((0, 5))
|
||||
|
||||
ax2.contour(xx, yy, data, colors='k')
|
||||
|
||||
|
||||
@image_comparison(['axisline_style.png'], remove_text=True, style='mpl20')
|
||||
def test_axisline_style():
|
||||
fig = plt.figure(figsize=(2, 2))
|
||||
ax = fig.add_subplot(axes_class=AxesZero)
|
||||
ax.axis["xzero"].set_axisline_style("-|>")
|
||||
ax.axis["xzero"].set_visible(True)
|
||||
ax.axis["yzero"].set_axisline_style("->")
|
||||
ax.axis["yzero"].set_visible(True)
|
||||
|
||||
for direction in ("left", "right", "bottom", "top"):
|
||||
ax.axis[direction].set_visible(False)
|
||||
|
||||
|
||||
@image_comparison(['axisline_style_size_color.png'], remove_text=True,
|
||||
style='mpl20')
|
||||
def test_axisline_style_size_color():
|
||||
fig = plt.figure(figsize=(2, 2))
|
||||
ax = fig.add_subplot(axes_class=AxesZero)
|
||||
ax.axis["xzero"].set_axisline_style("-|>", size=2.0, facecolor='r')
|
||||
ax.axis["xzero"].set_visible(True)
|
||||
ax.axis["yzero"].set_axisline_style("->, size=1.5")
|
||||
ax.axis["yzero"].set_visible(True)
|
||||
|
||||
for direction in ("left", "right", "bottom", "top"):
|
||||
ax.axis[direction].set_visible(False)
|
||||
|
||||
|
||||
@image_comparison(['axisline_style_tight.png'], remove_text=True,
|
||||
style='mpl20')
|
||||
def test_axisline_style_tight():
|
||||
fig = plt.figure(figsize=(2, 2), layout='tight')
|
||||
ax = fig.add_subplot(axes_class=AxesZero)
|
||||
ax.axis["xzero"].set_axisline_style("-|>", size=5, facecolor='g')
|
||||
ax.axis["xzero"].set_visible(True)
|
||||
ax.axis["yzero"].set_axisline_style("->, size=8")
|
||||
ax.axis["yzero"].set_visible(True)
|
||||
|
||||
for direction in ("left", "right", "bottom", "top"):
|
||||
ax.axis[direction].set_visible(False)
|
||||
|
||||
|
||||
@image_comparison(['subplotzero_ylabel.png'], style='mpl20')
|
||||
def test_subplotzero_ylabel():
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, axes_class=SubplotZero)
|
||||
|
||||
ax.set(xlim=(-3, 7), ylim=(-3, 7), xlabel="x", ylabel="y")
|
||||
|
||||
zero_axis = ax.axis["xzero", "yzero"]
|
||||
zero_axis.set_visible(True) # they are hidden by default
|
||||
|
||||
ax.axis["left", "right", "bottom", "top"].set_visible(False)
|
||||
|
||||
zero_axis.set_axisline_style("->")
|
|
@ -0,0 +1,115 @@
|
|||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.projections as mprojections
|
||||
import matplotlib.transforms as mtransforms
|
||||
from matplotlib.testing.decorators import image_comparison
|
||||
from mpl_toolkits.axisartist.axislines import Subplot
|
||||
from mpl_toolkits.axisartist.floating_axes import (
|
||||
FloatingAxes, GridHelperCurveLinear)
|
||||
from mpl_toolkits.axisartist.grid_finder import FixedLocator
|
||||
from mpl_toolkits.axisartist import angle_helper
|
||||
|
||||
|
||||
def test_subplot():
|
||||
fig = plt.figure(figsize=(5, 5))
|
||||
ax = Subplot(fig, 111)
|
||||
fig.add_subplot(ax)
|
||||
|
||||
|
||||
# Rather high tolerance to allow ongoing work with floating axes internals;
|
||||
# remove when image is regenerated.
|
||||
@image_comparison(['curvelinear3.png'], style='default', tol=5)
|
||||
def test_curvelinear3():
|
||||
fig = plt.figure(figsize=(5, 5))
|
||||
|
||||
tr = (mtransforms.Affine2D().scale(np.pi / 180, 1) +
|
||||
mprojections.PolarAxes.PolarTransform(apply_theta_transforms=False))
|
||||
grid_helper = GridHelperCurveLinear(
|
||||
tr,
|
||||
extremes=(0, 360, 10, 3),
|
||||
grid_locator1=angle_helper.LocatorDMS(15),
|
||||
grid_locator2=FixedLocator([2, 4, 6, 8, 10]),
|
||||
tick_formatter1=angle_helper.FormatterDMS(),
|
||||
tick_formatter2=None)
|
||||
ax1 = fig.add_subplot(axes_class=FloatingAxes, grid_helper=grid_helper)
|
||||
|
||||
r_scale = 10
|
||||
tr2 = mtransforms.Affine2D().scale(1, 1 / r_scale) + tr
|
||||
grid_helper2 = GridHelperCurveLinear(
|
||||
tr2,
|
||||
extremes=(0, 360, 10 * r_scale, 3 * r_scale),
|
||||
grid_locator2=FixedLocator([30, 60, 90]))
|
||||
|
||||
ax1.axis["right"] = axis = grid_helper2.new_fixed_axis("right", axes=ax1)
|
||||
|
||||
ax1.axis["left"].label.set_text("Test 1")
|
||||
ax1.axis["right"].label.set_text("Test 2")
|
||||
ax1.axis["left", "right"].set_visible(False)
|
||||
|
||||
axis = grid_helper.new_floating_axis(1, 7, axes=ax1,
|
||||
axis_direction="bottom")
|
||||
ax1.axis["z"] = axis
|
||||
axis.toggle(all=True, label=True)
|
||||
axis.label.set_text("z = ?")
|
||||
axis.label.set_visible(True)
|
||||
axis.line.set_color("0.5")
|
||||
|
||||
ax2 = ax1.get_aux_axes(tr)
|
||||
|
||||
xx, yy = [67, 90, 75, 30], [2, 5, 8, 4]
|
||||
ax2.scatter(xx, yy)
|
||||
l, = ax2.plot(xx, yy, "k-")
|
||||
l.set_clip_path(ax1.patch)
|
||||
|
||||
|
||||
# Rather high tolerance to allow ongoing work with floating axes internals;
|
||||
# remove when image is regenerated.
|
||||
@image_comparison(['curvelinear4.png'], style='default', tol=0.9)
|
||||
def test_curvelinear4():
|
||||
# Remove this line when this test image is regenerated.
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig = plt.figure(figsize=(5, 5))
|
||||
|
||||
tr = (mtransforms.Affine2D().scale(np.pi / 180, 1) +
|
||||
mprojections.PolarAxes.PolarTransform(apply_theta_transforms=False))
|
||||
grid_helper = GridHelperCurveLinear(
|
||||
tr,
|
||||
extremes=(120, 30, 10, 0),
|
||||
grid_locator1=angle_helper.LocatorDMS(5),
|
||||
grid_locator2=FixedLocator([2, 4, 6, 8, 10]),
|
||||
tick_formatter1=angle_helper.FormatterDMS(),
|
||||
tick_formatter2=None)
|
||||
ax1 = fig.add_subplot(axes_class=FloatingAxes, grid_helper=grid_helper)
|
||||
ax1.clear() # Check that clear() also restores the correct limits on ax1.
|
||||
|
||||
ax1.axis["left"].label.set_text("Test 1")
|
||||
ax1.axis["right"].label.set_text("Test 2")
|
||||
ax1.axis["top"].set_visible(False)
|
||||
|
||||
axis = grid_helper.new_floating_axis(1, 70, axes=ax1,
|
||||
axis_direction="bottom")
|
||||
ax1.axis["z"] = axis
|
||||
axis.toggle(all=True, label=True)
|
||||
axis.label.set_axis_direction("top")
|
||||
axis.label.set_text("z = ?")
|
||||
axis.label.set_visible(True)
|
||||
axis.line.set_color("0.5")
|
||||
|
||||
ax2 = ax1.get_aux_axes(tr)
|
||||
|
||||
xx, yy = [67, 90, 75, 30], [2, 5, 8, 4]
|
||||
ax2.scatter(xx, yy)
|
||||
l, = ax2.plot(xx, yy, "k-")
|
||||
l.set_clip_path(ax1.patch)
|
||||
|
||||
|
||||
def test_axis_direction():
|
||||
# Check that axis direction is propagated on a floating axis
|
||||
fig = plt.figure()
|
||||
ax = Subplot(fig, 111)
|
||||
fig.add_subplot(ax)
|
||||
ax.axis['y'] = ax.new_floating_axis(nth_coord=1, value=0,
|
||||
axis_direction='left')
|
||||
assert ax.axis['y']._axis_direction == 'left'
|
|
@ -0,0 +1,34 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from matplotlib.transforms import Bbox
|
||||
from mpl_toolkits.axisartist.grid_finder import (
|
||||
_find_line_box_crossings, FormatterPrettyPrint, MaxNLocator)
|
||||
|
||||
|
||||
def test_find_line_box_crossings():
|
||||
x = np.array([-3, -2, -1, 0., 1, 2, 3, 2, 1, 0, -1, -2, -3, 5])
|
||||
y = np.arange(len(x))
|
||||
bbox = Bbox.from_extents(-2, 3, 2, 12.5)
|
||||
left, right, bottom, top = _find_line_box_crossings(
|
||||
np.column_stack([x, y]), bbox)
|
||||
((lx0, ly0), la0), ((lx1, ly1), la1), = left
|
||||
((rx0, ry0), ra0), ((rx1, ry1), ra1), = right
|
||||
((bx0, by0), ba0), = bottom
|
||||
((tx0, ty0), ta0), = top
|
||||
assert (lx0, ly0, la0) == (-2, 11, 135)
|
||||
assert (lx1, ly1, la1) == pytest.approx((-2., 12.125, 7.125016))
|
||||
assert (rx0, ry0, ra0) == (2, 5, 45)
|
||||
assert (rx1, ry1, ra1) == (2, 7, 135)
|
||||
assert (bx0, by0, ba0) == (0, 3, 45)
|
||||
assert (tx0, ty0, ta0) == pytest.approx((1., 12.5, 7.125016))
|
||||
|
||||
|
||||
def test_pretty_print_format():
|
||||
locator = MaxNLocator()
|
||||
locs, nloc, factor = locator(0, 100)
|
||||
|
||||
fmt = FormatterPrettyPrint()
|
||||
|
||||
assert fmt("left", None, locs) == \
|
||||
[r'$\mathdefault{%d}$' % (l, ) for l in locs]
|
|
@ -0,0 +1,207 @@
|
|||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.path import Path
|
||||
from matplotlib.projections import PolarAxes
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
from matplotlib.transforms import Affine2D, Transform
|
||||
from matplotlib.testing.decorators import image_comparison
|
||||
|
||||
from mpl_toolkits.axisartist import SubplotHost
|
||||
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
|
||||
from mpl_toolkits.axisartist import angle_helper
|
||||
from mpl_toolkits.axisartist.axislines import Axes
|
||||
from mpl_toolkits.axisartist.grid_helper_curvelinear import \
|
||||
GridHelperCurveLinear
|
||||
|
||||
|
||||
@image_comparison(['custom_transform.png'], style='default', tol=0.2)
|
||||
def test_custom_transform():
|
||||
class MyTransform(Transform):
|
||||
input_dims = output_dims = 2
|
||||
|
||||
def __init__(self, resolution):
|
||||
"""
|
||||
Resolution is the number of steps to interpolate between each input
|
||||
line segment to approximate its path in transformed space.
|
||||
"""
|
||||
Transform.__init__(self)
|
||||
self._resolution = resolution
|
||||
|
||||
def transform(self, ll):
|
||||
x, y = ll.T
|
||||
return np.column_stack([x, y - x])
|
||||
|
||||
transform_non_affine = transform
|
||||
|
||||
def transform_path(self, path):
|
||||
ipath = path.interpolated(self._resolution)
|
||||
return Path(self.transform(ipath.vertices), ipath.codes)
|
||||
|
||||
transform_path_non_affine = transform_path
|
||||
|
||||
def inverted(self):
|
||||
return MyTransformInv(self._resolution)
|
||||
|
||||
class MyTransformInv(Transform):
|
||||
input_dims = output_dims = 2
|
||||
|
||||
def __init__(self, resolution):
|
||||
Transform.__init__(self)
|
||||
self._resolution = resolution
|
||||
|
||||
def transform(self, ll):
|
||||
x, y = ll.T
|
||||
return np.column_stack([x, y + x])
|
||||
|
||||
def inverted(self):
|
||||
return MyTransform(self._resolution)
|
||||
|
||||
fig = plt.figure()
|
||||
|
||||
SubplotHost = host_axes_class_factory(Axes)
|
||||
|
||||
tr = MyTransform(1)
|
||||
grid_helper = GridHelperCurveLinear(tr)
|
||||
ax1 = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)
|
||||
fig.add_subplot(ax1)
|
||||
|
||||
ax2 = ax1.get_aux_axes(tr, viewlim_mode="equal")
|
||||
ax2.plot([3, 6], [5.0, 10.])
|
||||
|
||||
ax1.set_aspect(1.)
|
||||
ax1.set_xlim(0, 10)
|
||||
ax1.set_ylim(0, 10)
|
||||
|
||||
ax1.grid(True)
|
||||
|
||||
|
||||
@image_comparison(['polar_box.png'], style='default', tol=0.04)
|
||||
def test_polar_box():
|
||||
fig = plt.figure(figsize=(5, 5))
|
||||
|
||||
# PolarAxes.PolarTransform takes radian. However, we want our coordinate
|
||||
# system in degree
|
||||
tr = (Affine2D().scale(np.pi / 180., 1.) +
|
||||
PolarAxes.PolarTransform(apply_theta_transforms=False))
|
||||
|
||||
# polar projection, which involves cycle, and also has limits in
|
||||
# its coordinates, needs a special method to find the extremes
|
||||
# (min, max of the coordinate within the view).
|
||||
extreme_finder = angle_helper.ExtremeFinderCycle(20, 20,
|
||||
lon_cycle=360,
|
||||
lat_cycle=None,
|
||||
lon_minmax=None,
|
||||
lat_minmax=(0, np.inf))
|
||||
|
||||
grid_helper = GridHelperCurveLinear(
|
||||
tr,
|
||||
extreme_finder=extreme_finder,
|
||||
grid_locator1=angle_helper.LocatorDMS(12),
|
||||
tick_formatter1=angle_helper.FormatterDMS(),
|
||||
tick_formatter2=FuncFormatter(lambda x, p: "eight" if x == 8 else f"{int(x)}"),
|
||||
)
|
||||
|
||||
ax1 = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)
|
||||
|
||||
ax1.axis["right"].major_ticklabels.set_visible(True)
|
||||
ax1.axis["top"].major_ticklabels.set_visible(True)
|
||||
|
||||
# let right axis shows ticklabels for 1st coordinate (angle)
|
||||
ax1.axis["right"].get_helper().nth_coord_ticks = 0
|
||||
# let bottom axis shows ticklabels for 2nd coordinate (radius)
|
||||
ax1.axis["bottom"].get_helper().nth_coord_ticks = 1
|
||||
|
||||
fig.add_subplot(ax1)
|
||||
|
||||
ax1.axis["lat"] = axis = grid_helper.new_floating_axis(0, 45, axes=ax1)
|
||||
axis.label.set_text("Test")
|
||||
axis.label.set_visible(True)
|
||||
axis.get_helper().set_extremes(2, 12)
|
||||
|
||||
ax1.axis["lon"] = axis = grid_helper.new_floating_axis(1, 6, axes=ax1)
|
||||
axis.label.set_text("Test 2")
|
||||
axis.get_helper().set_extremes(-180, 90)
|
||||
|
||||
# A parasite axes with given transform
|
||||
ax2 = ax1.get_aux_axes(tr, viewlim_mode="equal")
|
||||
assert ax2.transData == tr + ax1.transData
|
||||
# Anything you draw in ax2 will match the ticks and grids of ax1.
|
||||
ax2.plot(np.linspace(0, 30, 50), np.linspace(10, 10, 50))
|
||||
|
||||
ax1.set_aspect(1.)
|
||||
ax1.set_xlim(-5, 12)
|
||||
ax1.set_ylim(-5, 10)
|
||||
|
||||
ax1.grid(True)
|
||||
|
||||
|
||||
# Remove tol & kerning_factor when this test image is regenerated.
|
||||
@image_comparison(['axis_direction.png'], style='default', tol=0.13)
|
||||
def test_axis_direction():
|
||||
plt.rcParams['text.kerning_factor'] = 6
|
||||
|
||||
fig = plt.figure(figsize=(5, 5))
|
||||
|
||||
# PolarAxes.PolarTransform takes radian. However, we want our coordinate
|
||||
# system in degree
|
||||
tr = (Affine2D().scale(np.pi / 180., 1.) +
|
||||
PolarAxes.PolarTransform(apply_theta_transforms=False))
|
||||
|
||||
# polar projection, which involves cycle, and also has limits in
|
||||
# its coordinates, needs a special method to find the extremes
|
||||
# (min, max of the coordinate within the view).
|
||||
|
||||
# 20, 20 : number of sampling points along x, y direction
|
||||
extreme_finder = angle_helper.ExtremeFinderCycle(20, 20,
|
||||
lon_cycle=360,
|
||||
lat_cycle=None,
|
||||
lon_minmax=None,
|
||||
lat_minmax=(0, np.inf),
|
||||
)
|
||||
|
||||
grid_locator1 = angle_helper.LocatorDMS(12)
|
||||
tick_formatter1 = angle_helper.FormatterDMS()
|
||||
|
||||
grid_helper = GridHelperCurveLinear(tr,
|
||||
extreme_finder=extreme_finder,
|
||||
grid_locator1=grid_locator1,
|
||||
tick_formatter1=tick_formatter1)
|
||||
|
||||
ax1 = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)
|
||||
|
||||
for axis in ax1.axis.values():
|
||||
axis.set_visible(False)
|
||||
|
||||
fig.add_subplot(ax1)
|
||||
|
||||
ax1.axis["lat1"] = axis = grid_helper.new_floating_axis(
|
||||
0, 130,
|
||||
axes=ax1, axis_direction="left")
|
||||
axis.label.set_text("Test")
|
||||
axis.label.set_visible(True)
|
||||
axis.get_helper().set_extremes(0.001, 10)
|
||||
|
||||
ax1.axis["lat2"] = axis = grid_helper.new_floating_axis(
|
||||
0, 50,
|
||||
axes=ax1, axis_direction="right")
|
||||
axis.label.set_text("Test")
|
||||
axis.label.set_visible(True)
|
||||
axis.get_helper().set_extremes(0.001, 10)
|
||||
|
||||
ax1.axis["lon"] = axis = grid_helper.new_floating_axis(
|
||||
1, 10,
|
||||
axes=ax1, axis_direction="bottom")
|
||||
axis.label.set_text("Test 2")
|
||||
axis.get_helper().set_extremes(50, 130)
|
||||
axis.major_ticklabels.set_axis_direction("top")
|
||||
axis.label.set_axis_direction("top")
|
||||
|
||||
grid_helper.grid_finder.grid_locator1.set_params(nbins=5)
|
||||
grid_helper.grid_finder.grid_locator2.set_params(nbins=5)
|
||||
|
||||
ax1.set_aspect(1.)
|
||||
ax1.set_xlim(-8, 8)
|
||||
ax1.set_ylim(-4, 12)
|
||||
|
||||
ax1.grid(True)
|
|
@ -0,0 +1,3 @@
|
|||
from .axes3d import Axes3D
|
||||
|
||||
__all__ = ['Axes3D']
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1427
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/art3d.py
Normal file
1427
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/art3d.py
Normal file
File diff suppressed because it is too large
Load diff
4162
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/axes3d.py
Normal file
4162
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/axes3d.py
Normal file
File diff suppressed because it is too large
Load diff
750
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/axis3d.py
Normal file
750
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/axis3d.py
Normal file
|
@ -0,0 +1,750 @@
|
|||
# axis3d.py, original mplot3d version by John Porter
|
||||
# Created: 23 Sep 2005
|
||||
# Parts rewritten by Reinier Heeres <reinier@heeres.eu>
|
||||
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import (
|
||||
_api, artist, lines as mlines, axis as maxis, patches as mpatches,
|
||||
transforms as mtransforms, colors as mcolors)
|
||||
from . import art3d, proj3d
|
||||
|
||||
|
||||
def _move_from_center(coord, centers, deltas, axmask=(True, True, True)):
|
||||
"""
|
||||
For each coordinate where *axmask* is True, move *coord* away from
|
||||
*centers* by *deltas*.
|
||||
"""
|
||||
coord = np.asarray(coord)
|
||||
return coord + axmask * np.copysign(1, coord - centers) * deltas
|
||||
|
||||
|
||||
def _tick_update_position(tick, tickxs, tickys, labelpos):
|
||||
"""Update tick line and label position and style."""
|
||||
|
||||
tick.label1.set_position(labelpos)
|
||||
tick.label2.set_position(labelpos)
|
||||
tick.tick1line.set_visible(True)
|
||||
tick.tick2line.set_visible(False)
|
||||
tick.tick1line.set_linestyle('-')
|
||||
tick.tick1line.set_marker('')
|
||||
tick.tick1line.set_data(tickxs, tickys)
|
||||
tick.gridline.set_data([0], [0])
|
||||
|
||||
|
||||
class Axis(maxis.XAxis):
|
||||
"""An Axis class for the 3D plots."""
|
||||
# These points from the unit cube make up the x, y and z-planes
|
||||
_PLANES = (
|
||||
(0, 3, 7, 4), (1, 2, 6, 5), # yz planes
|
||||
(0, 1, 5, 4), (3, 2, 6, 7), # xz planes
|
||||
(0, 1, 2, 3), (4, 5, 6, 7), # xy planes
|
||||
)
|
||||
|
||||
# Some properties for the axes
|
||||
_AXINFO = {
|
||||
'x': {'i': 0, 'tickdir': 1, 'juggled': (1, 0, 2)},
|
||||
'y': {'i': 1, 'tickdir': 0, 'juggled': (0, 1, 2)},
|
||||
'z': {'i': 2, 'tickdir': 0, 'juggled': (0, 2, 1)},
|
||||
}
|
||||
|
||||
def _old_init(self, adir, v_intervalx, d_intervalx, axes, *args,
|
||||
rotate_label=None, **kwargs):
|
||||
return locals()
|
||||
|
||||
def _new_init(self, axes, *, rotate_label=None, **kwargs):
|
||||
return locals()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
params = _api.select_matching_signature(
|
||||
[self._old_init, self._new_init], *args, **kwargs)
|
||||
if "adir" in params:
|
||||
_api.warn_deprecated(
|
||||
"3.6", message=f"The signature of 3D Axis constructors has "
|
||||
f"changed in %(since)s; the new signature is "
|
||||
f"{inspect.signature(type(self).__init__)}", pending=True)
|
||||
if params["adir"] != self.axis_name:
|
||||
raise ValueError(f"Cannot instantiate {type(self).__name__} "
|
||||
f"with adir={params['adir']!r}")
|
||||
axes = params["axes"]
|
||||
rotate_label = params["rotate_label"]
|
||||
args = params.get("args", ())
|
||||
kwargs = params["kwargs"]
|
||||
|
||||
name = self.axis_name
|
||||
|
||||
self._label_position = 'default'
|
||||
self._tick_position = 'default'
|
||||
|
||||
# This is a temporary member variable.
|
||||
# Do not depend on this existing in future releases!
|
||||
self._axinfo = self._AXINFO[name].copy()
|
||||
# Common parts
|
||||
self._axinfo.update({
|
||||
'label': {'va': 'center', 'ha': 'center',
|
||||
'rotation_mode': 'anchor'},
|
||||
'color': mpl.rcParams[f'axes3d.{name}axis.panecolor'],
|
||||
'tick': {
|
||||
'inward_factor': 0.2,
|
||||
'outward_factor': 0.1,
|
||||
},
|
||||
})
|
||||
|
||||
if mpl.rcParams['_internal.classic_mode']:
|
||||
self._axinfo.update({
|
||||
'axisline': {'linewidth': 0.75, 'color': (0, 0, 0, 1)},
|
||||
'grid': {
|
||||
'color': (0.9, 0.9, 0.9, 1),
|
||||
'linewidth': 1.0,
|
||||
'linestyle': '-',
|
||||
},
|
||||
})
|
||||
self._axinfo['tick'].update({
|
||||
'linewidth': {
|
||||
True: mpl.rcParams['lines.linewidth'], # major
|
||||
False: mpl.rcParams['lines.linewidth'], # minor
|
||||
}
|
||||
})
|
||||
else:
|
||||
self._axinfo.update({
|
||||
'axisline': {
|
||||
'linewidth': mpl.rcParams['axes.linewidth'],
|
||||
'color': mpl.rcParams['axes.edgecolor'],
|
||||
},
|
||||
'grid': {
|
||||
'color': mpl.rcParams['grid.color'],
|
||||
'linewidth': mpl.rcParams['grid.linewidth'],
|
||||
'linestyle': mpl.rcParams['grid.linestyle'],
|
||||
},
|
||||
})
|
||||
self._axinfo['tick'].update({
|
||||
'linewidth': {
|
||||
True: ( # major
|
||||
mpl.rcParams['xtick.major.width'] if name in 'xz'
|
||||
else mpl.rcParams['ytick.major.width']),
|
||||
False: ( # minor
|
||||
mpl.rcParams['xtick.minor.width'] if name in 'xz'
|
||||
else mpl.rcParams['ytick.minor.width']),
|
||||
}
|
||||
})
|
||||
|
||||
super().__init__(axes, *args, **kwargs)
|
||||
|
||||
# data and viewing intervals for this direction
|
||||
if "d_intervalx" in params:
|
||||
self.set_data_interval(*params["d_intervalx"])
|
||||
if "v_intervalx" in params:
|
||||
self.set_view_interval(*params["v_intervalx"])
|
||||
self.set_rotate_label(rotate_label)
|
||||
self._init3d() # Inline after init3d deprecation elapses.
|
||||
|
||||
__init__.__signature__ = inspect.signature(_new_init)
|
||||
adir = _api.deprecated("3.6", pending=True)(
|
||||
property(lambda self: self.axis_name))
|
||||
|
||||
def _init3d(self):
|
||||
self.line = mlines.Line2D(
|
||||
xdata=(0, 0), ydata=(0, 0),
|
||||
linewidth=self._axinfo['axisline']['linewidth'],
|
||||
color=self._axinfo['axisline']['color'],
|
||||
antialiased=True)
|
||||
|
||||
# Store dummy data in Polygon object
|
||||
self.pane = mpatches.Polygon([[0, 0], [0, 1]], closed=False)
|
||||
self.set_pane_color(self._axinfo['color'])
|
||||
|
||||
self.axes._set_artist_props(self.line)
|
||||
self.axes._set_artist_props(self.pane)
|
||||
self.gridlines = art3d.Line3DCollection([])
|
||||
self.axes._set_artist_props(self.gridlines)
|
||||
self.axes._set_artist_props(self.label)
|
||||
self.axes._set_artist_props(self.offsetText)
|
||||
# Need to be able to place the label at the correct location
|
||||
self.label._transform = self.axes.transData
|
||||
self.offsetText._transform = self.axes.transData
|
||||
|
||||
@_api.deprecated("3.6", pending=True)
|
||||
def init3d(self): # After deprecation elapses, inline _init3d to __init__.
|
||||
self._init3d()
|
||||
|
||||
def get_major_ticks(self, numticks=None):
|
||||
ticks = super().get_major_ticks(numticks)
|
||||
for t in ticks:
|
||||
for obj in [
|
||||
t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
|
||||
obj.set_transform(self.axes.transData)
|
||||
return ticks
|
||||
|
||||
def get_minor_ticks(self, numticks=None):
|
||||
ticks = super().get_minor_ticks(numticks)
|
||||
for t in ticks:
|
||||
for obj in [
|
||||
t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
|
||||
obj.set_transform(self.axes.transData)
|
||||
return ticks
|
||||
|
||||
def set_ticks_position(self, position):
|
||||
"""
|
||||
Set the ticks position.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : {'lower', 'upper', 'both', 'default', 'none'}
|
||||
The position of the bolded axis lines, ticks, and tick labels.
|
||||
"""
|
||||
_api.check_in_list(['lower', 'upper', 'both', 'default', 'none'],
|
||||
position=position)
|
||||
self._tick_position = position
|
||||
|
||||
def get_ticks_position(self):
|
||||
"""
|
||||
Get the ticks position.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str : {'lower', 'upper', 'both', 'default', 'none'}
|
||||
The position of the bolded axis lines, ticks, and tick labels.
|
||||
"""
|
||||
return self._tick_position
|
||||
|
||||
def set_label_position(self, position):
|
||||
"""
|
||||
Set the label position.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : {'lower', 'upper', 'both', 'default', 'none'}
|
||||
The position of the axis label.
|
||||
"""
|
||||
_api.check_in_list(['lower', 'upper', 'both', 'default', 'none'],
|
||||
position=position)
|
||||
self._label_position = position
|
||||
|
||||
def get_label_position(self):
|
||||
"""
|
||||
Get the label position.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str : {'lower', 'upper', 'both', 'default', 'none'}
|
||||
The position of the axis label.
|
||||
"""
|
||||
return self._label_position
|
||||
|
||||
def set_pane_color(self, color, alpha=None):
|
||||
"""
|
||||
Set pane color.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
color : :mpltype:`color`
|
||||
Color for axis pane.
|
||||
alpha : float, optional
|
||||
Alpha value for axis pane. If None, base it on *color*.
|
||||
"""
|
||||
color = mcolors.to_rgba(color, alpha)
|
||||
self._axinfo['color'] = color
|
||||
self.pane.set_edgecolor(color)
|
||||
self.pane.set_facecolor(color)
|
||||
self.pane.set_alpha(color[-1])
|
||||
self.stale = True
|
||||
|
||||
def set_rotate_label(self, val):
|
||||
"""
|
||||
Whether to rotate the axis label: True, False or None.
|
||||
If set to None the label will be rotated if longer than 4 chars.
|
||||
"""
|
||||
self._rotate_label = val
|
||||
self.stale = True
|
||||
|
||||
def get_rotate_label(self, text):
|
||||
if self._rotate_label is not None:
|
||||
return self._rotate_label
|
||||
else:
|
||||
return len(text) > 4
|
||||
|
||||
def _get_coord_info(self):
|
||||
mins, maxs = np.array([
|
||||
self.axes.get_xbound(),
|
||||
self.axes.get_ybound(),
|
||||
self.axes.get_zbound(),
|
||||
]).T
|
||||
|
||||
# Project the bounds along the current position of the cube:
|
||||
bounds = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
|
||||
bounds_proj = self.axes._transformed_cube(bounds)
|
||||
|
||||
# Determine which one of the parallel planes are higher up:
|
||||
means_z0 = np.zeros(3)
|
||||
means_z1 = np.zeros(3)
|
||||
for i in range(3):
|
||||
means_z0[i] = np.mean(bounds_proj[self._PLANES[2 * i], 2])
|
||||
means_z1[i] = np.mean(bounds_proj[self._PLANES[2 * i + 1], 2])
|
||||
highs = means_z0 < means_z1
|
||||
|
||||
# Special handling for edge-on views
|
||||
equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps
|
||||
if np.sum(equals) == 2:
|
||||
vertical = np.where(~equals)[0][0]
|
||||
if vertical == 2: # looking at XY plane
|
||||
highs = np.array([True, True, highs[2]])
|
||||
elif vertical == 1: # looking at XZ plane
|
||||
highs = np.array([True, highs[1], False])
|
||||
elif vertical == 0: # looking at YZ plane
|
||||
highs = np.array([highs[0], False, False])
|
||||
|
||||
return mins, maxs, bounds_proj, highs
|
||||
|
||||
def _calc_centers_deltas(self, maxs, mins):
|
||||
centers = 0.5 * (maxs + mins)
|
||||
# In mpl3.8, the scale factor was 1/12. mpl3.9 changes this to
|
||||
# 1/12 * 24/25 = 0.08 to compensate for the change in automargin
|
||||
# behavior and keep appearance the same. The 24/25 factor is from the
|
||||
# 1/48 padding added to each side of the axis in mpl3.8.
|
||||
scale = 0.08
|
||||
deltas = (maxs - mins) * scale
|
||||
return centers, deltas
|
||||
|
||||
def _get_axis_line_edge_points(self, minmax, maxmin, position=None):
|
||||
"""Get the edge points for the black bolded axis line."""
|
||||
# When changing vertical axis some of the axes has to be
|
||||
# moved to the other plane so it looks the same as if the z-axis
|
||||
# was the vertical axis.
|
||||
mb = [minmax, maxmin] # line from origin to nearest corner to camera
|
||||
mb_rev = mb[::-1]
|
||||
mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]]
|
||||
mm = mm[self.axes._vertical_axis][self._axinfo["i"]]
|
||||
|
||||
juggled = self._axinfo["juggled"]
|
||||
edge_point_0 = mm[0].copy() # origin point
|
||||
|
||||
if ((position == 'lower' and mm[1][juggled[-1]] < mm[0][juggled[-1]]) or
|
||||
(position == 'upper' and mm[1][juggled[-1]] > mm[0][juggled[-1]])):
|
||||
edge_point_0[juggled[-1]] = mm[1][juggled[-1]]
|
||||
else:
|
||||
edge_point_0[juggled[0]] = mm[1][juggled[0]]
|
||||
|
||||
edge_point_1 = edge_point_0.copy()
|
||||
edge_point_1[juggled[1]] = mm[1][juggled[1]]
|
||||
|
||||
return edge_point_0, edge_point_1
|
||||
|
||||
def _get_all_axis_line_edge_points(self, minmax, maxmin, axis_position=None):
|
||||
# Determine edge points for the axis lines
|
||||
edgep1s = []
|
||||
edgep2s = []
|
||||
position = []
|
||||
if axis_position in (None, 'default'):
|
||||
edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin)
|
||||
edgep1s = [edgep1]
|
||||
edgep2s = [edgep2]
|
||||
position = ['default']
|
||||
else:
|
||||
edgep1_l, edgep2_l = self._get_axis_line_edge_points(minmax, maxmin,
|
||||
position='lower')
|
||||
edgep1_u, edgep2_u = self._get_axis_line_edge_points(minmax, maxmin,
|
||||
position='upper')
|
||||
if axis_position in ('lower', 'both'):
|
||||
edgep1s.append(edgep1_l)
|
||||
edgep2s.append(edgep2_l)
|
||||
position.append('lower')
|
||||
if axis_position in ('upper', 'both'):
|
||||
edgep1s.append(edgep1_u)
|
||||
edgep2s.append(edgep2_u)
|
||||
position.append('upper')
|
||||
return edgep1s, edgep2s, position
|
||||
|
||||
def _get_tickdir(self, position):
|
||||
"""
|
||||
Get the direction of the tick.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : str, optional : {'upper', 'lower', 'default'}
|
||||
The position of the axis.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tickdir : int
|
||||
Index which indicates which coordinate the tick line will
|
||||
align with.
|
||||
"""
|
||||
_api.check_in_list(('upper', 'lower', 'default'), position=position)
|
||||
|
||||
# TODO: Move somewhere else where it's triggered less:
|
||||
tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()] # default
|
||||
elev_mod = np.mod(self.axes.elev + 180, 360) - 180
|
||||
azim_mod = np.mod(self.axes.azim, 360)
|
||||
if position == 'upper':
|
||||
if elev_mod >= 0:
|
||||
tickdirs_base = [2, 2, 0]
|
||||
else:
|
||||
tickdirs_base = [1, 0, 0]
|
||||
if 0 <= azim_mod < 180:
|
||||
tickdirs_base[2] = 1
|
||||
elif position == 'lower':
|
||||
if elev_mod >= 0:
|
||||
tickdirs_base = [1, 0, 1]
|
||||
else:
|
||||
tickdirs_base = [2, 2, 1]
|
||||
if 0 <= azim_mod < 180:
|
||||
tickdirs_base[2] = 0
|
||||
info_i = [v["i"] for v in self._AXINFO.values()]
|
||||
|
||||
i = self._axinfo["i"]
|
||||
vert_ax = self.axes._vertical_axis
|
||||
j = vert_ax - 2
|
||||
# default: tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][vert_ax][i]
|
||||
tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i]
|
||||
return tickdir
|
||||
|
||||
def active_pane(self):
|
||||
mins, maxs, tc, highs = self._get_coord_info()
|
||||
info = self._axinfo
|
||||
index = info['i']
|
||||
if not highs[index]:
|
||||
loc = mins[index]
|
||||
plane = self._PLANES[2 * index]
|
||||
else:
|
||||
loc = maxs[index]
|
||||
plane = self._PLANES[2 * index + 1]
|
||||
xys = np.array([tc[p] for p in plane])
|
||||
return xys, loc
|
||||
|
||||
def draw_pane(self, renderer):
|
||||
"""
|
||||
Draw pane.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
renderer : `~matplotlib.backend_bases.RendererBase` subclass
|
||||
"""
|
||||
renderer.open_group('pane3d', gid=self.get_gid())
|
||||
xys, loc = self.active_pane()
|
||||
self.pane.xy = xys[:, :2]
|
||||
self.pane.draw(renderer)
|
||||
renderer.close_group('pane3d')
|
||||
|
||||
def _axmask(self):
|
||||
axmask = [True, True, True]
|
||||
axmask[self._axinfo["i"]] = False
|
||||
return axmask
|
||||
|
||||
def _draw_ticks(self, renderer, edgep1, centers, deltas, highs,
|
||||
deltas_per_point, pos):
|
||||
ticks = self._update_ticks()
|
||||
info = self._axinfo
|
||||
index = info["i"]
|
||||
juggled = info["juggled"]
|
||||
|
||||
mins, maxs, tc, highs = self._get_coord_info()
|
||||
centers, deltas = self._calc_centers_deltas(maxs, mins)
|
||||
|
||||
# Draw ticks:
|
||||
tickdir = self._get_tickdir(pos)
|
||||
tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir]
|
||||
|
||||
tick_info = info['tick']
|
||||
tick_out = tick_info['outward_factor'] * tickdelta
|
||||
tick_in = tick_info['inward_factor'] * tickdelta
|
||||
tick_lw = tick_info['linewidth']
|
||||
edgep1_tickdir = edgep1[tickdir]
|
||||
out_tickdir = edgep1_tickdir + tick_out
|
||||
in_tickdir = edgep1_tickdir - tick_in
|
||||
|
||||
default_label_offset = 8. # A rough estimate
|
||||
points = deltas_per_point * deltas
|
||||
for tick in ticks:
|
||||
# Get tick line positions
|
||||
pos = edgep1.copy()
|
||||
pos[index] = tick.get_loc()
|
||||
pos[tickdir] = out_tickdir
|
||||
x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M)
|
||||
pos[tickdir] = in_tickdir
|
||||
x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M)
|
||||
|
||||
# Get position of label
|
||||
labeldeltas = (tick.get_pad() + default_label_offset) * points
|
||||
|
||||
pos[tickdir] = edgep1_tickdir
|
||||
pos = _move_from_center(pos, centers, labeldeltas, self._axmask())
|
||||
lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M)
|
||||
|
||||
_tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
|
||||
tick.tick1line.set_linewidth(tick_lw[tick._major])
|
||||
tick.draw(renderer)
|
||||
|
||||
def _draw_offset_text(self, renderer, edgep1, edgep2, labeldeltas, centers,
|
||||
highs, pep, dx, dy):
|
||||
# Get general axis information:
|
||||
info = self._axinfo
|
||||
index = info["i"]
|
||||
juggled = info["juggled"]
|
||||
tickdir = info["tickdir"]
|
||||
|
||||
# Which of the two edge points do we want to
|
||||
# use for locating the offset text?
|
||||
if juggled[2] == 2:
|
||||
outeredgep = edgep1
|
||||
outerindex = 0
|
||||
else:
|
||||
outeredgep = edgep2
|
||||
outerindex = 1
|
||||
|
||||
pos = _move_from_center(outeredgep, centers, labeldeltas,
|
||||
self._axmask())
|
||||
olx, oly, olz = proj3d.proj_transform(*pos, self.axes.M)
|
||||
self.offsetText.set_text(self.major.formatter.get_offset())
|
||||
self.offsetText.set_position((olx, oly))
|
||||
angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
|
||||
self.offsetText.set_rotation(angle)
|
||||
# Must set rotation mode to "anchor" so that
|
||||
# the alignment point is used as the "fulcrum" for rotation.
|
||||
self.offsetText.set_rotation_mode('anchor')
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Note: the following statement for determining the proper alignment of
|
||||
# the offset text. This was determined entirely by trial-and-error
|
||||
# and should not be in any way considered as "the way". There are
|
||||
# still some edge cases where alignment is not quite right, but this
|
||||
# seems to be more of a geometry issue (in other words, I might be
|
||||
# using the wrong reference points).
|
||||
#
|
||||
# (TT, FF, TF, FT) are the shorthand for the tuple of
|
||||
# (centpt[tickdir] <= pep[tickdir, outerindex],
|
||||
# centpt[index] <= pep[index, outerindex])
|
||||
#
|
||||
# Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
|
||||
# from the variable 'highs'.
|
||||
# ---------------------------------------------------------------------
|
||||
centpt = proj3d.proj_transform(*centers, self.axes.M)
|
||||
if centpt[tickdir] > pep[tickdir, outerindex]:
|
||||
# if FT and if highs has an even number of Trues
|
||||
if (centpt[index] <= pep[index, outerindex]
|
||||
and np.count_nonzero(highs) % 2 == 0):
|
||||
# Usually, this means align right, except for the FTT case,
|
||||
# in which offset for axis 1 and 2 are aligned left.
|
||||
if highs.tolist() == [False, True, True] and index in (1, 2):
|
||||
align = 'left'
|
||||
else:
|
||||
align = 'right'
|
||||
else:
|
||||
# The FF case
|
||||
align = 'left'
|
||||
else:
|
||||
# if TF and if highs has an even number of Trues
|
||||
if (centpt[index] > pep[index, outerindex]
|
||||
and np.count_nonzero(highs) % 2 == 0):
|
||||
# Usually mean align left, except if it is axis 2
|
||||
align = 'right' if index == 2 else 'left'
|
||||
else:
|
||||
# The TT case
|
||||
align = 'right'
|
||||
|
||||
self.offsetText.set_va('center')
|
||||
self.offsetText.set_ha(align)
|
||||
self.offsetText.draw(renderer)
|
||||
|
||||
def _draw_labels(self, renderer, edgep1, edgep2, labeldeltas, centers, dx, dy):
|
||||
label = self._axinfo["label"]
|
||||
|
||||
# Draw labels
|
||||
lxyz = 0.5 * (edgep1 + edgep2)
|
||||
lxyz = _move_from_center(lxyz, centers, labeldeltas, self._axmask())
|
||||
tlx, tly, tlz = proj3d.proj_transform(*lxyz, self.axes.M)
|
||||
self.label.set_position((tlx, tly))
|
||||
if self.get_rotate_label(self.label.get_text()):
|
||||
angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
|
||||
self.label.set_rotation(angle)
|
||||
self.label.set_va(label['va'])
|
||||
self.label.set_ha(label['ha'])
|
||||
self.label.set_rotation_mode(label['rotation_mode'])
|
||||
self.label.draw(renderer)
|
||||
|
||||
@artist.allow_rasterization
|
||||
def draw(self, renderer):
|
||||
self.label._transform = self.axes.transData
|
||||
self.offsetText._transform = self.axes.transData
|
||||
renderer.open_group("axis3d", gid=self.get_gid())
|
||||
|
||||
# Get general axis information:
|
||||
mins, maxs, tc, highs = self._get_coord_info()
|
||||
centers, deltas = self._calc_centers_deltas(maxs, mins)
|
||||
|
||||
# Calculate offset distances
|
||||
# A rough estimate; points are ambiguous since 3D plots rotate
|
||||
reltoinches = self.get_figure(root=False).dpi_scale_trans.inverted()
|
||||
ax_inches = reltoinches.transform(self.axes.bbox.size)
|
||||
ax_points_estimate = sum(72. * ax_inches)
|
||||
deltas_per_point = 48 / ax_points_estimate
|
||||
default_offset = 21.
|
||||
labeldeltas = (self.labelpad + default_offset) * deltas_per_point * deltas
|
||||
|
||||
# Determine edge points for the axis lines
|
||||
minmax = np.where(highs, maxs, mins) # "origin" point
|
||||
maxmin = np.where(~highs, maxs, mins) # "opposite" corner near camera
|
||||
|
||||
for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
|
||||
minmax, maxmin, self._tick_position)):
|
||||
# Project the edge points along the current position
|
||||
pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M)
|
||||
pep = np.asarray(pep)
|
||||
|
||||
# The transAxes transform is used because the Text object
|
||||
# rotates the text relative to the display coordinate system.
|
||||
# Therefore, if we want the labels to remain parallel to the
|
||||
# axis regardless of the aspect ratio, we need to convert the
|
||||
# edge points of the plane to display coordinates and calculate
|
||||
# an angle from that.
|
||||
# TODO: Maybe Text objects should handle this themselves?
|
||||
dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
|
||||
self.axes.transAxes.transform([pep[0:2, 0]]))[0]
|
||||
|
||||
# Draw the lines
|
||||
self.line.set_data(pep[0], pep[1])
|
||||
self.line.draw(renderer)
|
||||
|
||||
# Draw ticks
|
||||
self._draw_ticks(renderer, edgep1, centers, deltas, highs,
|
||||
deltas_per_point, pos)
|
||||
|
||||
# Draw Offset text
|
||||
self._draw_offset_text(renderer, edgep1, edgep2, labeldeltas,
|
||||
centers, highs, pep, dx, dy)
|
||||
|
||||
for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
|
||||
minmax, maxmin, self._label_position)):
|
||||
# See comments above
|
||||
pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M)
|
||||
pep = np.asarray(pep)
|
||||
dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
|
||||
self.axes.transAxes.transform([pep[0:2, 0]]))[0]
|
||||
|
||||
# Draw labels
|
||||
self._draw_labels(renderer, edgep1, edgep2, labeldeltas, centers, dx, dy)
|
||||
|
||||
renderer.close_group('axis3d')
|
||||
self.stale = False
|
||||
|
||||
@artist.allow_rasterization
|
||||
def draw_grid(self, renderer):
|
||||
if not self.axes._draw_grid:
|
||||
return
|
||||
|
||||
renderer.open_group("grid3d", gid=self.get_gid())
|
||||
|
||||
ticks = self._update_ticks()
|
||||
if len(ticks):
|
||||
# Get general axis information:
|
||||
info = self._axinfo
|
||||
index = info["i"]
|
||||
|
||||
mins, maxs, tc, highs = self._get_coord_info()
|
||||
|
||||
minmax = np.where(highs, maxs, mins)
|
||||
maxmin = np.where(~highs, maxs, mins)
|
||||
|
||||
# Grid points where the planes meet
|
||||
xyz0 = np.tile(minmax, (len(ticks), 1))
|
||||
xyz0[:, index] = [tick.get_loc() for tick in ticks]
|
||||
|
||||
# Grid lines go from the end of one plane through the plane
|
||||
# intersection (at xyz0) to the end of the other plane. The first
|
||||
# point (0) differs along dimension index-2 and the last (2) along
|
||||
# dimension index-1.
|
||||
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
|
||||
lines[:, 0, index - 2] = maxmin[index - 2]
|
||||
lines[:, 2, index - 1] = maxmin[index - 1]
|
||||
self.gridlines.set_segments(lines)
|
||||
gridinfo = info['grid']
|
||||
self.gridlines.set_color(gridinfo['color'])
|
||||
self.gridlines.set_linewidth(gridinfo['linewidth'])
|
||||
self.gridlines.set_linestyle(gridinfo['linestyle'])
|
||||
self.gridlines.do_3d_projection()
|
||||
self.gridlines.draw(renderer)
|
||||
|
||||
renderer.close_group('grid3d')
|
||||
|
||||
# TODO: Get this to work (more) properly when mplot3d supports the
|
||||
# transforms framework.
|
||||
def get_tightbbox(self, renderer=None, *, for_layout_only=False):
|
||||
# docstring inherited
|
||||
if not self.get_visible():
|
||||
return
|
||||
# We have to directly access the internal data structures
|
||||
# (and hope they are up to date) because at draw time we
|
||||
# shift the ticks and their labels around in (x, y) space
|
||||
# based on the projection, the current view port, and their
|
||||
# position in 3D space. If we extend the transforms framework
|
||||
# into 3D we would not need to do this different book keeping
|
||||
# than we do in the normal axis
|
||||
major_locs = self.get_majorticklocs()
|
||||
minor_locs = self.get_minorticklocs()
|
||||
|
||||
ticks = [*self.get_minor_ticks(len(minor_locs)),
|
||||
*self.get_major_ticks(len(major_locs))]
|
||||
view_low, view_high = self.get_view_interval()
|
||||
if view_low > view_high:
|
||||
view_low, view_high = view_high, view_low
|
||||
interval_t = self.get_transform().transform([view_low, view_high])
|
||||
|
||||
ticks_to_draw = []
|
||||
for tick in ticks:
|
||||
try:
|
||||
loc_t = self.get_transform().transform(tick.get_loc())
|
||||
except AssertionError:
|
||||
# Transform.transform doesn't allow masked values but
|
||||
# some scales might make them, so we need this try/except.
|
||||
pass
|
||||
else:
|
||||
if mtransforms._interval_contains_close(interval_t, loc_t):
|
||||
ticks_to_draw.append(tick)
|
||||
|
||||
ticks = ticks_to_draw
|
||||
|
||||
bb_1, bb_2 = self._get_ticklabel_bboxes(ticks, renderer)
|
||||
other = []
|
||||
|
||||
if self.line.get_visible():
|
||||
other.append(self.line.get_window_extent(renderer))
|
||||
if (self.label.get_visible() and not for_layout_only and
|
||||
self.label.get_text()):
|
||||
other.append(self.label.get_window_extent(renderer))
|
||||
|
||||
return mtransforms.Bbox.union([*bb_1, *bb_2, *other])
|
||||
|
||||
d_interval = _api.deprecated(
|
||||
"3.6", alternative="get_data_interval", pending=True)(
|
||||
property(lambda self: self.get_data_interval(),
|
||||
lambda self, minmax: self.set_data_interval(*minmax)))
|
||||
v_interval = _api.deprecated(
|
||||
"3.6", alternative="get_view_interval", pending=True)(
|
||||
property(lambda self: self.get_view_interval(),
|
||||
lambda self, minmax: self.set_view_interval(*minmax)))
|
||||
|
||||
|
||||
class XAxis(Axis):
|
||||
axis_name = "x"
|
||||
get_view_interval, set_view_interval = maxis._make_getset_interval(
|
||||
"view", "xy_viewLim", "intervalx")
|
||||
get_data_interval, set_data_interval = maxis._make_getset_interval(
|
||||
"data", "xy_dataLim", "intervalx")
|
||||
|
||||
|
||||
class YAxis(Axis):
|
||||
axis_name = "y"
|
||||
get_view_interval, set_view_interval = maxis._make_getset_interval(
|
||||
"view", "xy_viewLim", "intervaly")
|
||||
get_data_interval, set_data_interval = maxis._make_getset_interval(
|
||||
"data", "xy_dataLim", "intervaly")
|
||||
|
||||
|
||||
class ZAxis(Axis):
|
||||
axis_name = "z"
|
||||
get_view_interval, set_view_interval = maxis._make_getset_interval(
|
||||
"view", "zz_viewLim", "intervalx")
|
||||
get_data_interval, set_data_interval = maxis._make_getset_interval(
|
||||
"data", "zz_dataLim", "intervalx")
|
219
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/proj3d.py
Normal file
219
venv/lib/python3.13/site-packages/mpl_toolkits/mplot3d/proj3d.py
Normal file
|
@ -0,0 +1,219 @@
|
|||
"""
|
||||
Various transforms used for by the 3D code
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from matplotlib import _api
|
||||
|
||||
|
||||
def world_transformation(xmin, xmax,
|
||||
ymin, ymax,
|
||||
zmin, zmax, pb_aspect=None):
|
||||
"""
|
||||
Produce a matrix that scales homogeneous coords in the specified ranges
|
||||
to [0, 1], or [0, pb_aspect[i]] if the plotbox aspect ratio is specified.
|
||||
"""
|
||||
dx = xmax - xmin
|
||||
dy = ymax - ymin
|
||||
dz = zmax - zmin
|
||||
if pb_aspect is not None:
|
||||
ax, ay, az = pb_aspect
|
||||
dx /= ax
|
||||
dy /= ay
|
||||
dz /= az
|
||||
|
||||
return np.array([[1/dx, 0, 0, -xmin/dx],
|
||||
[ 0, 1/dy, 0, -ymin/dy],
|
||||
[ 0, 0, 1/dz, -zmin/dz],
|
||||
[ 0, 0, 0, 1]])
|
||||
|
||||
|
||||
def _rotation_about_vector(v, angle):
|
||||
"""
|
||||
Produce a rotation matrix for an angle in radians about a vector.
|
||||
"""
|
||||
vx, vy, vz = v / np.linalg.norm(v)
|
||||
s = np.sin(angle)
|
||||
c = np.cos(angle)
|
||||
t = 2*np.sin(angle/2)**2 # more numerically stable than t = 1-c
|
||||
|
||||
R = np.array([
|
||||
[t*vx*vx + c, t*vx*vy - vz*s, t*vx*vz + vy*s],
|
||||
[t*vy*vx + vz*s, t*vy*vy + c, t*vy*vz - vx*s],
|
||||
[t*vz*vx - vy*s, t*vz*vy + vx*s, t*vz*vz + c]])
|
||||
|
||||
return R
|
||||
|
||||
|
||||
def _view_axes(E, R, V, roll):
|
||||
"""
|
||||
Get the unit viewing axes in data coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
E : 3-element numpy array
|
||||
The coordinates of the eye/camera.
|
||||
R : 3-element numpy array
|
||||
The coordinates of the center of the view box.
|
||||
V : 3-element numpy array
|
||||
Unit vector in the direction of the vertical axis.
|
||||
roll : float
|
||||
The roll angle in radians.
|
||||
|
||||
Returns
|
||||
-------
|
||||
u : 3-element numpy array
|
||||
Unit vector pointing towards the right of the screen.
|
||||
v : 3-element numpy array
|
||||
Unit vector pointing towards the top of the screen.
|
||||
w : 3-element numpy array
|
||||
Unit vector pointing out of the screen.
|
||||
"""
|
||||
w = (E - R)
|
||||
w = w/np.linalg.norm(w)
|
||||
u = np.cross(V, w)
|
||||
u = u/np.linalg.norm(u)
|
||||
v = np.cross(w, u) # Will be a unit vector
|
||||
|
||||
# Save some computation for the default roll=0
|
||||
if roll != 0:
|
||||
# A positive rotation of the camera is a negative rotation of the world
|
||||
Rroll = _rotation_about_vector(w, -roll)
|
||||
u = np.dot(Rroll, u)
|
||||
v = np.dot(Rroll, v)
|
||||
return u, v, w
|
||||
|
||||
|
||||
def _view_transformation_uvw(u, v, w, E):
|
||||
"""
|
||||
Return the view transformation matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
u : 3-element numpy array
|
||||
Unit vector pointing towards the right of the screen.
|
||||
v : 3-element numpy array
|
||||
Unit vector pointing towards the top of the screen.
|
||||
w : 3-element numpy array
|
||||
Unit vector pointing out of the screen.
|
||||
E : 3-element numpy array
|
||||
The coordinates of the eye/camera.
|
||||
"""
|
||||
Mr = np.eye(4)
|
||||
Mt = np.eye(4)
|
||||
Mr[:3, :3] = [u, v, w]
|
||||
Mt[:3, -1] = -E
|
||||
M = np.dot(Mr, Mt)
|
||||
return M
|
||||
|
||||
|
||||
def _persp_transformation(zfront, zback, focal_length):
|
||||
e = focal_length
|
||||
a = 1 # aspect ratio
|
||||
b = (zfront+zback)/(zfront-zback)
|
||||
c = -2*(zfront*zback)/(zfront-zback)
|
||||
proj_matrix = np.array([[e, 0, 0, 0],
|
||||
[0, e/a, 0, 0],
|
||||
[0, 0, b, c],
|
||||
[0, 0, -1, 0]])
|
||||
return proj_matrix
|
||||
|
||||
|
||||
def _ortho_transformation(zfront, zback):
|
||||
# note: w component in the resulting vector will be (zback-zfront), not 1
|
||||
a = -(zfront + zback)
|
||||
b = -(zfront - zback)
|
||||
proj_matrix = np.array([[2, 0, 0, 0],
|
||||
[0, 2, 0, 0],
|
||||
[0, 0, -2, 0],
|
||||
[0, 0, a, b]])
|
||||
return proj_matrix
|
||||
|
||||
|
||||
def _proj_transform_vec(vec, M):
|
||||
vecw = np.dot(M, vec.data)
|
||||
w = vecw[3]
|
||||
txs, tys, tzs = vecw[0]/w, vecw[1]/w, vecw[2]/w
|
||||
if np.ma.isMA(vec[0]): # we check each to protect for scalars
|
||||
txs = np.ma.array(txs, mask=vec[0].mask)
|
||||
if np.ma.isMA(vec[1]):
|
||||
tys = np.ma.array(tys, mask=vec[1].mask)
|
||||
if np.ma.isMA(vec[2]):
|
||||
tzs = np.ma.array(tzs, mask=vec[2].mask)
|
||||
return txs, tys, tzs
|
||||
|
||||
|
||||
def _proj_transform_vec_clip(vec, M, focal_length):
|
||||
vecw = np.dot(M, vec.data)
|
||||
w = vecw[3]
|
||||
txs, tys, tzs = vecw[0] / w, vecw[1] / w, vecw[2] / w
|
||||
if np.isinf(focal_length): # don't clip orthographic projection
|
||||
tis = np.ones(txs.shape, dtype=bool)
|
||||
else:
|
||||
tis = (-1 <= txs) & (txs <= 1) & (-1 <= tys) & (tys <= 1) & (tzs <= 0)
|
||||
if np.ma.isMA(vec[0]):
|
||||
tis = tis & ~vec[0].mask
|
||||
if np.ma.isMA(vec[1]):
|
||||
tis = tis & ~vec[1].mask
|
||||
if np.ma.isMA(vec[2]):
|
||||
tis = tis & ~vec[2].mask
|
||||
|
||||
txs = np.ma.masked_array(txs, ~tis)
|
||||
tys = np.ma.masked_array(tys, ~tis)
|
||||
tzs = np.ma.masked_array(tzs, ~tis)
|
||||
return txs, tys, tzs, tis
|
||||
|
||||
|
||||
def inv_transform(xs, ys, zs, invM):
|
||||
"""
|
||||
Transform the points by the inverse of the projection matrix, *invM*.
|
||||
"""
|
||||
vec = _vec_pad_ones(xs, ys, zs)
|
||||
vecr = np.dot(invM, vec)
|
||||
if vecr.shape == (4,):
|
||||
vecr = vecr.reshape((4, 1))
|
||||
for i in range(vecr.shape[1]):
|
||||
if vecr[3][i] != 0:
|
||||
vecr[:, i] = vecr[:, i] / vecr[3][i]
|
||||
return vecr[0], vecr[1], vecr[2]
|
||||
|
||||
|
||||
def _vec_pad_ones(xs, ys, zs):
|
||||
if np.ma.isMA(xs) or np.ma.isMA(ys) or np.ma.isMA(zs):
|
||||
return np.ma.array([xs, ys, zs, np.ones_like(xs)])
|
||||
else:
|
||||
return np.array([xs, ys, zs, np.ones_like(xs)])
|
||||
|
||||
|
||||
def proj_transform(xs, ys, zs, M):
|
||||
"""
|
||||
Transform the points by the projection matrix *M*.
|
||||
"""
|
||||
vec = _vec_pad_ones(xs, ys, zs)
|
||||
return _proj_transform_vec(vec, M)
|
||||
|
||||
|
||||
@_api.deprecated("3.10")
|
||||
def proj_transform_clip(xs, ys, zs, M):
|
||||
return _proj_transform_clip(xs, ys, zs, M, focal_length=np.inf)
|
||||
|
||||
|
||||
def _proj_transform_clip(xs, ys, zs, M, focal_length):
|
||||
"""
|
||||
Transform the points by the projection matrix
|
||||
and return the clipping result
|
||||
returns txs, tys, tzs, tis
|
||||
"""
|
||||
vec = _vec_pad_ones(xs, ys, zs)
|
||||
return _proj_transform_vec_clip(vec, M, focal_length)
|
||||
|
||||
|
||||
def _proj_points(points, M):
|
||||
return np.column_stack(_proj_trans_points(points, M))
|
||||
|
||||
|
||||
def _proj_trans_points(points, M):
|
||||
points = np.asanyarray(points)
|
||||
xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
|
||||
return proj_transform(xs, ys, zs, M)
|
|
@ -0,0 +1,10 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
# Check that the test directories exist
|
||||
if not (Path(__file__).parent / "baseline_images").exists():
|
||||
raise OSError(
|
||||
'The baseline image directory does not exist. '
|
||||
'This is most likely because the test data is not installed. '
|
||||
'You may need to install matplotlib from source to get the '
|
||||
'test data.')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,2 @@
|
|||
from matplotlib.testing.conftest import (mpl_test_settings, # noqa
|
||||
pytest_configure, pytest_unconfigure)
|
|
@ -0,0 +1,102 @@
|
|||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from matplotlib.backend_bases import MouseEvent
|
||||
from mpl_toolkits.mplot3d.art3d import (
|
||||
Line3DCollection,
|
||||
Poly3DCollection,
|
||||
_all_points_on_plane,
|
||||
)
|
||||
|
||||
|
||||
def test_scatter_3d_projection_conservation():
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(projection='3d')
|
||||
# fix axes3d projection
|
||||
ax.roll = 0
|
||||
ax.elev = 0
|
||||
ax.azim = -45
|
||||
ax.stale = True
|
||||
|
||||
x = [0, 1, 2, 3, 4]
|
||||
scatter_collection = ax.scatter(x, x, x)
|
||||
fig.canvas.draw_idle()
|
||||
|
||||
# Get scatter location on canvas and freeze the data
|
||||
scatter_offset = scatter_collection.get_offsets()
|
||||
scatter_location = ax.transData.transform(scatter_offset)
|
||||
|
||||
# Yaw -44 and -46 are enough to produce two set of scatter
|
||||
# with opposite z-order without moving points too far
|
||||
for azim in (-44, -46):
|
||||
ax.azim = azim
|
||||
ax.stale = True
|
||||
fig.canvas.draw_idle()
|
||||
|
||||
for i in range(5):
|
||||
# Create a mouse event used to locate and to get index
|
||||
# from each dots
|
||||
event = MouseEvent("button_press_event", fig.canvas,
|
||||
*scatter_location[i, :])
|
||||
contains, ind = scatter_collection.contains(event)
|
||||
assert contains is True
|
||||
assert len(ind["ind"]) == 1
|
||||
assert ind["ind"][0] == i
|
||||
|
||||
|
||||
def test_zordered_error():
|
||||
# Smoke test for https://github.com/matplotlib/matplotlib/issues/26497
|
||||
lc = [(np.fromiter([0.0, 0.0, 0.0], dtype="float"),
|
||||
np.fromiter([1.0, 1.0, 1.0], dtype="float"))]
|
||||
pc = [np.fromiter([0.0, 0.0], dtype="float"),
|
||||
np.fromiter([0.0, 1.0], dtype="float"),
|
||||
np.fromiter([1.0, 1.0], dtype="float")]
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(projection="3d")
|
||||
ax.add_collection(Line3DCollection(lc))
|
||||
ax.scatter(*pc, visible=False)
|
||||
plt.draw()
|
||||
|
||||
|
||||
def test_all_points_on_plane():
|
||||
# Non-coplanar points
|
||||
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
|
||||
assert not _all_points_on_plane(*points.T)
|
||||
|
||||
# Duplicate points
|
||||
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]])
|
||||
assert _all_points_on_plane(*points.T)
|
||||
|
||||
# NaN values
|
||||
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, np.nan]])
|
||||
assert _all_points_on_plane(*points.T)
|
||||
|
||||
# Less than 3 unique points
|
||||
points = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
|
||||
assert _all_points_on_plane(*points.T)
|
||||
|
||||
# All points lie on a line
|
||||
points = np.array([[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0]])
|
||||
assert _all_points_on_plane(*points.T)
|
||||
|
||||
# All points lie on two lines, with antiparallel vectors
|
||||
points = np.array([[-2, 2, 0], [-1, 1, 0], [1, -1, 0],
|
||||
[0, 0, 0], [2, 0, 0], [1, 0, 0]])
|
||||
assert _all_points_on_plane(*points.T)
|
||||
|
||||
# All points lie on a plane
|
||||
points = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 2, 0]])
|
||||
assert _all_points_on_plane(*points.T)
|
||||
|
||||
|
||||
def test_generate_normals():
|
||||
# Smoke test for https://github.com/matplotlib/matplotlib/issues/29156
|
||||
vertices = ((0, 0, 0), (0, 5, 0), (5, 5, 0), (5, 0, 0))
|
||||
shape = Poly3DCollection([vertices], edgecolors='r', shade=True)
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(projection='3d')
|
||||
ax.add_collection3d(shape)
|
||||
plt.draw()
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,117 @@
|
|||
import platform
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib.colors import same_color
|
||||
from matplotlib.testing.decorators import image_comparison
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import art3d
|
||||
|
||||
|
||||
@image_comparison(['legend_plot.png'], remove_text=True, style='mpl20')
|
||||
def test_legend_plot():
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
x = np.arange(10)
|
||||
ax.plot(x, 5 - x, 'o', zdir='y', label='z=1')
|
||||
ax.plot(x, x - 5, 'o', zdir='y', label='z=-1')
|
||||
ax.legend()
|
||||
|
||||
|
||||
@image_comparison(['legend_bar.png'], remove_text=True, style='mpl20')
|
||||
def test_legend_bar():
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
x = np.arange(10)
|
||||
b1 = ax.bar(x, x, zdir='y', align='edge', color='m')
|
||||
b2 = ax.bar(x, x[::-1], zdir='x', align='edge', color='g')
|
||||
ax.legend([b1[0], b2[0]], ['up', 'down'])
|
||||
|
||||
|
||||
@image_comparison(['fancy.png'], remove_text=True, style='mpl20',
|
||||
tol=0 if platform.machine() == 'x86_64' else 0.011)
|
||||
def test_fancy():
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
ax.plot(np.arange(10), np.full(10, 5), np.full(10, 5), 'o--', label='line')
|
||||
ax.scatter(np.arange(10), np.arange(10, 0, -1), label='scatter')
|
||||
ax.errorbar(np.full(10, 5), np.arange(10), np.full(10, 10),
|
||||
xerr=0.5, zerr=0.5, label='errorbar')
|
||||
ax.legend(loc='lower left', ncols=2, title='My legend', numpoints=1)
|
||||
|
||||
|
||||
def test_linecollection_scaled_dashes():
|
||||
lines1 = [[(0, .5), (.5, 1)], [(.3, .6), (.2, .2)]]
|
||||
lines2 = [[[0.7, .2], [.8, .4]], [[.5, .7], [.6, .1]]]
|
||||
lines3 = [[[0.6, .2], [.8, .4]], [[.5, .7], [.1, .1]]]
|
||||
lc1 = art3d.Line3DCollection(lines1, linestyles="--", lw=3)
|
||||
lc2 = art3d.Line3DCollection(lines2, linestyles="-.")
|
||||
lc3 = art3d.Line3DCollection(lines3, linestyles=":", lw=.5)
|
||||
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
ax.add_collection(lc1)
|
||||
ax.add_collection(lc2)
|
||||
ax.add_collection(lc3)
|
||||
|
||||
leg = ax.legend([lc1, lc2, lc3], ['line1', 'line2', 'line 3'])
|
||||
h1, h2, h3 = leg.legend_handles
|
||||
|
||||
for oh, lh in zip((lc1, lc2, lc3), (h1, h2, h3)):
|
||||
assert oh.get_linestyles()[0] == lh._dash_pattern
|
||||
|
||||
|
||||
def test_handlerline3d():
|
||||
# Test marker consistency for monolithic Line3D legend handler.
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
ax.scatter([0, 1], [0, 1], marker="v")
|
||||
handles = [art3d.Line3D([0], [0], [0], marker="v")]
|
||||
leg = ax.legend(handles, ["Aardvark"], numpoints=1)
|
||||
assert handles[0].get_marker() == leg.legend_handles[0].get_marker()
|
||||
|
||||
|
||||
def test_contour_legend_elements():
|
||||
x, y = np.mgrid[1:10, 1:10]
|
||||
h = x * y
|
||||
colors = ['blue', '#00FF00', 'red']
|
||||
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
cs = ax.contour(x, y, h, levels=[10, 30, 50], colors=colors, extend='both')
|
||||
|
||||
artists, labels = cs.legend_elements()
|
||||
assert labels == ['$x = 10.0$', '$x = 30.0$', '$x = 50.0$']
|
||||
assert all(isinstance(a, mpl.lines.Line2D) for a in artists)
|
||||
assert all(same_color(a.get_color(), c)
|
||||
for a, c in zip(artists, colors))
|
||||
|
||||
|
||||
def test_contourf_legend_elements():
|
||||
x, y = np.mgrid[1:10, 1:10]
|
||||
h = x * y
|
||||
|
||||
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
|
||||
cs = ax.contourf(x, y, h, levels=[10, 30, 50],
|
||||
colors=['#FFFF00', '#FF00FF', '#00FFFF'],
|
||||
extend='both')
|
||||
cs.cmap.set_over('red')
|
||||
cs.cmap.set_under('blue')
|
||||
cs.changed()
|
||||
artists, labels = cs.legend_elements()
|
||||
assert labels == ['$x \\leq -1e+250s$',
|
||||
'$10.0 < x \\leq 30.0$',
|
||||
'$30.0 < x \\leq 50.0$',
|
||||
'$x > 1e+250s$']
|
||||
expected_colors = ('blue', '#FFFF00', '#FF00FF', 'red')
|
||||
assert all(isinstance(a, mpl.patches.Rectangle) for a in artists)
|
||||
assert all(same_color(a.get_facecolor(), c)
|
||||
for a, c in zip(artists, expected_colors))
|
||||
|
||||
|
||||
def test_legend_Poly3dCollection():
|
||||
|
||||
verts = np.asarray([[0, 0, 0], [0, 1, 1], [1, 0, 1]])
|
||||
mesh = art3d.Poly3DCollection([verts], label="surface")
|
||||
|
||||
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
|
||||
mesh.set_edgecolor('k')
|
||||
handle = ax.add_collection3d(mesh)
|
||||
leg = ax.legend()
|
||||
assert (leg.legend_handles[0].get_facecolor()
|
||||
== handle.get_facecolor()).all()
|
Loading…
Add table
Add a link
Reference in a new issue