"""
MUD Plotting Module
Plotting utility functions for visualizing data-sets and distributions related
to algorithm implemented within the MUD library.
Functions
---------
"""
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt # type: ignore
from scipy.stats.contingency import margins # type: ignore
from mud.util import null_space
# Matplotlib plotting options
mud_plot_params = {
"mathtext.fontset": "stix",
"font.family": "STIXGeneral",
"axes.labelsize": 14,
"axes.titlesize": 26,
"legend.fontsize": 16,
"xtick.labelsize": 14,
"ytick.labelsize": 14,
"axes.titlepad": 1,
"axes.labelpad": 1,
"font.size": 16,
"savefig.facecolor": "white",
"text.usetex": True,
"text.latex.preamble": " ".join(
[r"\usepackage{bm}", r"\usepackage{amsfonts}", r"\usepackage{amsmath}"]
),
}
plt.rcParams.update(mud_plot_params)
def _check_latex():
"""check latex installation"""
global mud_plot_params
path = Path.cwd() / ".test_fig.png"
plt.plot([0], [1], label=r"$a_\text{foo} = \lambda$")
try:
plt.legend()
plt.savefig(str(path))
path.unlink(missing_ok=True)
except RuntimeError:
print("NOT USING TEX")
mud_plot_params["text.usetex"] = False
mud_plot_params["text.latex.preamble"] = ""
plt.rcParams.update(mud_plot_params)
[docs]def plotChain(mud_chain, ref_param, color="k", s=100):
num_steps = len(mud_chain)
current_point = mud_chain[0]
plt.scatter(current_point[0], current_point[1], c="b", s=s)
for i in range(0, num_steps):
next_point = mud_chain[i]
points = np.hstack([current_point, next_point])
plt.plot(points[0, :], points[1, :], c=color)
current_point = next_point
plt.ylim([0, 1])
plt.xlim([0, 1])
# plt.axis('off')
plt.scatter(ref_param[0], ref_param[1], c="r", s=s)
[docs]def plot_contours(
A, ref_param, subset=None, color="k", ls=":", lw=1, fs=20, w=1, s=100, **kwds
):
if subset is None:
subset = np.arange(A.shape[0])
A = A[np.array(subset), :]
numQoI = A.shape[0]
AA = np.hstack([null_space(A[i, :].reshape(1, -1)) for i in range(numQoI)]).T
for i, contour in enumerate(subset):
xloc = [ref_param[0] - w * AA[i, 0], ref_param[1] + w * AA[i, 0]]
yloc = [ref_param[0] - w * AA[i, 1], ref_param[1] + w * AA[i, 1]]
plt.plot(xloc, yloc, c=color, ls=ls, lw=lw, **kwds)
plt.annotate("%d" % (contour + 1), (xloc[0], yloc[0]), fontsize=fs)
[docs]def make_2d_normal_mesh(N=50, window=1):
X = np.linspace(-window, window, N)
Y = np.linspace(-window, window, N)
X, Y = np.meshgrid(X, Y)
XX = np.vstack([X.ravel(), Y.ravel()]).T
return (X, Y, XX)
[docs]def make_2d_unit_mesh(N=50, window=1):
X = np.linspace(0, window, N)
Y = np.linspace(0, window, N)
X, Y = np.meshgrid(X, Y)
XX = np.vstack([X.ravel(), Y.ravel()]).T
return (X, Y, XX)
[docs]def plot_1D_vecs(vecs, markers=None, ax=None, label=True, **kwargs):
"""Plot components of 1D vectors"""
# Plot density of trained data
if ax is None:
fig = plt.figure(figsize=(10, 3))
ax = fig.add_subplot(1, 1, 1)
for i, v in enumerate(vecs):
if markers is not None:
kwargs["marker"] = markers[i]
plt.scatter(np.arange(len(v)), v, label=f"$p^{{({i+1})}}$", **kwargs)
if label:
ax.set_xlabel("i")
ax.set_ylabel("$p^{{(l)}}_i$")
ax.legend()
return ax
# TODO: replace make_2d_normal_Grid with this
[docs]def build_nd_mesh_grid(domain, aff=100):
"""
Build n-dimensional mesh grid
Parameters
----------
domain: List[List[int]]
List of `[min, max]` ranges for each parameter to construct grid over.
aff : int, default=100
Number of points to use in each dimension.
Returns
-------
(mesh, grid_points): Tuple[numpy.ndarray, numpy.ndarray]
Tuple consistent of n-dimensional grid of points `mesh` and the list of
coordinates $(x_1, x_2, ..., x_n)$ for each grid point along.
Examples
--------
Building 1d ``mesh"
>>> mesh_1d, pts_1d = build_nd_mesh_grid([[0,1]], aff=5)
>>> mesh_1d
[array([0. , 0.25, 0.5 , 0.75])]
>>> pts_1d
array([[0. , 0.25, 0.5 , 0.75]])
"""
mesh = np.meshgrid(*[np.linspace(i, j, aff)[:-1] for i, j in domain])
grid_points = np.vstack([x.ravel() for x in mesh])
return (mesh, grid_points)
[docs]def plot_dist(dist, domain, ax=None, idx=0, source="kde", aff=100, **kwargs):
"""
Plot a probability distribution over a given domain.
"""
# Create plot if one isn't passed in
_, ax = plt.subplots(1, 1) if ax is None else (None, ax)
# Mesh is a list of grids over n-dim space the distribution is defined over
mesh, grid = build_nd_mesh_grid(domain, aff=aff)
plot_x = np.linspace(domain[idx, 0], domain[idx, 1], aff)[: aff - 1]
# TODO: Add check that domain size matches dimension of dist
if source == "pdf":
# Compute observed distribution using stored pdf
plot_y = margins(np.reshape(dist.pdf(grid.T).prod(axis=1), mesh[idx].shape))[
idx
].reshape(-1)
elif source == "kde":
# Compute observed distribution using kernel density estimate
plot_y = margins(np.reshape(dist(grid), mesh[idx].shape).T)[idx].reshape(-1)
else:
raise ValueError("Source must be one of pdf or kde.")
# Scale distribution
total = np.sum(plot_y * np.abs(plot_x[1] - plot_x[0]))
plot_y = plot_y / total
# Plot updated distribution over parameter space
ax.plot(plot_x, plot_y, **kwargs)
return ax
[docs]def plot_vert_line(ax, x_loc, ylim=None, **kwargs):
"""Plot a vertical line on an existing axis at `x_loc`"""
ylims = list(ax.get_ylim())
ylims[1] = ylim if ylim is not None else ylims[1]
ax.plot([x_loc, x_loc], [ylims[0], ylims[1]], **kwargs)
ax.set_ylim(ylims)
_check_latex()