# Source code for gasfir.retrieval

"""Systematic nonlinear regression pipeline for GASFIR parameter retrieval.

Three-phase pipeline
--------------------
1. **Global search** — :func:`run_global_search`
   CMA-ES (default, ``cma`` package) or Differential Evolution (scipy;
   ``workers=1`` by default, ``workers=-1`` for process-level parallelism).
   Finds the basin of attraction robustly without a good initial guess.

2. **Local polish** — :func:`run_local_polish`
   Bounded nonlinear least squares via lmfit (``"least_squares"`` by default;
   ``"leastsq"`` selects Levenberg–Marquardt).  Takes the DE/CMA-ES output and
   converges to the precise minimum quickly.

3. **Uncertainty quantification** — :func:`run_mcmc`
   ``emcee`` MCMC.  Samples the posterior and reports credible intervals.
   Supports an optional ``multiprocessing.Pool`` for walker-level parallelism.

Parallelisation notes
---------------------
The GASFIR kernel (``get_diabatic_ionization_probability_batch``) already
parallelises over pulses via Numba ``prange``, using all available CPU cores.
This is the dominant speedup for typical datasets.

For emcee, walker-level parallelism via ``multiprocessing.Pool`` is available
through the ``pool`` parameter of :func:`run_mcmc`.  Each worker process runs
the Numba kernel serially (``NUMBA_NUM_THREADS=1`` is set automatically in
workers); the overall CPU utilisation is comparable.  The pool approach
helps when *nwalkers ≫ ncores* and each individual kernel call is fast.

For differential evolution, ``scipy.optimize.differential_evolution`` is used
with ``workers=1`` by default (Numba handles inner-loop parallelism).  Passing
``workers=-1`` enables scipy's process-level parallelism, which can help when
individual residual evaluations are slow.

Typical usage
-------------
::

    from gasfir import create_pulse, ret_pulse_from_pandas_table
    from gasfir.retrieval import RetrievalConfig, retrieve

    data = pd.read_csv("my_data.csv")
    data["pulses"] = ret_pulse_from_pandas_table(data)

    cfg = RetrievalConfig(medium_name="H_SFA")         # saves to ./H_SFA/
    cfg = RetrievalConfig(medium_name="H_SFA", output_dir="results/H_SFA")  # explicit path
    result = retrieve(data_NA=data, config=cfg)
"""

from __future__ import annotations

import json
import logging
import multiprocessing
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

if TYPE_CHECKING:
    import lmfit
    from lmfit.minimizer import MinimizerResult

import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.optimize import differential_evolution

try:
    import lmfit
    from lmfit.minimizer import MinimizerResult

    _HAS_LMFIT = True
except ImportError:
    lmfit = None  # type: ignore[assignment]
    _HAS_LMFIT = False
    # MinimizerResult is intentionally left undefined when lmfit is absent.
    # Accessing it raises NameError rather than the silent TypeError that
    # `isinstance(x, None)` would produce.

try:
    import emcee

    _HAS_EMCEE = True
except ImportError:
    emcee = None  # type: ignore[assignment]
    _HAS_EMCEE = False

try:
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    _HAS_MPL = True
except ImportError:
    mpl = None  # type: ignore[assignment]
    plt = None  # type: ignore[assignment]
    _HAS_MPL = False

try:
    import corner as _corner_lib

    _HAS_CORNER = True
except ImportError:
    _corner_lib = None  # type: ignore[assignment]
    _HAS_CORNER = False

try:
    from tqdm.auto import tqdm

    _HAS_TQDM = True
except ImportError:

    class tqdm:  # type: ignore[no-redef]
        """No-op stand-in used when the optional ``tqdm`` package is missing.

        Supports the call shapes used with the real progress bar:
        ``for x in tqdm(iterable, ...)`` and
        ``with tqdm(total=n, ...) as bar: bar.update(1)``.  All display
        options are accepted and ignored.

        Args:
            iterable: Optional iterable to pass through unchanged.
            **kwargs: Ignored; accepted for signature compatibility.
        """

        def __init__(self, iterable=None, **kwargs):
            self._iterable = iterable

        def __iter__(self):
            # Without an iterable (``total=`` usage) there is nothing to yield.
            return iter(() if self._iterable is None else self._iterable)

        def __len__(self):
            return len(self._iterable)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside the ``with`` body.
            return False

        def update(self, n=1):
            """Accept and discard a progress increment."""

        def close(self):
            """Do nothing; provided for API parity."""

    _HAS_TQDM = False

from ._atomic_data import get_ionization_potential
from ._parameter_store import get_parameters
from .fitting import ret_residual_function

_log = logging.getLogger(__name__)

_RETRIEVAL_INSTALL_HINT = (
    "Install the retrieval extras with:  pip install gasfir[retrieval]"
)


def _require_lmfit() -> None:
    """Guard for entry points that need the optional ``lmfit`` dependency.

    Raises:
        ImportError: If ``lmfit`` could not be imported when this module was
            loaded.  The message includes the install hint for the retrieval
            extras.
    """
    if _HAS_LMFIT:
        return
    message = (
        "lmfit is required for parameter fitting and retrieval. "
        + _RETRIEVAL_INSTALL_HINT
    )
    raise ImportError(message)


# ---------------------------------------------------------------------------
# Optional dependency: CMA-ES
# ---------------------------------------------------------------------------
try:
    import cma  # type: ignore[import]

    _HAS_CMA = True
except ImportError:
    _HAS_CMA = False


# ---------------------------------------------------------------------------
# Pretty-print helpers
# ---------------------------------------------------------------------------


def _section(title: str) -> None:
    """Write a framed section heading to stdout.

    The heading is a blank line, a 62-character rule, the indented title and
    a closing rule; stdout is flushed afterwards.

    Args:
        title: Text shown between the two rules.
    """
    rule = "═" * 62
    heading = "\n".join(("", rule, f"  {title}", rule))
    print(heading, flush=True)


def _build_emcee_minimizer_result(
    params: lmfit.Parameters,
    var_names: List[str],
    flat_samples: npt.NDArray[np.float64],
    ndata: int,
    n_walkers: int,
    n_steps: int,
    chisqr: float,
) -> MinimizerResult:
    """Wrap emcee output in an lmfit MinimizerResult for ``fit_report()`` use.

    The function copies *params*, attaches pairwise correlation coefficients
    derived from the sample covariance, and fills in the summary statistics.
    It does **not** compute ``stderr`` itself: whatever ``stderr`` values are
    already present on *params* are carried over by the copy.  The flattened
    chain is attached as a DataFrame on a best-effort basis.

    Args:
        params: Parameters to report; copied, never mutated.
        var_names: Ordered list of varied parameter names, matching the
            columns of *flat_samples*.
        flat_samples: Flattened MCMC chain, shape ``(n_samples, ndim)``.
        ndata: Number of data points used in the fit.
        n_walkers: Number of emcee walkers.
        n_steps: Total MCMC steps per walker.
        chisqr: Sum of squared residuals at the reported parameter values.

    Returns:
        A :class:`lmfit.minimizer.MinimizerResult` with ``method="emcee"``.
        ``covar`` is ``None`` unless *flat_samples* is 2-D with more than one
        column.
    """
    nvarys = len(var_names)

    # A covariance matrix (and hence correlations) is only built when there
    # are at least two sampled parameters.
    covar = (
        np.cov(flat_samples, rowvar=False)
        if flat_samples.ndim == 2 and flat_samples.shape[1] > 1
        else None
    )

    # Build correlation dicts on each parameter (used by fit_report)
    p = params.copy()
    if covar is not None:
        stds = np.sqrt(np.diag(covar))
        for i, name in enumerate(var_names):
            p[name].correl = {
                other: float(covar[i, j] / (stds[i] * stds[j]))
                for j, other in enumerate(var_names)
                if j != i
            }

    redchi = chisqr / max(ndata - nvarys, 1)
    aic = chisqr + 2 * nvarys
    bic = chisqr + nvarys * float(np.log(ndata))

    result = MinimizerResult()
    result.method = "emcee"
    result.var_names = var_names
    result.params = p
    result.covar = covar
    result.nvarys = nvarys
    result.ndata = ndata
    result.nfev = n_walkers * n_steps
    result.success = True
    result.errorbars = True
    result.chisqr = chisqr
    result.redchi = redchi
    result.aic = aic
    result.bic = bic
    result.residual = np.array([])  # individual residuals not retained

    try:
        result.flatchain = pd.DataFrame(flat_samples, columns=var_names)
    except Exception:
        # Best effort: the summary is still usable without the chain, so the
        # failure is recorded instead of being raised or silently dropped.
        _log.debug(
            "Could not attach flatchain to the emcee MinimizerResult",
            exc_info=True,
        )

    return result


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------


@dataclass
class RetrievalConfig:
    """All settings for the three-phase retrieval pipeline.

    Args:
        medium_name: Name used to look up initial parameters via
            :func:`~gasfir.get_parameters` (if ``initial_params`` is ``None``)
            and as the stem for output file names.
        output_dir: Directory where JSON results, HDF5 chains, and
            LaTeX/corner outputs are written.
        initial_params: Explicit initial parameter dict.  If ``None``,
            :func:`~gasfir.get_parameters(medium_name)` is used.
        fixed_params: Names of parameters to hold fixed (not varied).
            ``E_g`` and ``m_eff`` are fixed by default for non-crystal media.
        relative_bound: Half-width of parameter bounds as a multiple of the
            absolute initial value.  Default 5.0 means the search range is
            ``[val - 5*abs(val), val + 5*abs(val)]``.
        is_crystal: If ``True``, ``m_eff`` is also varied and ``a0_per_au3``
            is converted to ``a0`` using the lattice constant.
        simultaneous: If ``True``, load and concatenate quasi-static residuals.
        dt: Coarse time step for pulse pre-computation (a.u.).
        dT: Fine time step for pulse pre-computation (a.u.).
        ret_electron_density: Pass ``True`` to fit electron density instead of
            ionization probability.
        global_method: ``"differential_evolution"`` or ``"cma"``.
        de_workers: Number of workers for differential evolution.  ``-1`` uses
            all cores.  ``1`` (default) lets Numba handle intra-call
            parallelism.
        de_maxiter: Maximum iterations for differential evolution.
        de_tol: Convergence tolerance for differential evolution.
        de_popsize: Population size multiplier for differential evolution.
        cma_sigma0: Initial step size for CMA-ES.
        cma_maxiter: Maximum function evaluations for CMA-ES.
        ls_method: lmfit method for local polish (e.g. ``"least_squares"``,
            ``"leastsq"``).
        emcee_nwalkers: Number of MCMC walkers.  ``None`` →
            ``max(32, 4×ndim)``.  :func:`run_mcmc` raises the value to at
            least ``2×ndim + 2`` and rounds odd values up to the next even
            number.
        emcee_nsteps: Total MCMC steps per walker.
        emcee_burn_fraction: Fraction of chain discarded as burn-in if
            autocorrelation time estimation fails.
        emcee_thin: Thinning factor applied when drawing flat samples.
        emcee_max_autocorr_loops: Maximum auto-correction loops (re-centre
            bounds and restart MCMC if parameters drift).
        emcee_pool: Optional ``multiprocessing.Pool``-compatible object passed
            to :class:`emcee.EnsembleSampler`.  ``None`` (default) uses Numba
            thread-level parallelism instead.
        trial_run: If ``True``, run DE and MCMC with minimal iterations for a
            quick sanity check.
        latex_mapping: Optional ``{name: latex_string}`` label overrides used
            in plots and ``.tex`` output.
    """

    medium_name: str = "unknown"
    output_dir: Optional[Union[str, Path]] = None
    """Output directory for JSON, HDF5 chain, plots, and LaTeX files.
    Defaults to a folder named after *medium_name* in the current directory."""
    initial_params: Optional[Dict[str, float]] = None
    fixed_params: Optional[List[str]] = None
    relative_bound: float = 5.0
    is_crystal: bool = False
    simultaneous: bool = False
    dt: float = 4.0
    dT: float = 0.25
    ret_electron_density: bool = False
    # --- global search ---
    global_method: str = "cma"
    de_workers: int = 1
    de_maxiter: int = 1000
    de_tol: float = 1e-7
    de_popsize: int = 15
    cma_sigma0: float = 0.3
    cma_maxiter: int = 10_000
    # --- local polish ---
    ls_method: str = "least_squares"
    # --- emcee ---
    emcee_nwalkers: Optional[int] = None
    emcee_nsteps: int = 3000
    emcee_burn_fraction: float = 0.3
    emcee_thin: int = 10
    emcee_max_autocorr_loops: int = 5
    emcee_pool: Optional[object] = field(default=None, repr=False)
    trial_run: bool = False
    # --- reporting ---
    latex_mapping: Optional[Dict[str, str]] = field(default=None, repr=False)
    """LaTeX label overrides for parameter names used in plots and .tex output.

    Example::

        {"E_g": r"$E_g$ [a.u.]", "m_eff": r"$m_{\\mathrm{eff}}$"}
    """
# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------
@dataclass
class RetrievalResult:
    """Full results from a retrieval run.

    Attributes:
        params: Final lmfit Parameters with best-fit values and (after MCMC)
            ``stderr`` attributes set to posterior standard deviations.
        var_names: Names of the varied parameters.
        chisqr: Sum of squared residuals at the best-fit point.
        redchi: Reduced chi-squared.
        ndata: Number of data points used in the fit.
        nvarys: Number of varied parameters.
        covar: Covariance matrix from the MCMC posterior.
        flat_samples: Flattened MCMC chain (shape: ``(n_samples, ndim)``).
        lnprob: Log-probability for each flat sample.
        val_stats: Out-of-sample validation statistics keyed by subset name.
        de_result: Raw scipy DE result (or ``None``).
        ls_result: Raw lmfit least-squares result (or ``None``).
        method: Comma-separated list of phases that were run; the field
            default is ``"emcee"``.
        chain_path: Path to the HDF5 emcee chain file, if one was written.
    """

    params: lmfit.Parameters
    var_names: List[str]
    chisqr: float = 0.0
    redchi: float = 0.0
    ndata: int = 0
    nvarys: int = 0
    covar: Optional[npt.NDArray[np.float64]] = None
    flat_samples: Optional[npt.NDArray[np.float64]] = None
    lnprob: Optional[npt.NDArray[np.float64]] = None
    val_stats: Dict = field(default_factory=dict)
    de_result: object = field(default=None, repr=False)
    ls_result: object = field(default=None, repr=False)
    method: str = "emcee"
    chain_path: Optional[Path] = field(default=None, repr=False)
    """Path to the HDF5 emcee chain file, if one was written."""
# ---------------------------------------------------------------------------
# I/O helpers
# ---------------------------------------------------------------------------
def save_result(result: RetrievalResult, path: Union[str, Path]) -> None:
    """Serialise a :class:`RetrievalResult` to a JSON file.

    The payload holds the parameters (via ``result.params.dumps()``), the
    scalar fit statistics, the covariance matrix, the validation statistics,
    the method string and the chain path.  ``flat_samples``, ``lnprob`` and
    the raw optimiser results are not written.

    Args:
        result: The retrieval result to save.
        path: Output file path (will be created/overwritten).  Missing parent
            directories are created.

    Raises:
        TypeError: If the payload contains a value that cannot be converted
            to JSON.  The target file is left untouched in that case.
    """

    def _jsonable(obj: object) -> object:
        # NumPy scalars/arrays (e.g. inside ``val_stats``) are not handled by
        # the json module; convert them to plain Python values.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    path = Path(path)
    covar_list = result.covar.tolist() if result.covar is not None else None
    payload = {
        "params": json.loads(result.params.dumps()),
        "var_names": result.var_names,
        "chisqr": result.chisqr,
        "redchi": result.redchi,
        "ndata": result.ndata,
        "nvarys": result.nvarys,
        "covar": covar_list,
        "val_stats": result.val_stats,
        "method": result.method,
        "chain_path": (
            str(result.chain_path) if result.chain_path is not None else None
        ),
    }
    # Serialise first so a failure cannot truncate an existing file.
    text = json.dumps(payload, indent=4, default=_jsonable)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")
    _log.info("Result saved to %s", path)
def load_result(path: Union[str, Path]) -> RetrievalResult:
    """Load a :class:`RetrievalResult` from a JSON file written by :func:`save_result`.

    Args:
        path: Path to the JSON file.

    Returns:
        A :class:`RetrievalResult` with ``flat_samples`` and ``lnprob`` set to
        ``None`` (they are not persisted in JSON; reload the HDF5 chain for
        those).  ``chain_path`` is restored when the file contains it.

    Raises:
        ImportError: If lmfit is not installed.
        KeyError: If the file has no ``"params"`` entry.
    """
    # Fail with the install hint rather than an AttributeError on ``None``.
    _require_lmfit()

    with open(path, encoding="utf-8") as fh:
        data = json.load(fh)

    params = lmfit.Parameters()
    params.loads(json.dumps(data["params"]))

    covar = data.get("covar")
    chain_path = data.get("chain_path")
    return RetrievalResult(
        params=params,
        var_names=data.get("var_names", []),
        chisqr=data.get("chisqr", 0.0),
        redchi=data.get("redchi", 0.0),
        ndata=data.get("ndata", 0),
        nvarys=data.get("nvarys", 0),
        covar=np.array(covar) if covar is not None else None,
        val_stats=data.get("val_stats", {}),
        method=data.get("method", "unknown"),
        chain_path=Path(chain_path) if chain_path else None,
    )
# ---------------------------------------------------------------------------
# Parameter setup
# ---------------------------------------------------------------------------


def setup_parameters(
    params_dict: Dict[str, float],
    relative_bound: float = 5.0,
    fixed_params: Optional[List[str]] = None,
    is_crystal: bool = False,
    param_bounds: Optional[Dict[str, tuple]] = None,
) -> lmfit.Parameters:
    """Build a bounded :class:`lmfit.Parameters` object from a raw dict.

    Args:
        params_dict: Initial parameter values (e.g. from
            :func:`~gasfir.get_parameters`).  The key
            ``"lattice_constant_au"`` is never added as a parameter.
        relative_bound: Half-width of the search range as a multiple of the
            absolute value.  A value of 5.0 means the range spans
            ``[val - 5*abs(val), val + 5*abs(val)]``.
        fixed_params: Names of parameters to hold fixed.  Defaults to
            ``["E_g", "m_eff"]`` for atomic/molecular media and to an empty
            list for crystals (``is_crystal=True``).
        is_crystal: When ``True``, ``m_eff`` is not fixed by default and
            ``a0_per_au3`` is converted to ``a0`` via the lattice constant.
        param_bounds: Optional explicit ``{name: (min, max)}`` overrides.  Any
            parameter listed here uses the given bounds instead of the
            ``relative_bound`` formula.  The crystal ``a0`` conversion is
            handled before this lookup, so it is not affected.

    Returns:
        A :class:`lmfit.Parameters` instance with bounds set.

    Raises:
        ImportError: If lmfit is not installed.
    """
    _require_lmfit()

    if param_bounds is None:
        param_bounds = {}
    if fixed_params is None:
        fixed_params = [] if is_crystal else ["E_g", "m_eff"]

    # Parameters that are physically positive-definite
    positive = {"E_g", "a0", "a1"}

    params = lmfit.Parameters()
    for name, val in params_dict.items():
        if name == "lattice_constant_au":
            continue  # derived constant, not a fit parameter

        if name in fixed_params:
            params.add(name, value=val, vary=False)
            continue

        # Crystal: convert a0_per_au3 → a0 in atomic units
        if name == "a0_per_au3" and is_crystal:
            lat = params_dict.get("lattice_constant_au", 1.0)
            a0 = val * lat**3
            offset = relative_bound * abs(a0)
            params.add(
                "a0",
                value=a0,
                vary=True,
                min=max(1e-8, a0 - offset),
                max=a0 + offset,
            )
            continue

        # Explicit bounds override the relative-bound formula
        if name in param_bounds:
            lb, ub = param_bounds[name]
            params.add(name, value=val, vary=True, min=lb, max=ub)
            continue

        offset = relative_bound * abs(val)
        if offset == 0.0:
            # A zero-width range cannot be searched: start just off zero with
            # a fixed half-width instead.
            val = 1e-6
            offset = 4.0

        if name == "a4":
            # NOTE(review): after the zero-offset substitution above ``val``
            # can no longer be exactly 0.0 here, so this "hold a4 fixed"
            # branch looks unreachable — confirm whether a zero ``a4`` was
            # meant to stay fixed.
            if float(val) == 0.0:
                params.add(name, value=val, vary=False)
            else:
                lb = max(-5.0, val - offset)
                ub = min(5.0, val + offset)
                if is_crystal:
                    lb = max(-1.0, val - offset)
                params.add(name, value=val, vary=True, min=lb, max=ub)
        elif name in positive:
            params.add(
                name,
                value=val,
                vary=True,
                min=max(1e-8, val - offset),
                max=val + offset,
            )
        else:
            params.add(
                name, value=val, vary=True, min=val - offset, max=val + offset
            )

    return params


# ---------------------------------------------------------------------------
# Statistical regularisation
# ---------------------------------------------------------------------------


def get_cleanest_value(
    mcmc_median: float,
    mcmc_err: float,
    initial_val: Optional[float] = None,
) -> float:
    """Return the statistically cleanest representative value in the 1-σ interval.

    Priority:

    1. ``initial_val`` if it falls within ``[median - err, median + err]``.
    2. ``0.0`` if zero is contained in the interval.
    3. The first value inside the interval found by rounding the median, the
       lower edge and the upper edge (in that order) to 1, 2, … 14
       significant figures.
    4. ``mcmc_median`` as a fallback.

    Args:
        mcmc_median: Median of the posterior sample.
        mcmc_err: Standard deviation of the posterior sample (1-σ).
        initial_val: Prior best estimate (e.g. the stored parameter value).
            Pass ``None`` when there is no prior.

    Returns:
        The cleanest representative value within the 1-σ interval.
    """
    low = mcmc_median - mcmc_err
    high = mcmc_median + mcmc_err

    if initial_val is not None and low <= initial_val <= high:
        return initial_val
    if low <= 0.0 <= high:
        return 0.0

    for sig in range(1, 15):
        fmt = "{:." + str(sig) + "g}"
        for candidate in (
            float(fmt.format(mcmc_median)),
            float(fmt.format(low)),
            float(fmt.format(high)),
        ):
            if low <= candidate <= high:
                return candidate

    return mcmc_median


# ---------------------------------------------------------------------------
# Safe residual wrapper
# ---------------------------------------------------------------------------


def _make_safe_residual(
    base_fn: Callable[[lmfit.Parameters], npt.NDArray[np.float64]],
) -> Callable[[lmfit.Parameters], npt.NDArray[np.float64]]:
    """Wrap a residual function to replace NaN/Inf with large finite values.

    Non-finite entries are replaced by a penalty of ten times the largest
    finite ``|residual|`` (capped at ``1e100``).  When there is no finite
    entry, or every finite entry is exactly zero, the penalty is ``200.0`` so
    that an invalid model output can never look like a perfect fit.

    Args:
        base_fn: The original residual function.

    Returns:
        A wrapped callable with the same signature that never returns NaN or
        infinite values, making it safe for all optimisers.
    """

    def safe_fn(params: lmfit.Parameters) -> npt.NDArray[np.float64]:
        res = np.asarray(base_fn(params))
        valid = np.isfinite(res)
        peak = float(np.max(np.abs(res[valid]))) if np.any(valid) else 0.0
        # A zero peak would give a zero penalty, silently scoring NaN/Inf
        # entries as exact agreement; fall back to the fixed penalty instead.
        penalty = min(peak * 10, 1e100) if peak > 0.0 else 200.0
        return np.clip(
            np.nan_to_num(res, nan=penalty, posinf=penalty, neginf=-penalty),
            -1e100,
            1e100,
        )

    return safe_fn


# ---------------------------------------------------------------------------
# Picklable log-probability for emcee
# ---------------------------------------------------------------------------


class _LogProbability:
    """Picklable callable wrapping a GASFIR residual function for emcee.

    Using a class (rather than a closure) ensures the object is picklable
    across ``multiprocessing.Pool`` workers, which require all callables and
    their captured state to be picklable.

    Attributes:
        var_names: Ordered list of varied parameter names.
        bounds: Dict ``{name: (min, max)}`` for each varied parameter.
        _params_template: A :class:`lmfit.Parameters` template that is copied
            on each call so the shared template is never mutated.
    """

    def __init__(
        self,
        residual_fn: Callable,
        var_names: List[str],
        params: lmfit.Parameters,
    ) -> None:
        self._residual_fn = residual_fn
        self.var_names = var_names
        self.bounds = {
            name: (params[name].min, params[name].max) for name in var_names
        }
        self._params_template = params

    def __call__(self, theta: npt.NDArray[np.float64]) -> float:
        """Evaluate log-probability at walker position *theta*.

        Args:
            theta: 1-D array of parameter values in the order of
                ``self.var_names``.

        Returns:
            ``-0.5 * sum(residual**2)``, or ``-inf`` when a value lies outside
            its bounds, when any residual is NaN/Inf, or when any
            ``|residual|`` is at least ``1e99``.
        """
        p = self._params_template.copy()
        for name, val in zip(self.var_names, theta):
            lo, hi = self.bounds[name]
            if not (lo <= val <= hi):
                return -np.inf
            p[name].value = val

        resid = np.asarray(self._residual_fn(p))
        # A NaN log-probability is rejected by emcee, so treat non-finite
        # residuals like any other invalid point.
        if not np.all(np.isfinite(resid)):
            return -np.inf
        if np.any(np.abs(resid) >= 1e99):
            return -np.inf
        return float(-0.5 * np.sum(resid**2))


# ---------------------------------------------------------------------------
# Phase 1 — global search
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# Phase 2 — local polish
# ---------------------------------------------------------------------------
def run_local_polish(
    residual_fn: Callable,
    params: lmfit.Parameters,
    method: str = "least_squares",
) -> Tuple[lmfit.Parameters, MinimizerResult]:
    """Locally polish the global-search result with an lmfit least-squares fit.

    The fit is run through :class:`lmfit.Minimizer` with
    ``nan_policy="omit"``.  Progress is printed every tenth iteration.

    Args:
        residual_fn: Callable ``f(params) -> residual_array``.
        params: :class:`lmfit.Parameters` at the global search optimum.
        method: lmfit minimisation method.  ``"least_squares"`` (default) is
            recommended as it handles bounds natively; pass ``"leastsq"`` for
            Levenberg–Marquardt.

    Returns:
        Tuple of (updated params, :class:`lmfit.MinimizerResult`).

    Raises:
        ImportError: If lmfit is not installed.
    """
    _require_lmfit()
    _log.info("Phase 2: Local polish with %s", method)
    print(f" [LS] Starting {method} …", flush=True)

    last_iter = 0

    def _iter_cb(params, iter_num, resid, **kw):
        # Remember the most recent iteration number for the summary line.
        nonlocal last_iter
        last_iter = iter_num
        if iter_num % 10 == 0:
            # The printed quantity is the sum of squared residuals.
            print(
                f" [LS] iter {iter_num:4d} | "
                f"‖resid‖² = {float(np.sum(resid**2)):.6g}",
                flush=True,
            )

    mini = lmfit.Minimizer(
        residual_fn, params, nan_policy="omit", iter_cb=_iter_cb
    )
    ls_result = mini.minimize(method=method)

    print(
        f" [LS] done — {last_iter} iters | "
        f"success={ls_result.success} | redchi={ls_result.redchi:.4g}",
        flush=True,
    )
    _log.info(
        "LS finished: success=%s redchi=%.4g",
        ls_result.success,
        ls_result.redchi,
    )
    return ls_result.params, ls_result
# ---------------------------------------------------------------------------
# Phase 3 — MCMC
# ---------------------------------------------------------------------------
def run_mcmc(
    residual_fn: Callable,
    params: lmfit.Parameters,
    n_steps: int = 3000,
    n_walkers: Optional[int] = None,
    burn_fraction: float = 0.3,
    thin: int = 10,
    backend_path: Optional[Union[str, Path]] = None,
    pool: Optional[object] = None,
    max_autocorr_loops: int = 5,
    initial_stored_params: Optional[Dict[str, float]] = None,
    trial_run: bool = False,
) -> Tuple[lmfit.Parameters, npt.NDArray, npt.NDArray, object]:
    """Run emcee MCMC and return posterior summary.

    Wraps the residual function in a log-probability callable and runs emcee
    with optional HDF5 chain persistence.  After each run the posterior
    median of every varied parameter is snapped with
    :func:`get_cleanest_value`; if any snapped value differs from its stored
    reference the sampler is restarted from the snapped values, up to
    *max_autocorr_loops* times.

    Args:
        residual_fn: Safe residual callable ``f(params) -> array``.
        params: :class:`lmfit.Parameters` at the local-polish optimum.
        n_steps: Total MCMC steps per walker.
        n_walkers: Number of walkers.  ``None`` → ``max(32, 4 × ndim)``.  The
            value is raised to at least ``2 × ndim + 2`` and rounded up to an
            even number.
        burn_fraction: Fraction of chain used as burn-in when autocorrelation
            estimation fails.
        thin: Thinning factor for flat samples, used when autocorrelation
            estimation fails.
        backend_path: HDF5 file for chain persistence.  Pass ``None`` to skip
            persistence.  An existing chain with a matching shape is resumed
            on the first loop only.
        pool: Optional ``multiprocessing.Pool`` for walker-level parallelism.
            Pass ``None`` (default) to rely on Numba intra-call parallelism.
        max_autocorr_loops: Maximum number of sample/snap/restart cycles.
        initial_stored_params: Original parameter dict used as the reference
            for drift detection (passed to :func:`get_cleanest_value`).
        trial_run: Use a very small step count (20) for quick testing.

    Returns:
        ``(final_params, flat_samples, lnprob, sampler)`` where *final_params*
        has ``stderr`` set to posterior standard deviations and values snapped
        to clean estimates via :func:`get_cleanest_value`.

    Raises:
        ImportError: If lmfit or emcee is not installed.
        ValueError: If *max_autocorr_loops* is below 1 or no parameter in
            *params* is varied.
    """
    _require_lmfit()
    if not _HAS_EMCEE:
        raise ImportError(
            "emcee is required for MCMC sampling. " + _RETRIEVAL_INSTALL_HINT
        )
    if max_autocorr_loops < 1:
        raise ValueError("max_autocorr_loops must be at least 1")

    var_names = sorted(n for n in params if params[n].vary)
    ndim = len(var_names)
    if ndim == 0:
        raise ValueError("run_mcmc needs at least one varied parameter")

    if n_walkers is None:
        n_walkers = max(32, 4 * ndim)
    # emcee requires n_walkers > 2 × ndim; also enforce even
    n_walkers = max(n_walkers, 2 * ndim + 2)
    n_walkers = n_walkers + (n_walkers % 2)

    steps = 20 if trial_run else n_steps
    stored = initial_stored_params or {}

    _log.info(
        "Phase 3: emcee ndim=%d nwalkers=%d steps=%d pool=%s",
        ndim,
        n_walkers,
        steps,
        "yes" if pool else "no",
    )

    # ----- auto-correction loop -----
    final_params = params.copy()
    flat_samples: npt.NDArray = np.empty(0)
    lnprob_arr: npt.NDArray = np.empty(0)
    sampler: Optional[object] = None

    for loop in range(max_autocorr_loops):
        _log.info("MCMC auto-correction loop %d/%d", loop + 1, max_autocorr_loops)

        # Rebuild the log-probability from the current parameter set.
        # NOTE(review): only values/stderr change between loops in this
        # function; min/max are not re-centred here — confirm intent.
        log_prob = _LogProbability(residual_fn, var_names, final_params)

        # HDF5 backend
        backend: Optional[emcee.backends.HDFBackend] = None
        if backend_path is not None:
            bp = Path(backend_path)
            bp.parent.mkdir(parents=True, exist_ok=True)
            backend = emcee.backends.HDFBackend(str(bp))

        # Determine initial state
        if (
            loop == 0
            and backend is not None
            and bp.exists()
            and backend.iteration > 0
            and backend.shape == (n_walkers, ndim)
        ):
            _log.info("Resuming from existing chain at step %d", backend.iteration)
            initial_state = backend.get_last_sample()
            remaining = max(0, steps - backend.iteration)
        else:
            if backend is not None:
                backend.reset(n_walkers, ndim)
            x0 = np.array([final_params[n].value for n in var_names])
            # Seeded per loop so that reruns are reproducible.
            initial_state = x0 + 1e-4 * np.random.default_rng(
                42 + loop
            ).standard_normal((n_walkers, ndim))
            remaining = steps

        sampler = emcee.EnsembleSampler(
            n_walkers, ndim, log_prob, backend=backend, pool=pool
        )

        if remaining > 0:
            # Wrapping the generator works with the real tqdm and with the
            # module's pass-through fallback alike.
            progress = tqdm(
                sampler.sample(initial_state, iterations=remaining),
                total=remaining,
                desc=f"emcee loop {loop+1}",
                unit="step",
            )
            for _ in progress:
                pass

        n_stored = sampler.get_chain().shape[0]

        # Autocorrelation-based burn-in
        try:
            tau = sampler.get_autocorr_time(quiet=True)
            burnin = int(2 * np.max(tau))
            thin_auto = max(1, int(0.5 * np.min(tau)))
        except Exception:
            # emcee.autocorr.AutocorrError or chain too short
            burnin = max(0, int(n_stored * burn_fraction))
            thin_auto = thin

        # Never discard more than half of the stored chain.
        burnin = min(burnin, n_stored // 2)
        flat_samples = sampler.get_chain(discard=burnin, thin=thin_auto, flat=True)
        lnprob_arr = sampler.get_log_prob(discard=burnin, thin=thin_auto, flat=True)

        if len(flat_samples) == 0:
            _log.warning(
                "Empty flat samples — chain too short. Proceeding without thinning."
            )
            flat_samples = sampler.get_chain(flat=True)
            lnprob_arr = sampler.get_log_prob(flat=True)

        # Drift detection and snapping
        all_stable = True
        next_params = final_params.copy()
        _log.info("Drift analysis:")
        for i, name in enumerate(var_names):
            median = float(np.median(flat_samples[:, i]))
            sigma = float(np.std(flat_samples[:, i]))
            init = stored.get(name)
            clean = get_cleanest_value(median, sigma, init)
            next_params[name].value = clean
            next_params[name].stderr = sigma
            if init is not None and clean != init:
                _log.info(
                    " DRIFT %s: %.4g → %.4g (σ=%.2g)", name, init, clean, sigma
                )
                all_stable = False
            else:
                _log.info(" STABLE %s: %.4g (σ=%.2g)", name, clean, sigma)

        final_params = next_params

        if all_stable or loop == max_autocorr_loops - 1:
            if not all_stable:
                _log.warning("Max autocorrection loops reached; using current best.")
            else:
                _log.info("Convergence confirmed — all parameters stable.")
            break

    # Copy fixed params
    for name in params:
        if not params[name].vary:
            final_params[name].value = params[name].value

    return final_params, flat_samples, lnprob_arr, sampler
# ---------------------------------------------------------------------------
# Validation metrics
# ---------------------------------------------------------------------------


def compute_validation_metrics(
    result_params: lmfit.Parameters,
    validation_df: pd.DataFrame,
    uncertainty: float = 5e-2,
    dt: float = 4.0,
    dT: float = 0.25,
    ret_electron_density: bool = False,
) -> Dict[str, float]:
    """Compute chi-squared and RMSE for an out-of-sample validation set.

    Args:
        result_params: Best-fit parameters.
        validation_df: DataFrame with ``"pulses"`` and ``"Y"`` columns.
        uncertainty: Fractional uncertainty (applied as ``uncertainty × Y``).
        dt: Coarse time step for the kernel.
        dT: Fine time step for the kernel.
        ret_electron_density: Pass ``True`` to use electron density.

    Returns:
        Dict with keys ``"N"``, ``"chisqr"``, ``"redchi"``, ``"rmse"``.
        ``redchi`` is ``chisqr / N``.  For an empty *validation_df* the
        residual function is not evaluated and the result is ``N=0``,
        ``chisqr=0.0`` and NaN for ``redchi`` and ``rmse``.
    """
    n = len(validation_df)
    if n == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        nan = float("nan")
        return {"N": 0, "chisqr": 0.0, "redchi": nan, "rmse": nan}

    uncert_arr = uncertainty * validation_df["Y"].to_numpy()
    val_fn = ret_residual_function(
        validation_df,
        uncert_arr,
        None,
        None,
        dt=dt,
        dT=dT,
        ret_electron_density=ret_electron_density,
    )
    resid = np.asarray(val_fn(result_params))
    chisqr = float(np.sum(resid**2))
    # Multiplying by the per-point uncertainty rescales the residuals before
    # the root-mean-square is taken.
    rmse = float(np.sqrt(np.mean((resid * uncert_arr) ** 2)))
    return {"N": n, "chisqr": chisqr, "redchi": chisqr / n, "rmse": rmse}


# ---------------------------------------------------------------------------
# Stored-parameter comparison
# ---------------------------------------------------------------------------
def compare_to_stored(
    ls_result: MinimizerResult,
    stored_params: Dict[str, float],
    sigma: float = 2.0,
) -> None:
    """Print a table comparing LS-fitted values to stored reference parameters.

    For each varied parameter, reports the stored value, fitted value,
    standard error, number of standard deviations separating them, and
    whether the stored value falls within *sigma*-σ of the fitted value.
    A parameter whose stored and fitted values are identical counts as 0σ
    away even when no standard error is available; otherwise a missing or
    non-positive standard error is reported as infinitely far away.

    Args:
        ls_result: :class:`lmfit.minimizer.MinimizerResult` from the LS phase.
        stored_params: Reference parameter dict (e.g. from
            :func:`~gasfir.get_parameters`).  Only keys that appear in
            ``ls_result.params`` and were varied are compared.
        sigma: Confidence threshold in units of stderr (default 2 → 95 % CI).

    Raises:
        ImportError: If lmfit is not installed.
    """
    _require_lmfit()

    var_names = [n for n in ls_result.params if ls_result.params[n].vary]
    # filter to only those present in stored_params
    names_to_check = [n for n in var_names if n in stored_params]
    if not names_to_check:
        print(
            " (no overlap between fitted and stored parameters — skipping comparison)",
            flush=True,
        )
        return

    _section(f"Stored vs LS-fitted parameters ({sigma:.0f}σ tolerance)")
    print(
        f" {'Param':10s} {'Stored':>12s} {'Fitted':>12s} {'±stderr':>10s} "
        f"{'Nσ away':>9s} {f'In {sigma:.0f}σ?':>8s}",
        flush=True,
    )
    print(f" {'─'*10} {'─'*12} {'─'*12} {'─'*10} {'─'*9} {'─'*8}", flush=True)

    all_pass = True
    for name in names_to_check:
        p = ls_result.params[name]
        s = stored_params[name]
        f = p.value
        err = p.stderr if p.stderr is not None else float("nan")

        diff = abs(s - f)
        if diff == 0.0:
            # Identical values agree regardless of the error estimate.
            nsigma = 0.0
        elif err > 0:
            nsigma = diff / err
        else:
            # Covers zero, negative and NaN standard errors.
            nsigma = float("inf")

        ok = nsigma <= sigma
        if not ok:
            all_pass = False
        icon = "✅" if ok else "❌"
        print(
            f" {name:10s} {s:12.4g} {f:12.4g} {err:10.4g} "
            f"{nsigma:9.2f} {icon:>8s}",
            flush=True,
        )

    if all_pass:
        print(
            f"\n All stored values are within {sigma:.0f}σ of the LS fit. ✅",
            flush=True,
        )
    else:
        print(
            f"\n Some stored values lie outside {sigma:.0f}σ — the stored parameters "
            "may be rough starting guesses rather than a prior fit to this dataset.",
            flush=True,
        )
# ---------------------------------------------------------------------------
# Visualisation & reporting
# ---------------------------------------------------------------------------


def _latex_label(name: str, latex_mapping: Dict[str, str]) -> str:
    """Look up the display label for a parameter.

    Args:
        name: Parameter name.
        latex_mapping: ``{name: label}`` overrides.

    Returns:
        The override for *name* when one exists; otherwise *name* wrapped in
        ``$…$`` with every underscore escaped as ``\\_``.
    """
    if name in latex_mapping:
        return latex_mapping[name]
    return "$" + name.replace("_", r"\_") + "$"
def generate_publication_corner(
    flat_samples: npt.NDArray[np.float64],
    var_names: List[str],
    output_dir: Union[str, Path],
    medium_name: str,
    truths: Optional[List[float]] = None,
    latex_mapping: Optional[Dict[str, str]] = None,
    constants: Optional[Dict[str, float]] = None,
) -> Optional[Path]:
    """Render the posterior as a corner plot and write it to a PDF.

    The figure carries the 1-D marginals on the diagonal and filled 2-D
    contours below it.  If *truths* is given, the values are marked on the
    plot and two inset tables are drawn: the adopted values (followed by any
    fixed *constants*) and the sample correlation matrix.

    Args:
        flat_samples: Flattened MCMC chain, shape ``(n_samples, ndim)``.
        var_names: Ordered parameter names matching the chain columns.
        output_dir: Directory where the PDF is saved (created if missing).
        medium_name: Used as the file-name stem.
        truths: Best-fit / snapped values to mark on the plot (one per param).
        latex_mapping: ``{name: latex_string}`` overrides for axis labels.
        constants: Fixed parameter values to list in the adopted-values table.

    Returns:
        Path to the saved PDF, or ``None`` if matplotlib / corner are absent.
    """
    if not (_HAS_MPL and _HAS_CORNER):
        _log.warning("corner / matplotlib not installed — skipping corner plot.")
        return None

    label_map = {} if latex_mapping is None else latex_mapping
    fixed = {} if constants is None else constants
    axis_labels = [_latex_label(name, label_map) for name in var_names]
    n_par = len(var_names)

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    with mpl.rc_context({"font.size": 14, "font.family": "serif"}):
        fig = _corner_lib.corner(
            flat_samples,
            labels=axis_labels,
            quantiles=[0.16, 0.5, 0.84],
            show_titles=True,
            truths=truths,
            truth_color="firebrick",
            title_kwargs={"fontsize": 14},
            label_kwargs={"fontsize": 16},
            levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-4.5)),
            plot_datapoints=False,
            fill_contours=True,
            smooth=1.0,
            color="#005b96",
        )

        if truths is not None:
            # ---- Inset 1: adopted values, then the fixed constants ----
            value_ax = fig.add_axes([0.60, 0.75, 0.35, 0.25])
            value_ax.axis("off")
            value_rows = [["Parameter", "Adopted Value"]]
            for idx, label in enumerate(axis_labels):
                value_rows.append([label, f"{truths[idx]:.4g}"])
            for const_name, const_val in fixed.items():
                value_rows.append(
                    [_latex_label(const_name, label_map), f"{const_val:.4g} (Fixed)"]
                )
            value_tbl = value_ax.table(
                cellText=value_rows, loc="center", cellLoc="center", edges="horizontal"
            )
            value_tbl.auto_set_font_size(False)
            value_tbl.set_fontsize(14)
            value_tbl.scale(1, 1.8)
            for (row, col), cell in value_tbl.get_celld().items():
                text = cell.get_text()
                if row == 0:
                    # Header row: bold, with the value column in the truth colour.
                    text.set_weight("bold")
                    if col == 1:
                        text.set_color("firebrick")
                elif col == 1 and 1 <= row <= n_par:
                    # Fitted values (not the fixed constants) get the truth colour.
                    text.set_color("firebrick")
                    text.set_weight("bold")

            # ---- Inset 2: correlation matrix of the samples ----
            corr_ax = fig.add_axes([0.53, 0.6, 0.42, 0.15])
            corr_ax.axis("off")
            corr_ax.set_title(
                "Correlation Matrix", fontsize=14, fontweight="bold", pad=10
            )
            corr = np.corrcoef(flat_samples, rowvar=False)
            corr_rows = [[""] + axis_labels]
            for i, label in enumerate(axis_labels):
                corr_rows.append(
                    [label]
                    + ["1" if i == j else f"{corr[i, j]:.2g}" for j in range(n_par)]
                )
            corr_tbl = corr_ax.table(
                cellText=corr_rows, loc="center", cellLoc="center", edges="open"
            )
            corr_tbl.auto_set_font_size(False)
            corr_tbl.set_fontsize(12)
            corr_tbl.scale(1, 1.8)
            for (r, c), cell in corr_tbl.get_celld().items():
                # Only the header row/column carry rules: bottom and right edges.
                if r == 0 and c == 0:
                    cell.visible_edges = "BR"
                elif r == 0:
                    cell.visible_edges = "B"
                elif c == 0:
                    cell.visible_edges = "R"
                else:
                    cell.visible_edges = ""
                cell.set_linewidth(1.5)
                if r == 0 or c == 0:
                    cell.get_text().set_weight("bold")
                if r > 0 and c > 0 and r == c:
                    cell.get_text().set_weight("bold")

        out_path = target_dir / f"{medium_name}_publication_corner.pdf"
        fig.savefig(out_path, dpi=300, bbox_inches="tight")

    plt.close(fig)
    print(f" 📊 Corner plot saved to: {out_path}", flush=True)
    return out_path
def generate_trace_plot(
    chain_path: Union[str, Path],
    var_names: List[str],
    output_dir: Union[str, Path],
    medium_name: str,
    latex_mapping: Optional[Dict[str, str]] = None,
    discard: int = 0,
) -> Optional[Path]:
    """Plot every walker's trajectory, one panel per parameter, as a PDF.

    The chain is read back from the emcee HDF5 backend file, so the plot can
    be regenerated from a persisted run without a live sampler.

    Args:
        chain_path: Path to the ``*.h5`` emcee HDF5 backend file.
        var_names: Parameter names in chain column order.
        output_dir: Directory where the PDF is saved (created if missing).
        medium_name: Used as the file-name stem.
        latex_mapping: ``{name: latex_string}`` label overrides.
        discard: Number of initial steps to discard from the trace.

    Returns:
        Path to the saved PDF, or ``None`` if matplotlib or emcee is absent
        or the chain file does not exist.
    """
    if not (_HAS_MPL and _HAS_EMCEE):
        _log.warning("matplotlib / emcee not installed — skipping trace plot.")
        return None

    h5_file = Path(chain_path)
    if not h5_file.exists():
        _log.warning("Chain file not found: %s — skipping trace plot.", h5_file)
        return None

    label_map = {} if latex_mapping is None else latex_mapping

    reader = emcee.backends.HDFBackend(str(h5_file), read_only=True)
    walkers = reader.get_chain(discard=discard)  # shape: (steps, walkers, ndim)
    n_steps = len(walkers)
    ndim = len(var_names)

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    with mpl.rc_context({"font.size": 12, "font.family": "serif"}):
        # squeeze=False always yields a 2-D axes grid, so the single-parameter
        # case needs no special handling.
        fig, grid = plt.subplots(
            ndim, figsize=(10, 2 * ndim), sharex=True, squeeze=False
        )
        panels = grid[:, 0]
        for idx, (ax, name) in enumerate(zip(panels, var_names)):
            ax.plot(walkers[:, :, idx], "k", alpha=0.3)
            ax.set_xlim(0, n_steps)
            ax.set_ylabel(_latex_label(name, label_map), fontsize=14)
            ax.yaxis.set_label_coords(-0.1, 0.5)
        panels[-1].set_xlabel("Step Number", fontsize=14)
        fig.tight_layout()

        out_path = target_dir / f"{medium_name}_trace.pdf"
        fig.savefig(out_path, dpi=300, bbox_inches="tight")

    plt.close(fig)
    print(f" 📈 Trace plot saved to: {out_path}", flush=True)
    return out_path
def generate_latex_summary(
    result: RetrievalResult,
    output_dir: Union[str, Path],
    latex_mapping: Optional[Dict[str, str]] = None,
) -> Path:
    """Write a self-contained LaTeX section summarising the fit.

    Produces a ``.tex`` file with:

    * Parameter table (initial value, best fit, ±stderr).
    * Out-of-sample validation table (if ``result.val_stats`` is populated).
    * Correlation matrix (if ``result.covar`` is available).
    * ``\\includegraphics`` link to the corner plot.

    The name of *output_dir* (its last path component) is used as the stem
    for the section title, the LaTeX labels, the referenced corner-plot file
    and the output file name.

    Args:
        result: :class:`RetrievalResult` from :func:`retrieve`.
        output_dir: Directory where the ``.tex`` file is saved (created if
            it does not exist).
        latex_mapping: ``{name: latex_string}`` label overrides.

    Returns:
        Path to the written ``.tex`` file.
    """
    if latex_mapping is None:
        latex_mapping = {}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    stem = output_dir.name
    clean = stem.replace("_", r"\_")

    parts: List[str] = [f"\\section*{{Fit Summary for {clean}}}\n\n"]

    # --- 1. Parameter table ---
    parts.append("\\subsection*{Parameters and Uncertainties}\n")
    parts.append("\\begin{table}[htbp]\n\\centering\n")
    parts.append("\\begin{tabular}{l c c c}\n\\toprule\n")
    parts.append(
        "Parameter & Initial Value & Best Fit & Uncertainty ($\\pm$) \\\\\n\\midrule\n"
    )
    for name, param in result.params.items():
        label = _latex_label(name, latex_mapping)
        init = f"{param.init_value:.4g}" if param.init_value is not None else "---"
        if param.vary:
            err = param.stderr if param.stderr is not None else float("nan")
            uncertainty = f"{err:.4g}"
        else:
            uncertainty = "(Fixed)"
        parts.append(f"{label} & {init} & {param.value:.4g} & {uncertainty} \\\\\n")
    parts.append("\\bottomrule\n\\end{tabular}\n")
    parts.append(f"\\caption{{Best fit parameters for {clean}.}}\n")
    parts.append(f"\\label{{tab:params_{stem}}}\n\\end{{table}}\n\n")

    # --- 2. Validation table (once, not duplicated) ---
    if result.val_stats:
        parts.append("\\subsection*{Out-of-Sample Validation}\n")
        parts.append("\\begin{table}[htbp]\n\\centering\n")
        parts.append("\\begin{tabular}{l c c c c}\n\\toprule\n")
        parts.append(
            "Data Subset & $N$ & Reduced $\\chi^2$ & RMSE & Avg.\\,Error \\\\\n\\midrule\n"
        )
        train_avg_err = np.sqrt(max(result.redchi, 0)) * 5.0
        parts.append(
            f"Training (baseline) & {result.ndata} & {result.redchi:.3f} & --- & "
            f"{train_avg_err:.1f}\\% \\\\\n\\midrule\n"
        )
        for subset_name, stats in result.val_stats.items():
            subset_label = subset_name.replace("_", r"\_")
            val_err = np.sqrt(max(stats["redchi"], 0)) * 5.0
            parts.append(
                f"{subset_label} & {stats['N']} & {stats['redchi']:.3f} & "
                f"{stats['rmse']:.3e} & {val_err:.1f}\\% \\\\\n"
            )
        parts.append("\\bottomrule\n\\end{tabular}\n")
        # The formula literal is a plain raw string (not an f-string), so its
        # braces are literal: exactly ONE closing brace ends \caption{...}.
        parts.append(
            f"\\caption{{Out-of-sample validation for {clean}. "
            + r"Avg.\ Error = $\sqrt{\tilde{\chi}^2}\times5\%$.}"
            + "\n"
        )
        parts.append(f"\\label{{tab:val_{stem}}}\n\\end{{table}}\n\n")

    # --- 3. Correlation matrix ---
    if result.covar is not None:
        vn = result.var_names
        # atleast_2d: a single-parameter covariance may arrive as a 0-d array,
        # which np.diag rejects.
        covar = np.atleast_2d(np.asarray(result.covar, dtype=float))
        stdevs = np.sqrt(np.diag(covar))
        corr = covar / np.outer(stdevs, stdevs)
        col_fmt = "c|" + "c" * len(vn)
        # Labels are "$...$" strings but array cells are already in math mode;
        # \mbox{} drops back to text mode so the inline math is valid there.
        cell_labels = [f"\\mbox{{{_latex_label(n, latex_mapping)}}}" for n in vn]
        parts.append("\\subsection*{Parameter Correlation Matrix}\n")
        parts.append("\\begin{equation}\n")
        parts.append(f"\\begin{{array}}{{{col_fmt}}}\n")
        parts.append(" & " + " & ".join(cell_labels) + " \\\\\n\\hline\n")
        for i, row_label in enumerate(cell_labels):
            cells = [
                "\\mathbf{1}" if i == j else f"{corr[i, j]:.4f}"
                for j in range(len(vn))
            ]
            parts.append(" & ".join([row_label] + cells) + " \\\\\n")
        parts.append("\\end{array}\n\\end{equation}\n\n")

    # --- 4. Figure link ---
    parts.append("\\subsection*{Parameter Correlations}\n")
    parts.append("\\begin{figure}[htbp]\n\\centering\n")
    parts.append(
        f"\\includegraphics[width=0.9\\textwidth]{{{stem}_publication_corner.pdf}}\n"
    )
    # Single f-string so the trailing "}}" collapses to the one brace that
    # closes \caption{...}.
    parts.append(
        f"\\caption{{Corner plot for {clean}. "
        f"Solid lines mark the adopted (snapped) parameter values.}}\n"
    )
    parts.append(f"\\label{{fig:corner_{stem}}}\n\\end{{figure}}\n")

    out_path = output_dir / f"{stem}_fit_summary.tex"
    out_path.write_text("".join(parts), encoding="utf-8")
    print(f" 📝 LaTeX summary saved to: {out_path}", flush=True)
    return out_path
# --------------------------------------------------------------------------- # Post-processing: chain manipulation after a completed run # ---------------------------------------------------------------------------
def post_process_mcmc(
    medium_name: str,
    output_dir: Union[str, Path],
    original_stored_params: Dict[str, float],
    drop_vars: Optional[List[str]] = None,
    derived_exprs: Optional[Dict[str, str]] = None,
    latex_mapping: Optional[Dict[str, str]] = None,
    is_crystal: bool = False,
) -> RetrievalResult:
    """Re-analyse an existing MCMC chain: drop/derive variables, snap, and re-plot.

    Loads the HDF5 chain written by a previous :func:`retrieve` call, applies
    transformations (drop parameters, add derived quantities via pandas
    ``eval``), snaps each parameter to its cleanest 1-σ value, and writes
    updated corner plot, trace plot, and LaTeX summary.

    This mirrors the logic in ``uncertainty_emcee_post.py`` but as a proper
    library function.

    Args:
        medium_name: Stem used for file names (matches the original run).
        output_dir: Directory containing the ``.h5`` chain and ``.json`` result.
        original_stored_params: The reference parameter dict (e.g. from
            :func:`~gasfir.get_parameters`) used for drift detection.
        drop_vars: Parameter names to remove from the chain before re-analysis.
        derived_exprs: ``{new_name: pandas_eval_expression}`` pairs, e.g.
            ``{"a0_per_au3": "a0 / 6.74**3"}``.  Expressions are evaluated
            with :meth:`pandas.DataFrame.eval`; pass trusted strings only.
        latex_mapping: ``{name: latex_string}`` label overrides.
        is_crystal: If ``True``, ``lattice_constant_au`` is kept as a fixed
            constant in the reported tables.

    Returns:
        A new :class:`RetrievalResult` with post-processed parameters.

    Raises:
        FileNotFoundError: If the HDF5 chain or JSON result is missing.
        ImportError: If emcee or lmfit are not installed.
    """
    if not _HAS_EMCEE:
        raise ImportError(
            "emcee is required for post-processing. " + _RETRIEVAL_INSTALL_HINT
        )
    if not _HAS_LMFIT:
        raise ImportError(
            "lmfit is required for post-processing. " + _RETRIEVAL_INSTALL_HINT
        )

    output_dir = Path(output_dir)
    labels = latex_mapping or {}
    h5_path = output_dir / f"{medium_name}_emcee_chain.h5"
    json_path = output_dir / f"{medium_name}_retrieval_result.json"
    if not h5_path.exists():
        raise FileNotFoundError(f"HDF5 chain not found: {h5_path}")
    if not json_path.exists():
        raise FileNotFoundError(f"JSON result not found: {json_path}")

    # ---- load original result ----
    original = load_result(json_path)
    original_var_names: List[str] = original.var_names

    # ---- load chain ----
    backend = emcee.backends.HDFBackend(str(h5_path), read_only=True)
    # Read the chain length once; it is needed on both the success and the
    # fallback path below.
    total_steps = backend.get_chain().shape[0]
    try:
        tau = backend.get_autocorr_time(quiet=True)
        burnin = int(2 * np.max(tau))
        thin_val = max(1, int(0.5 * np.min(tau)))
    except Exception as exc:  # best-effort: any failure falls back to defaults
        _log.warning(
            "Autocorrelation time unavailable (%s); "
            "falling back to burn-in = 1/3 of the chain and no thinning.",
            exc,
        )
        burnin = total_steps // 3
        thin_val = 1
    # Never discard more than half of the chain.
    burnin = min(burnin, total_steps // 2)

    flat_samples = backend.get_chain(discard=burnin, thin=thin_val, flat=True)
    df_chain = pd.DataFrame(flat_samples, columns=original_var_names)

    _section(f"Post-processing MCMC chain — {medium_name}")
    print(
        f" Loaded {len(df_chain)} samples (burn={burnin}, thin={thin_val})", flush=True
    )

    # ---- apply transformations ----
    if derived_exprs:
        for new_var, expr in derived_exprs.items():
            print(f" ➕ Derived: {new_var} = {expr}", flush=True)
            # NOTE(review): DataFrame.eval executes the expression string;
            # it must not be fed untrusted input.
            df_chain[new_var] = df_chain.eval(expr)

    if drop_vars:
        for dv in drop_vars:
            if dv in df_chain.columns:
                print(f" ➖ Dropped: {dv}", flush=True)
                df_chain = df_chain.drop(columns=[dv])

    # ---- snap parameters ----
    new_var_names = list(df_chain.columns)
    rebuilt = lmfit.Parameters()

    print("\n Snapping posteriors to cleanest values:", flush=True)
    for name in new_var_names:
        median = float(np.median(df_chain[name]))
        sigma = float(np.std(df_chain[name]))
        init = original_stored_params.get(name)
        clean = get_cleanest_value(median, sigma, init)
        tag = "STABLE" if init is not None and clean == init else "SNAPPED"
        print(f" [{tag}] {name}: {clean:.4g} (σ={sigma:.3g})", flush=True)
        rebuilt.add(name, value=clean, vary=True)
        rebuilt[name].stderr = sigma

    # carry fixed constants from original
    for name in original.params:
        if name not in rebuilt and not original.params[name].vary:
            rebuilt.add(name, value=original.params[name].value, vary=False)

    # add lattice_constant_au if crystal
    has_lattice = is_crystal and "lattice_constant_au" in original_stored_params
    if has_lattice:
        lat_val = original_stored_params["lattice_constant_au"]
        rebuilt.add("lattice_constant_au", value=lat_val, vary=False)

    # ---- assemble RetrievalResult ----
    new_flat = df_chain.values
    # With a single column np.cov returns a 0-d array, so no covariance /
    # correlation is reported in that case.
    new_covar = np.cov(new_flat, rowvar=False) if new_flat.shape[1] > 1 else None
    if new_covar is not None:
        stds = np.sqrt(np.diag(new_covar))
        for i, name in enumerate(new_var_names):
            rebuilt[name].correl = {
                other: float(new_covar[i, j] / (stds[i] * stds[j]))
                for j, other in enumerate(new_var_names)
                if j != i
            }

    post_result = RetrievalResult(
        params=rebuilt,
        var_names=new_var_names,
        chisqr=original.chisqr,
        redchi=original.redchi,
        ndata=original.ndata,
        nvarys=len(new_var_names),
        covar=new_covar,
        flat_samples=new_flat,
        val_stats=original.val_stats,
        method="emcee_postprocessed",
        chain_path=h5_path,
    )

    # ---- outputs ----
    post_stem = f"{medium_name}_postprocessed"
    truths = [rebuilt[n].value for n in new_var_names]
    fixed_consts: Dict[str, float] = {}
    if has_lattice:
        fixed_consts["lattice_constant_au"] = original_stored_params[
            "lattice_constant_au"
        ]

    generate_publication_corner(
        new_flat,
        new_var_names,
        output_dir,
        post_stem,
        truths=truths,
        latex_mapping=labels,
        constants=fixed_consts,
    )
    # The trace is drawn from the raw HDF5 chain, so it uses the original
    # (untransformed) parameter names.
    generate_trace_plot(
        h5_path,
        original_var_names,
        output_dir,
        post_stem,
        latex_mapping=labels,
        discard=burnin,
    )
    generate_latex_summary(post_result, output_dir, latex_mapping=labels)

    # save updated JSON
    post_json = output_dir / f"{post_stem}_result.json"
    save_result(post_result, post_json)
    print(f" 💾 Post-processed result saved to: {post_json}", flush=True)

    return post_result
# --------------------------------------------------------------------------- # Cold-start defaults # --------------------------------------------------------------------------- def _cold_start_params( E_g: float, is_crystal: bool, ) -> tuple: """Return ``(initial_values, absolute_bounds)`` for a cold-start run. Used when no stored parameters and no user-supplied initial guess exist. Values and bounds are physically motivated broad ranges that cover all known GASFIR parameter sets across atoms, molecules, and solids. Args: E_g: Ionization potential / band gap in atomic units. This is the one parameter the user **must** supply — it sets the energy scale and cannot be guessed from the data alone. is_crystal: If ``True``, use wider bounds appropriate for solids (larger a0, a1; m_eff free with negative values allowed). Returns: Tuple of ``(initial_params_dict, param_bounds_dict)`` where ``param_bounds`` contains ``{name: (min, max)}`` entries that are passed directly to :func:`setup_parameters`. """ if is_crystal: initial = { "E_g": E_g, "m_eff": 0.3, # typical solid; can be negative "a0": 50.0, # broad: Diamond≈107, SiO2≈5 "a1": 5.0, # broad: Diamond≈14, SiO2≈1.6 "a2": 2.0, "a3": -2.0, # often negative for crystals } bounds = { "E_g": (max(1e-4, E_g / 5), E_g * 5), "m_eff": (-2.0, 5.0), "a0": (1e-6, 1e5), "a1": (0.1, 200.0), "a2": (-15.0, 20.0), "a3": (-30.0, 15.0), } else: initial = { "E_g": E_g, "a0": 5.0, # broad: He≈0.82, Ar≈14, H≈3.4 "a1": 3.5, # broad: He≈2, Ar≈4, H≈3.5 "a2": 2.0, "a3": 1.0, } bounds = { "E_g": (max(1e-4, E_g / 5), E_g * 5), "a0": (1e-6, 1e3), "a1": (0.1, 20.0), "a2": (0.0, 15.0), "a3": (-10.0, 10.0), } return initial, bounds # --------------------------------------------------------------------------- # Main entry point # ---------------------------------------------------------------------------
def retrieve(
    data_NA: pd.DataFrame,
    config: RetrievalConfig,
    uncertainty_NA: Optional[npt.NDArray[np.float64]] = None,
    data_QS: Optional[pd.DataFrame] = None,
    uncertainty_QS: Optional[Union[npt.NDArray[np.float64], float]] = None,
    data_validation: Optional[pd.DataFrame] = None,
    run_phases: Tuple[bool, bool, bool] = (True, True, True),
) -> RetrievalResult:
    """Run the full DE/CMA-ES → least_squares → emcee retrieval pipeline.

    Besides returning the result, the function creates the output directory,
    prints progress reports, writes the corner plot, trace plot and LaTeX
    summary when MCMC samples are available, and saves the result as JSON.

    Args:
        data_NA: Training DataFrame with ``"pulses"`` and ``"Y"`` columns.
        config: :class:`RetrievalConfig` controlling all pipeline settings.
        uncertainty_NA: Per-point uncertainty array.  If ``None``, defaults
            to ``0.05 × Y`` (5% relative uncertainty).
        data_QS: Optional quasi-static data with ``"field"`` and ``"Y"``
            columns for simultaneous fitting.
        uncertainty_QS: Uncertainty for QS data.  Defaults to 5%.
        data_validation: Optional out-of-sample DataFrame for validation
            metrics (same columns as ``data_NA``).
        run_phases: Three-tuple ``(global, polish, mcmc)`` controlling which
            pipeline phases to execute.  Default ``(True, True, True)`` runs
            all three.

    Returns:
        A :class:`RetrievalResult` with the best-fit parameters, posterior
        statistics, and optional validation metrics.

    Raises:
        ImportError: If a requested phase needs ``cma`` or ``emcee`` and the
            package is not installed.
        ValueError: On a cold start (medium not in the parameter store) when
            ``E_g`` is neither supplied in ``config.initial_params`` nor
            resolvable from the ionization-potential lookup.
    """
    _require_lmfit()

    # Validate optional dependencies before any expensive computation.
    run_global, run_polish, run_mcmc_flag = run_phases
    if run_global and config.global_method == "cma" and not _HAS_CMA:
        raise ImportError(
            "CMA-ES is the default global search method but the 'cma' package "
            "is not installed. Either install it with "
            " pip install gasfir[retrieval]\n"
            "or switch to differential evolution:\n"
            " config.global_method = 'differential_evolution'"
        )
    if run_mcmc_flag and not _HAS_EMCEE:
        raise ImportError(
            "emcee is required for MCMC sampling. " + _RETRIEVAL_INSTALL_HINT
        )

    output_dir = (
        Path(config.output_dir)
        if config.output_dir is not None
        else Path(config.medium_name)
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f" Output directory: {output_dir.resolve()}", flush=True)

    # --- uncertainty defaults ---
    if uncertainty_NA is None:
        uncertainty_NA = 0.05 * data_NA["Y"].to_numpy()
    if uncertainty_QS is None and data_QS is not None:
        uncertainty_QS = 0.05 * data_QS["Y"].to_numpy()

    # --- initial parameters ---
    _cold_start = False
    _explicit_bounds: Dict[str, tuple] = {}

    # Try to get stored params; fall back to cold start if unavailable.
    _stored_ref: Optional[Dict[str, float]] = None
    try:
        _stored_ref = get_parameters(config.medium_name)
    except (ValueError, KeyError):
        pass  # medium not in store → cold start (initial_params act as hints)

    if config.initial_params is not None and _stored_ref is not None:
        # Medium is in the store AND the user supplied initial params —
        # the user-supplied set is used as-is.
        stored = config.initial_params.copy()
    elif _stored_ref is not None:
        # Store hit, no override → use stored params.
        stored = _stored_ref.copy()
    else:
        # ── Cold start ───────────────────────────────────────────────────
        # Medium not in store.  E_g must be provided — it sets the energy
        # scale and cannot be inferred from data alone.
        hints = config.initial_params or {}

        # Try to auto-resolve E_g from the atomic database if medium_name
        # looks like an element symbol or name (atoms/molecules only).
        _auto_E_g: Optional[float] = None
        if not config.is_crystal and "E_g" not in hints:
            _auto_E_g = get_ionization_potential(config.medium_name)

        if "E_g" not in hints and _auto_E_g is None:
            raise ValueError(
                f"No stored parameters found for '{config.medium_name}'.\n"
                "For a cold start, supply E_g (ionization potential in a.u.):\n"
                f" config.initial_params = {{'E_g': <value_in_au>}}\n"
                "For atoms/molecules the ionization potential can be looked up:\n"
                " from gasfir import get_ionization_potential\n"
                " E_g = get_ionization_potential('Ar') # example\n"
                "For crystals, E_g must always be provided explicitly.\n"
                "All other parameters are initialised from broad physical defaults."
            )

        E_g_init = hints.get("E_g", _auto_E_g)
        stored, _explicit_bounds = _cold_start_params(E_g_init, config.is_crystal)
        stored.update(hints)  # user hints override defaults
        _cold_start = True

        _src = "NIST atomic database" if _auto_E_g is not None else "user-supplied"
        _E_g_eV = E_g_init * 27.21138383
        print(
            f"\n ⚠ Cold start — no stored parameters for '{config.medium_name}'.\n"
            f" E_g = {E_g_init:.4g} a.u. ({_E_g_eV:.4g} eV, {_src})\n"
            f" Physical defaults used for all other parameters.\n"
            f" Consider increasing cma_maxiter for a thorough global search.",
            flush=True,
        )

    # Crystal: pre-expand a0_per_au3 → a0
    if config.is_crystal and "a0_per_au3" in stored and "lattice_constant_au" in stored:
        stored["a0"] = stored["a0_per_au3"] * stored["lattice_constant_au"] ** 3

    if _cold_start:
        # Defensive: guarantee an explicit a0 range on a cold start.
        _explicit_bounds.setdefault("a0", (1e-6, 1e5))

    params = setup_parameters(
        stored,
        relative_bound=config.relative_bound,
        fixed_params=config.fixed_params,
        is_crystal=config.is_crystal,
        param_bounds=_explicit_bounds if _cold_start else None,
    )

    _section(f"Initial Parameters — {config.medium_name}")
    params.pretty_print()

    # --- residual function ---
    base_residual = ret_residual_function(
        data_NA,
        uncertainty_NA,
        data_QS,
        uncertainty_QS,
        dt=config.dt,
        dT=config.dT,
        ret_electron_density=config.ret_electron_density,
    )
    residual_fn = _make_safe_residual(base_residual)

    # --- track which phases ran ---
    phases_run: List[str] = []
    de_result = None
    ls_result = None

    # === Phase 1: Global search ===
    if run_global:
        params, de_result = run_global_search(
            residual_fn,
            params,
            method=config.global_method,
            de_workers=config.de_workers,
            de_maxiter=config.de_maxiter,
            de_tol=config.de_tol,
            de_popsize=config.de_popsize,
            cma_sigma0=config.cma_sigma0,
            cma_maxiter=config.cma_maxiter,
            trial_run=config.trial_run,
        )
        phases_run.append(config.global_method)
        _section(f"Phase 1 — {config.global_method.upper()} best parameters")
        params.pretty_print()

    # === Phase 2: Local polish ===
    if run_polish:
        params, ls_result = run_local_polish(
            residual_fn, params, method=config.ls_method
        )
        phases_run.append("least_squares")
        _section("Phase 2 — Least Squares Fit Report")
        print(lmfit.fit_report(ls_result), flush=True)

        # Compare to stored reference parameters if they exist
        if config.medium_name and config.medium_name != "unknown":
            try:
                stored_ref = get_parameters(config.medium_name)
                # For crystals: convert a0_per_au3 → a0 for the comparison
                if "a0_per_au3" in stored_ref and "lattice_constant_au" in stored_ref:
                    stored_ref = stored_ref.copy()
                    stored_ref["a0"] = (
                        stored_ref["a0_per_au3"]
                        * stored_ref["lattice_constant_au"] ** 3
                    )
                compare_to_stored(ls_result, stored_ref)
            except (ValueError, KeyError):
                pass  # medium not in store — skip silently

    # === Phase 3: MCMC ===
    flat_samples: Optional[npt.NDArray] = None
    lnprob_arr: Optional[npt.NDArray] = None
    covar: Optional[npt.NDArray] = None

    if run_mcmc_flag:
        backend_path = output_dir / f"{config.medium_name}_emcee_chain.h5"
        params, flat_samples, lnprob_arr, _sampler = run_mcmc(
            residual_fn,
            params,
            n_steps=config.emcee_nsteps,
            n_walkers=config.emcee_nwalkers,
            burn_fraction=config.emcee_burn_fraction,
            thin=config.emcee_thin,
            backend_path=backend_path,
            pool=config.emcee_pool,
            max_autocorr_loops=config.emcee_max_autocorr_loops,
            initial_stored_params=stored,
            trial_run=config.trial_run,
        )
        phases_run.append("emcee")
        if flat_samples.ndim == 2:
            # np.cov returns a 0-d array for a single free parameter; promote
            # it to (1, 1) so downstream np.diag(covar) keeps working.
            covar = np.atleast_2d(np.cov(flat_samples, rowvar=False))
        else:
            covar = None

    # === Goodness of fit at best-fit point ===
    best_resid = residual_fn(params)
    # NOTE(review): var_names is alphabetically sorted here and then used to
    # label the columns of flat_samples — confirm run_mcmc emits the chain in
    # the same order.
    var_names = sorted([n for n in params if params[n].vary])
    ndata = len(data_NA) + (len(data_QS) if data_QS is not None else 0)
    nvarys = len(var_names)
    chisqr = float(np.sum(best_resid**2))
    redchi = chisqr / max(ndata - nvarys, 1)

    # === MCMC fit report (needs chisqr computed above) ===
    # HDF5 chain path (written during run_mcmc if emcee ran)
    _chain_path: Optional[Path] = None
    if run_mcmc_flag:
        _chain_path = output_dir / f"{config.medium_name}_emcee_chain.h5"

    have_samples = (
        run_mcmc_flag and flat_samples is not None and len(flat_samples) > 0
    )

    if have_samples:
        # NOTE(review): the walker count reported here is max(32, 4 * nvarys),
        # not config.emcee_nwalkers — confirm this matches what run_mcmc used.
        emcee_mr = _build_emcee_minimizer_result(
            params,
            var_names,
            flat_samples,
            ndata=ndata,
            n_walkers=max(32, 4 * nvarys),
            n_steps=config.emcee_nsteps,
            chisqr=chisqr,
        )
        _section("Phase 3 — MCMC Posterior Report")
        print(lmfit.fit_report(emcee_mr), flush=True)

        # Corner plot, trace plot, and LaTeX summary
        truths = [params[n].value for n in var_names]
        fixed_consts = {n: params[n].value for n in params if not params[n].vary}
        generate_publication_corner(
            flat_samples,
            var_names,
            output_dir,
            config.medium_name,
            truths=truths,
            latex_mapping=config.latex_mapping or {},
            constants=fixed_consts,
        )
        if _chain_path and _chain_path.exists():
            generate_trace_plot(
                _chain_path,
                var_names,
                output_dir,
                config.medium_name,
                latex_mapping=config.latex_mapping or {},
            )

    # === Out-of-sample validation ===
    val_stats: Dict = {}
    if data_validation is not None and len(data_validation) > 0:
        _log.info("Computing out-of-sample validation metrics...")
        subsets: Dict[str, pd.DataFrame] = {"All_Validation": data_validation}
        if "intens" in data_validation.columns:
            subsets["High_Intensity"] = data_validation[
                data_validation["intens"] >= 1e14
            ]
            subsets["Low_Intensity"] = data_validation[data_validation["intens"] < 1e14]
        if "wavel" in data_validation.columns:
            subsets["Short_wl"] = data_validation[data_validation["wavel"] <= 800]
            subsets["Long_wl"] = data_validation[data_validation["wavel"] > 800]

        for subset_name, df_sub in subsets.items():
            if len(df_sub) == 0:
                continue
            val_stats[subset_name] = compute_validation_metrics(
                params,
                df_sub,
                dt=config.dt,
                dT=config.dT,
                ret_electron_density=config.ret_electron_density,
            )
            v = val_stats[subset_name]
            _log.info(
                " [%s] N=%d redchi=%.3f RMSE=%.2e",
                subset_name,
                v["N"],
                v["redchi"],
                v["rmse"],
            )

    # === Assemble result ===
    result = RetrievalResult(
        params=params,
        var_names=var_names,
        chisqr=chisqr,
        redchi=redchi,
        ndata=ndata,
        nvarys=nvarys,
        covar=covar,
        flat_samples=flat_samples,
        lnprob=lnprob_arr,
        val_stats=val_stats,
        de_result=de_result,
        ls_result=ls_result,
        method="+".join(phases_run),
        chain_path=_chain_path,
    )

    # LaTeX summary (after val_stats are populated)
    if have_samples:
        generate_latex_summary(
            result, output_dir, latex_mapping=config.latex_mapping or {}
        )

    # === Final summary ===
    _section(f"Pipeline Complete — {config.medium_name} [{'+'.join(phases_run)}]")
    print(f" Phases run : {' → '.join(phases_run)}", flush=True)
    print(f" N data points : {ndata}", flush=True)
    print(f" N free params : {nvarys}", flush=True)
    print(f" χ² : {chisqr:.4g}", flush=True)
    print(f" red-χ² : {redchi:.4g}", flush=True)
    if val_stats:
        print(" Validation:", flush=True)
        for name, v in val_stats.items():
            print(
                f" {name:30s} N={v['N']:4d} red-χ²={v['redchi']:.3f} RMSE={v['rmse']:.2e}",
                flush=True,
            )
    print(f" Output dir : {output_dir}", flush=True)

    # === Persist ===
    json_path = output_dir / f"{config.medium_name}_retrieval_result.json"
    save_result(result, json_path)
    print(f" JSON saved : {json_path}", flush=True)

    return result
# ---------------------------------------------------------------------------
# Multiprocessing helper
# ---------------------------------------------------------------------------


def _emcee_worker_init() -> None:
    """Pool initializer: ask Numba for a single thread in this worker.

    Defined at module level (not nested inside :func:`make_emcee_pool`) so
    that it can be pickled.  The ``spawn`` and ``forkserver`` start methods
    pickle the initializer, and a function defined inside another function
    cannot be pickled.
    """
    # NOTE(review): Numba reads NUMBA_NUM_THREADS when it is imported; in a
    # forked worker that inherits an already-imported Numba this may come too
    # late to take effect — confirm, and consider numba.set_num_threads(1).
    os.environ["NUMBA_NUM_THREADS"] = "1"


def make_emcee_pool(n_workers: Optional[int] = None) -> multiprocessing.Pool:
    """Create a ``multiprocessing.Pool`` for use as ``config.emcee_pool``.

    Each worker sets ``NUMBA_NUM_THREADS=1`` on start-up so Numba does not
    over-subscribe the CPU when combined with process-level parallelism.

    Args:
        n_workers: Number of worker processes.  ``None`` (or ``0``) → CPU
            count.

    Returns:
        An initialised :class:`multiprocessing.Pool`.  **Remember to close it
        after the retrieval run:** ``pool.close(); pool.join()``.

    Example::

        pool = make_emcee_pool()
        config.emcee_pool = pool
        result = retrieve(data_NA, config)
        pool.close()
        pool.join()
    """
    n = n_workers or multiprocessing.cpu_count()
    return multiprocessing.Pool(processes=n, initializer=_emcee_worker_init)