# Source code for mlresearch.utils._visualization
"""
Functions for visualization formatting or producing pre-formatted
visualizations.
"""
from distutils.spawn import find_executable
import warnings
import types
import numpy as np
def _optional_import(module: str) -> types.ModuleType:
"""
Import an optional dependency.
Parameters
----------
module : str
The identifier for the backend. Either an entrypoint item registered
with importlib.metadata, "matplotlib", or a module name.
Returns
-------
types.ModuleType
The imported backend.
"""
# This function was adapted from the _load_backend function from the pandas.plotting
# source code.
import importlib
# Attempt an import of an optional dependency here and raise an ImportError if
# needed.
try:
module_ = importlib.import_module(module)
except ImportError:
mod = module.split(".")[0]
raise ImportError(f"{mod} is required to use this functionality.") from None
return module_
def set_matplotlib_style(font_size=8, **rcparams):
    """
    Load LaTeX-style configurations for Matplotlib visualizations.

    Applies seaborn's ``whitegrid`` style, then a set of font and subplot
    layout settings, enables LaTeX text rendering when a ``latex``
    executable is found on ``PATH``, and finally applies any user
    overrides. All changes are made globally through
    ``matplotlib.pyplot.rcParams``.

    Parameters
    ----------
    font_size : float, default=8
        Font size for legends and tick labels. Axis labels and the base
        font size are set to ``1.25 * font_size`` (10pt with the default).
    **rcparams
        Extra rcParams applied last, so they override every setting above.
        rcParams keys contain dots, so pass them by unpacking a dict, e.g.
        ``set_matplotlib_style(**{"axes.grid": False})``.

    Warns
    -----
    UserWarning
        If no LaTeX installation is found; ``text.usetex`` is then
        explicitly set to False.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    # ``shutil.which`` replaces ``distutils.spawn.find_executable``:
    # ``distutils`` is deprecated (PEP 632) and removed in Python 3.12.
    import shutil

    plt = _optional_import("matplotlib.pyplot")

    # Replicates the rcParams of seaborn's "whitegrid" style. Matplotlib 3.6
    # renamed its bundled seaborn styles with a "seaborn-v0_8-" prefix, so
    # fall back to the old name on earlier versions.
    style_name = "seaborn-v0_8-whitegrid"
    if style_name not in plt.style.available:
        style_name = "seaborn-whitegrid"
    plt.style.use(style_name)

    base_style = {
        "font.family": "serif",
        # Axis labels and base text are 25% larger than ``font_size``
        # (10pt with the default of 8, to match a 10pt document font).
        "axes.labelsize": (10 / 8) * font_size,
        "font.size": (10 / 8) * font_size,
        # Legend and tick labels use ``font_size`` itself (slightly smaller).
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        # Subplots size/shape
        "figure.subplot.left": 0.098,
        "figure.subplot.right": 0.938,
        "figure.subplot.bottom": 0.12,
        "figure.subplot.top": 0.944,
        "figure.subplot.wspace": 0.071,
    }
    plt.rcParams.update(base_style)

    if shutil.which("latex"):
        # Use LaTeX to write all text
        plt.rcParams["text.usetex"] = True
    else:
        warnings.warn(
            "Could not find a LaTeX installation. "
            "``text.usetex`` will be set to False.",
            stacklevel=2,
        )
        # Set it explicitly, as the warning promises: a previous call (or the
        # user's matplotlibrc) may have enabled it, and nothing above resets it.
        plt.rcParams["text.usetex"] = False

    # Used to pass any additional custom configurations
    plt.rcParams.update(rcparams)
def feature_to_color(col, cmap="RdYlBu_r"):
    """
    Convert a column of values to hex color codes.

    Values are linearly normalized between their minimum and maximum
    (ignoring NaN) and mapped through ``cmap``.

    Parameters
    ----------
    col : array-like of shape (n_samples,)
        Numeric values to convert. Anything accepted by
        :func:`numpy.asarray` works (list, tuple, ndarray, pandas Series).
    cmap : str or `~matplotlib.colors.Colormap`, default="RdYlBu_r"
        The colormap used to map normalized data values to RGBA colors.

    Returns
    -------
    colors : ndarray of shape (n_samples,)
        Hex color strings of the form ``"#rrggbb"`` (alpha is dropped).
        An empty input yields an empty array.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    colors = _optional_import("matplotlib.colors")
    cm = _optional_import("matplotlib.cm")

    # Convert every array-like, not just lists: tuples have no ``.min()``.
    values = np.asarray(col)
    if values.size == 0:
        # min/max reductions raise on empty arrays; there is nothing to map.
        return np.array([], dtype="<U7")

    # nanmin/nanmax so a single NaN does not make the whole scale NaN
    # (this matches how pandas Series.min/max already behaved).
    norm = colors.Normalize(
        vmin=np.nanmin(values), vmax=np.nanmax(values), clip=True
    )
    mapper = cm.ScalarMappable(norm=norm, cmap=cmap)
    rgba = mapper.to_rgba(values)
    # rgb2hex drops alpha by default, so every code is exactly 7 characters.
    return np.array([colors.rgb2hex(rgba_row) for rgba_row in rgba], dtype="<U7")