# Source code for research.utils._visualization

"""
Functions for visualization formatting or producing pre-formatted
visualizations.
"""

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import rgb2hex, Normalize
from matplotlib.cm import ScalarMappable


def load_plt_sns_configs(font_size=8):
    """
    Apply the LaTeX-style Matplotlib/Seaborn plotting configuration.

    Sets Seaborn's ``"whitegrid"`` style and then updates the global
    ``plt.rcParams`` with LaTeX text rendering, a serif font family, font
    sizes derived from ``font_size`` and fixed subplot margins/spacing.

    Parameters
    ----------
    font_size : int or float, default=8
        Size used for legend and tick labels. Axis labels and the base font
        use ``(10 / 8) * font_size``, i.e. 10pt at the default of 8.

    Returns
    -------
    None
        The function only modifies global plotting state.
    """
    # Axis labels / base font are slightly larger than legend and tick text
    # (10pt vs 8pt at the default), scaled along with ``font_size``.
    scaled_size = (10 / 8) * font_size

    rc_overrides = dict(
        [
            # Render all text through LaTeX, using a serif family
            ("text.usetex", True),
            ("font.family", "serif"),
            # Larger size for axis labels and the base font
            ("axes.labelsize", scaled_size),
            ("font.size", scaled_size),
            # Legend and tick labels are a little smaller
            ("legend.fontsize", font_size),
            ("xtick.labelsize", font_size),
            ("ytick.labelsize", font_size),
            # Subplot margins and spacing
            ("figure.subplot.left", 0.098),
            ("figure.subplot.right", 0.938),
            ("figure.subplot.bottom", 0.12),
            ("figure.subplot.top", 0.944),
            ("figure.subplot.wspace", 0.071),
            ("figure.subplot.hspace", 0.2),
        ]
    )

    # The Seaborn style is applied first so that the overrides above win for
    # any rcParams that both touch.
    sns.set_style("whitegrid")
    plt.rcParams.update(rc_overrides)
def val_to_color(col, cmap="RdYlBu_r"):
    """
    Converts a column of values to hex-type colors.

    Values are linearly normalized between the column's minimum and maximum
    (with clipping enabled) and then mapped through ``cmap``.

    Parameters
    ----------
    col : array-like of shape (n_samples,)
        Values to convert to hex-type color code. Objects that already
        provide ``min``/``max`` (e.g. NumPy arrays, pandas Series) are used
        as-is; other sequences such as lists or tuples are converted with
        ``np.asarray``.

    cmap : str or `~matplotlib.colors.Colormap`, default="RdYlBu_r"
        The colormap used to map normalized data values to RGBA colors

    Returns
    -------
    colors : ndarray of shape (n_samples,)
        Array with hex values as string type. An empty input yields an
        empty string array.
    """
    # Only coerce inputs that lack min/max, so array and Series inputs keep
    # exactly the reduction semantics they had before.
    if hasattr(col, "min") and hasattr(col, "max"):
        values = col
    else:
        values = np.asarray(col)

    # min()/max() of an empty column is undefined; there is nothing to color.
    if np.size(values) == 0:
        return np.empty(0, dtype="<U7")

    norm = Normalize(vmin=values.min(), vmax=values.max(), clip=True)
    mapper = ScalarMappable(norm=norm, cmap=cmap)
    rgba = mapper.to_rgba(values)

    # One hex string per RGBA row.
    return np.apply_along_axis(rgb2hex, 1, rgba)