Source code for mplutils.utils

import typing as tp
import inspect
import os
import pathlib
import math

import cycler
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.axes as maxes
import matplotlib.text as mtext
import matplotlib.figure as mfig
import matplotlib.colors as mcolors
import matplotlib.colorbar as mcbar
import matplotlib.cm as cm
import matplotlib.patches as mpatches
import numpy as np
import numpy.typing as npt

from .constants import PTS_PER_INCH
from .colors import OkabeItoPalette
from ._core import FontsizeLike, convert_to_inches
from ._fixed_layout import get_axes_for_layout, get_axes_grid, get_bboxes_inch_grid
from ._layout import (
    set_colorbar_thickness_inch,
    set_colorbar_pad_inch,
    set_axes_width_inch,
    set_axes_height_inch,
)
from ._utils import normalize_lw_fs_lh


[docs] def centers_to_edges( centers: npt.ArrayLike, lower: None | float = None, upper: None | float = None, ) -> npt.NDArray[np.float64]: """ Work out bin edges from bin centers. If the bins don't have constant size, at least one limit has to be provided, from which the edges can be determined. .. attention:: If `centers` are not the centers of *all* bins, or if `lower` or `upper` are not indeed the lower or upper edge, `centers_to_edges` will silently produce nonsense. Parameters ---------- centers : array_like, shape(n) centers of the bins lower, uppper : float, optional Lower/upper limits of the range. At least one limit must be provided if bins don't have a constant size. If both lower and upper limits are provided, the lower one will be prioritized. Returns ------- edges : ndarray, shape(n+1) Edges of the bins. """ # if bins don't have a constant size, determine xbinedges differently centers = np.asarray(centers, copy=True).astype(np.float64) edges = np.empty(len(centers) + 1) binsizes = np.diff(centers) if not np.allclose(binsizes, binsizes[0], atol=0.0): if lower is not None: # take lower edge and work out binsize forward edges[0] = lower for i in range(len(centers)): edges[i + 1] = 2.0 * centers[i] - edges[i] elif upper is not None: # take upper edge and work out binsize backward edges[-1] = upper for i in reversed(range(len(centers))): edges[i] = 2.0 * centers[i] - edges[i + 1] else: # cannot determine binsize, throw exception raise ValueError( "cannot determine binsizes without 'upper' or 'lower' bounds" ) else: # bins have equal size edges[:-1] = centers - 0.5 * binsizes[0] edges[-1] = centers[-1] + 0.5 * binsizes[0] return edges
[docs] def savefig( fname: None | str = None, ftype: None | str | tp.Sequence[str] = None, fig: None | mfig.Figure = None, **savefig_kwargs, ) -> None: r""" Save a :class:`matplotlib.figure.Figure` to a file. Wraps :func:`matplotlib.pyplot.savefig`. Parameters ---------- fig : :class:`matplotlib.figure.Figure`, optional If None, use last active figure. fname : str, optional File name (and path). If None, uses the filename of the programs entry point. If a file name without a file-type extension is provided, uses `rcParams["savefig.format"] <https://matplotlib.org/stable/users/explain/customizing.html#matplotlibrc-sample>`__ unless `ftype` is provided. If `fname` ends in ``/`` or ``\``, it is assumed as a directory in which the output will be saved using the file name of the programs entry point. If ``*`` is in `fname`, replace it with the file name of the programs entry point. ftype : str or Sequence[str], optional The file type(s). If provided, the appropriate extension is appended to `fname` and the file is saved as that file type. If a sequence is provided, saves one file for each provided type. If nothing is provided, infers the file type from `fname`. Other Parameters ---------------- **savefig_kwargs Keyword arguments of :func:`matplotlib.pyplot.savefig`. See also -------- matplotlib.pyplot.savefig Examples -------- Assume a script named ``my_plot.py`` with the contents .. code-block:: python :caption: ``my_plot.py`` import mplutils as mplu import matplotlib.pyplot as plt plt.plot(some_data) mplu.savefig() mplu.savefig(ftype="pdf") mplu.savefig("output/") mplu.savefig("output/*_extra") mplu.savefig("a_plot") mplu.savefig("a_plot", ftype=("pdf", "png")) mplu.savefig("a_plot.pdf", ftype=("pdf", "png")) Executing ``my_plot.py`` will save: - ``my_plot.png`` (assuming `plt.rcParams["savefig.format"] <https://matplotlib.org/stable/users/explain/customizing.html#the-default-matplotlibrc-file>`__ = "png") - ``my_plot.pdf`` - ``output/my_plot.png`` - ``output/my_plot_extra.png`` - ``a_plot.png`` - ``a_plot.pdf`` and ``a_plot.png`` - ``a_plot.pdf.pdf`` and ``a_plot.pdf.png`` """ fig = fig or plt.gcf() main_caller_fname = inspect.stack()[-1].filename main_caller_fname_base, _ = os.path.splitext(main_caller_fname) if fname is None: fname = main_caller_fname_base elif fname.endswith("\\") or fname.endswith("/"): if fname.startswith("\\") or fname.startswith("/"): fname = fname[1:] path = pathlib.Path(main_caller_fname_base) fname = path.parent / fname / path.name # type: ignore elif "*" in fname: path = pathlib.Path(main_caller_fname_base) fname = fname.replace("*", path.name) assert fname is not None # create directories if not present pathlib.Path(fname).parent.mkdir(parents=True, exist_ok=True) # type: ignore if ftype is None: fig.savefig(fname, **savefig_kwargs) else: ftypes = (ftype,) if isinstance(ftype, str) else tuple(ftype) for type in ftypes: fig.savefig(f"{fname}.{type}", **savefig_kwargs)
[docs] def for_pcolormesh( xcenters: npt.ArrayLike, ycenters: npt.ArrayLike, z: npt.ArrayLike, *, iteration_order: tp.Literal["x_first", "y_first"] = "x_first", xmin: float | None = None, xmax: float | None = None, ymin: float | None = None, ymax: float | None = None, ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]: """ Convert `xcenters, ycenters, z` such that they can be plotted by :func:`matplotlib.pyplot.pcolormesh`. See Examples for an example how the arrays are structured. Parameters ---------- xcenters : array_like A flat sequence of x-coordinates, representing the horizontal position of each element in the grid. If the bins are not all equal in size, `xmin` or `xmax` needs to be specified. ycenters : array_like A flat sequence of y-coordinates, representing the vertical position of each element in the grid. If the bins are not all equal in size, `ymin` or `ymax` needs to be specified. z : array_like A flat sequence of data values corresponding to each (xcenters, ycenters) coordinate The x, y, and z sequences must all be the same length, where each z[i] corresponds to the position (xcenters[i], ycenters[i]) in a 2D grid. iteration_order : {"x_first", "y_first"}, default "x_first" Specify if the outer iteration is along x or y. xmin, xmax, ymin, ymax : float, optional If x (y) bins do not have constant size, at least one corresponding limit has to be provided. .. note:: This does not refer to the limits of the bin centers, but the limits of the bin edges! Returns ------- X, Y, C : ndarray Output formatted to work with :func:`matplotlib.pyplot.pcolormesh`. """ z_ = np.asarray(z) if len(np.asarray(xcenters)) != len(np.asarray(ycenters)) != len(z_): raise ValueError("xcenters, ycenters, and z must have the same length") x_ = np.unique(xcenters) y_ = np.unique(ycenters) xedges = centers_to_edges(x_, xmin, xmax) yedges = centers_to_edges(y_, ymin, ymax) if iteration_order == "x_first": z_ = z_.reshape(y_.size, x_.size) elif iteration_order == "y_first": z_ = z_.reshape(x_.size, y_.size).T else: msg = f"'{iteration_order=}', but it should be 'x_first' or 'y_first'" raise ValueError(msg) return xedges, yedges, z_
[docs] def convert_to_steps( x: npt.ArrayLike, y: npt.ArrayLike, start_at: float | tp.Literal["auto"] = 0.0, xlim_lower: float | None = None, xlim_upper: float | None = None, ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: """ Convert (x,y) data to steps. Parameters ---------- x : array_like *x* data, corresponding to the centers of the resulting steps. y : array_like corresponding *y* data. start_at : float or "auto", default 0.0 Start/end steps at this value. If "auto", start (end) at first (last) *y* value. xlim_lower, ylim_uppper : float, optional Lower/upper limits of the range. At least one limit must be provided if *x* data doesn't have constant spacing. If both lower and upper limits are provided, the lower one will be prioritized. Returns ------- x, y : ndarray Data corresponding to steps. Examples -------- .. plot:: _examples/convert_to_steps.py :include-source: """ y = np.asarray(y) edges = centers_to_edges(x, xlim_lower, xlim_upper) x_ = np.repeat(edges, 2) y_ = np.empty_like(x_, dtype=np.float64) y_[0] = start_at if start_at != "auto" else y[0] y_[1:-1] = np.repeat(y, 2) y_[-1] = start_at if start_at != "auto" else y[-1] return x_, y_
[docs] def set_color_cycle( *colors: str, nsteps: int = 7, fig: mfig.Figure | None = None, ) -> None: """ Set the color cycle in plots. Modifies ``rcParams`` of matplotlib (see `here <https://matplotlib.org/stable/users/explain/customizing.html#the-default-matplotlibrc-file>`__). Parameters ---------- *colors : str, optional Colors seperated by comma and given in HEX-codes. If no colors are provided, defaults to :class:`.OkabeItoPalette` (see `here <https://jfly.uni-koeln.de/color/>`__ for a motivation). Alternatively, the name of a colormap can be specified and the color cycle picks ``nsteps`` colors from that colormap. See `matplolib's colormaps documenation <https://matplotlib.org/stable/users/explain/colors/colormaps.html>`__ for available colormaps. nsteps : int, default 7 Define how many different colors will be set in the color cycler. Irrelevant if specific colors are passed in `colors`. fig : :class:`matplotlib.figure.Figure`, optional Optionally, provide a figure. The color cycle of all axes of that figure will be updated. If None, check if a figure already exists. If so, update the color cycle of all axes of the last active figure. Examples -------- .. code:: python # Set color cycle to red, green, blue set_color_cycle("#FF0000", "#00FF00", "#0000FF") # Set color cycle to Okabe and Ito palette set_color_cycle() # Set color cycle to a continuous colormap that repeats after 8 colors set_color_cylce(cmap="viridis", nsteps=8) # Set color cycle to a diverging colormap that passes through zero # with the second color set_color_cylce("RdBu", nsteps=3) See also -------- OkabeItoPalette OkabeItoMutedPalette OkabeItoAccentPalette """ # format colors appropriately if not colors: if nsteps > 7: msg = ( "For Okabe and Ito palette, a maximum of 7 distinct colors can " f"be used, but {nsteps=}" ) raise ValueError(msg) colors = OkabeItoPalette()[:nsteps] if len(colors) == 1 and colors[0] in plt.colormaps(): cmap_ = plt.get_cmap(colors[0]) colors = tuple([mcolors.to_hex(cmap_(i / (nsteps - 1))) for i in range(nsteps)]) # no figure passed, but a figure exists if fig is None and plt.get_fignums(): fig = plt.gcf() # some axes exist already, update the cycler for them if fig is not None and fig.get_axes(): axs = fig.get_axes() color_cycler = cycler.cycler(color=colors) for ax in axs: ax.set_prop_cycle(color_cycler) # update rcParams so future axes will use this one cycler_str = "cycler('color', [" for color in colors: if color[0] == "#": color = color[1:] cycler_str += f"'{color}', " cycler_str += "])" plt.rcParams["axes.prop_cycle"] = cycler_str
[docs] def set_latex_backend(font: tp.Literal["FiraSans", "Times", "ScholaX"]) -> None: """ Enable a latex backend for rendering figures. Must be called before figure creation. Parameters ---------- font : {``"FiraSans"``, ``"Times"``, ``"ScholaX"``} Choose a font. ``"FiraSans"`` uses ``lualatex`` (which is very slow) as backend, the other fonts use ``pdflatex`` (not as slow). """ plt.rcParams["pgf.rcfonts"] = False plt.rcParams["backend"] = "pgf" if font == "FiraSans": plt.rcParams["font.family"] = "sans-serif" plt.rcParams["pgf.texsystem"] = "lualatex" plt.rcParams["pgf.preamble"] = ( r"\usepackage[mathrm=sym]{unicode-math}\setmathfont{Fira Math}[Scale=MatchUppercase,Numbers=Tabular]\setsansfont{Fira Sans}[Scale=MatchUppercase,Numbers=Lining]\usepackage{picture,xcolor}\usepackage{nicefrac}" ) elif font == "Times": plt.rcParams["font.family"] = "serif" plt.rcParams["pgf.texsystem"] = "pdflatex" plt.rcParams["pgf.preamble"] = ( r"\usepackage[T1]{fontenc}\usepackage{newtxtext,newtxmath}\usepackage{picture,xcolor}\usepackage{nicefrac}" ) elif font == "ScholaX": plt.rcParams["font.family"] = "serif" plt.rcParams["pgf.texsystem"] = "pdflatex" plt.rcParams["pgf.preamble"] = ( r"\usepackage{scholax}\usepackage{amsmath,amsthm}\usepackage[scaled=1.075,ncf,vvarbb]{newtxmath}\usepackage{picture,xcolor}\usepackage{nicefrac}" )
[docs] def crop_colormap( cmap, x: float = 0.0, y: float = 1.0, new_cmap_name: str | None = None, register: bool = False, n_lut: int = 256, ) -> mcolors.LinearSegmentedColormap: """ Return a colormap within (x,y) range Parameters ---------- cmap A matplotlib colormap, as returned by ``matplotlib.colormaps["colormap_name"]``. x : float, default = 0.0 Lower limit of the colorbar (in relative units). y : float, default = 1.0 Upper limit of the colorbar (in relative units). new_cmap_name : str, optional Optionally, give the colormap a name. If ``None``, colormap will be called ``"_new_colormap"``. register : bool, default ``False`` Register the colormap. If it is registered, one can call it simply by it's name (i.e., ``new_cmap_name``). n_lut : int, default 256 The number of segments of the old colormap. Returns ------- new_cmap : :class:`matplotlib.colors.LinearSegmentedColormap`. The new colormap. Examples -------- .. plot:: _examples/crop_colormap.py :include-source: """ if x < 0.0 or x > 1.0: raise ValueError(f"{x=}, but it must be within [0.0, 1.0]") if y < 0.0 or y > 1.0: raise ValueError(f"{y=}, but it must be within [0.0, 1.0]") # Ensure we have a list of colors (works for both Listed and LinearSegmented) if hasattr(cmap, "colors"): # ListedColormap colors = cmap.colors else: # LinearSegmentedColormap → sample N colors colors = [cmap(i / (n_lut - 1)) for i in range(n_lut)] length = len(colors) index_low = int(x * length) index_high = int(y * length) new_cmap_name = new_cmap_name or "_new_colormap" new_cmap = mcolors.LinearSegmentedColormap.from_list( new_cmap_name, colors[index_low:index_high] ) if register: matplotlib.colormaps.register(new_cmap) return new_cmap
[docs] def textwithbox( axes: maxes.Axes, x: float, y: float, text: str, pad: float = 1.0, boxbackground: str | None = "white", boxedgecolor: str = "black", boxedgewidth: float = 0.5, **text_kwargs, ) -> mtext.Text: """ Plot text with matplotlib surrounded by a box using LaTeX commands. Parameters ---------- ax : :class:`matplotlib.axes.Axes` the axes x : float x-position y : float y-position text : str The text to be surrounded by the box pad : float, default: :code:`1.0` (in pts) padding between boxedge and text boxbackground : ``None``, ``False``, or str, default: ``"white"`` background of box - ``None`` or ``False``: No background color - str: latex xcolor named color boxedgecolor : str, optional, default: :code:`"black"` edge color using named color from latex package *xcolor* only used if boxbackground != None boxedgewidth : float, default :code:`0.5` (in pts) Linewidth of the box' edges. **text_kwargs Additional :class:`matpotlib.text.Text` keyword arguments. Returns ------- text : :class:`matplotlib.text.Text` The text artist. See also -------- set_latex_backend """ sep = r"\setlength{\fboxsep}{%lfpt}" % pad rule = r"\setlength{\fboxrule}{%lfpt}" % boxedgewidth if boxbackground is not None: text = r"%s\fcolorbox{%s}{%s}{%s}" % ( sep + rule, boxedgecolor, boxbackground, text, ) else: text = r"%s\fbox{%s}" % (sep + rule, text) return axes.text(x, y, text, **text_kwargs)
[docs] def dotted( linewidth: float | None = None, fontsize: FontsizeLike = None, legend_handlelength: float | None = None, **aliases, ) -> tuple[float, tuple[float, float]]: """ Return a tuple to create a dotted line that fits perfectly into a legend. For that to work properly you may need to provide the linewidth of the graph and the fontsize of the legend. Parameters ---------- linewidth (or lw) : float, optional, default: ``rcParams["lines.linewidth"]`` fontsize : float or str, Optional, default: ``rcParams["legend.fontsize"]`` The fontsize used in the legend - float: fontsize in pts - str: :code:`"xx-small"`, :code:`"x-small"`, :code:`"small"`, :code:`"medium"`, :code:`"large"`, :code:`"x-large"`, :code:`"xx-large"`, :code:`"larger"`, or :code:`"smaller"` legend_handlelength (or lh) : float, default ``rcParams["legend.handlelength"]`` Length of the legend handles (the dotted line, in this case) in font units Returns ------- tuple : (float, (float, float)) tuple to be used as linetype in plotting See Also -------- :func:`.dash_dotted` : Create dotted-line linestyle. :func:`.dashed` : Create dash-dotted-line linestyle. Examples -------- .. code-block:: python import matplotlib.pyplot as plt import mplutils as mplu plt.plot([0., 1.], linestyle=mplu.dotted()) plt.legend() # if one changes the linewidth, the fontsize of the legend, or the # handlelength of the legend from the default, this needs to be passed # to dotted(). plt.plot([0., 1.], linewidth=2., linestyle=(mplu.dotted(linewidth=2., legend_handlelength=3., fontsize="x-small"))) plt.legend(fontsize="x-small", handlelength=3.) # alternatively, use rcParams to set these values plt.rcParams["lines.linewidth"] = 2. plt.rcParams["legend.handlelength"] = 3. plt.rcParams["legend.fontsize"] = "x-small" plt.plot([0., 1.], linestyle=mplu.dotted()) plt.legend() .. plot:: _examples/legend_dotted.py :include-source: """ lw_, fs_, lh_ = normalize_lw_fs_lh( linewidth, fontsize, legend_handlelength, **aliases ) total_points = fs_ * lh_ / lw_ n_dots = math.ceil(total_points / 2.0) spacewidth = (total_points - n_dots) / (n_dots - 1) return 0.0, (1.0, spacewidth)
[docs] def dash_dotted( ratio: float = 3.0, n_dashes: int = 2, linewidth: float | None = None, fontsize: FontsizeLike = None, legend_handlelength: float | None = None, **aliases, ) -> tuple[float, tuple[float, float, float, float]]: """ Return a tuple to create a dash-dotted line that fits perfectly into a legend. For that to work properly you may need to provide the linewidth of the graph and the fontsize of the legend. Parameters ---------- ratio : float, default: 3.0 Ratio between dash-length and gap-length n_dashes : int, default: 2 Number of dashes drawn linewidth (or lw) : float, optional, default: ``rcParams["lines.linewidth"]`` fontsize : float or str, Optional, default: ``rcParams["legend.fontsize"]`` The fontsize used in the legend - float: fontsize in pts - str: :code:`"xx-small"`, :code:`"x-small"`, :code:`"small"`, :code:`"medium"`, :code:`"large"`, :code:`"x-large"`, :code:`"xx-large"`, :code:`"larger"`, or :code:`"smaller"` legend_handlelength (or 'lh') : float, default :code:`rcParams["legend.handlelength"]` Length of the legend handles (the dotted line, in this case) in font units Returns ------- tuple : (float, (float, float, float, float)) tuple to be used as linetype in plotting See Also -------- :func:`.dotted` : Create dotted-line linestyle. :func:`.dashed` : Create dash-dotted-line linestyle. Examples -------- .. plot:: _examples/legend_dash_dotted.py :include-source: """ lw_, fs_, lh_ = normalize_lw_fs_lh( linewidth, fontsize, legend_handlelength, **aliases ) total_points = fs_ * lh_ / lw_ spacewidth = (total_points - n_dashes) / (2.0 * n_dashes - 1 + n_dashes * ratio) dashwidth = ratio * spacewidth return 0.0, (dashwidth, spacewidth, 1.0, spacewidth)
[docs] def dashed( ratio: float = 1.5, n_dashes: int = 3, linewidth: float | None = None, fontsize: FontsizeLike = None, legend_handlelength: float | None = None, **aliases, ) -> tuple[float, tuple[float, float]]: """ Return a tuple to create a dashed line that fits perfectly into a legend. For that to work properly you may need to provide the linewidth of the graph and the fontsize of the legend. Parameters ---------- ratio : float, default: 1.5 Ratio between dash-length and gap-length n_dashes : int, default: 3 Number of dashes drawn linewidth (or lw) : float, optional, default: rcParams["lines.linewidth"] fontsize : float or str, Optional, default: rcParams["legend.fontsize"] The fontsize used in the legend - float: fontsize in pts - str: :code:`"xx-small"`, :code:`"x-small"`, :code:`"small"`, :code:`"medium"`, :code:`"large"`, :code:`"x-large"`, :code:`"xx-large"`, :code:`"larger"`, or :code:`"smaller"` legend_handlelength (or lh) : float, default ``rcParams["legend.handlelength"]`` Length of the legend handles (the dotted line, in this case) in font units Returns ------- (float, (float, float, float, float)) tuple to be used as linetype in plotting See Also -------- :func:`.dotted` : Create dotted-line linestyle. :func:`.dash_dotted` : Create dash-dotted-line linestyle. Examples -------- .. plot:: _examples/legend_dashed.py :include-source: """ lw_, fs_, lh_ = normalize_lw_fs_lh( linewidth, fontsize, legend_handlelength, **aliases ) total_points = fs_ * lh_ / lw_ n_gaps = n_dashes - 1 spacewidth = total_points / (n_gaps + n_dashes * ratio) dashwidth = ratio * spacewidth return 0.0, (dashwidth, spacewidth)
[docs] def add_abc( fig: mfig.Figure | None = None, xoffset_pts: npt.ArrayLike = 2.0, yoffset_pts: npt.ArrayLike = -12.0, anchor: tp.Literal[ "top left", "top right", "bottom left", "bottom right" ] = "top left", labels: str = "a;b;c;d;e;f;g;h;i;j;k;l;m;n;o;p;q;r;s;t;u;v;w;x;y;z", labels_sep: str = ";", pre: str = "(", post: str = ")", start_at: int = 0, rowsfirst: bool = True, **text_kwargs, ) -> dict[maxes.Axes, mtext.Text]: """ Add labels to all suplots in `fig`. By default, adds '(a)', '(b)', ... to each subplot in the upper-right corner. Parameters ---------- fig : :class:`matplotlib.figure.Figure`, optional If None, use last active figure. xoffset_pts : float, default: ``2.0`` Offset in pts from `anchor`. Positive moves right. yoffset_pts : float, default: ``-12.0`` Offset in pts from `anchor`. Positive moves up. anchor : {``"top left"``, ``"top right"``, ``"bottom left"``, ``"bottom right"``} Specify anchor point of the labels (offsets are relative to this). Refers to the corner of the graph-area of the axes. labels : str, optional, default ``"a;b;c;d; ... ;y;z",`` A string of labels, where each label is seperated by `labels_sep`. If ``None``, use label of the respective axes (i.e., :meth:`ax.get_label() <matplotlib.artist.Artist.get_label>`). labels_sep : str, default ``";"`` Separator used for `labels`. pre : str, default: ``"("`` String in front of `labels`. post : str, default: ``")"`` String after `labels`. start_at : int, default: 0 Skip `start_at` entries in `labels`. Only applies if `labels` is not ``None``. rowsfirst : bool, default: ``True`` Label rows first, e.g., "a b c / d e f" instead of "a c e / b d f". Only applies if `labels` is not ``None``. text_kwargs Additional keyword arguments of :class:`matplotlib.text.Text`. For a list thereof, see `here <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.text.html>`_. Returns ------- text_dict : dict[:class:`matplotlib.axes.Axes`, :class:`matplotlib.text.Text`] A dictionary with the axes of `fig` as keys and the corresponding text instances added by ``add_abc`` as values. Can be used to manipulate the text later (to, e.g., change the color of the text only for certain subplots). Notes ----- Cannot handle a fancy :class:`matplotlib.gridspec.GridSpec`, e.g., where one subplot spans multiple other subplots. If you need one of those, you're on your own. Examples -------- .. plot:: _examples/add_abc0.py :include-source: .. plot:: _examples/add_abc1.py :include-source: .. plot:: _examples/add_abc2.py :include-source: """ fig = fig or plt.gcf() axs = get_axes_grid(get_axes_for_layout(fig.axes)) nrows, ncols = axs.shape valid_anchors = ["top left", "top right", "bottom left", "bottom right"] if anchor not in valid_anchors: err_msg = f"{anchor=}, but it needs to be one of {valid_anchors}" raise ValueError(err_msg) topbottom, leftright = anchor.split(" ") # process offsets xoffsets_inch = np.asarray(xoffset_pts) / PTS_PER_INCH if xoffsets_inch.ndim == 0: xoffsets_inch = np.full(axs.shape, xoffsets_inch) elif xoffsets_inch.ndim == 1: xoffsets_inch = np.repeat(xoffsets_inch[np.newaxis, :], nrows, axis=0) if xoffsets_inch.shape != (nrows, ncols): msg = f"{xoffset_pts=} but must be either scalar or of shapes ({ncols},) or ({nrows}, {ncols})" raise ValueError(msg) yoffsets_inch = np.asarray(yoffset_pts) / PTS_PER_INCH if yoffsets_inch.ndim == 0: yoffsets_inch = np.full(axs.shape, yoffsets_inch) elif yoffsets_inch.ndim == 1: yoffsets_inch = np.repeat(yoffsets_inch[:, np.newaxis], ncols, axis=1) if yoffsets_inch.shape != (nrows, ncols): msg = f"{yoffset_pts=} but must be either scalar or of shapes ({nrows},) or ({nrows}, {ncols})" raise ValueError(msg) if labels is not None: labels_ = labels.split(labels_sep) out: dict[maxes.Axes, mtext.Text] = {} bboxes_inch = get_bboxes_inch_grid(fig, axs) for (row, col), bbox in np.ndenumerate(bboxes_inch): xloc = 0.0 if leftright == "left" else 1.0 yloc = 1.0 if topbottom == "top" else 0.0 x = xloc + xoffsets_inch[row, col] / bbox.width y = yloc + yoffsets_inch[row, col] / bbox.height if labels is None: text = pre + str(axs[row, col].get_label()) + post else: idx = start_at + (ncols * row + col if rowsfirst else nrows * col + row) text = pre + labels_[idx] + post out[axs[row, col]] = axs[row, col].text( x, y, text, transform=axs[row, col].transAxes, **text_kwargs ) return out
[docs] def square_polar_axes(ax: maxes.Axes | None = None, zorder: float = 0.98) -> None: """ Format a polar axes to have a squared outline. Parameters ---------- ax : :class:`matplotlib.axes.Axes`, optional If None, use last active axes. Examples -------- .. plot:: _examples/square_polar_axes.py :include-source: """ ax = ax or plt.gca() ec = plt.rcParams["axes.edgecolor"] def circle_to_square_distance(theta_rad): h = 1.0 max_cos_sin = np.maximum(np.abs(np.cos(theta_rad)), np.abs(np.sin(theta_rad))) return h / max_cos_sin - h ax.add_patch( mpatches.Rectangle( (-0.0, -0.0), 1.0, 1.0, transform=ax.transAxes, fill=False, edgecolor=ec, lw=plt.rcParams["axes.linewidth"], clip_on=False, ) ) kwargs = dict( clip_on=False, solid_capstyle="butt", lw=plt.rcParams["axes.linewidth"], c=ec, zorder=zorder, ) # crosshair for n in (0, 1, 2, 3): ax.axvline(np.pi / 2.0 * n, 0.0, 1.0, **kwargs) for angle in (30, 60, 120, 150, 210, 240, 300, 330): angle_rad = np.deg2rad(angle) dist = circle_to_square_distance(angle_rad) ax.axvline(angle_rad, 1.0 + dist / 3.0, 1.0 + dist, **kwargs) ax.set_axis_off()
def normalize_colorbar_metrics( fig: mfig.Figure, ax: maxes.Axes, thickness: float | str | None, pad: float | str | None, unit: tp.Literal["mm", "inch", "pts"], location: tp.Literal["left", "right", "top", "bottom"] = "right", ) -> tuple[float | None, float | None]: fw, fh = fig.get_size_inches() pos = ax.get_position() if location in "left right": ax_size = pos.width fig_size = fw else: # top bottom ax_size = pos.height fig_size = fh metrics: list[float | None] = [None, None] for i, metric in enumerate((thickness, pad)): if metric is None: continue elif isinstance(metric, str): if metric[-1] != "%": raise ValueError("invalid string, must end with a %") metrics[i] = (float(metric[:-1]) * ax_size / 100.0) * fig_size else: metrics[i] = convert_to_inches(metric, unit) return tuple(metrics) # type: ignore
[docs] def add_colorbar( mappable: cm.ScalarMappable, ax: maxes.Axes | None = None, location: tp.Literal["left", "right", "top", "bottom"] = "right", thickness: float | str | None = "5%", pad: float | str | None = "3.5%", unit: tp.Literal["inch", "pts", "mm"] = "pts", label: str | None = None, **text_kwargs, ) -> mcbar.Colorbar: """ Add a colorbar to `ax`. It uses :meth:`matplotlib.figure.Figure.colorbar` to create a colorbar and adjusts tick positions and labels appropriately depending on `location`. .. warning:: This method is intended to be used with :class:`.FixedLayoutEngine`. Matplotlib's own layouts (e.g., `constrained layout <https://matplotlib.org/stable/users/explain/axes/constrainedlayout_guide.html>`__) may not work as intended. .. warning:: If you use this method with :func:`~matplotlib.pyplot.imshow`, set the use the keyword ``aspect="auto"``, otherwise the colorbar may not be aligned properly. Parameters ---------- mappable : :class:`matplotlib.cm.ScalarMappable` The colormap described by this colorbar. For more information, see :func:`matplotlib.pyplot.colorbar`. ax : :class:`matplotlib.axes.Axes`, optional The axes to which the colorbar is added. If ``None``, use currently active axes. location : {"left", "right", "top", "bottom"}, default: ``"right"`` Location of the colorbar relative to `ax`. thickness : float or str or None, default "5%" The thickness of the colorbar in pts. If ``None``, use the default values of :meth:`matplotlib.figure.Figure.colorbar`. If str, must be of form "x%" and will set the thickness relative to the axes size. .. warning:: If *thickness* is of form "x%", it is relative to the initial size of the axes. If you resize the axes afterward (e.g., with :func:`.set_axes_size`), the thickness will not be updated. pad : float or str or None, default "3.5%" The pad between the colorbar and `axes` in pts. Analog to *thickness*. label : str, optional Label of the colorbar. By default, the label of a right-sided colorbar will read from top to bottom. unit : {"mm", "pts", "inch"}, default "pts" Unit of *thickness* and *pad*. Other parameters ---------------- **text_kwargs Additional keyword arguments passed to :meth:`matplotlib.colorbar.Colorbar.set_label`. Returns ------- colorbar : :class:`matplotlib.colorbar.Colorbar` Examples -------- .. plot:: _examples/add_colorbar.py :include-source: """ ax = ax or plt.gca() fig = ax.figure if not isinstance(fig, mfig.Figure): raise ValueError("ax must be in a top-level Figure") old_pos = ax.get_position().frozen() cbar = fig.colorbar(mappable, ax=ax, use_gridspec=False, location=location) axis = cbar.ax.yaxis if location in ("left", "right") else cbar.ax.xaxis axis.set_ticks_position(location) # type: ignore axis.set_label_position(location) # type: ignore thickness_inch, pad_inch = normalize_colorbar_metrics(fig, ax, thickness, pad, unit) if thickness_inch is not None: set_colorbar_thickness_inch(fig, cbar.ax, thickness_inch) if pad_inch is not None: set_colorbar_pad_inch(fig, cbar.ax, pad_inch) if label is not None: kwargs = text_kwargs.copy() if location == "right": if (not "verticalalignment" in kwargs) and (not "va" in kwargs): kwargs["verticalalignment"] = "baseline" kwargs.setdefault("rotation", 270.0) cbar.set_label(label, **kwargs) fw, fh = fig.get_size_inches() if location in "left right": anchor = (0, 0.5) if location == "right" else (1, 0.5) set_axes_width_inch(fig, ax, old_pos.width * fw, anchor) else: # top bottom anchor = (0.5, 1.0) if location == "top" else (0.5, 0.0) set_axes_height_inch(fig, ax, old_pos.height * fh, anchor) return cbar