Source code for gana.utils.draw

"""Plotting Utilities"""

from __future__ import annotations

from typing import TYPE_CHECKING

from matplotlib import pyplot as plt
from matplotlib import rc

if TYPE_CHECKING:
    from ..sets.parameter import P
    from ..sets.variable import V


[docs] def draw( element: V | P, data: list[float] | None = None, kind: str = "line", font_size: float = 16, fig_size: tuple[float, float] = (12, 6), linewidth: float = 0.7, color: str = "blue", grid_alpha: float = 0.3, usetex: bool = True, str_idx_lim: int = 10, ): """ Plot the variable set :param kind: Type of plot ['line', 'bar']. Defaults to 'line'. :type kind: str, optional :param font_size: Font size for the plot. Defaults to 16. :type font_size: float, optional :param fig_size: Size of the figure. Defaults to (12, 6). :type fig_size: tuple[float, float], optional :param linewidth: Width of the line in the plot. Defaults to 0.7. :type linewidth: float, optional :param color: Color of the line in the plot. Defaults to 'blue'. :type color: str, optional :param grid_alpha: Transparency of the grid lines. Defaults to 0.3. :type grid_alpha: float, optional :param usetex: Use LaTeX for text rendering. Defaults to True. :type usetex: bool, optional :param str_idx_lim: Limit for string indices display. Defaults to 10. :type str_idx_lim: int, optional """ ax = plt.subplots(figsize=fig_size)[1] # the values are the y-axis y = data _len = len(y) # the indices are the x-axis if _len <= str_idx_lim: x = [str(idx) for idx in element.map] else: x = list(range(len(y))) if usetex: rc( "font", **{"family": "serif", "serif": ["Computer Modern"], "size": font_size}, ) rc("text", usetex=usetex) else: rc("font", **{"size": font_size}) if kind == "line": ax.plot(x, y, linewidth=linewidth, color=color) elif kind == "bar": ax.bar(x, y, linewidth=linewidth, color=color) ax.set_title( rf"${element.latex()}$", ) ax.set_ylabel(r"Values") ax.set_xlabel(r"Indices") ax.grid(alpha=grid_alpha) if _len <= str_idx_lim: ax.set_xticks(x) ax.set_xticklabels( [ rf"${tuple([idx.ltx for idx in index])}$".replace("'", "").replace( "\\", "" ) for index in element.map ] ) plt.rcdefaults() return plt