# Source path: multipitch-architectures/libfmp/b/b_plot.py
"""
Module: libfmp.b.plot
Author: Frank Zalkow, Meinard Mueller
License: The MIT license, https://opensource.org/licenses/MIT

This file is part of the FMP Notebooks (https://www.audiolabs-erlangen.de/FMP)
"""

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatch


# Named color tables that can be selected by name through the `colors` argument of the annotation
# plotting functions (looked up in color_argument_to_dict). Each row is an RGB triple in [0, 1].
FMP_COLORMAPS = {
    'FMP_1': np.array([
        [1.0, 0.5, 0.0],
        [0.33, 0.75, 0.96],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
        [0.0, 1.0, 1.0],
        [1.0, 0.0, 1.0],
        [0.99, 0.51, 0.71],
        [0.53, 0.0, 0.46],
        [0.56, 0.93, 0.72],
        [0, 0, 0.9],
    ]),
}


def plot_signal(x, Fs=1, T_coef=None, ax=None, figsize=(6, 2), xlabel='Time (seconds)', ylabel='', title='', dpi=72,
                ylim=True, **kwargs):
    """Plot a signal, e.g. a waveform or a novelty function

    Notebook: B/B_PythonVisualization.ipynb

    Args:
        x: Input signal (np.ndarray; samples along axis 0)
        Fs: Sample rate
        T_coef: Time coefficients. If None, will be computed, based on Fs.
        ax: The Axes instance to plot on. If None, will create a figure and axes.
        figsize: Width, height in inches
        xlabel: Label for x axis
        ylabel: Label for y axis
        title: Title for plot
        dpi: Dots per inch
        ylim: True (auto-adjust ylim to the finite values of x), False or None (leave ylim untouched), or a
            pair (bottom, top) given as tuple, list, or np.ndarray with the actual ylim
        **kwargs: Keyword arguments for matplotlib.pyplot.plot ('color' defaults to 'gray')

    Returns:
        fig: The created matplotlib figure or None if ax was given.
        ax: The used axes.
        line: The list of lines returned by ax.plot
    """
    fig = None
    if ax is None:
        fig = plt.figure(figsize=figsize, dpi=dpi)
        ax = plt.subplot(1, 1, 1)
    if T_coef is None:
        T_coef = np.arange(x.shape[0]) / Fs

    if 'color' not in kwargs:
        kwargs['color'] = 'gray'

    line = ax.plot(T_coef, x, **kwargs)

    ax.set_xlim([T_coef[0], T_coef[-1]])
    if ylim is True:
        finite_x = x[np.isfinite(x)]
        # A signal without any finite value has no usable range; keep matplotlib's limits rather than
        # failing on min()/max() of an empty array.
        if finite_x.size > 0:
            x_min, x_max = finite_x.min(), finite_x.max()
            if x_max == x_min:
                x_max = x_max + 1
            # Pad both ends by 10% away from zero, independent of their sign.
            ax.set_ylim([min(1.1 * x_min, 0.9 * x_min), max(1.1 * x_max, 0.9 * x_max)])
    elif ylim is not None and np.ndim(ylim) > 0:
        # Explicit limits. Checked via ndim instead of `ylim not in [True, False, None]`, which compares an
        # np.ndarray element-wise and raises "truth value of an array is ambiguous".
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if fig is not None:
        plt.tight_layout()

    return fig, ax, line


def _half_sample_steps(coef, default_step):
    """Return half the coefficient spacing at the first and at the last coefficient.

    Used to center image cells on their coefficients. With fewer than two coefficients there is no spacing to
    measure, so half of `default_step` is used on both sides.
    """
    if len(coef) < 2:
        return default_step / 2, default_step / 2
    return (coef[1] - coef[0]) / 2, (coef[-1] - coef[-2]) / 2


def plot_matrix(X, Fs=1, Fs_F=1, T_coef=None, F_coef=None, xlabel='Time (seconds)', ylabel='Frequency (Hz)',
                xlim=None, ylim=None, clim=None, title='', dpi=72,
                colorbar=True, colorbar_aspect=20.0, cbar_label='', ax=None, figsize=(6, 3), **kwargs):
    """Plot a matrix, e.g. a spectrogram or a tempogram

    Notebook: B/B_PythonVisualization.ipynb

    Args:
        X: The matrix
        Fs: Sample rate for axis 1
        Fs_F: Sample rate for axis 0
        T_coef: Time coefficients. If None, will be computed, based on Fs.
        F_coef: Frequency coefficients. If None, will be computed, based on Fs_F.
        xlabel, ylabel: Label for x axis and y axis
        xlim, ylim, clim: Limits for x axis, yaxis, and image values
        title: Title for plot
        dpi: Dots per inch
        colorbar: Create a colorbar.
        colorbar_aspect: Aspect used for colorbar, in case only a single axes is used.
        cbar_label: Label for colorbar
        ax: Either (1.) a list of two axes (first used for matrix, second for colorbar), or (2.) a list with a single
            axes (used for matrix), or (3.) a single Axes instance (treated like 2.), or (4.) None (an axes will be
            created).
        figsize: Width, height in inches
        **kwargs: Keyword arguments for matplotlib.pyplot.imshow. Defaults: extent centered on the coefficients
            (for a single coefficient, a cell width of 1/Fs resp. 1/Fs_F is assumed), cmap='gray_r',
            aspect='auto', origin='lower'.

    Returns:
        fig: The created matplotlib figure or None if ax was given.
        ax: The used axes (as a list, or as the sequence that was passed in).
        im: The image plot
    """
    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
        ax = [ax]
    elif not isinstance(ax, (list, tuple, np.ndarray)):
        # A bare Axes instance: wrap it so the code below can index uniformly.
        ax = [ax]
    if T_coef is None:
        T_coef = np.arange(X.shape[1]) / Fs
    if F_coef is None:
        F_coef = np.arange(X.shape[0]) / Fs_F

    if 'extent' not in kwargs:
        x_ext1, x_ext2 = _half_sample_steps(T_coef, 1 / Fs)
        y_ext1, y_ext2 = _half_sample_steps(F_coef, 1 / Fs_F)
        kwargs['extent'] = [T_coef[0] - x_ext1, T_coef[-1] + x_ext2, F_coef[0] - y_ext1, F_coef[-1] + y_ext2]
    if 'cmap' not in kwargs:
        kwargs['cmap'] = 'gray_r'
    if 'aspect' not in kwargs:
        kwargs['aspect'] = 'auto'
    if 'origin' not in kwargs:
        kwargs['origin'] = 'lower'

    im = ax[0].imshow(X, **kwargs)

    if len(ax) == 2 and colorbar:
        cbar = plt.colorbar(im, cax=ax[1])
        cbar.set_label(cbar_label)
    elif len(ax) == 2 and not colorbar:
        ax[1].set_axis_off()
    elif len(ax) == 1 and colorbar:
        plt.sca(ax[0])
        # Pass the parent axes explicitly so the colorbar takes its space from the matrix axes.
        cbar = plt.colorbar(im, ax=ax[0], aspect=colorbar_aspect)
        cbar.set_label(cbar_label)

    ax[0].set_xlabel(xlabel)
    ax[0].set_ylabel(ylabel)
    ax[0].set_title(title)
    if xlim is not None:
        ax[0].set_xlim(xlim)
    if ylim is not None:
        ax[0].set_ylim(ylim)
    if clim is not None:
        im.set_clim(clim)

    if fig is not None:
        plt.tight_layout()

    return fig, ax, im


def plot_chromagram(*args, chroma_yticks=np.arange(12), **kwargs):
    """Plot a chromagram via libfmp.b.plot_matrix and label the y axis with pitch-class names

    Notebook: B/B_PythonVisualization.ipynb

    Args:
        *args, **kwargs: Forwarded to libfmp.b.plot_matrix; 'ylabel' defaults to 'Chroma'
        chroma_yticks: Chroma indices (0 = C, 1 = C#, ..., 11 = B) at which ticks and names are placed

    Returns:
        fig: The created matplotlib figure or None if ax was given.
        ax: The used axes.
        im: The image plot
    """
    kwargs.setdefault('ylabel', 'Chroma')
    fig, ax, im = plot_matrix(*args, **kwargs)

    pitch_classes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    matrix_ax = ax[0]
    matrix_ax.set_yticks(np.array(chroma_yticks))
    matrix_ax.set_yticklabels([pitch_classes[idx] for idx in chroma_yticks])

    return fig, ax, im


def compressed_gray_cmap(alpha=5, N=256, reverse=False):
    """Build a grayscale colormap whose ramp is log-compressed (alpha > 0) or exp-compressed (alpha < 0)

    Notebook: B/B_PythonVisualization.ipynb

    Args:
        alpha: Compression factor. alpha > 0 applies log compression (enhancing black colors), alpha < 0 applies
            exp compression (enhancing white colors). alpha = 0 triggers an AssertionError.
        N: Number of rgb quantization levels (usually 256 in matplotlib)
        reverse: False gives "white to black", True gives "black to white"

    Returns:
        color_wb: The colormap (a LinearSegmentedColormap named 'color_wb')
    """
    assert alpha != 0

    ramp = np.linspace(0, 1, N)
    curve = np.log(1 + abs(alpha) * ramp)
    curve /= curve.max()

    # Both branches run from white (1) to black (0): alpha > 0 inverts the log curve,
    # alpha < 0 mirrors it along the ramp instead.
    curve = 1 - curve if alpha > 0 else curve[::-1]
    if reverse:
        curve = curve[::-1]

    rgb_levels = np.tile(curve[:, np.newaxis], (1, 3))
    return LinearSegmentedColormap.from_list('color_wb', rgb_levels, N=N)


class MultiplePlotsWithColorbar():
    """
    Two-column layout plot, where the first column is for user-given plots and the second column
    is for colorbars if the corresponding row needs a colorbar.

    Notebook: B/B_PythonVisualization.ipynb

    Attributes:
        fig: The figure created by the init method; all axes below belong to it.
        axes: A list of axes for the first column.
        cbar_axes: A list of axes for the second column.
        num_plots: Number of rows, as given to init method.
    """

    def __init__(self, num_plots, figsize=(8, 4), dpi=72, cbar_ratio=0.1, height_ratios=None):
        """Creates an instance of the MultiplePlotsWithColorbar class

        Args:
            num_plots: Number of plots (also number of rows)
            figsize: Width, height in inches
            dpi: Dots per inch
            cbar_ratio: Width of the colorbar column relative to the plot column (which has width 1)
            height_ratios: Relative heights of the rows; if None, all rows get the same height
        """
        if height_ratios is None:
            height_ratios = [1] * num_plots

        # Keep a handle on the figure: make_colorbars may run after other figures were created, at which
        # point pyplot's "current figure" is no longer this one.
        self.fig = plt.figure(figsize=figsize, dpi=dpi)
        gs = gridspec.GridSpec(num_plots, 2, width_ratios=[1, cbar_ratio], height_ratios=height_ratios)

        self.num_plots = num_plots
        self.axes = []
        self.cbar_axes = []

        for i in range(self.num_plots):
            self.axes.append(self.fig.add_subplot(gs[i, 0]))
            self.cbar_axes.append(self.fig.add_subplot(gs[i, 1]))

    def make_colorbars(self):
        """
        Creates colorbars if the corresponding row needs a colorbar, or hides the axis in other cases.

        A row gets a colorbar (for its first image) if its plot axes contain an image as reported by
        Axes.get_images(); otherwise its colorbar axes is switched off. Finally the layout of this
        instance's figure is tightened.
        """
        for i in range(self.num_plots):
            ax_img = self.axes[i].get_images()

            if len(ax_img) == 0:
                self.cbar_axes[i].set_axis_off()
            else:
                self.fig.colorbar(ax_img[0], cax=self.cbar_axes[i])

        self.fig.tight_layout()


def color_argument_to_dict(colors, labels_set, default='gray'):
    """Creates a color dictionary

    Args:
        colors: Several options: 1. string of FMP_COLORMAPS, 2. string of matplotlib colormap, 3. list, tuple, or
            np.ndarray of matplotlib color specifications, 4. dict that assigns labels to colors
        labels_set: List of all labels
        default: Default color, used for labels that are in labels_set, but not in colors (this includes labels
            beyond the length of a color list/colormap table)

    Returns:
        color_dict: New dictionary that maps labels to colors. A dict passed as `colors` is copied, never modified.

    Raises:
        ValueError: If `colors` has none of the supported types.
    """
    if isinstance(colors, str):
        if colors in FMP_COLORMAPS:
            # FMP colormap: assign table rows to labels in order
            color_dict = dict(zip(labels_set, FMP_COLORMAPS[colors]))
        else:
            # matplotlib colormap: sample it at i / (num_labels + 1)
            cm = plt.get_cmap(colors)
            num_labels = len(labels_set)
            color_dict = {label: cm(i / (num_labels + 1)) for i, label in enumerate(labels_set)}

    elif isinstance(colors, (list, np.ndarray, tuple)):
        color_dict = dict(zip(labels_set, colors))

    elif isinstance(colors, dict):
        # Copy so that filling in defaults below does not mutate the caller's dict.
        color_dict = dict(colors)

    else:
        raise ValueError('`colors` must be str, list, tuple, np.ndarray, or dict')

    for key in labels_set:
        color_dict.setdefault(key, default)

    return color_dict


def check_line_annotations(annot, default_label=''):
    """Normalize line annotations so that every entry carries a label.

    Args:
        annot: A List of the form of [(time_position, label), ...], or [(time_position, ), ...], or
            [time_position, ...]
        default_label: Label attached to entries that have none

    Returns:
        annot: [(time_position, label), ...]; input whose entries already have more than one element is
            returned unchanged (same object)
    """
    first = annot[0]

    if not isinstance(first, (list, np.ndarray, tuple)):
        # Bare positions: each one must be a scalar; pair it with the default label.
        assert isinstance(first, (int, float, complex)) or np.isscalar(first)
        return [(position, default_label) for position in annot]

    entry_len = len(first)
    assert all(len(entry) == entry_len for entry in annot)
    if entry_len != 1:
        return annot
    return [(entry[0], default_label) for entry in annot]


def check_segment_annotations(annot, default_label=''):
    """Normalize segment annotations so that every segment carries a label.

    Args:
        annot: A List of the form of [(start_position, end_position, label), ...], or
            [(start_position, end_position), ...]
        default_label: Label attached to segments that have none

    Returns:
        annot: [(start_position, end_position, label), ...]; input whose entries do not have exactly two
            elements is returned unchanged (same object)
    """
    first_segment = annot[0]
    assert isinstance(first_segment, (list, np.ndarray, tuple))

    n_fields = len(first_segment)
    assert all(len(segment) == n_fields for segment in annot)

    if n_fields != 2:
        return annot
    return [(segment[0], segment[1], default_label) for segment in annot]


def plot_annotation_line(annot, ax=None, label_keys=None, colors='FMP_1', figsize=(6, 1), direction='horizontal',
                         time_min=None, time_max=None, time_axis=True, nontime_axis=False, swap_time_ticks=False,
                         axis_off=False, dpi=72):
    """Creates a line plot for annotation data

    Args:
        annot: A List of tuples in the form of [(time_position, label), ...] (the other forms accepted by
            check_line_annotations work as well)
        ax: The Axes instance to plot on. If None, will create a figure and axes.
        label_keys: A dict, where the keys are the labels used in `annot`. The values are dicts, which are used as
            keyword arguments for matplotlib.pyplot.axvline or matplotlib.pyplot.axhline. Neither the dict nor its
            values are modified. None (default) means no extra keyword arguments.
        colors: Several options: 1. string of FMP_COLORMAPS, 2. string of matplotlib colormap, 3. list or np.ndarray of
            matplotlib color specifications, 4. dict that assigns labels to colors. A 'color' entry in
            label_keys takes precedence.
        figsize: Width, height in inches
        direction: 'vertical' or 'horizontal'
        time_min: Minimal limit for time axis. If None, will be min annotation.
        time_max: Maximal limit for time axis. If None, will be max from annotation.
        time_axis: Display time axis ticks or not
        nontime_axis: Display non-time axis ticks or not
        swap_time_ticks: For horizontal: xticks on top; for vertical: yticks on the right
        axis_off: Calls ax.axis('off')
        dpi: Dots per inch

    Returns:
        fig: The created matplotlib figure or None if ax was given.
        ax: The used axes.
    """

    assert direction in ['vertical', 'horizontal']
    annot = check_line_annotations(annot)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    labels_set = sorted(set([label for pos, label in annot]))
    colors = color_argument_to_dict(colors, labels_set)

    # Work on per-call copies: filling in colors must neither leak into later calls (the former `{}` default was
    # shared between calls) nor modify the caller's dicts.
    line_kwargs = {key: dict(kws) for key, kws in (label_keys or {}).items()}
    for key, value in colors.items():
        line_kwargs.setdefault(key, {}).setdefault('color', value)

    for pos, label in annot:
        if direction == 'horizontal':
            ax.axvline(pos, **line_kwargs[label])
        else:
            ax.axhline(pos, **line_kwargs[label])

    if time_min is None:
        time_min = min(pos for pos, label in annot)
    if time_max is None:
        time_max = max(pos for pos, label in annot)

    if direction == 'horizontal':
        ax.set_xlim([time_min, time_max])
        if not time_axis:
            ax.set_xticks([])
        if not nontime_axis:
            ax.set_yticks([])
        if swap_time_ticks:
            ax.xaxis.tick_top()
    else:
        ax.set_ylim([time_min, time_max])
        if not time_axis:
            ax.set_yticks([])
        if not nontime_axis:
            ax.set_xticks([])
        if swap_time_ticks:
            ax.yaxis.tick_right()

    if axis_off:
        ax.axis('off')

    if fig is not None:
        plt.tight_layout()

    return fig, ax


def plot_annotation_line_overlay(*args, **kwargs):
    """Plot line annotations as overlay, keeping the non-time axis ticks of the target axes

    Always forces nontime_axis=True; passing nontime_axis explicitly is rejected by an assertion.

    Args:
        see libfmp.b.plot_annotation_line

    Returns:
        see libfmp.b.plot_annotation_line
    """
    assert 'nontime_axis' not in kwargs
    overlay_kwargs = {**kwargs, 'nontime_axis': True}
    return plot_annotation_line(*args, **overlay_kwargs)


def plot_annotation_multiline(annot, ax=None, label_keys=None, colors='FMP_1', figsize=(6, 1.5),
                              direction='horizontal', sort_labels=None, time_min=None, time_max=None, time_axis=True,
                              swap_time_ticks=False, axis_off=False, dpi=72):
    """Creates a multi-line plot for annotation data: one lane per label, with a line at each of its positions

    Args:
        annot: A List of tuples in the form of [(time_position, label), ...] (the other forms accepted by
            check_line_annotations work as well)
        ax: The Axes instance to plot on. If None, will create a figure and axes.
        label_keys: A dict, where the keys are the labels used in `annot`. The values are dicts, which are used as
            keyword arguments for matplotlib.pyplot.vlines or matplotlib.pyplot.hlines. Neither the dict nor its
            values are modified. None (default) means no extra keyword arguments.
        colors: Several options: 1. string of FMP_COLORMAPS, 2. string of matplotlib colormap, 3. list or np.ndarray of
            matplotlib color specifications, 4. dict that assigns labels to colors. A 'color' entry in
            label_keys takes precedence.
        figsize: Width, height in inches
        direction: 'vertical' or 'horizontal'
        sort_labels: List of labels defining the lane order. Labels of `annot` not contained in it come after the
            listed ones, in alphabetical order. If None or empty, lanes are ordered alphabetically.
        time_min: Minimal limit for time axis. If None, will be min annotation.
        time_max: Maximal limit for time axis. If None, will be max from annotation.
        time_axis: Display time axis ticks or not
        swap_time_ticks: For horizontal: xticks on top; for vertical: yticks on the right
        axis_off: Calls ax.axis('off')
        dpi: Dots per inch

    Returns:
        fig: The created matplotlib figure or None if ax was given.
        ax: The used axes.
    """

    assert direction in ['vertical', 'horizontal']
    annot = check_line_annotations(annot)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    labels_set = sorted(set([label for pos, label in annot]))
    colors = color_argument_to_dict(colors, labels_set)

    # Work on per-call copies: filling in colors must neither leak into later calls (the former `{}` default was
    # shared between calls) nor modify the caller's dicts.
    line_kwargs = {key: dict(kws) for key, kws in (label_keys or {}).items()}
    for key, value in colors.items():
        line_kwargs.setdefault(key, {}).setdefault('color', value)

    if sort_labels:
        num_listed = len(sort_labels)
        # sorted() is stable and labels_set is already alphabetical, so all unlisted labels (same key) end up
        # after the listed ones in alphabetical order. Previously they shared key 0 with sort_labels[0] and kept
        # the hash-dependent iteration order of a set.
        all_labels = sorted(labels_set,
                            key=lambda x: sort_labels.index(x) if x in sort_labels else num_listed)
    else:
        all_labels = labels_set

    for i, cur_label in enumerate(all_labels):
        cur_pos = [pos for pos, label in annot if cur_label == label]

        # Lane i spans [i, i + 1] on the non-time axis.
        if direction == 'horizontal':
            ax.vlines(cur_pos, i, i+1, **line_kwargs[cur_label])
        else:
            ax.hlines(cur_pos, i, i+1, **line_kwargs[cur_label])

    if time_min is None:
        time_min = min(pos for pos, label in annot)
    if time_max is None:
        time_max = max(pos for pos, label in annot)

    if direction == 'horizontal':
        ax.set_ylim([0, len(all_labels)])
        ax.set_xlim([time_min, time_max])
        for separator in range(1, len(all_labels)):
            ax.axhline(separator, color='k')
        ax.set_yticks(np.arange(len(all_labels)) + 0.5)
        ax.set_yticklabels(all_labels)
        if not time_axis:
            ax.set_xticks([])
        if swap_time_ticks:
            ax.xaxis.tick_top()

    else:
        ax.set_xlim([0, len(all_labels)])
        ax.set_ylim([time_min, time_max])
        for separator in range(1, len(all_labels)):
            ax.axvline(separator, color='k')
        ax.set_xticks(np.arange(len(all_labels)) + 0.5)
        ax.set_xticklabels(all_labels, rotation=90)
        if not time_axis:
            ax.set_yticks([])
        if swap_time_ticks:
            ax.yaxis.tick_right()

    if axis_off:
        ax.axis('off')

    if fig is not None:
        plt.tight_layout()

    return fig, ax


def plot_segments(annot, ax=None, figsize=(6, 1), direction='horizontal', colors='FMP_1', time_min=None,
                  time_max=None, nontime_min=0, nontime_max=1, time_axis=True, nontime_axis=False, time_label=None,
                  swap_time_ticks=False, edgecolor='k', axis_off=False, dpi=72, adjust_time_axislim=True,
                  adjust_nontime_axislim=True, alpha=None, print_labels=True, label_ticks=False, **kwargs):
    """Creates a segment plot for annotation data: one colored rectangle per (start, end, label) segment

    Args:
        annot: A List of tuples in the form of [(start_position, end_position, label), ...]
            (or [(start_position, end_position), ...], see check_segment_annotations)
        ax: The Axes instance to plot on. If None, will create a figure and axes.
        figsize: Width, height in inches
        direction: 'vertical' or 'horizontal'
        colors: Several options: 1. string of FMP_COLORMAPS, 2. string of matplotlib colormap, 3. list or np.ndarray of
            matplotlib color specifications, 4. dict that assigns labels to colors. None means 'FMP_1'.
        time_min: Minimal limit for time axis. If None, will be min annotation.
        time_max: Maximal limit for time axis. If None, will be max from annotation.
        nontime_min: Minimal limit for non-time axis.
        nontime_max: Maximal limit for non-time axis.
        time_axis: Display time axis ticks or not
        nontime_axis: Display non-time axis ticks or not
        time_label: Label for the time axis; not set if None or empty
        swap_time_ticks: For horizontal: xticks on top; for vertical: yticks on the right
        edgecolor: Color for edgelines of segment box
        axis_off: Calls ax.axis('off')
        dpi: Dots per inch
        adjust_time_axislim: Adjust time-axis. Usually True for plotting on standalone axes and False for
            overlay plotting
        adjust_nontime_axislim: Adjust non-time-axis. Usually True for plotting on standalone axes and False for
            overlay plotting
        alpha: Alpha value for rectangle
        print_labels: Print labels inside Rectangles
        label_ticks: Print labels as ticks at the segment centers on the time axis
        kwargs: Keyword arguments for matplotlib.pyplot.annotate (defaults: color='k', weight='bold',
            fontsize=12, ha='center', va='center')

    Returns:
        fig: The created matplotlib figure or None if ax was given.
        ax: The used axes.
    """
    assert direction in ['vertical', 'horizontal']
    annot = check_segment_annotations(annot)

    text_defaults = {'color': 'k', 'weight': 'bold', 'fontsize': 12, 'ha': 'center', 'va': 'center'}
    for option, value in text_defaults.items():
        kwargs.setdefault(option, value)

    if colors is None:
        colors = 'FMP_1'

    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    labels_set = sorted(set([label for _, _, label in annot]))
    label_colors = color_argument_to_dict(colors, labels_set)

    horizontal = direction == 'horizontal'
    nontime_width = nontime_max - nontime_min
    nontime_middle = nontime_min + nontime_width / 2
    all_time_middles = []

    for start, end, label in annot:
        time_width = end - start
        time_middle = start + time_width / 2
        all_time_middles.append(time_middle)

        # Time runs along x for horizontal plots and along y for vertical ones.
        if horizontal:
            corner, width, height = (start, nontime_min), time_width, nontime_width
            text_position = (time_middle, nontime_middle)
        else:
            corner, width, height = (nontime_min, start), nontime_width, time_width
            text_position = (nontime_middle, time_middle)

        ax.add_patch(mpatch.Rectangle(corner, width, height,
                                      facecolor=label_colors[label], edgecolor=edgecolor, alpha=alpha))
        if print_labels:
            ax.annotate(label, text_position, **kwargs)

    if time_min is None:
        time_min = min(start for start, _, _ in annot)
    if time_max is None:
        time_max = max(end for _, end, _ in annot)

    # Bind the axis-specific setters once so both directions share the same sequence of steps below.
    if horizontal:
        set_time_lim, set_nontime_lim = ax.set_xlim, ax.set_ylim
        set_time_ticks, set_nontime_ticks = ax.set_xticks, ax.set_yticks
        set_time_ticklabels, set_time_label = ax.set_xticklabels, ax.set_xlabel
        move_time_ticks = ax.xaxis.tick_top
    else:
        set_time_lim, set_nontime_lim = ax.set_ylim, ax.set_xlim
        set_time_ticks, set_nontime_ticks = ax.set_yticks, ax.set_xticks
        set_time_ticklabels, set_time_label = ax.set_yticklabels, ax.set_ylabel
        move_time_ticks = ax.yaxis.tick_right

    if adjust_time_axislim:
        set_time_lim([time_min, time_max])
    if adjust_nontime_axislim:
        set_nontime_lim([nontime_min, nontime_max])
    if not nontime_axis:
        set_nontime_ticks([])
    if not time_axis:
        set_time_ticks([])
    if swap_time_ticks:
        move_time_ticks()
    if time_label:
        set_time_label(time_label)
    if label_ticks:
        set_time_ticks(all_time_middles)
        set_time_ticklabels([label for _, _, label in annot])

    if axis_off:
        ax.axis('off')

    if fig is not None:
        plt.tight_layout()

    return fig, ax


def plot_segments_overlay(*args, **kwargs):
    """Plot segment annotations as overlay on existing axes

    Args:
        see libfmp.b.plot_segments; `ax` must be passed as keyword argument. Unless given explicitly, these
        defaults differ from plot_segments: adjust_time_axislim=False, adjust_nontime_axislim=False,
        alpha=0.3, edgecolor=None, nontime_axis=True, and nontime_min/nontime_max are taken from the current
        limits of the non-time axis of `ax` (x axis for direction='vertical', y axis otherwise).

    Returns:
        see libfmp.b.plot_segments
    """
    assert 'ax' in kwargs
    ax = kwargs['ax']

    overlay_defaults = {
        'adjust_time_axislim': False,
        'adjust_nontime_axislim': False,
        'alpha': 0.3,
        'edgecolor': None,
        'nontime_axis': True,
    }
    for option, value in overlay_defaults.items():
        kwargs.setdefault(option, value)

    if kwargs.get('direction') == 'vertical':
        nontime_lim = ax.get_xlim()
    else:
        nontime_lim = ax.get_ylim()
    # Respect an explicitly requested non-time extent (e.g. to cover only part of the axes) instead of
    # silently overwriting it with the full axis range.
    kwargs.setdefault('nontime_min', nontime_lim[0])
    kwargs.setdefault('nontime_max', nontime_lim[1])

    fig, ax = plot_segments(*args, **kwargs)

    return fig, ax