#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sized,
Tuple,
TypeAlias,
Union,
)
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
from aepsych.strategy import Strategy
from aepsych.transforms import ParameterTransforms
from aepsych.transforms.ops import NormalizeScale
from aepsych.utils import dim_grid, get_lse_contour, get_lse_interval, make_scaled_sobol
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from scipy.stats import norm
linestyle_str = Literal["solid", "dashsed", "dashdot", "dotted"]
ColorType: TypeAlias = str
def plot_predict_1d(
    x: Iterable[float],
    prediction: Iterable[float],
    ax: Optional[Axes] = None,
    pred_lower: Optional[Iterable[float]] = None,
    pred_upper: Optional[Iterable[float]] = None,
    shaded_kwargs: Optional[Dict[Any, Any]] = None,
    **kwargs,
) -> Axes:
    """Return the ax with the model predictions plotted in place as a 1D line plot.

    Usually plots the predictions in the posterior space or the probability space.

    Args:
        x (Iterable[float]): The values where the model was evaluated, assumed to be
            ordered from lb to ub.
        prediction (Iterable[float]): The values of the predictions at every point it
            was evaluated, assumed to be the same order as x.
        ax (Axes, optional): The Matplotlib axes to plot onto. If not set, an axes is
            made and returned.
        pred_lower (Iterable[float], optional): The lower edge of the shaded region
            around the prediction line. Both pred_lower and pred_upper must be set
            for the band to be drawn.
        pred_upper (Iterable[float], optional): The upper edge of the shaded region
            around the prediction line. Both pred_lower and pred_upper must be set
            for the band to be drawn.
        shaded_kwargs (Dict[Any, Any], optional): Kwargs passed to the
            `ax.fill_between()` call. They are merged over the default band style
            (alpha=0.3, hatch="///", edgecolor="gray"), so any of those defaults can
            be overridden. Defaults to None.
        **kwargs: Extra kwargs passed to the ax.plot() call, not passed to the
            plotting call in charge of the shaded region.

    Returns:
        Axes: The input axes with the prediction plotted onto it. Note that plotting
            is done in-place.
    """
    x, prediction, pred_lower, pred_upper = _tensor_cast(
        x, prediction, pred_lower, pred_upper
    )

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(x.squeeze(), prediction.squeeze(), **kwargs)

    if pred_lower is not None and pred_upper is not None:
        # Merge the caller's options over the defaults instead of passing both
        # sets of keywords; passing e.g. alpha twice raises a TypeError.
        band_style: Dict[Any, Any] = {
            "alpha": 0.3,
            "hatch": "///",
            "edgecolor": "gray",
        }
        band_style.update(shaded_kwargs or {})
        ax.fill_between(
            x.squeeze(),
            pred_lower.squeeze(),
            pred_upper.squeeze(),
            **band_style,
        )

    return ax
def plot_points_1d(
    x: Iterable[Union[float, Iterable[float]]],
    y: Iterable[float],
    ax: Optional[Axes] = None,
    pred_x: Optional[Iterable[float]] = None,
    pred_y: Optional[Iterable[float]] = None,
    point_size: float = 5.0,
    cmap_colors: List[ColorType] = ["r", "b"],
    label_points: bool = True,
    legend_loc: str = "best",
    **kwargs,
) -> Axes:
    r"""Return the ax with the points plotted based on x and y in a 1D plot.

    If pred_x/pred_y are not both set, the points are plotted as marks at the
    bottom of the plot; otherwise each point is plotted at the prediction value
    whose pred_x is nearest to the point. Usually used alongside
    `plot_predict_1d()`.

    Args:
        x (Iterable[Union[float, Iterable[float]]]): The `(n, 1)` or `(n, d, 2)`
            points to plot. The 3D case will be considered a pairwise plot.
        y (Iterable[float]): The `(n, 1)` responses to plot.
        ax (Axes, optional): The Matplotlib axes to plot onto. If not set, an axes is
            made and returned.
        pred_x (Iterable[float], optional): The points where the model was evaluated,
            used to position each point as close as possible to the line. If not
            set, the points are plotted as marks at the bottom of the plot.
        pred_y (Iterable[float], optional): The model outputs at each point in
            pred_x, used to position each point as close as possible to the line. If
            not set, the points are plotted as marks at the bottom of the plot.
        point_size (float): The size of each plotted point, defaults to 5.0. Only
            used when the points are placed on the prediction line.
        cmap_colors (List[ColorType]): A list of colors to map the point colors to
            from min to max of y. At least 2 colors are needed, but more colors will
            allow customizing intermediate colors. Defaults to ["r", "b"].
        label_points (bool): Add a way to identify the value of the points, whether
            as a legend for cases there are 6 or less unique responses or a colorbar
            for 7 or more unique responses. Defaults to True.
        legend_loc (str): If a legend is added, where should it be placed.
        **kwargs: Extra kwargs passed to the ax.plot() call, note that every point is
            plotted with an individual call so these kwargs must be applicable to
            single points.

    Returns:
        Axes: The input axes with the points plotted onto it. Note that plotting is
            done in-place.

    Raises:
        ValueError: If fewer than two colors are given in cmap_colors.
    """
    x, y, pred_x, pred_y = _tensor_cast(x, y, pred_x, pred_y)

    if len(cmap_colors) < 2:
        raise ValueError("cmap_colors must be at least 2 colors.")

    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("cmap", cmap_colors)
    # Deliberately not called `norm`, which would shadow scipy.stats.norm here.
    color_norm = matplotlib.colors.Normalize(y.min().item(), y.max().item())

    if ax is None:
        _, ax = plt.subplots()

    on_line = pred_x is not None and pred_y is not None

    def nearest_prediction(value):
        # Prediction value at the evaluated point closest to `value`.
        return pred_y[torch.argmin(torch.abs(value - pred_x))]

    def as_scalar(value):
        return value.item() if isinstance(value, torch.Tensor) else value

    def curve(start, end, mid):
        # Quadratic through three (x, y) points, sampled at 100 x positions
        # between the start and end x.
        poly_x = np.array([as_scalar(pt[0]) for pt in (start, end, mid)]).squeeze()
        poly_y = np.array([as_scalar(pt[1]) for pt in (start, end, mid)]).squeeze()
        fit = np.poly1d(np.polyfit(poly_x, poly_y, 2))
        line_xs = np.linspace(start[0], end[0], 100)
        return line_xs, fit(line_xs)

    if len(x.shape) == 3:  # Pairwise case
        # NOTE(review): here the first stimulus uses the reversed colormap, while
        # plot_points_2d uses the reversed colormap for the second stimulus.
        # Confirm which convention is intended.
        flipped_cmap = cmap.reversed()
        if on_line:
            for pair, response in zip(x, y):
                x1 = pair[:, 0]
                x2 = pair[:, 1]
                y1 = nearest_prediction(x1)
                y2 = nearest_prediction(x2)

                ax.plot(
                    [np.array(x1), np.array(x2)],
                    [np.array(y1), np.array(y2)],
                    "-",
                    c="gray",
                    alpha=0.5,
                )
                ax.plot(
                    x1,
                    y1,
                    marker="o",
                    color=flipped_cmap(color_norm(response)),
                    markersize=point_size,
                    alpha=0.5,
                    **kwargs,
                )
                ax.plot(
                    x2,
                    y2,
                    marker="o",
                    color=cmap(color_norm(response)),
                    markersize=point_size,
                    alpha=0.5,
                    **kwargs,
                )
        else:
            # Hatches sit on the current lower y-limit; the connecting arc peaks
            # 5% of the way up the current y-range.
            hatch_y, y_max = ax.get_ylim()
            mid_y = hatch_y + ((y_max - hatch_y) * 0.05)

            for pair, response in zip(x, y):
                x1 = pair[:, 0]
                x2 = pair[:, 1]

                # Create a curvy line between the hatches
                mid_x = torch.min(pair) + (torch.abs(x1 - x2) / 2)
                line_x, line_y = curve([x1, hatch_y], [x2, hatch_y], [mid_x, mid_y])
                ax.plot(line_x, line_y, "-", c="gray", alpha=0.5)

                ax.plot(
                    x1,
                    hatch_y,
                    marker=3,
                    color=flipped_cmap(color_norm(response)),
                    **kwargs,
                )
                ax.plot(
                    x2,
                    hatch_y,
                    marker=3,
                    color=cmap(color_norm(response)),
                    **kwargs,
                )
    elif on_line:
        for x_, y_ in zip(x, y):
            ax.plot(
                x_,
                nearest_prediction(x_),
                marker="o",
                color=cmap(color_norm(y_)),
                markersize=point_size,
                alpha=0.5,
                **kwargs,
            )
    else:
        # Hatches sit on the current lower y-limit.
        hatch_y, _ = ax.get_ylim()
        for x_, y_ in zip(x, y):
            ax.plot(
                x_,
                hatch_y,
                marker=3,
                color=cmap(color_norm(y_)),
                **kwargs,
            )

    if label_points:
        _point_labeler(ax, y, cmap, color_norm, legend_loc)

    return ax
def plot_predict_2d(
    prediction: Iterable[Iterable[float]],
    lb: Iterable[float],
    ub: Iterable[float],
    ax: Optional[Axes] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    edge_multiplier: float = 0.0,
    colorbar: bool = True,
    **kwargs,
) -> Axes:
    """Return the ax with the model predictions plotted in-place as a 2D heatmap.

    Usually used to plot the model outputs in posterior space or probability space.

    Args:
        prediction (Iterable[Iterable[float]]): A 2D array of the predictions, assumes
            it was a square parameter grid where each cell was evaluated by the model.
        lb (Iterable[float]): The lower bound of the two parameters being plotted.
            Not modified by this function.
        ub (Iterable[float]): The upper bound of the two parameters being plotted.
            Not modified by this function.
        ax (Axes, optional): The Matplotlib axes to plot onto. If not set, an axes is
            made and returned.
        vmin (float, optional): Passed to `ax.imshow()` as the lower end of the color
            scale.
        vmax (float, optional): Passed to `ax.imshow()` as the upper end of the color
            scale.
        edge_multiplier (float): How much to extend the plot extents beyond the
            parameter bounds (lb, ub). The total extension per parameter is the
            absolute difference of the bounds multiplied by edge_multiplier, split
            evenly between both sides. 0 does not extend the edge, positive values
            plot beyond the bounds, negative values plot less than the bounds.
            Useful when many points sit at the parameter boundaries. Defaults to 0.
        colorbar (bool): Whether to add a colorbar to the parent figure of the Axes,
            defaults to True.
        **kwargs: Extra kwargs passed to the ax.imshow() call that creates the heatmap.

    Returns:
        Axes: The input axes with the predictions plotted onto it. Note that the
            plotting is done in-place.
    """
    prediction, lb, ub = _tensor_cast(prediction, lb, ub)

    if ax is None:
        _, ax = plt.subplots()

    # Make sure bounds are floats
    lb = lb.double()
    ub = ub.double()

    diff = torch.abs(ub - lb)
    edge_bumps = (diff - (diff * (1 - edge_multiplier))) / 2
    # Out-of-place on purpose: .double() hands back the very same tensor when the
    # input is already float64, so `-=`/`+=` would change the caller's bounds.
    lb = lb - edge_bumps
    ub = ub + edge_bumps
    extent = (float(lb[0]), float(ub[0]), float(lb[1]), float(ub[1]))

    # Transpose and flip the rows so the grid lines up with origin="upper".
    prediction = prediction.T
    prediction = torch.flip(prediction, dims=[0])

    mappable = ax.imshow(
        prediction,
        origin="upper",
        aspect=((ub[0] - lb[0]) / (ub[1] - lb[1])).item(),
        extent=extent,
        alpha=0.5,
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )

    if colorbar and ax.figure is not None:
        # Orphaned axes will ignore colorbar, but it's rare
        ax.figure.colorbar(mappable)

    return ax
def plot_points_2d(
    x: Iterable[Iterable[float]],
    y: Iterable[float],
    point_size: float = 5.0,
    ax: Optional[Axes] = None,
    axis: Optional[List[int]] = None,
    slice_vals: Optional[Iterable[float]] = None,
    slice_gradient: float = 1.0,
    cmap_colors: List[ColorType] = ["r", "b"],
    label_points: bool = True,
    legend_loc: str = "best",
    **kwargs,
) -> Axes:
    r"""Plot trial parameters (x), colored by outcome (y), onto a 2D axes in-place.

    When x is 3-dimensional each trial holds two stimuli; both are plotted and
    joined by a gray line. Usually used alongside `plot_predict_2d()`.

    Args:
        x (Iterable[Iterable[float]]): The `(n, d, 2)` or `(n, d)` data points to
            plot. Each point is a different trial.
        y (Iterable[float]): The `(n, 1)` responses given each set of datapoints.
            Each value is the response to a given trial.
        point_size (float): The size of the points, defaults to 5.
        ax (Axes, optional): The Matplotlib axes to plot onto. If not set, an axes is
            made and returned.
        axis (List[int], optional): If the dimensionality `d` is higher than 2, which
            two dimensions should the points be positioned with.
        slice_vals (Iterable[float], optional): If the dimensionality `d` is higher
            than 2, where the other dimensions were sliced at. The distance (s) of
            each point from the slice is converted to a multiplier `exp(-c * s)`,
            where c is slice_gradient, and applied to both the point size and the
            alpha (transparency). If not set, every multiplier is 1.
        slice_gradient (float): The rate at which the multiplier decreases as a
            function of the distance between a point and the slice. Defaults to 1.
        cmap_colors (List[ColorType]): A list of colors to map the point colors to
            from min to max of y. At least 2 colors are needed, but more colors will
            allow customizing intermediate colors. Defaults to ["r", "b"].
        label_points (bool): Add a way to identify the value of the points, whether
            as a legend for cases there are 6 or less unique responses or a colorbar
            for 7 or more unique responses. Defaults to True.
        legend_loc (str): If a legend is added, where should it be placed.
        **kwargs: Extra kwargs passed to the ax.plot() call, note that every point is
            plotted with an individual call so these kwargs must be applicable to
            single points.

    Returns:
        Axes: The axes with the points plotted in-place.

    Raises:
        ValueError: If fewer than two colors are given, or if x has more than two
            parameter dimensions and axis is not set.
    """
    x, y, slice_vals = _tensor_cast(x, y, slice_vals)

    if len(cmap_colors) < 2:
        raise ValueError("cmap_colors must be at least 2 colors.")

    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("cmap", cmap_colors)
    color_norm = matplotlib.colors.Normalize(y.min().item(), y.max().item())

    if ax is None:
        _, ax = plt.subplots()

    n_dims = x.shape[1]
    high_dim = n_dims > 2

    # Choose which two parameter dimensions give the plotted coordinates.
    if high_dim:
        if axis is None:
            raise ValueError(
                "x has more than 2 dimensions and no axis has been defined."
            )
        xcoords = x[:, axis[0], ...]
        ycoords = x[:, axis[1], ...]
    else:
        xcoords = x[:, 0]
        ycoords = x[:, 1]

    # Per-point multiplier used for both alpha and marker size.
    if high_dim and slice_vals is not None:
        slice_vals = slice_vals.double()
        off_axis = [i for i in range(n_dims) if i not in axis]

        scaler = ParameterTransforms(
            normalize=NormalizeScale(d=n_dims, indices=off_axis)
        )
        # The data is transformed first, then the transform is switched to eval
        # mode before the slice point goes through it (original comment:
        # "Freezes the inferred bounds").
        off_axis_x = scaler.transform(x)[:, off_axis, ...]
        scaler.eval()

        slice_point = torch.zeros(n_dims)
        slice_point[off_axis] = slice_vals
        slice_point = scaler.transform(slice_point)
        slice_point = slice_point[off_axis]

        squared_offsets = (off_axis_x - slice_point).pow(2)
        slice_dist = torch.sqrt(squared_offsets.sum(dim=1))

        point_alphas = torch.exp(-slice_dist * slice_gradient)
    else:
        point_alphas = torch.ones_like(xcoords)

    trials = zip(xcoords, ycoords, y, point_alphas)
    if len(x.shape) == 3:  # n-wise experiment
        # First stimulus uses the colormap, second uses its reverse; a gray line
        # joins the two.
        stimulus_cmaps = (cmap, cmap.reversed())
        for pair_x, pair_y, response, pair_alpha in trials:
            ax.plot(pair_x, pair_y, "-", c="gray", alpha=0.5)
            for stim, stim_cmap in enumerate(stimulus_cmaps):
                weight = pair_alpha[stim].item()
                ax.plot(
                    pair_x[stim],
                    pair_y[stim],
                    marker="o",
                    color=stim_cmap(color_norm(response)),
                    alpha=weight,
                    markersize=weight * point_size,
                    **kwargs,
                )
    else:
        for px, py, response, alpha in trials:
            weight = alpha.item()
            ax.plot(
                px,
                py,
                marker="o",
                color=cmap(color_norm(response)),
                alpha=weight,
                markersize=weight * point_size,
                **kwargs,
            )

    if label_points:
        _point_labeler(ax, y, cmap, color_norm, legend_loc)

    return ax
def plot_contours(
    prediction: Iterable[Iterable[float]],
    lb: Iterable[float],
    ub: Iterable[float],
    ax: Optional[Axes] = None,
    levels: Optional[Iterable[float]] = None,
    edge_multiplier: float = 0,
    color: Optional[ColorType] = "white",
    labels: bool = False,
    linestyles: Optional[linestyle_str] = "solid",
    **kwargs,
) -> Axes:
    """Plot contour lines at the levels onto the axes based on the model predictions
    with extents defined by lb and ub. Assumes that you're plotting on top of a
    heatmap of those predictions given the same extents.

    Args:
        prediction (Iterable[Iterable[float]]): A 2D array of the predictions, assumes
            it was a square parameter grid where each cell was evaluated by the model.
        lb (Iterable[float]): The lower bound of the two parameters being plotted.
            Not modified by this function.
        ub (Iterable[float]): The upper bound of the two parameters being plotted.
            Not modified by this function.
        ax (Axes, optional): The Matplotlib axes to plot onto. If not set, an axes is
            made and returned.
        levels (Iterable[float], optional): A sequence of values to plot the contours
            given the predictions. If not set, a contour is plotted at each integer
            from the floor of the minimum prediction to the ceiling of the maximum.
        edge_multiplier (float): How much to extend the plot extents beyond the
            parameter bounds (lb, ub). The total extension per parameter is the
            absolute difference of the bounds multiplied by edge_multiplier, split
            evenly between both sides. 0 does not extend the edge, positive values
            plot beyond the bounds, negative values plot less than the bounds.
            Defaults to 0.
        color (ColorType): What colors the contours should be, defaults to white.
        labels (bool): Whether or not to label the contours.
        linestyles (linestyle_str, optional): How the contour lines should be styled,
            passed to `ax.contour()`. Matplotlib's named styles are "solid",
            "dashed", "dashdot" and "dotted"; defaults to "solid". Can be set to
            None to use the Matplotlib default.
        **kwargs: Extra keyword arguments to pass to the ax.contour() call that plots
            the contours.

    Returns:
        Axes: The axes with the contours plotted in-place.
    """
    prediction, lb, ub, levels = _tensor_cast(prediction, lb, ub, levels)

    if ax is None:
        _, ax = plt.subplots()

    # Make sure bounds are floats
    lb = lb.double()
    ub = ub.double()

    diff = torch.abs(ub - lb)
    edge_bumps = (diff - (diff * (1 - edge_multiplier))) / 2
    # Out-of-place on purpose: .double() hands back the very same tensor when the
    # input is already float64, so `-=`/`+=` would change the caller's bounds.
    lb = lb - edge_bumps
    ub = ub + edge_bumps
    extent = (float(lb[0]), float(ub[0]), float(lb[1]), float(ub[1]))

    # Transpose and flip the rows so the grid lines up with origin="upper".
    prediction = prediction.T
    prediction = torch.flip(prediction, dims=[0])

    if levels is None:
        # One contour per integer spanning the range of the predictions.
        levels = torch.arange(
            prediction.min().floor().item(), prediction.max().ceil().item() + 1
        )

    contours = ax.contour(
        prediction,
        levels=levels,
        extent=extent,
        origin="upper",
        colors=color,
        linestyles=linestyles,
        **kwargs,
    )

    if labels:
        ax.clabel(contours, fontsize=10)

    return ax
def facet_slices(
    prediction: Union[torch.Tensor, np.ndarray],
    plotted_axes: List[int],
    lb: Iterable[float],
    ub: Iterable[float],
    nrows: int,
    ncols: int,
    plot_size: float,
    **kwargs,
) -> Tuple[matplotlib.figure.Figure, np.ndarray, np.ndarray, np.ndarray]:
    """Sets up a set of subplots to plot either a 3D or a 4D space where two
    dimensions are plotted and the other dimensions are sliced over the subplots.

    Args:
        prediction (Union[torch.Tensor, np.ndarray]): The model predictions cube to
            slice, in a 3D parameter space, it would be a 3D array, in a 4D parameter
            space it would be a 4D array.
        plotted_axes (List[int]): The two parameter indices that will be plotted, the
            other parameters will be sliced over subplots.
        lb (Iterable[float]): The lower bound of the parameter space.
        ub (Iterable[float]): The upper bound of the parameter space.
        nrows (int): How many rows to plot, which will also be how many slices there
            are of the first sliced dimension.
        ncols (int): How many columns to plot, which will also be how many slices
            there are of the second sliced dimension.
        plot_size (float): The width of each individual square plot in inches.
        **kwargs: Kwargs passed to the plt.subplots() call. `layout` defaults to
            "constrained"; overriding it emits a warning.

    Returns:
        Figure: A Matplotlib figure of all of the subplots.
        np.ndarray[Axes]: Object array of each subplot, as returned by
            plt.subplots().
        np.ndarray: Object array, shaped like the subplot array, of the sliced
            parameter value(s) for each subplot: a `[row_value, col_value]` pair
            when the subplot array is 2D, a single value otherwise.
        np.ndarray[torch.Tensor]: Object array, shaped like the subplot array, of
            the sliced predictions for each subplot.

    Raises:
        ValueError: If the number of non-plotted parameters is not 1 or 2.
    """
    prediction, lb, ub = _tensor_cast(prediction, lb, ub)

    not_axes = [i for i in range(len(lb)) if i not in plotted_axes]
    if len(not_axes) not in [1, 2]:
        raise ValueError("Only 3 and 4 dimensional spaces can use this function.")

    if "layout" in kwargs:
        layout = kwargs.pop("layout")
        warnings.warn(
            "The layout arg for subplots is defaulted to 'constrained', be careful when changing this"
        )
    else:
        layout = "constrained"

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(plot_size * ncols, plot_size * nrows),
        layout=layout,
        **kwargs,
    )

    # Full slices for plotted dimensions; None marks a dimension that still needs
    # a concrete index for a given subplot.
    slice_template = [
        None if idx in not_axes else slice(0, prediction.shape[idx])
        for idx in range(len(lb))
    ]

    slice_vals = np.empty_like(axes)
    slice_predictions = np.empty_like(axes)

    if len(axes.shape) > 1:
        row_dim = not_axes[0]
        col_dim = not_axes[1]
        row_slices = np.linspace(lb[row_dim], ub[row_dim], nrows)
        col_slices = np.linspace(lb[col_dim], ub[col_dim], ncols)
        # Index along each sliced dimension's own length so that grids with
        # different sizes per dimension are sliced correctly.
        row_idxs = np.linspace(0, prediction.shape[row_dim] - 1, nrows, dtype=int)
        col_idxs = np.linspace(0, prediction.shape[col_dim] - 1, ncols, dtype=int)

        for idx in np.ndindex(axes.shape):
            slice_vals[idx] = [row_slices[idx[0]], col_slices[idx[1]]]

            tmp_slice = slice_template[:]
            tmp_slice[row_dim] = int(row_idxs[idx[0]])
            tmp_slice[col_dim] = int(col_idxs[idx[1]])
            # Index with a tuple: multi-dimensional indexing with a list is not
            # a supported spelling across array libraries.
            slice_predictions[idx] = prediction[tuple(tmp_slice)]
    else:
        n_plots = nrows if nrows != 1 else ncols
        sliced_dim = not_axes[0]
        slices = np.linspace(lb[sliced_dim], ub[sliced_dim], n_plots)
        idxs = np.linspace(0, prediction.shape[sliced_dim] - 1, n_plots, dtype=int)

        for idx in np.ndindex(axes.shape):
            slice_vals[idx] = slices[idx[0]]

            tmp_slice = slice_template[:]
            tmp_slice[sliced_dim] = int(idxs[idx[0]])
            slice_predictions[idx] = prediction[tuple(tmp_slice)]

    return fig, axes, slice_vals, slice_predictions
def _point_labeler(
    ax: Axes,
    responses: torch.Tensor,
    cmap: matplotlib.colors.Colormap,
    norm: matplotlib.colors.Normalize,
    legend_loc: str,
):
    """Add a key to `ax` explaining what the point colors mean.

    With 6 or fewer unique responses a legend with one entry per response is
    added at `legend_loc`; with more, a colorbar built from `cmap` and `norm` is
    added to the axes' figure.
    """
    unique_responses = responses.unique()

    if len(unique_responses) > 6:  # Probably continuous: use a colorbar
        assert ax.figure is not None  # for mypy, unlikely to actually happen
        color_key = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
        ax.figure.colorbar(color_key, ax=ax)
        return

    # Probably categorical: one legend marker per unique response
    handles = [
        matplotlib.lines.Line2D(
            [0],
            [0],
            label=np.around(res.item(), decimals=1),
            marker="o",
            linestyle="",
            color=cmap(norm(res)),
        )
        for res in unique_responses
    ]
    ax.legend(handles=handles, loc=legend_loc)
def _tensor_cast(*objs: Any) -> Tuple[torch.Tensor, ...]:
    """Best-effort conversion of each argument into a tensor.

    Arguments that are already tensors, or that are not iterable (e.g. None or a
    number), are passed through untouched. Iterable arguments are handed to
    `torch.tensor()`; if that fails the original object is returned instead.

    Args:
        *objs (Any): The objects to convert.

    Returns:
        Tuple[torch.Tensor, ...]: One entry per input, in the same order. Entries
            that could not be converted are the original objects.
    """
    casted_objs: List[Any] = []
    for obj in objs:
        if isinstance(obj, torch.Tensor) or not hasattr(obj, "__iter__"):
            casted_objs.append(obj)
            continue

        try:
            casted_objs.append(torch.tensor(obj))
        except (ValueError, TypeError, RuntimeError):
            # RuntimeError is included because torch.tensor uses it for some
            # inputs it cannot convert (e.g. a dict), not only ValueError or
            # TypeError.
            casted_objs.append(obj)

    return tuple(casted_objs)
def plot_strat(
    strat: Strategy,
    ax: Optional[plt.Axes] = None,
    true_testfun: Optional[Callable] = None,
    cred_level: float = 0.95,
    target_level: Optional[float] = 0.75,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    yes_label: str = "Yes trial",
    no_label: str = "No trial",
    flipx: bool = False,
    logx: bool = False,
    gridsize: int = 30,
    title: str = "",
    save_path: Optional[str] = None,
    show: bool = True,
    include_legend: bool = True,
    include_colorbar: bool = True,
) -> None:
    """Plot a 1-d or 2-d strategy (deprecated).

    Emits a DeprecationWarning, then delegates the drawing to the 1-d or 2-d
    helper depending on `strat.dim`, and finally handles the title, legend,
    saving and showing of the figure.

    Args:
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality
            of 2 or less and include "binary" in its outcome types.
        ax (plt.Axes, optional): Matplotlib axis to plot on (if None, creates a new
            axis). Default: None.
        true_testfun (Callable, optional): Ground truth response function, forwarded
            to the plotting helper. Default: None.
        cred_level (float): Forwarded to the plotting helper. Default: 0.95.
        target_level (float, optional): Response probability to estimate the
            threshold of. A warning is issued when this is set and the model has no
            `monotonic_idxs` attribute. Default: 0.75.
        xlabel (str, optional): Label of the x-axis. If None, "Context (abstract)"
            is used.
        ylabel (str, optional): Label of the y-axis. If None, "Response Probability"
            is used for 1-d plots and "Intensity (abstract)" for 2-d plots.
        yes_label (str): Label of trials with response of 1. Default: "Yes trial".
        no_label (str): Label of trials with response of 0. Default: "No trial".
        flipx (bool): Forwarded to the 2-d helper only. Default: False.
        logx (bool): Forwarded to the 2-d helper only. Default: False.
        gridsize (int): The number of points to sample each dimension at.
            Default: 30.
        title (str): Title of the plot. Default: ''.
        save_path (str, optional): File name to save the plot to. Default: None.
        show (bool): Whether the plot should be shown. Default: True.
        include_legend (bool): Whether to include the legend in the figure.
            Default: True.
        include_colorbar (bool): Forwarded to the 2-d helper; also shifts the legend
            anchor for 2-d plots. Default: True.

    Raises:
        RuntimeError: If the strategy is 3-dimensional.
        NotImplementedError: If the strategy dimensionality is not 1, 2 or 3.
    """
    warnings.warn(
        "Plotting directly from strategy is deprecated, plots should be composed manually using the Matplotlib API, AEPsych specific helper functions are available in the plotting submodule.",
        DeprecationWarning,
    )

    assert (
        "binary" in strat.outcome_types
    ), f"Plotting not supported for outcome_type {strat.outcome_types[0]}"

    model_is_monotonic = hasattr(strat.model, "monotonic_idxs")
    if target_level is not None and not model_is_monotonic:
        warnings.warn(
            "Threshold estimation may not be accurate for non-monotonic models."
        )

    if ax is None:
        _, ax = plt.subplots()

    xlabel = "Context (abstract)" if xlabel is None else xlabel

    dim = strat.dim
    if dim == 3:
        raise RuntimeError("Use plot_strat_3d for 3d plots!")
    if dim not in (1, 2):
        raise NotImplementedError("No plots for >3d!")

    if dim == 1:
        _plot_strat_1d(
            strat=strat,
            ax=ax,
            true_testfun=true_testfun,
            cred_level=cred_level,
            target_level=target_level,
            xlabel=xlabel,
            ylabel="Response Probability" if ylabel is None else ylabel,
            yes_label=yes_label,
            no_label=no_label,
            gridsize=gridsize,
        )
    else:
        _plot_strat_2d(
            strat=strat,
            ax=ax,
            true_testfun=true_testfun,
            cred_level=cred_level,
            target_level=target_level,
            xlabel=xlabel,
            ylabel="Intensity (abstract)" if ylabel is None else ylabel,
            yes_label=yes_label,
            no_label=no_label,
            flipx=flipx,
            logx=logx,
            gridsize=gridsize,
            include_colorbar=include_colorbar,
        )

    ax.set_title(title)

    # Only 2-d plots can carry a colorbar next to the axes.
    colorbar_shown = include_colorbar and dim > 1

    if include_legend:
        # Push the legend further right when a colorbar sits beside the plot.
        anchor = (1.4, 0.5) if colorbar_shown else (1, 0.5)
        plt.legend(loc="center left", bbox_to_anchor=anchor)

    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")

    if not show:
        return

    plt.tight_layout()
    if include_legend or colorbar_shown:
        plt.subplots_adjust(left=0.1, bottom=0.25, top=0.75)
    plt.show()
def _plot_strat_1d(
    strat: Strategy,
    ax: plt.Axes,
    true_testfun: Optional[Callable],
    cred_level: float,
    target_level: Optional[float],
    xlabel: str,
    ylabel: str,
    yes_label: str,
    no_label: str,
    gridsize: int,
) -> plt.Axes:
    """Helper function for creating 1-d plots. See plot_strat for the arguments.

    Draws the mean of the sampled response probabilities over a grid, an
    optional shaded quantile band, an optional threshold estimate with error
    bars, an optional ground-truth curve and threshold, and the trials as marks
    at y = 0.

    Args:
        strat (Strategy): Strategy object to be plotted. Must have data and a model.
        ax (plt.Axes): Matplotlib axis to plot on.
        true_testfun (Callable, optional): Ground truth response function, called
            with the evaluation grid.
        cred_level (float): Quantile level used for the shaded band and for the
            threshold error bars (quantiles at 1 - cred_level and cred_level).
        target_level (float, optional): Response probability to estimate the
            threshold of. If None, no threshold is drawn.
        xlabel (str): Label of the x-axis.
        ylabel (str): Label of the y-axis.
        yes_label (str): Label of trials with response of 1.
        no_label (str): Label of trials with response of 0.
        gridsize (int): The number of points to sample each dimension at.

    Returns:
        plt.Axes: The axis object with the plot.

    Raises:
        RuntimeError: If the strategy has no model.
    """
    x, y = strat.x, strat.y
    assert x is not None and y is not None, "No data to plot!"

    if strat.model is None:
        raise RuntimeError("Cannot plot without a model!")

    grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu()
    # Model samples pushed through the normal CDF (response-probability scale).
    latent_samples = strat.model.sample(grid, num_samples=10000).detach()
    samps = norm.cdf(latent_samples)
    phimean = samps.mean(0)

    ax.plot(np.squeeze(grid), phimean)

    if cred_level is not None:
        band_upper = np.quantile(samps, cred_level, axis=0)
        band_lower = np.quantile(samps, 1 - cred_level, axis=0)
        ax.fill_between(
            np.squeeze(grid),
            band_lower,
            band_upper,
            alpha=0.3,
            hatch="///",
            edgecolor="gray",
            label=f"{cred_level*100:.0f}% posterior mass",
        )

    if target_level is not None:
        from aepsych.utils import interpolate_monotonic

        # One threshold estimate per posterior sample.
        threshold_samps = [
            interpolate_monotonic(grid, s, target_level, strat.lb[0], strat.ub[0])
            for s in samps
        ]
        thresh_center = np.mean(threshold_samps)
        thresh_lower = np.quantile(threshold_samps, q=1 - cred_level)
        thresh_upper = np.quantile(threshold_samps, q=cred_level)
        # Asymmetric error bar: (distance below, distance above) the center.
        error_extent = np.r_[
            thresh_center - thresh_lower, thresh_upper - thresh_center
        ][:, None]

        ax.errorbar(
            thresh_center,
            target_level,
            xerr=error_extent,
            capsize=5,
            elinewidth=1,
            label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass marked)",
        )

    if true_testfun is not None:
        true_f = true_testfun(grid)
        ax.plot(grid, true_f.squeeze(), label="True function")
        if target_level is not None:
            # interpolate_monotonic was imported above under this same condition.
            true_thresh = interpolate_monotonic(
                grid,
                true_f.squeeze(),
                target_level,
                strat.lb[0],
                strat.ub[0],
            )
            ax.plot(
                true_thresh,
                target_level,
                "o",
                label=f"True {target_level*100:.0f}% threshold",
            )

    # Trials as marks along y = 0: "no" responses in red, "yes" in blue.
    for outcome, color, label in ((0, "r", no_label), (1, "b", yes_label)):
        trial_x = x[y == outcome, 0]
        ax.scatter(
            trial_x,
            np.zeros_like(trial_x),
            marker="3",
            color=color,
            label=label,
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax
def _plot_strat_2d(
    strat: Strategy,
    ax: plt.Axes,
    true_testfun: Optional[Callable],
    cred_level: float,
    target_level: Optional[float],
    xlabel: str,
    ylabel: str,
    yes_label: str,
    no_label: str,
    flipx: bool,
    logx: bool,
    gridsize: int,
    include_colorbar: bool,
) -> None:
    """Helper function for creating 2-d plots. See plot_strat for an explanation of the arguments.

    Draws, in place on ``ax``: a heatmap of the model's mean prediction passed
    through the normal CDF, the observed trials split by response, and
    (optionally) the estimated threshold with its credible band and the
    ground-truth threshold.

    Args:
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 2.
        ax (plt.Axes): Matplotlib axis to plot on.
        true_testfun (Callable, optional): Ground truth response function. It is called with the
            evaluation grid and its output is reshaped to gridsize x gridsize. Only used when
            target_level is not None.
        cred_level (float): Credible level passed to get_lse_interval for the shaded band around
            the estimated threshold.
        target_level (float, optional): Response probability to estimate the threshold of. If
            None, no threshold is drawn.
        xlabel (str): Label of the x-axis.
        ylabel (str): Label of the y-axis.
        yes_label (str): Legend label of trials with response of 1; also used in the colorbar label.
        no_label (str): Legend label of trials with response of 0.
        flipx (bool): If True the heatmap is drawn with origin="upper" and the last two extent
            entries swapped; otherwise origin="lower" with the bounds in increasing order.
        logx (bool): If True, x ticks are placed at np.arange(lb[0], ub[0]) and labelled
            with 2.0 ** tick.
        gridsize (int): The number of points to sample each dimension at.
        include_colorbar (bool): Whether to attach a colorbar for the heatmap to ``ax``.

    Raises:
        AssertionError: If the strategy has no data.
        RuntimeError: If the strategy has no model.
    """
    x, y = strat.x, strat.y
    assert x is not None and y is not None, "No data to plot!"

    if strat.model is None:
        raise RuntimeError("Cannot plot without a model!")

    # make sure the model is fit well if we've been limiting fit time
    strat.model.fit(train_x=x, train_y=y, max_fit_time=None)

    grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu()
    fmean, _ = strat.model.predict(grid)
    phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T

    lb = strat.transforms.untransform(strat.lb)
    ub = strat.transforms.untransform(strat.ub)

    # Draw the heatmap exactly once. Drawing it unconditionally and then again
    # inside the flipx branch would stack two half-transparent images (and, for
    # flipx, overlay a flipped image on a non-flipped one).
    if flipx:
        extent = np.r_[lb[0], ub[0], ub[1], lb[1]]
        origin = "upper"
    else:
        extent = np.r_[lb[0], ub[0], lb[1], ub[1]]
        origin = "lower"
    colormap = ax.imshow(
        phimean, aspect="auto", origin=origin, extent=extent, alpha=0.5
    )

    # hacky relabel to be in logspace
    if logx:
        locs: np.ndarray = np.arange(lb[0], ub[0])
        ax.set_xticks(ticks=locs)
        ax.set_xticklabels(2.0**locs)

    # Observed trials: red for response 0, blue for response 1.
    ax.plot(x[y == 0, 0], x[y == 0, 1], "ro", alpha=0.7, label=no_label)
    ax.plot(x[y == 1, 0], x[y == 1, 1], "bo", alpha=0.7, label=yes_label)

    if target_level is not None:  # plot threshold
        mono_grid = np.linspace(lb[1], ub[1], num=gridsize)
        context_grid = np.linspace(lb[0], ub[0], num=gridsize)
        threshold, lower, upper = get_lse_interval(
            model=strat.model,
            mono_grid=mono_grid,
            target_level=target_level,
            grid_lb=strat.lb,
            grid_ub=strat.ub,
            cred_level=cred_level,
            mono_dim=1,
            lb=mono_grid.min(),
            ub=mono_grid.max(),
            gridsize=gridsize,
        )
        ax.plot(
            context_grid,
            threshold.cpu().numpy(),
            label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)",
        )
        ax.fill_between(
            context_grid,
            lower.cpu().numpy(),
            upper.cpu().numpy(),
            alpha=0.3,
            hatch="///",
            edgecolor="gray",
        )

        # The ground-truth contour reuses the grids built above, so it can only
        # be drawn when a target level was requested.
        if true_testfun is not None:
            true_f = true_testfun(grid).reshape(gridsize, gridsize)
            true_thresh = (
                get_lse_contour(
                    true_f,
                    mono_grid,
                    level=target_level,
                    lb=strat.lb[-1],
                    ub=strat.ub[-1],
                )
                .cpu()
                .numpy()
            )
            ax.plot(context_grid, true_thresh, label="Ground truth threshold")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if include_colorbar:
        colorbar = plt.colorbar(colormap, ax=ax)
        colorbar.set_label(f"Probability of {yes_label}")
def plot_strat_3d(
    strat: Strategy,
    parnames: Optional[List[str]] = None,
    outcome_label: str = "Yes Trial",
    slice_dim: int = 0,
    slice_vals: Union[List[float], int] = 5,
    contour_levels: Optional[Union[Iterable[float], bool]] = None,
    probability_space: bool = False,
    gridsize: int = 30,
    extent_multiplier: Optional[List[float]] = None,
    save_path: Optional[str] = None,
    show: bool = True,
) -> None:
    """Creates a plot of 2d slices of a 3D strategy, showing the estimated model or probability response and contours

    Args:
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 3.
        parnames (List[str], optional): list of the parameter names. If None, defaults to ["x1", "x2", "x3"].
        outcome_label (str): The label of the outcome variable
        slice_dim (int): dimension to slice on. Default: 0.
        slice_vals (Union[List[float], int]): values to take slices; OR number of values to take even slices from.
            At least 2 slices are required. Default: 5.
        contour_levels (Union[Iterable[float], bool], optional): Contour values to plot. If True, [0.75] is used
            in probability space and all integer levels between the model's min and max otherwise. If None or
            False, no contours are drawn. Default: None.
        probability_space (bool): Whether to plot probability. Default: False
        gridsize (int): The number of points to sample each dimension at. Default: 30.
        extent_multiplier (List[float], optional): multipliers for each of the dimensions when plotting. If None,
            the extents are not rescaled.
        save_path (str, optional): File name to save the plot to. Default: None.
        show (bool): Whether the plot should be shown in an interactive window. Default: True.

    Raises:
        AssertionError: If the strategy has no model or fewer than 2 slices are requested.
        TypeError: If slice_vals is neither an int nor a list, or contour_levels is neither a bool nor an iterable.
    """
    warnings.warn(
        "Plotting directly from strategy is deprecated, plots should be composed manually using the Matplotlib API, AEPsych specific helper functions are available in the plotting submodule.",
        DeprecationWarning,
    )
    assert strat.model is not None, "Cannot plot without a model!"

    if parnames is None:
        parnames = ["x1", "x2", "x3"]

    contour_levels_list: List[float] = []

    # Get global min/max for all slices
    if probability_space:
        vmax = 1
        vmin = 0
        if contour_levels is True:
            contour_levels_list = [0.75]
    else:
        d = make_scaled_sobol(strat.lb, strat.ub, 2000)
        post = strat.model.posterior(d)
        fmean = post.mean.squeeze().detach().numpy()
        vmax = np.max(fmean)
        vmin = np.min(fmean)
        if contour_levels is True:
            contour_levels_list = list(np.arange(np.ceil(vmin), vmax + 1))

    # Honour explicitly supplied levels; previously only `True` had any effect
    # and a user-provided iterable was silently dropped.
    if contour_levels is not None and not isinstance(contour_levels, bool):
        try:
            contour_levels_list = list(contour_levels)
        except TypeError as err:
            raise TypeError(
                "contour_levels must be a bool or an iterable of floats."
            ) from err

    # slice_vals is either a list of values or an integer number of values to slice on
    if isinstance(slice_vals, int):
        slices = np.linspace(strat.lb[slice_dim], strat.ub[slice_dim], slice_vals)
        slices = np.around(slices, 4)
    elif not isinstance(slice_vals, list):
        raise TypeError("slice_vals must be either an integer or a list of values")
    else:
        slices = np.array(slice_vals)

    # Validate before creating the figure so a bad request does not leave an
    # orphan figure behind (with one column plt.subplots returns a bare Axes,
    # which could not be indexed below anyway).
    assert len(slices) > 1, "Must have at least 2 slices"

    # make mypy happy, note that this can't be more specific
    # because of https://github.com/numpy/numpy/issues/24738
    axs: np.ndarray[Any, Any]
    _, axs = plt.subplots(1, len(slices), constrained_layout=True, figsize=(20, 3))  # type: ignore

    for slice_ax, dim_val in zip(axs, slices):
        img = plot_slice(
            slice_ax,
            strat,
            parnames,
            slice_dim,
            dim_val,
            vmin,
            vmax,
            gridsize,
            contour_levels_list,
            probability_space,
            extent_multiplier,
        )

    # Only the left-most panel carries the y label; the colorbar hangs off the
    # right-most one.
    plt_parnames = np.delete(parnames, slice_dim)
    axs[0].set_ylabel(plt_parnames[1])
    cbar = plt.colorbar(img, ax=axs[-1])

    if probability_space:
        cbar.ax.set_ylabel(f"Probability of {outcome_label}")
    else:
        cbar.ax.set_ylabel(outcome_label)

    # Mark each contour level on the colorbar.
    for clevel in contour_levels_list:
        cbar.ax.axhline(y=clevel, c="w")

    if save_path is not None:
        plt.savefig(save_path)

    if show:
        plt.show()
def plot_slice(
    ax: Axes,
    strat: Strategy,
    parnames: List[str],
    slice_dim: int,
    slice_val: float,
    vmin: float,
    vmax: float,
    gridsize: int = 30,
    contour_levels: Optional[Sized] = None,
    lse: bool = False,
    extent_multiplier: Optional[List] = None,
) -> AxesImage:
    """Creates a plot of a 2d slice of a 3D strategy, showing the estimated model or probability response and contours

    Args:
        ax (plt.Axes): Matplotlib axis to plot on
        strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 3.
        parnames (List[str]): list of the parameter names.
        slice_dim (int): dimension to slice on.
        slice_val (float): value to take the slice along that dimension.
        vmin (float): global model minimum to use for plotting.
        vmax (float): global model maximum to use for plotting.
        gridsize (int): The number of points to sample each dimension at. Default: 30.
        contour_levels (Sized, optional): Contours to plot. If None or empty, no contours are drawn. Default: None
        lse (bool): Whether to plot probability (the strategy's mean prediction passed through the normal
            CDF) instead of the model's posterior mean. Default: False
        extent_multiplier (List, optional): multipliers for each of the dimensions when plotting. Default: None

    Returns:
        AxesImage: The image drawn on ``ax``.

    Raises:
        RuntimeError: If the strategy has no model.
    """
    if strat.model is None:
        raise RuntimeError("Cannot plot without a model!")

    # Flattened as [lb0, ub0, lb1, ub1, lb2, ub2].
    extent = np.c_[strat.lb, strat.ub].reshape(-1)

    x = dim_grid(
        lower=strat.lb,
        upper=strat.ub,
        gridsize=gridsize,
        slice_dims={slice_dim: slice_val},
    ).cpu()

    if lse:
        fmean, _ = strat.predict(x)
        fmean = fmean.detach().numpy().reshape(gridsize, gridsize)
        fmean = norm.cdf(fmean)
    else:
        post = strat.model.posterior(x)
        fmean = post.mean.squeeze().detach().numpy().reshape(gridsize, gridsize)

    # optionally rescale extents to correct values
    if extent_multiplier is not None:
        extent_scaled = extent * np.repeat(extent_multiplier, 2)
        dim_val_scaled = slice_val * extent_multiplier[slice_dim]
    else:
        extent_scaled = extent
        dim_val_scaled = slice_val

    # Drop the (lb, ub) pair and the name of the sliced dimension; what remains
    # describes the two plotted axes.
    plt_extents = np.delete(extent_scaled, [slice_dim * 2, slice_dim * 2 + 1])
    plt_parnames = np.delete(parnames, slice_dim)

    img = ax.imshow(
        fmean.T,
        extent=tuple(plt_extents),
        origin="lower",
        aspect="auto",
        vmin=vmin,
        vmax=vmax,
    )
    ax.set_title(parnames[slice_dim] + "=" + str(dim_val_scaled))
    ax.set_xlabel(plt_parnames[0])

    # None is the documented default and simply means "no contours"; it must
    # not raise after the image has already been drawn.
    if contour_levels is not None and len(contour_levels) > 0:
        # `aspect` is not a contour option, so it is not forwarded here.
        ax.contour(
            fmean.T,
            contour_levels,
            colors="w",
            extent=plt_extents,
            origin="lower",
        )

    return img