Source code for pimpmyplot.legend



import matplotlib.pyplot as plt
import matplotlib

from pimpmyplot.utils import setupax

from typing import Tuple




EXT_LOC_MAP = {
    'ext lower center': ('upper center', (.5, -0.1)),
    'ext lower right': ('upper right', (1., -0.1)),
    'ext lower left': ('upper left', (-0.01, -0.1)),
    'ext upper center': ('lower center', (.5, 1.03)), 
    'ext upper right': ('lower right', (1, 1.03)),
    'ext upper left': ('lower left', (-0.01, 1.03)),

    'ext side lower left': ('lower right', (None, -0.02)),
    'ext side upper left': ('upper right', (None, 1.02)),
    'ext side lower right': ('lower left', (1., -0.02)),
    'ext side upper right': ('upper left', (1., 1.02))
}


DEFAULT_OFFSET_ANNOTATION_LEGEND =  {
    'left': (5, 0),
    'center': (0, 5),
    'right': (-5, 0)
}


def extract_line_info(line):
    """
    Extract coordinate and label information from a matplotlib Line2D object.

    Parameters
    ----------
    line : matplotlib.lines.Line2D
        The line object to extract information from.

    Returns
    -------
    Tuple[float, float, str]
        A tuple containing (last_x_coord, last_y_coord, label).
    """
    x_coord = line.get_xdata()[-1]
    y_coord = line.get_ydata()[-1]
    label = line.get_label() 
    return x_coord, y_coord, label


def get_number_labels(ax) -> int:
    """
    Count the number of labeled handles in the axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis to inspect.

    Returns
    -------
    int
        The number of legend labels found.
    """
    _, labels = ax.get_legend_handles_labels()
    return len(labels)


def infer_ncol(ncol: int, ax, loc: str):
    """
    Determine the appropriate number of columns for the legend if not explicitly provided.

    Parameters
    ----------
    ncol : int
        The requested number of columns.
    ax : matplotlib.axes.Axes
        The axis object.
    loc : str
        The location string for the legend.

    Returns
    -------
    int
        The inferred number of columns.
    """
    if ncol == -1:
        ncol = get_number_labels(ax=ax)
    elif ncol is None:
        ncol = 1
        if not loc.startswith('ext side'):
            ncol = get_number_labels(ax=ax)
    return ncol


def get_tick_labels_shape(ax) -> Tuple:
    """
    Calculate the maximum width and height of x and y tick labels.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis object.

    Returns
    -------
    Tuple[Tuple[float, float], Tuple[float, float]]
        ((max_x_width, max_x_height), (max_y_width, max_y_height)) in points.
    """
    x_width = max(*[l.get_window_extent().width for l in ax.get_xticklabels()])
    x_height = max(*[l.get_window_extent().height for l in ax.get_xticklabels()])

    y_width = max(*[l.get_window_extent().width for l in ax.get_yticklabels()])
    y_height = max(*[l.get_window_extent().height for l in ax.get_yticklabels()])

    return (x_width, x_height), (y_width, y_height)
    


def _build_ext_bbox_to_anchor(loc: str, ax) -> Tuple:
    """
    Calculate the bounding box to anchor (bbox_to_anchor) for external legend locations.

    Parameters
    ----------
    loc : str
        The external location string (starts with 'ext').
    ax : matplotlib.axes.Axes
        The axis object.

    Returns
    -------
    Tuple[str, Tuple[float, float]]
        A tuple of (standardized_loc, bbox_to_anchor_coordinates).
    """

    _cases_dynamic_mapping = (loc.endswith('left') and loc.startswith('ext side'))
    loc, bbox_to_anchor = EXT_LOC_MAP[loc]
    if not _cases_dynamic_mapping:
        return loc, bbox_to_anchor

    # TODO Make MARGIN dynamic. Percentage of (max_width_labels) / (ax_width)?
    MARGIN = 0.05
    y_labels = ax.get_yticklabels()
    max_width_labels = max(*[l.get_window_extent().width for l in y_labels])
    ax_width = ax.get_position().transformed(plt.gca().transAxes.get_affine()).width
    L = (max_width_labels) / (ax_width) + MARGIN
    bbox_to_anchor = (-L, bbox_to_anchor[1])

    return loc, bbox_to_anchor




[docs] @setupax def legend(*args, shadow: bool = True, frameon: bool = True, loc: str = 'ext lower center', bbox_to_anchor: Tuple = None, edgecolor: str = 'k', ax: matplotlib.axes.Axes = None, ncol: int = None, **kwargs) -> matplotlib.legend.Legend: """ Enhanced matplotlib legend with support for external positioning and custom styling. Parameters ---------- *args : Arguments passed to the standard `ax.legend`. shadow : bool, optional Whether to draw a shadow behind the legend, by default True. frameon : bool, optional Whether to draw a frame around the legend, by default True. loc : str, optional The legend location. Supports standard matplotlib locations and 'ext' prefixed locations (e.g., 'ext lower center' or 'ext side upper left'), by default 'ext lower center'. bbox_to_anchor : Tuple, optional The bounding box to anchor the legend to, by default None. edgecolor : str, optional The color of the legend frame edge, by default 'k' (black). ax : matplotlib.axes.Axes, optional The matplotlib axis to add the legend to. If None, uses the current axis (via @setupax). ncol : int, optional Number of columns. If -1, sets to the number of labels. If None, it is inferred. **kwargs : Additional keyword arguments passed to `ax.legend`. Returns ------- matplotlib.legend.Legend The created legend object. """ original_loc = loc if loc.startswith('ext'): loc, bbox_to_anchor = _build_ext_bbox_to_anchor(loc=loc, ax=ax) ncol = infer_ncol(ncol=ncol, ax=ax, loc=original_loc) l = ax.legend(*args, shadow=shadow, frameon=frameon, loc=loc, bbox_to_anchor=bbox_to_anchor, edgecolor=edgecolor, ncol=ncol, **kwargs) # Shadow style l._shadow_props = {'ox': 3, 'oy': -3, 'color': 'black', 'shade': 1., 'alpha': 1.} return l
@setupax def annotation_legend(ax: matplotlib.axes.Axes = None, offset: Tuple = None, fontweight: str = None, size: str = None, ha: str = 'left', color: str = None): """ Create an annotation-style legend where labels are printed directly next to the lines. Labels are placed at the end of each line on the right side of the plot. Parameters ---------- ax : matplotlib.axes.Axes, optional The matplotlib axis to apply annotations to. If None, uses the current axis (via @setupax). offset : Tuple, optional Offset in points for the labels (dx, dy), by default derived from `ha`. fontweight : str, optional The font weight for the labels (e.g., 'bold'). size : str or float, optional The font size for the labels. ha : str, optional Horizontal alignment of the labels, by default 'left'. color : str, optional Color for all labels. If None, each label uses the color of its corresponding line. """ if offset is None: offset = DEFAULT_OFFSET_ANNOTATION_LEGEND[ha] for line in ax.get_lines(): x_coord, y_coord, label = extract_line_info(line=line) c = line.get_color() if color is None else color ax.annotate( text=label, xy=(x_coord, y_coord), textcoords='offset points', xytext=offset, color=c, ha=ha, size=size, fontweight=fontweight )