#!/usr/bin/env python3
import warnings
from typing import Callable, Iterable, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from aepsych.strategy import Strategy
from aepsych.utils import get_lse_contour, get_lse_interval, make_scaled_sobol
from scipy.stats import norm

def plot_strat(
    strat: Strategy,
    ax: Optional[plt.Axes] = None,
    true_testfun: Optional[Callable] = None,
    cred_level: float = 0.95,
    target_level: Optional[float] = 0.75,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    yes_label: str = "Yes trial",
    no_label: str = "No trial",
    flipx: bool = False,
    logx: bool = False,
    gridsize: int = 30,
    title: str = "",
    save_path: Optional[str] = None,
    show: bool = True,
    include_legend: bool = True,
    include_colorbar: bool = True,
) -> None:
    """Creates a plot of a strategy, showing participants responses on each trial, the estimated response function and
    threshold, and optionally a ground truth response threshold.

    Args:
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 2 or less.
        ax (plt.Axes, optional): Matplotlib axis to plot on (if None, creates a new axis). The legend, saving and
            layout adjustments are applied to this axis / the figure that owns it. Default: None.
        true_testfun (Callable, optional): Ground truth response function. Should take a n_samples x n_parameters tensor
            as input and produce the response probability at each sample as output. Default: None.
        cred_level (float): Fraction of posterior mass (e.g. 0.95) to be shaded around the estimate. Default: 0.95.
        target_level (float): Response probability to estimate the threshold of. Default: 0.75.
        xlabel (str): Label of the x-axis (if None, defaults to "Context (abstract)"). Default: None.
        ylabel (str): Label of the y-axis (if None, defaults to "Response Probability" for 1-d plots or
            "Intensity (abstract)" for 2-d plots). Default: None.
        yes_label (str): Label of trials with response of 1. Default: "Yes trial".
        no_label (str): Label of trials with response of 0. Default: "No trial".
        flipx (bool): Whether the values of the x-axis should be flipped such that the min becomes the max and vice
            versa. (Only valid for 2-d plots.) NOTE(review): the 2-d helper implements this by swapping the bounds of
            the second parameter in the image extent, which inverts the vertical axis — confirm this is intended.
            Default: False.
        logx (bool): Whether the x-axis should be log-transformed. (Only valid for 2-d plots.) This relabels integer
            x ticks as powers of two, so it presumably expects the first parameter on a log2 scale. Default: False.
        gridsize (int): The number of points to sample each dimension at. Default: 30.
        title (str): Title of the plot. Default: ''.
        save_path (str, optional): File name to save the plot to. Default: None.
        show (bool): Whether the plot should be shown in an interactive window. Default: True.
        include_legend (bool): Whether to include the legend in the figure. Default: True.
        include_colorbar (bool): Whether to include the colorbar indicating the probability of "Yes" trials.
            Default: True.

    Raises:
        AssertionError: If the strategy has no "binary" outcome type.
        RuntimeError: If the strategy is 3-dimensional (use plot_strat_3d instead).
        NotImplementedError: If the strategy's dimensionality is anything other than 1, 2 or 3.
    """
    assert (
        "binary" in strat.outcome_types
    ), f"Plotting not supported for outcome_type {strat.outcome_types[0]}"

    dim = strat.dim
    # Validate the dimensionality before creating any figure so that a failed
    # call does not leave an empty figure behind.
    if dim == 3:
        raise RuntimeError("Use plot_strat_3d for 3d plots!")
    if dim not in (1, 2):
        raise NotImplementedError("No plots for >3d!")

    if target_level is not None and not hasattr(strat.model, "monotonic_idxs"):
        warnings.warn(
            "Threshold estimation may not be accurate for non-monotonic models."
        )

    if ax is None:
        _, ax = plt.subplots()
    # Work with the figure that owns `ax` rather than pyplot's "current" figure,
    # so a caller-supplied axis on another figure is saved/laid out correctly.
    fig = ax.figure

    if xlabel is None:
        xlabel = "Context (abstract)"

    if dim == 1:
        if ylabel is None:
            ylabel = "Response Probability"
        _plot_strat_1d(
            strat,
            ax,
            true_testfun,
            cred_level,
            target_level,
            xlabel,
            ylabel,
            yes_label,
            no_label,
            gridsize,
        )
    else:  # dim == 2
        if ylabel is None:
            ylabel = "Intensity (abstract)"
        _plot_strat_2d(
            strat,
            ax,
            true_testfun,
            cred_level,
            target_level,
            xlabel,
            ylabel,
            yes_label,
            no_label,
            flipx,
            logx,
            gridsize,
            include_colorbar,
        )

    ax.set_title(title)

    if include_legend:
        # Push the legend further right when a colorbar occupies that space.
        anchor = (1.4, 0.5) if include_colorbar and dim > 1 else (1, 0.5)
        ax.legend(loc="center left", bbox_to_anchor=anchor)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")

    if show:
        fig.tight_layout()
        if include_legend or (include_colorbar and dim > 1):
            fig.subplots_adjust(left=0.1, bottom=0.25, top=0.75)
        plt.show()
def _plot_strat_1d(
    strat: Strategy,
    ax: plt.Axes,
    true_testfun: Optional[Callable],
    cred_level: float,
    target_level: Optional[float],
    xlabel: str,
    ylabel: str,
    yes_label: str,
    no_label: str,
    gridsize: int,
):
    """Helper function for creating 1-d plots. See plot_strat for an explanation of the arguments.

    Draws the posterior-mean response probability over a grid, an optional central credible band, an optional
    threshold estimate with its credible interval, an optional ground truth, and every trial as a tick mark at y=0.

    Returns:
        plt.Axes: The axis that was drawn on.
    """
    x, y = strat.x, strat.y
    assert x is not None and y is not None, "No data to plot!"

    grid = strat.model.dim_grid(gridsize=gridsize)
    # Flattened numpy copy of the grid, computed once; it is reused for every
    # posterior sample in the threshold interpolation below.
    grid_np = grid.squeeze().numpy()

    # Posterior samples pushed through the standard normal CDF to obtain
    # response probabilities; one row per sample (rows are averaged and
    # iterated over below).
    samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
    phimean = samps.mean(0)
    ax.plot(grid_np, phimean)

    # Probability left in each tail of a central interval holding `cred_level`
    # of the posterior mass (e.g. 0.025 per side for 0.95), so the shaded
    # region really contains the mass stated in its label.
    tail = None if cred_level is None else (1 - cred_level) / 2

    if cred_level is not None:
        lower, upper = np.quantile(samps, [tail, 1 - tail], axis=0)
        ax.fill_between(
            grid_np,
            lower,
            upper,
            alpha=0.3,
            hatch="///",
            edgecolor="gray",
            label=f"{cred_level*100:.0f}% posterior mass",
        )

    if target_level is not None:
        from aepsych.utils import interpolate_monotonic

        # Threshold location implied by each individual posterior sample.
        threshold_samps = [
            interpolate_monotonic(grid_np, s, target_level, strat.lb[0], strat.ub[0])
            for s in samps
        ]
        # The median always lies inside the central interval, so the error-bar
        # lengths computed below can never be negative.
        thresh_med = np.median(threshold_samps)
        thresh_lower, thresh_upper = np.quantile(threshold_samps, [tail, 1 - tail])

        ax.errorbar(
            thresh_med,
            target_level,
            xerr=np.r_[thresh_med - thresh_lower, thresh_upper - thresh_med][:, None],
            capsize=5,
            elinewidth=1,
            label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass marked)",
        )

    if true_testfun is not None:
        true_f = true_testfun(grid)
        ax.plot(grid, true_f.squeeze(), label="True function")
        if target_level is not None:
            true_thresh = interpolate_monotonic(
                grid_np,
                true_f.squeeze(),
                target_level,
                strat.lb[0],
                strat.ub[0],
            )
            ax.plot(
                true_thresh,
                target_level,
                "o",
                label=f"True {target_level*100:.0f}% threshold",
            )

    # Observed trials drawn as tick marks along y=0, coloured by response.
    no_x = x[y == 0, 0]
    yes_x = x[y == 1, 0]
    ax.scatter(no_x, np.zeros_like(no_x), marker=3, color="r", label=no_label)
    ax.scatter(yes_x, np.zeros_like(yes_x), marker=3, color="b", label=yes_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax


def _plot_strat_2d(
    strat: Strategy,
    ax: plt.Axes,
    true_testfun: Optional[Callable],
    cred_level: float,
    target_level: Optional[float],
    xlabel: str,
    ylabel: str,
    yes_label: str,
    no_label: str,
    flipx: bool,
    logx: bool,
    gridsize: int,
    include_colorbar: bool,
):
    """Helper function for creating 2-d plots. See plot_strat for an explanation of the arguments.

    Refits the model on all observed data, then draws the mean response probability as a heatmap, the observed
    trials, an optional threshold contour with a shaded credible band, an optional ground-truth contour, and an
    optional colorbar.
    """
    x, y = strat.x, strat.y
    assert x is not None and y is not None, "No data to plot!"

    # make sure the model is fit well if we've been limiting fit time
    strat.model.fit(train_x=x, train_y=y, max_fit_time=None)

    grid = strat.model.dim_grid(gridsize=gridsize)
    fmean, _ = strat.model.predict(grid)
    # Transposed for imshow so that image rows run along the second parameter
    # (assumes dim_grid uses the first parameter as the outer axis — confirm).
    phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T

    # The heatmap is drawn exactly once so the alpha=0.5 layer is not stacked.
    # `flipx` swaps the second parameter's bounds (with origin="upper"), which
    # inverts the vertical axis while keeping the data in place.
    if flipx:
        extent = np.r_[strat.lb[0], strat.ub[0], strat.ub[1], strat.lb[1]]
        origin = "upper"
    else:
        extent = np.r_[strat.lb[0], strat.ub[0], strat.lb[1], strat.ub[1]]
        origin = "lower"
    colormap = ax.imshow(
        phimean, aspect="auto", origin=origin, extent=extent, alpha=0.5
    )

    # hacky relabel to be in logspace: integer ticks are labelled as 2**tick
    if logx:
        locs = np.arange(strat.lb[0], strat.ub[0])
        ax.set_xticks(ticks=locs)
        ax.set_xticklabels(2.0**locs)

    ax.plot(x[y == 0, 0], x[y == 0, 1], "ro", alpha=0.7, label=no_label)
    ax.plot(x[y == 1, 0], x[y == 1, 1], "bo", alpha=0.7, label=yes_label)

    if target_level is not None:
        # plot threshold: the second parameter is treated as the monotonic one
        mono_grid = np.linspace(strat.lb[1], strat.ub[1], num=gridsize)
        context_grid = np.linspace(strat.lb[0], strat.ub[0], num=gridsize)
        thresh_est, lower, upper = get_lse_interval(
            model=strat.model,
            mono_grid=mono_grid,
            target_level=target_level,
            cred_level=cred_level,
            mono_dim=1,
            lb=mono_grid.min(),
            ub=mono_grid.max(),
            gridsize=gridsize,
        )
        ax.plot(
            context_grid,
            thresh_est,
            label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)",
        )
        ax.fill_between(
            context_grid, lower, upper, alpha=0.3, hatch="///", edgecolor="gray"
        )

        if true_testfun is not None:
            true_f = true_testfun(grid).reshape(gridsize, gridsize)
            true_thresh = get_lse_contour(
                true_f, mono_grid, level=target_level, lb=strat.lb[-1], ub=strat.ub[-1]
            )
            ax.plot(context_grid, true_thresh, label="Ground truth threshold")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if include_colorbar:
        # Attach the colorbar to the figure that owns `ax`, not pyplot's current one.
        colorbar = ax.figure.colorbar(colormap, ax=ax)
        colorbar.set_label(f"Probability of {yes_label}")
def plot_strat_3d(
    strat: Strategy,
    parnames: Optional[List[str]] = None,
    outcome_label: str = "Yes Trial",
    slice_dim: int = 0,
    slice_vals: Union[List[float], int] = 5,
    contour_levels: Optional[Union[Iterable[float], bool]] = None,
    probability_space: bool = False,
    gridsize: int = 30,
    extent_multiplier: Optional[List[float]] = None,
    save_path: Optional[str] = None,
    show: bool = True,
):
    """Creates a plot of 2d slices of a 3D strategy, showing the estimated model or probability response and contours

    Args:
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 3.
        parnames (str list): list of the parameter names. Default: ["x1", "x2", "x3"].
        outcome_label (str): The label of the outcome variable
        slice_dim (int): dimension to slice on
        slice_vals (list of floats or int): values at which to take slices (a list, tuple or numpy array); OR the
            number of evenly spaced slices to take between the bounds of `slice_dim`.
        contour_levels (iterable of floats or bool, optional): Contour values to plot. Default: None. If True, use
            [0.75] in probability space, otherwise every integer level between the global min and max.
        probability_space (bool): Whether to plot probability. Default: False
        gridsize (int): The number of points to sample each dimension at. Default: 30.
        extent_multiplier (list, optional): multipliers for each of the dimensions when plotting. Default:None
        save_path (str, optional): File name to save the plot to. Default: None.
        show (bool): Whether the plot should be shown in an interactive window. Default: True.

    Raises:
        TypeError: If `slice_vals` is neither an integer nor a list/tuple/array of values.
        ValueError: If no slices are requested.
    """
    assert strat.model is not None, "Cannot plot without a model!"

    if parnames is None:
        parnames = ["x1", "x2", "x3"]

    # Get global min/max for all slices so every panel shares one colour scale.
    if probability_space:
        vmax = 1
        vmin = 0
    else:
        d = make_scaled_sobol(strat.lb, strat.ub, 2000)
        post = strat.model.posterior(d)
        fmean = post.mean.squeeze().detach().numpy()
        vmax = np.max(fmean)
        vmin = np.min(fmean)

    # Normalise contour_levels with explicit checks: `contour_levels or []`
    # raises on numpy arrays and would pass generators through unsized.
    if contour_levels is True:
        if probability_space:
            contour_levels_list = [0.75]
        else:
            contour_levels_list = np.arange(np.ceil(vmin), vmax + 1)
    elif contour_levels is None or contour_levels is False:
        contour_levels_list = []
    else:
        contour_levels_list = list(contour_levels)

    # slice_vals is either a list of values or an integer number of values to slice on
    if isinstance(slice_vals, (int, np.integer)) and not isinstance(slice_vals, bool):
        slices = np.linspace(strat.lb[slice_dim], strat.ub[slice_dim], slice_vals)
        slices = np.around(slices, 4)
    elif isinstance(slice_vals, (list, tuple, np.ndarray)):
        slices = np.array(slice_vals)
    else:
        raise TypeError("slice_vals must be either an integer or a list of values")

    if len(slices) == 0:
        raise ValueError("At least one slice value is required")

    # squeeze=False keeps `axs` indexable even when only a single slice is drawn.
    fig, axs = plt.subplots(
        1, len(slices), constrained_layout=True, figsize=(20, 3), squeeze=False
    )
    axs = axs[0]

    for slice_ax, dim_val in zip(axs, slices):
        img = plot_slice(
            slice_ax,
            strat,
            parnames,
            slice_dim,
            dim_val,
            vmin,
            vmax,
            gridsize,
            contour_levels_list,
            probability_space,
            extent_multiplier,
        )

    plt_parnames = np.delete(parnames, slice_dim)
    axs[0].set_ylabel(plt_parnames[1])

    cbar = fig.colorbar(img, ax=axs[-1])
    if probability_space:
        cbar.ax.set_ylabel(f"Probability of {outcome_label}")
    else:
        cbar.ax.set_ylabel(outcome_label)
    # Mark each contour level on the colorbar as a white line.
    for clevel in contour_levels_list:  # type: ignore
        cbar.ax.axhline(y=clevel, c="w")

    if save_path is not None:
        fig.savefig(save_path)

    if show:
        plt.show()
def plot_slice(
    ax,
    strat,
    parnames,
    slice_dim,
    slice_val,
    vmin,
    vmax,
    gridsize=30,
    contour_levels=None,
    lse=False,
    extent_multiplier=None,
):
    """Creates a plot of a 2d slice of a 3D strategy, showing the estimated model or probability response and contours

    Args:
        ax (plt.Axes): Matplotlib axis to plot on
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 3.
        parnames (str list): list of the parameter names
        slice_dim (int): dimension to slice on
        slice_val (float): value to take the slice at along that dimension
        vmin (float): global model minimum to use for plotting
        vmax (float): global model maximum to use for plotting
        gridsize (int): The number of points to sample each dimension at. Default: 30.
        contour_levels (sized iterable of floats, optional): Contours to plot. None or empty draws no contours.
            Default: None
        lse (bool): Whether to plot probability (normal CDF of the predicted mean). Default: False
        extent_multiplier (list, optional): multipliers for each of the dimensions when plotting. Default:None

    Returns:
        The image returned by `ax.imshow`, suitable for building a shared colorbar.
    """
    # Flat [lb0, ub0, lb1, ub1, lb2, ub2] bounds for all three dimensions.
    extent = np.c_[strat.lb, strat.ub].reshape(-1)
    x = strat.model.dim_grid(gridsize=gridsize, slice_dims={slice_dim: slice_val})
    if lse:
        fmean, _ = strat.predict(x)
        fmean = fmean.detach().numpy().reshape(gridsize, gridsize)
        fmean = norm.cdf(fmean)
    else:
        post = strat.model.posterior(x)
        fmean = post.mean.squeeze().detach().numpy().reshape(gridsize, gridsize)

    # optionally rescale extents to correct values
    if extent_multiplier is not None:
        extent_scaled = extent * np.repeat(extent_multiplier, 2)
        dim_val_scaled = slice_val * extent_multiplier[slice_dim]
    else:
        extent_scaled = extent
        dim_val_scaled = slice_val

    # Drop the sliced dimension's bounds/name, leaving the two plotted axes.
    plt_extents = np.delete(extent_scaled, [slice_dim * 2, slice_dim * 2 + 1])
    plt_parnames = np.delete(parnames, slice_dim)

    img = ax.imshow(
        fmean.T, extent=plt_extents, origin="lower", aspect="auto", vmin=vmin, vmax=vmax
    )
    ax.set_title(parnames[slice_dim] + "=" + str(dim_val_scaled))
    ax.set_xlabel(plt_parnames[0])
    # The default None means "no contours" (it used to crash in len()).
    # `aspect` is not a contour() argument, so it is not passed here.
    if contour_levels is not None and len(contour_levels) > 0:
        ax.contour(
            fmean.T,
            contour_levels,
            colors="w",
            extent=plt_extents,
            origin="lower",
        )
    return img