Source code for eqc_models.algorithms.alm
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple, Optional, Sequence, Union

import numpy as np

from eqc_models.base.polynomial import PolynomialModel

Array = np.ndarray
PolyTerm = Tuple[Tuple[int, ...], float]

@dataclass
class ALMConstraint:
    """One constraint family; fun returns a vector; jac returns its Jacobian."""
    kind: str
    fun: Callable[[Array], Array]
    jac: Optional[Callable[[Array], Array]] = None
    name: str = ""

@dataclass
class ALMBlock:
    """Lifted discrete variable block (optional)."""
    idx: Sequence[int]
    levels: Array
    enforce_sum_to_one: bool = True
    enforce_one_hot: bool = True

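# Example (illustrative): variables with 0-based indices 2, 3, 4 form one lifted discrete
# choice among the levels 0, 5, and 10; `ALMAlgorithm.run` installs sum-to-one and pairwise
# one-hot equalities for the block and decodes it to its arg-max level on exit.
#
#   blk = ALMBlock(idx=[2, 3, 4], levels=np.array([0.0, 5.0, 10.0]))
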
@dataclass
class ALMConfig:
    """Tunable parameters for the ALM outer loop."""

    # initial penalty weights and their allowed range
    rho_h: float = 50.0
    rho_g: float = 50.0
    rho_min: float = 1e-3
    rho_max: float = 1e3

    # penalty adaptation thresholds and factors
    adapt: bool = True
    tau_up_h: float = 0.90
    tau_down_h: float = 0.50
    tau_up_g: float = 0.90
    tau_down_g: float = 0.50
    gamma_up: float = 2.0
    gamma_down: float = 1.0

    # feasibility tolerances and outer-iteration limit
    tol_h: float = 1e-6
    tol_g: float = 1e-6
    max_outer: int = 100

    # stagnation-triggered penalty bumps
    use_stagnation_bump: bool = True
    patience_h: int = 10
    patience_g: int = 10
    stagnation_factor: float = 1e-3

    # EMA smoothing factor for the infeasibility measures
    ema_alpha: float = 0.3

    # finite-difference step used when a constraint has no Jacobian
    fd_eps: float = 1e-6

    # activity tolerance for inequality constraints
    act_tol: float = 1e-10

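# Example (illustrative): tighter feasibility tolerances with a gentler penalty ramp.
#
#   cfg = ALMConfig(rho_h=10.0, rho_g=10.0, tol_h=1e-8, tol_g=1e-8,
#                   gamma_up=1.5, max_outer=200)
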
class ConstraintRegistry:
    """
    Holds constraints and block metadata; keeps ALMAlgorithm stateless. Register constraints and
    (optional) lifted-discrete blocks here.
    """

    def __init__(self):
        self.constraints: List[ALMConstraint] = []
        self.blocks: List[ALMBlock] = []

    def add_equality(self, fun, jac=None, name=""):
        self.constraints.append(ALMConstraint("eq", fun, jac, name))

    def add_inequality(self, fun, jac=None, name=""):
        self.constraints.append(ALMConstraint("ineq", fun, jac, name))

    def add_block(self, idx: Sequence[int], levels: Array, sum_to_one=True, one_hot=True):
        self.blocks.append(ALMBlock(list(idx), np.asarray(levels, float), sum_to_one, one_hot))

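# Example (illustrative): register one equality, one inequality (g(x) <= 0 convention,
# matching the multiplier update in `ALMAlgorithm.run`), and one lifted block.
#
#   reg = ConstraintRegistry()
#   reg.add_equality(lambda x: np.array([x[0] + x[1] - 1.0]), name="budget")
#   reg.add_inequality(lambda x: np.array([x[0] - 0.8]), name="cap_x1")
#   reg.add_block(idx=[2, 3, 4], levels=[0.0, 5.0, 10.0])
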
class ALMAlgorithm:
    """Stateless ALM outer loop. Call `run(model, registry, solver, cfg, **solver_kwargs)`."""

    @staticmethod
    def _finite_diff_jac(fun: Callable[[Array], Array], x: Array, eps: float) -> Array:
        y0 = fun(x)
        m = int(np.prod(y0.shape))
        y0 = y0.reshape(-1)
        n = x.size
        J = np.zeros((m, n), dtype=float)
        for j in range(n):
            xp = x.copy()
            xp[j] += eps
            J[:, j] = (fun(xp).reshape(-1) - y0) / eps
        return J

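    # Forward differences: column j of J is (fun(x + eps * e_j) - fun(x)) / eps, accurate to
    # O(eps). Illustrative check: for fun(x) = [x[0] * x[1]] at x = [2, 3] the exact Jacobian
    # is [[3., 2.]], and _finite_diff_jac(fun, np.array([2.0, 3.0]), 1e-6) recovers it up to
    # floating-point error.
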
    @staticmethod
    def _pairwise_M(k: int) -> Array:
        return np.ones((k, k), dtype=float) - np.eye(k, dtype=float)

    @staticmethod
    def _sum_to_one_selector(n: int, idx: Sequence[int]) -> Array:
        S = np.zeros((1, n), dtype=float)
        S[0, np.array(list(idx), int)] = 1.0
        return S

    @staticmethod
    def _make_sum1_fun(S):
        return lambda x: S @ x - np.array([1.0])

    @staticmethod
    def _make_sum1_jac(S):
        return lambda x: S

    @staticmethod
    def _make_onehot_fun(sl, M):
        sl = np.array(sl, int)

        def _f(x):
            s = x[sl]
            return np.array([float(s @ (M @ s))])

        return _f

    @staticmethod
    def _make_onehot_jac(sl, M, n):
        sl = np.array(sl, int)

        def _J(x):
            s = x[sl]
            grad_blk = 2.0 * (M @ s)
            J = np.zeros((1, n), dtype=float)
            J[0, sl] = grad_blk
            return J

        return _J

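    # The block penalty s^T M s with M = ones - I equals 2 * sum_{i<j} s_i * s_j, which is zero
    # exactly when at most one entry of the (nonnegative) block is nonzero; e.g. s = [1, 0, 0]
    # gives 0 while s = [1, 1, 0] gives 2.
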
    @staticmethod
    def _poly_value(poly_terms: List[PolyTerm], x: Array) -> float:
        val = 0.0
        for inds, coeff in poly_terms:
            prod = 1.0
            for j in inds:
                if j == 0:
                    continue
                prod *= x[j - 1]
            val += coeff * prod
        return float(val)

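    # Index convention (1-based, with 0 meaning "no variable"): ((0, 2), 3.0) contributes
    # 3 * x2, ((1, 2), 4.0) contributes 4 * x1 * x2, and ((1, 1), 1.0) contributes x1 ** 2.
    # For example, _poly_value([((0, 2), 3.0), ((1, 2), 4.0)], np.array([2.0, 5.0])) == 55.0.
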
    @staticmethod
    def _merge_poly(poly_terms: Optional[List[PolyTerm]], Q_aug: Optional[Array],
                    c_aug: Optional[Array]) -> List[PolyTerm]:
        """
        Merge ALM's quadratic/linear increments (Q_aug, c_aug) into the base polynomial term
        list `poly_terms`. If `poly_terms` is None, turn x^T Q_aug x + c_aug^T x into
        polynomial monomials. Terms are of the form ((0, i), w) for linear and
        ((i, j), w) for quadratic.
        """
        merged = list(poly_terms) if poly_terms is not None else []
        if Q_aug is not None:
            Qs = 0.5 * (Q_aug + Q_aug.T)
            n = Qs.shape[0]
            for i in range(n):
                if Qs[i, i] != 0.0:
                    merged.append(((i + 1, i + 1), float(Qs[i, i])))
                for j in range(i + 1, n):
                    q = 2.0 * Qs[i, j]
                    if q != 0.0:
                        merged.append(((i + 1, j + 1), float(q)))
        if c_aug is not None:
            for i, ci in enumerate(c_aug):
                if ci != 0.0:
                    merged.append(((0, i + 1), float(ci)))
        return merged

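    # Illustrative check: with Q_aug = [[1, 2], [2, 0]] and c_aug = [3, 0] (i.e. the
    # augmentation x1^2 + 4*x1*x2 + 3*x1), _merge_poly(None, Q_aug, c_aug) returns
    # [((1, 1), 1.0), ((1, 2), 4.0), ((0, 1), 3.0)].
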
    @staticmethod
    def run(
        model: PolynomialModel,
        registry: ConstraintRegistry,
        solver,
        cfg: ALMConfig = ALMConfig(),
        x0: Optional[Array] = None,
        *,
        parse_output=None,
        verbose: bool = True,
        **solver_kwargs,
    ) -> Dict[str, Union[Array, Dict[int, float], Dict]]:
        """
        Solve with ALM. Keep all ALM state local to this call (no global side-effects).

        Returns:
            {
                "x": final iterate,
                "decoded": {start_idx_of_block: level_value, ...} for lifted blocks,
                "hist": {"eq_inf": [...], "ineq_inf": [...], "obj": [...], "x": [...]},
            }
        """
        n = int(getattr(model, "n", len(getattr(model, "upper_bound", [])) or 0))
        x = (np.asarray(x0, float).copy() if x0 is not None else
             np.zeros(n, float))
        lb = getattr(model, "lower_bound", None)
        ub = getattr(model, "upper_bound", None)

        problem_eqs = [c for c in registry.constraints if c.kind == "eq"]
        problem_ineqs = [c for c in registry.constraints if c.kind == "ineq"]

        def _install_block_equalities() -> List[ALMConstraint]:
            eqs: List[ALMConstraint] = []
            for blk in registry.blocks:
                if blk.enforce_sum_to_one:
                    S = ALMAlgorithm._sum_to_one_selector(n, blk.idx)
                    eqs.append(ALMConstraint(
                        "eq",
                        fun=ALMAlgorithm._make_sum1_fun(S),
                        jac=ALMAlgorithm._make_sum1_jac(S),
                        name=f"sum_to_one_block_{blk.idx[0]}",
                    ))
                if blk.enforce_one_hot:
                    k = len(blk.idx)
                    M = ALMAlgorithm._pairwise_M(k)
                    eqs.append(ALMConstraint(
                        "eq",
                        fun=ALMAlgorithm._make_onehot_fun(blk.idx, M),
                        jac=ALMAlgorithm._make_onehot_jac(blk.idx, M, n),
                        name=f"onehot_block_{blk.idx[0]}",
                    ))
            return eqs

        block_eqs = _install_block_equalities()
        full_eqs = problem_eqs + block_eqs

        lam_eq = []
        for csp in full_eqs:
            r0 = csp.fun(x).reshape(-1)
            lam_eq.append(np.zeros_like(r0, dtype=float))

        mu_ineq = []
        for csp in problem_ineqs:
            r0 = csp.fun(x).reshape(-1)
            mu_ineq.append(np.zeros_like(r0, dtype=float))

        rho_h, rho_g = cfg.rho_h, cfg.rho_g
        best_eq, best_ineq = np.inf, np.inf
        no_imp_eq = no_imp_ineq = 0
        prev_eq_inf, prev_ineq_inf = np.inf, np.inf
        eps = 1e-12
        hist = {"eq_inf": [], "ineq_inf": [], "obj": [], "x": [],
                "rho_h": [], "rho_g": []}
        for k_idx, csp in enumerate(full_eqs):
            if csp.kind != "eq":
                continue
            hist[f"lam_eq_max_idx{k_idx}"] = []
            hist[f"lam_eq_min_idx{k_idx}"] = []
        for k_idx, csp in enumerate(problem_ineqs):
            if csp.kind != "ineq":
                continue
            hist[f"mu_ineq_max_idx{k_idx}"] = []
            hist[f"mu_ineq_min_idx{k_idx}"] = []

        for it in range(cfg.max_outer):
            # rebuild the augmented objective around the current iterate x
            base_terms: List[PolyTerm] = list(zip(model.polynomial.indices, model.polynomial.coefficients))

            Q_aug = np.zeros((n, n), dtype=float)
            c_aug = np.zeros(n, dtype=float)
            have_aug = False

            # equality constraints: linearize h(x) ~ A x - b and add the quadratic penalty
            for k_idx, csp in enumerate(full_eqs):
                if csp.kind != "eq":
                    continue
                h = csp.fun(x).reshape(-1)
                A = csp.jac(x) if csp.jac is not None else ALMAlgorithm._finite_diff_jac(csp.fun, x, cfg.fd_eps)
                A = np.atleast_2d(A)
                assert A.shape[1] == n, f"A has {A.shape[1]} cols, expected {n}"

                b = A @ x - h
                Qk = 0.5 * rho_h * (A.T @ A)
                ck = (A.T @ lam_eq[k_idx]) - rho_h * (A.T @ b)
                Q_aug += Qk
                c_aug += ck
                have_aug = True

            # inequality constraints (g(x) <= 0): penalize only the active rows
            for k_idx, csp in enumerate(problem_ineqs):
                if csp.kind != "ineq":
                    continue
                g = csp.fun(x).reshape(-1)
                G = csp.jac(x) if csp.jac is not None else ALMAlgorithm._finite_diff_jac(csp.fun, x, cfg.fd_eps)
                G = np.atleast_2d(G)
                assert G.shape[1] == n, f"G has {G.shape[1]} cols, expected {n}"
                d = G @ x - g

                # activity test: y = G x - d + mu / rho_g simplifies to g(x) + mu / rho_g
                y = G @ x - d + mu_ineq[k_idx] / rho_g
                active = (y > cfg.act_tol)
                if np.any(active):
                    GA = G[active, :]
                    muA = mu_ineq[k_idx][active]
                    gA = g[active]

                    Qk = 0.5 * rho_g * (GA.T @ GA)
                    ck = (GA.T @ muA) - rho_g * (GA.T @ (GA @ x - gA))
                    Q_aug += Qk
                    c_aug += ck
                    have_aug = True

            # merge the augmentation into the base polynomial and build the subproblem model
            all_terms = ALMAlgorithm._merge_poly(base_terms, Q_aug if have_aug else None,
                                                 c_aug if have_aug else None)
            idxs, coeffs = zip(*[(inds, w) for (inds, w) in all_terms]) if all_terms else ([], [])
            poly_model = PolynomialModel(list(coeffs), list(idxs))
            if lb is not None and hasattr(poly_model, "lower_bound"):
                poly_model.lower_bound = np.asarray(lb, float)
            if ub is not None and hasattr(poly_model, "upper_bound"):
                poly_model.upper_bound = np.asarray(ub, float)
            x_ws = x.copy()

            # expose the warm start under several attribute names a solver may look for
            setattr(poly_model, "initial_guess", x_ws)
            setattr(poly_model, "warm_start", x_ws)
            setattr(poly_model, "x0", x_ws)

            out = solver.solve(poly_model, **solver_kwargs)

            # extract the iterate from whatever shape the solver returned
            if parse_output:
                x = parse_output(out)
            else:
                if isinstance(out, tuple) and len(out) == 2:
                    _, x = out
                elif isinstance(out, dict) and "results" in out and "solutions" in out["results"]:
                    x = out["results"]["solutions"][0]
                elif isinstance(out, dict) and "x" in out:
                    x = out["x"]
                else:
                    x = getattr(out, "x", out)
            x = np.asarray(x, float)

            # multiplier updates and infeasibility measures
            eq_infs = []
            for k_idx, csp in enumerate(full_eqs):
                if csp.kind != "eq":
                    continue
                r = csp.fun(x).reshape(-1)
                lam_eq[k_idx] = lam_eq[k_idx] + rho_h * r
                if r.size:
                    eq_infs.append(np.max(np.abs(r)))
            eq_inf = float(np.max(eq_infs)) if eq_infs else 0.0

            ineq_infs = []
            for k_idx, csp in enumerate(problem_ineqs):
                if csp.kind != "ineq":
                    continue
                r = csp.fun(x).reshape(-1)
                mu_ineq[k_idx] = np.maximum(0.0, mu_ineq[k_idx] + rho_g * r)
                if r.size:
                    ineq_infs.append(np.max(np.maximum(0.0, r)))
            ineq_inf = float(np.max(ineq_infs)) if ineq_infs else 0.0

            assert len(lam_eq) == len(full_eqs)
            assert len(mu_ineq) == len(problem_ineqs)

            f_val = ALMAlgorithm._poly_value(base_terms, x)
            hist["eq_inf"].append(eq_inf)
            hist["ineq_inf"].append(ineq_inf)
            hist["obj"].append(float(f_val))
            hist["x"].append(x.copy())
            hist["rho_h"].append(float(rho_h))
            hist["rho_g"].append(float(rho_g))
            for k_idx, csp in enumerate(full_eqs):
                if csp.kind != "eq":
                    continue
                hist[f"lam_eq_max_idx{k_idx}"].append(float(np.max(lam_eq[k_idx])))
                hist[f"lam_eq_min_idx{k_idx}"].append(float(np.min(lam_eq[k_idx])))
            for k_idx, csp in enumerate(problem_ineqs):
                if csp.kind != "ineq":
                    continue
                hist[f"mu_ineq_max_idx{k_idx}"].append(float(np.max(mu_ineq[k_idx])))
                hist[f"mu_ineq_min_idx{k_idx}"].append(float(np.min(mu_ineq[k_idx])))
            if verbose:
                print(f"[ALM {it:02d}] f={f_val:.6g} | eq_inf={eq_inf:.2e} | ineq_inf={ineq_inf:.2e} "
                      f"| rho_h={rho_h:.2e} | rho_g={rho_g:.2e}")

            if eq_inf <= cfg.tol_h and ineq_inf <= cfg.tol_g:
                if verbose:
                    print(f"[ALM] converged at iter {it}")
                break

            # exponentially smoothed infeasibilities drive the penalty adaptation
            if it == 0:
                eq_inf_smooth = eq_inf
                ineq_inf_smooth = ineq_inf
            else:
                eq_inf_smooth = cfg.ema_alpha * eq_inf + (1 - cfg.ema_alpha) * eq_inf_smooth
                ineq_inf_smooth = cfg.ema_alpha * ineq_inf + (1 - cfg.ema_alpha) * ineq_inf_smooth

            if cfg.adapt and it > 0:
                if eq_inf_smooth > cfg.tau_up_h * max(prev_eq_inf, eps):
                    rho_h = min(cfg.gamma_up * rho_h, cfg.rho_max)
                elif eq_inf_smooth < cfg.tau_down_h * max(prev_eq_inf, eps):
                    rho_h = max(cfg.gamma_down * rho_h, cfg.rho_min)

                if ineq_inf_smooth > cfg.tau_up_g * max(prev_ineq_inf, eps):
                    rho_g = min(cfg.gamma_up * rho_g, cfg.rho_max)
                elif ineq_inf_smooth < cfg.tau_down_g * max(prev_ineq_inf, eps):
                    rho_g = max(cfg.gamma_down * rho_g, cfg.rho_min)

            if cfg.use_stagnation_bump:
                # bump the penalty when the raw infeasibility has not improved for a while
                if eq_inf <= best_eq * (1 - cfg.stagnation_factor):
                    best_eq = eq_inf
                    no_imp_eq = 0
                else:
                    no_imp_eq += 1
                if no_imp_eq >= cfg.patience_h:
                    rho_h = min(2.0 * rho_h, cfg.rho_max)
                    no_imp_eq = 0

                if ineq_inf <= best_ineq * (1 - cfg.stagnation_factor):
                    best_ineq = ineq_inf
                    no_imp_ineq = 0
                else:
                    no_imp_ineq += 1
                if no_imp_ineq >= cfg.patience_g:
                    rho_g = min(2.0 * rho_g, cfg.rho_max)
                    no_imp_ineq = 0

            prev_eq_inf = max(eq_inf_smooth, eps)
            prev_ineq_inf = max(ineq_inf_smooth, eps)

        # decode each lifted block to its arg-max level, keyed by the block's first index
        decoded: Dict[int, Union[int, float]] = {}
        for blk in registry.blocks:
            sl = np.array(blk.idx, int)
            if len(sl) == 0:
                continue
            s = x[sl]
            j = int(np.argmax(s))
            decoded[sl[0]] = float(blk.levels[j])
        return {"x": x, "decoded": decoded, "hist": hist}