Source code for mlresearch.utils._visualization
"""
Functions for visualization formatting or producing pre-formatted
visualizations.
"""
import shutil
import warnings

import numpy as np
from setuptools import distutils

from ._utils import _optional_import
[docs]
def set_matplotlib_style(font_size=8, use_latex=True, **rcparams):
    """
    Load LaTeX-style configurations for Matplotlib visualizations. You may pass
    additional parameters to the rcParams as keyworded arguments.

    Parameters
    ----------
    font_size : int, default=8
        Desired default font size displayed in visualizations. ``axes.labelsize``
        and ``font.size`` will take 1.25x the size passed in ``font_size``, whereas
        ``legend.fontsize``, ``xtick.labelsize`` and ``ytick.labelsize`` will take
        the value passed in this parameter.

    use_latex : bool, default=True
        Whether to use LaTeX to render visualizations. If ``True`` and a ``latex``
        executable is found on the system path, ``text.usetex`` is set to ``True``
        so text is rendered with LaTeX and math mode can be used. If ``True`` and
        no ``latex`` executable is found, a warning is issued and ``text.usetex``
        is set to ``False``. If ``False``, ``text.usetex`` is set to ``False``.

    **rcparams : dict
        Additional rcParams entries. They are applied last, so they override any
        value set by this function (including ``text.usetex``).

    Returns
    -------
    None : NoneType
    """
    plt = _optional_import("matplotlib.pyplot")

    # Start from seaborn's "whitegrid" style sheet and layer the overrides below
    # on top of it.
    plt.style.use("seaborn-v0_8-whitegrid")

    # Axis labels and body text are 1.25x the tick/legend size (10pt vs 8pt at
    # the default), so the body text matches a 10pt document.
    label_size = (10 / 8) * font_size
    base_style = {
        "font.family": "Times",
        "axes.labelsize": label_size,
        "font.size": label_size,
        # Legend and tick labels are kept a little smaller
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        # Subplots size/shape
        "figure.subplot.left": 0.098,
        "figure.subplot.right": 0.938,
        "figure.subplot.bottom": 0.12,
        "figure.subplot.top": 0.944,
        "figure.subplot.wspace": 0.071,
    }
    plt.rcParams.update(base_style)

    # Only probe the system path when LaTeX rendering was actually requested.
    latex_available = bool(use_latex) and shutil.which("latex") is not None
    if use_latex and not latex_available:
        warn_msg = (
            "Could not find a LaTeX installation. ``text.usetex`` will be set to False."
        )
        warnings.warn(warn_msg, stacklevel=2)

    # Always write the flag explicitly: a ``True`` left behind by an earlier call
    # (or by the user's own configuration) must not survive ``use_latex=False``
    # or a missing LaTeX installation.
    plt.rcParams.update({"text.usetex": latex_available})

    # Used to pass any additional custom configurations
    plt.rcParams.update(rcparams)
[docs]
def feature_to_color(col, cmap="RdYlBu_r"):
    """
    Converts a column of values to hex-type colors.

    The values are linearly normalized between the column's minimum and maximum
    and then mapped through ``cmap``.

    Parameters
    ----------
    col : {list, tuple, array-like} of shape (n_samples,)
        Values to convert to hex-type color code. Lists and tuples are converted
        to a NumPy array; any other input is used as-is and must provide
        ``min()`` and ``max()`` methods.

    cmap : str or `~matplotlib.colors.Colormap`
        The colormap used to map normalized data values to RGBA colors

    Returns
    -------
    colors : array-like of shape (n_samples,)
        Array with hex values as string type. An empty input yields an empty
        string array.
    """
    colors = _optional_import("matplotlib.colors")
    cm = _optional_import("matplotlib.cm")

    # Plain Python sequences have no ``.min()``/``.max()``. Other array-likes are
    # deliberately left untouched so their own reduction semantics are preserved.
    if isinstance(col, (list, tuple)):
        col = np.asarray(col)

    # Nothing to normalize: min()/max() of an empty column would raise.
    if np.size(col) == 0:
        return np.array([], dtype=str)

    norm = colors.Normalize(vmin=col.min(), vmax=col.max(), clip=True)
    mapper = cm.ScalarMappable(norm=norm, cmap=cmap)

    # One RGBA row per sample; each row is converted to its hex string.
    rgba = mapper.to_rgba(col)
    return np.apply_along_axis(colors.rgb2hex, 1, rgba)