# Source code for gana.sets.constraint

"""General Constraint Class"""

from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Self

from IPython.display import Math, display

from .cases import FCase

if TYPE_CHECKING:
    from .function import F
    from .index import I
    from .parameter import P
    from .theta import T
    from .variable import V


class C:
    """Represents a relationship between Parameters, Variables, or Expressions.

    This class is not intended to be used directly. It is constructed based on
    relationships between parameter sets, variable sets, or function sets.

    The constraint is always stored in the form ``function <= 0`` (when
    ``leq`` is True) or ``function = 0`` (otherwise).

    :param function: Function set
    :type function: F
    :param leq: If the constraint is less than or equal to. Defaults to False.
    :type leq: bool, optional
    :param parent: Parent constraint set. Defaults to None.
    :type parent: C, optional
    :param pos: Position of the constraint in the set. Defaults to None.
    :type pos: int, optional
    :param nn: If the constraint is non-negative. Defaults to False.
    :type nn: bool, optional
    :param category: Category of the constraint. Defaults to "General".
    :type category: str, optional

    :ivar _: List of constraints
    :vartype _: list[C]
    :ivar function: Function set
    :vartype function: F
    :ivar leq: If the constraint is less than or equal to
    :vartype leq: bool
    :ivar binding: If the constraint is binding
    :vartype binding: bool
    :ivar nn: If the constraint is non-negative
    :vartype nn: bool
    :ivar index: Index of the constraint set (product of all indices)
    :vartype index: P
    :ivar eq: If the constraint is an equality constraint
    :vartype eq: bool
    :ivar one: Element one in the function
    :vartype one: V | P
    :ivar two: Element two in the function
    :vartype two: V | P
    :ivar name: Name of the constraint (shows the operation)
    :vartype name: str
    :ivar n: Number of the set in the program
    :vartype n: int
    :ivar pname: Name given by user in program
    :vartype pname: str
    :ivar category: Category of the constraint
    :vartype category: str

    :raises ValueError: Adding constraints of different types (leq and eq)
    :raises ValueError: Subtracting constraints of different types (leq and eq)
    :raises ValueError: Cannot multiply constraints
    :raises ValueError: Cannot divide constraints
    """

    def __init__(
        self,
        function: F | V,
        leq: bool = False,
        parent: C | None = None,
        pos: int | None = None,
        nn: bool = False,
        category: str = "General",
    ):
        if function.case == FCase.VAR:
            # if the function is a variable, the index needs to be made
            # consistent with what a function index looks like
            function = function.make_function()

        self.function = function(*function.index)
        # index is the same as the function
        self.index = function.index
        # variables in the constraint
        self.variables = function.variables
        # whether the constraint is less than or equal to
        self.leq = leq
        # the map of indices and constraints
        self.map = function.map
        # and the structure
        self.struct = function.struct
        # if part of a constraint set
        self.parent = parent
        # position in the parent set
        self.pos = pos
        # if its a non-negativity constraint for a variable
        self.nn = nn
        # arguments to pass on when the constraint is re-indexed (__call__);
        # captured before nn is inferred below, so it holds the nn passed in
        self.args = {"leq": self.leq, "nn": self.nn}
        # whether the constraint is binding
        self.binding = False
        # position of the constraint in the cons_by of its variables
        self.cons_by_pos = {}

        # a "<= 0" constraint on a negated variable is a non-negativity
        # constraint, even if the caller did not flag it as such
        if not self.nn and self.function.case == FCase.NEGVAR and self.leq:
            self.nn = True

        if self.parent is None:
            # if this is a constraint set, birth constraints
            # (one child per truthy element of the function set)
            self._ = [
                C(function=f, leq=self.leq, parent=self, pos=n, nn=self.nn)
                for n, f in enumerate(self.function)
                if f
            ]
        else:
            # single constraint of a constraint set
            self._ = [self]

        # number of the set in the program
        self.n: int | None = None
        # name given by user in program
        self.pname: str | None = None
        # category of the constraint
        # constraints can be printed by category
        self.category: str = category

    @property
    def name(self) -> str:
        """Name of the constraint (function name followed by the relation)"""
        relation = r"<=0" if self.leq else r"=0"
        return self.function.name + relation

    # -----------------------------------------------------
    #                    Helpers
    # -----------------------------------------------------

    def categorize(self, category: str):
        """Categorizes the constraint and every constraint in the set

        :param category: Category name
        :type category: str
        """
        self.category = category
        for c in self._:
            c.category = category

    def update_variables(self):
        """Update variables in the constraint set

        Registers every child constraint in the ``cons_by`` list of each of
        its variables, and records the position it was inserted at.
        """
        for cons in self._:
            for v in cons.variables:
                if v is not None:
                    # update cons_by for variables of children in constraint
                    cons.cons_by_pos[v] = len(v.cons_by)
                    v.cons_by.append(cons)

        # for v in self.variables:
        #     if v is not None:
        #         v.cons_by.append(self)

    def copy(self) -> Self:
        """Copy the constraint set (deep copy)"""
        return deepcopy(self)

    # -----------------------------------------------------
    #                    Matrices
    # -----------------------------------------------------

    @property
    def A(self) -> list[float | None]:
        """Variable Coefficients"""
        return self.function.A

    @property
    def P(self) -> list[None | int]:
        """Variables"""
        return self.function.P

    @property
    def B(self) -> float | None:
        """Constant"""
        return self.function.B

    @property
    def F(self) -> float | None:
        """The ``F`` attribute of the underlying function"""
        return self.function.F

    @property
    def Z(self) -> float | None:
        """The ``Z`` attribute of the underlying function"""
        return self.function.Z

    @property
    def matrix(self) -> dict:
        """Matrix as dict"""
        return self.function.matrix

    # -----------------------------------------------------
    #                    Form
    # -----------------------------------------------------

    @property
    def eq(self):
        """Equality Constraint"""
        return not self.leq

    @property
    def one(self):
        """element one in function"""
        return self.function.one

    @property
    def two(self):
        """element two in function"""
        return self.function.two

    # -----------------------------------------------------
    #                    Printing
    # -----------------------------------------------------

    def mps(self):
        """Name in MPS file"""
        return f"C{self.n}"

    def latex(self) -> str:
        """Latex representation"""
        if self.leq:
            rel = r"\leq"
        else:
            rel = r"="

        return rf"[{self.n}]" + r"\text{ }" + rf"{self.function.latex()} {rel} 0"

    def show(self, descriptive: bool = False):
        """Display the constraint

        :param descriptive: If True, display every constraint in the set
            individually instead of the set as a whole. Defaults to False.
        :type descriptive: bool, optional
        """
        if descriptive:
            for c in self._:
                display(Math(c.latex()))
        else:
            display(Math(self.latex()))

    @property
    def longname(self) -> str:
        """Long name"""
        if self.leq:
            return f"{self.function.longname} <= 0"
        return f"{self.function.longname} == 0"

    # -----------------------------------------------------
    #                    Solution
    # -----------------------------------------------------

    def output(self, n_sol: int = 0, compare=False):
        """Display the solution values of the function of each constraint

        Only inequality (leq) constraints display anything.

        :param n_sol: Key of the solution to display. Defaults to 0.
        :type n_sol: int, optional
        :param compare: If True, display the values from all solutions side
            by side instead of only ``n_sol``. Defaults to False.
        :type compare: bool, optional
        """
        # NOTE(review): the else branch is read as belonging to
        # ``if compare`` — confirm against the upstream source
        if self.leq:
            if compare:
                for c in self._:
                    display(
                        Math(
                            c.function.latex()
                            + r"="
                            + ", ".join(str(val) for val in c.function.X.values())
                        )
                    )
            else:
                for c in self._:
                    display(Math(c.function.latex() + r"=" + rf"{c.function.X[n_sol]}"))

    # -----------------------------------------------------
    #                    Operators
    # -----------------------------------------------------

    def __add__(self, other: V | P | T | F | int | float) -> Self:
        if isinstance(other, C):
            if self.leq != other.leq:
                raise ValueError(
                    f"Cannot add constraints with different types: {self.leq} and {other.leq}"
                )
            return C(
                function=self.function + other.function,
                leq=self.leq or other.leq,
                category=self.category,
            )
        return C(function=self.function + other, leq=self.leq, category=self.category)

    def __radd__(self, other: V | P | T | F | int | float) -> Self:
        # the result has to be returned, otherwise ``other + constraint``
        # silently evaluates to None
        return self + other

    def __sub__(self, other: V | P | T | F | int | float) -> Self:
        if isinstance(other, C):
            if self.leq != other.leq:
                raise ValueError(
                    f"Cannot subtract constraints with different types: {self.leq} and {other.leq}"
                )
            return C(
                function=self.function - other.function,
                leq=self.leq or other.leq,
                category=self.category,
            )
        return C(function=self.function - other, leq=self.leq, category=self.category)

    def __rsub__(self, other: V | P | T | F | int | float) -> Self:
        # the result has to be returned, otherwise ``other - constraint``
        # silently evaluates to None
        # NOTE(review): operand order is kept as written (self - other);
        # ``other - self`` would reverse the direction of a leq constraint
        return self - other

    def __mul__(self, other: V | P | T | F | int | float) -> Self:
        if isinstance(other, C):
            raise ValueError("Cannot multiply constraints")
        return C(function=self.function * other, leq=self.leq)

    def __rmul__(self, other: V | P | T | F | int | float) -> Self:
        return C(function=self.function * other, leq=self.leq)

    def __truediv__(self, other: V | P | T | F | int | float) -> Self:
        if isinstance(other, C):
            raise ValueError("Cannot divide constraints")
        return C(function=self.function / other, leq=self.leq)

    # -----------------------------------------------------
    #                    Vector
    # -----------------------------------------------------

    def __call__(self, *key: list[I]) -> Self:
        if not key or (key == self.index):
            # if the index is an exact match
            # or no key is passed
            return self

        if self.function.case == FCase.VAR:
            return C(function=self.function(*key), **self.args)
        return C(function=self.function(key), **self.args)

    def __getitem__(self, pos: int) -> Self:
        return self._[pos]

    def __iter__(self) -> Self:
        return iter(self._)

    def order(self) -> int:
        """Order of the constraint set (number of elements in its index)"""
        return len(self.index)

    def __len__(self):
        return len(self._)

    # -----------------------------------------------------
    #                    Hashing
    # -----------------------------------------------------

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __hash__(self):
        try:
            return hash(self.name)
        except AttributeError:
            # Fallback for uninitialized state during unpickling
            return id(self)