"""Systematic nonlinear regression pipeline for GASFIR parameter retrieval.
Three-phase pipeline
--------------------
1. **Global search** — :func:`run_global_search`
CMA-ES (default, ``cma`` package) or Differential Evolution (scipy;
optional process-level parallelism via ``workers``, see below). Locates the
basin of attraction without requiring a good initial guess.
2. **Local polish** — :func:`run_local_polish`
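Bounded nonlinear least squares via lmfit (default ``least_squares``, i.e.
scipy's trust-region reflective solver; ``leastsq`` selects
Levenberg–Marquardt). Takes the DE/CMA-ES output and converges quickly to
the nearby minimum.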
3. **Uncertainty quantification** — :func:`run_mcmc`
``emcee`` MCMC. Samples the posterior and reports credible intervals.
Supports an optional ``multiprocessing.Pool`` for walker-level parallelism.
Parallelisation notes
---------------------
The GASFIR kernel (``get_diabatic_ionization_probability_batch``) already
parallelises over pulses via Numba ``prange``, using all available CPU cores.
This is the dominant speedup for typical datasets.
For emcee, walker-level parallelism via ``multiprocessing.Pool`` is available
through the ``pool`` parameter of :func:`run_mcmc`. Each worker process runs
the Numba kernel serially (``NUMBA_NUM_THREADS=1`` is set automatically in
workers); the overall CPU utilisation is comparable. The pool approach
helps when *nwalkers ≫ ncores* and each individual kernel call is fast.
For differential evolution, ``scipy.optimize.differential_evolution`` is used
with ``workers=1`` by default (Numba handles inner-loop parallelism). Passing
``workers=-1`` enables scipy's process-level parallelism, which can help when
individual residual evaluations are slow.
Typical usage
-------------
::
import pandas as pd
from gasfir import create_pulse, ret_pulse_from_pandas_table
from gasfir.retrieval import RetrievalConfig, retrieve
data = pd.read_csv("my_data.csv")
data["pulses"] = ret_pulse_from_pandas_table(data)
cfg = RetrievalConfig(medium_name="H_SFA") # saves to ./H_SFA/
cfg = RetrievalConfig(medium_name="H_SFA", output_dir="results/H_SFA") # explicit path
result = retrieve(data_NA=data, config=cfg)
"""
from __future__ import annotations
import json
import logging
import multiprocessing
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
if TYPE_CHECKING:
import lmfit
from lmfit.minimizer import MinimizerResult
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.optimize import differential_evolution
try:
import lmfit
from lmfit.minimizer import MinimizerResult
_HAS_LMFIT = True
except ImportError:
lmfit = None # type: ignore[assignment]
_HAS_LMFIT = False
# MinimizerResult is intentionally left undefined when lmfit is absent.
# Accessing it raises NameError rather than the silent TypeError that
# `isinstance(x, None)` would produce.
try:
import emcee
_HAS_EMCEE = True
except ImportError:
emcee = None # type: ignore[assignment]
_HAS_EMCEE = False
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
_HAS_MPL = True
except ImportError:
mpl = None # type: ignore[assignment]
plt = None # type: ignore[assignment]
_HAS_MPL = False
try:
import corner as _corner_lib
_HAS_CORNER = True
except ImportError:
_corner_lib = None # type: ignore[assignment]
_HAS_CORNER = False
try:
from tqdm.auto import tqdm
_HAS_TQDM = True
except ImportError:
def tqdm(iterable, **kwargs): # type: ignore[misc]
return iterable
_HAS_TQDM = False
from ._atomic_data import get_ionization_potential
from ._parameter_store import get_parameters
from .fitting import ret_residual_function
# Module-level logger shared by every phase of the pipeline.
_log = logging.getLogger(__name__)

# Suffix appended to ImportError messages when optional extras are missing.
_RETRIEVAL_INSTALL_HINT = (
    "Install the retrieval extras with: "
    "pip install gasfir[retrieval]"
)
def _require_lmfit() -> None:
    """Fail fast with an informative ImportError when lmfit is unavailable.

    Returns silently when lmfit was imported successfully at module load.

    Raises:
        ImportError: If the optional ``lmfit`` dependency is not installed.
    """
    if _HAS_LMFIT:
        return
    message = (
        "lmfit is required for parameter fitting and retrieval. "
        + _RETRIEVAL_INSTALL_HINT
    )
    raise ImportError(message)
# ---------------------------------------------------------------------------
# Optional dependency: CMA-ES
# ---------------------------------------------------------------------------
try:
import cma # type: ignore[import]
_HAS_CMA = True
except ImportError:
_HAS_CMA = False
# ---------------------------------------------------------------------------
# Pretty-print helpers
# ---------------------------------------------------------------------------
def _section(title: str) -> None:
    """Write *title* to stdout, framed above and below by a heavy rule.

    Each line is flushed immediately so headers interleave correctly with
    progress output from long-running phases.
    """
    rule = "═" * 62
    for line in (f"\n{rule}", f" {title}", rule):
        print(line, flush=True)
def _build_emcee_minimizer_result(
    params: lmfit.Parameters,
    var_names: List[str],
    flat_samples: npt.NDArray[np.float64],
    ndata: int,
    n_walkers: int,
    n_steps: int,
    chisqr: float,
) -> MinimizerResult:
    """Wrap emcee output in an lmfit MinimizerResult for ``fit_report()`` support.

    The ``stderr`` values are taken as-is from *params*; this function does
    not set them. It attaches the posterior covariance and pairwise
    correlations (``Parameter.correl``). It also stores the flattened chain
    as a DataFrame in ``flatchain`` so lmfit's corner/trace helpers work.

    Information criteria treat ``chisqr`` as ``-2 ln L`` up to a constant,
    i.e. a Gaussian likelihood over already-normalised residuals. lmfit's
    own minimisers use ``ndata * ln(chisqr / ndata)`` instead.

    Args:
        params: Parameters at the emcee median (stderr already set by the
            caller).
        var_names: Ordered list of varied parameter names; must match the
            column order of *flat_samples*.
        flat_samples: Flattened MCMC chain, shape ``(n_samples, ndim)``.
        ndata: Number of data points used in the fit.
        n_walkers: Number of emcee walkers.
        n_steps: Total MCMC steps per walker.
        chisqr: Sum of squared residuals at the median parameter values.

    Returns:
        A :class:`lmfit.minimizer.MinimizerResult` ready for
        ``lmfit.fit_report()``. ``covar`` is ``None`` when fewer than two
        parameters vary. ``errorbars`` is ``False`` when any posterior
        standard deviation is zero or non-finite, which mirrors lmfit.

    Raises:
        ImportError: If lmfit is not installed.
    """
    _require_lmfit()
    nvarys = len(var_names)
    covar = (
        np.cov(flat_samples, rowvar=False)
        if flat_samples.ndim == 2 and flat_samples.shape[1] > 1
        else None
    )

    # Build correlation dicts on each parameter (used by fit_report).
    p = params.copy()
    errorbars = True
    if covar is not None:
        stds = np.sqrt(np.diag(covar))
        # A zero-width (e.g. stuck-walker) or non-finite marginal makes the
        # correlation undefined (0/0). Skip those pairs instead of storing
        # NaN, and flag the uncertainty estimate as degenerate like lmfit does.
        degenerate = ~(np.isfinite(stds) & (stds > 0.0))
        errorbars = not bool(np.any(degenerate))
        for i, name in enumerate(var_names):
            p[name].correl = {
                var_names[j]: float(covar[i, j] / (stds[i] * stds[j]))
                for j in range(nvarys)
                if j != i and not degenerate[i] and not degenerate[j]
            }

    redchi = chisqr / max(ndata - nvarys, 1)
    aic = chisqr + 2 * nvarys
    bic = chisqr + nvarys * float(np.log(ndata))

    result = MinimizerResult()
    result.method = "emcee"
    result.var_names = var_names
    result.params = p
    result.covar = covar
    result.nvarys = nvarys
    result.ndata = ndata
    result.nfev = n_walkers * n_steps
    result.success = True
    result.errorbars = errorbars
    result.chisqr = chisqr
    result.redchi = redchi
    result.aic = aic
    result.bic = bic
    result.residual = np.array([])  # individual residuals not retained
    try:
        result.flatchain = pd.DataFrame(flat_samples, columns=var_names)
    except Exception as exc:  # best-effort: flatchain is optional
        _log.warning("Could not attach flatchain to emcee MinimizerResult: %s", exc)
    return result
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
[docs]
@dataclass
class RetrievalConfig:
    """Configuration for the three-phase retrieval pipeline.

    Every field has a default, so a config can be built from keywords alone,
    e.g. ``RetrievalConfig(medium_name="H_SFA")``.

    Attributes:
        medium_name: Stem for output file names. When ``initial_params`` is
            ``None`` it is also used to look up starting values via
            :func:`~gasfir.get_parameters`.
        output_dir: Directory for JSON results, the HDF5 chain, plots and
            LaTeX output. ``None`` selects a folder named after
            ``medium_name`` in the current working directory.
        initial_params: Explicit starting values. ``None`` falls back to
            ``get_parameters(medium_name)``.
        fixed_params: Names of parameters held constant during the fit.
            For non-crystal media ``E_g`` and ``m_eff`` are fixed by default.
        relative_bound: Half-width of each search range as a multiple of
            ``abs(initial value)``. The default 5.0 gives
            ``[val - 5*abs(val), val + 5*abs(val)]``.
        is_crystal: Also vary ``m_eff``, and convert ``a0_per_au3`` to
            ``a0`` using the lattice constant.
        simultaneous: Load and concatenate the quasi-static residuals.
        dt: Coarse time step for pulse pre-computation (a.u.).
        dT: Fine time step for pulse pre-computation (a.u.).
        ret_electron_density: Fit electron density instead of ionization
            probability.
        global_method: ``"cma"`` (default) or ``"differential_evolution"``.
        de_workers: Differential-evolution workers. ``1`` (default) leaves
            parallelism to Numba inside each call; ``-1`` uses every core.
        de_maxiter: Differential-evolution iteration cap.
        de_tol: Differential-evolution convergence tolerance.
        de_popsize: Differential-evolution population-size multiplier.
        cma_sigma0: Initial CMA-ES step size.
        cma_maxiter: CMA-ES budget, counted in function evaluations.
        ls_method: lmfit method for the local polish, e.g.
            ``"least_squares"`` or ``"leastsq"``.
        emcee_nwalkers: Number of MCMC walkers (even, and greater than
            2×ndim). ``None`` means ``max(32, 4×ndim)``.
        emcee_nsteps: MCMC steps per walker.
        emcee_burn_fraction: Fraction of the chain discarded as burn-in
            when the autocorrelation time cannot be estimated.
        emcee_thin: Thinning factor applied to the flattened samples.
        emcee_max_autocorr_loops: Maximum number of auto-correction loops
            (re-centre bounds and restart MCMC when parameters drift).
        emcee_pool: Optional ``multiprocessing.Pool``-compatible object
            handed to :class:`emcee.EnsembleSampler`. ``None`` relies on
            Numba thread-level parallelism. Excluded from ``repr``.
        trial_run: Run DE and MCMC with minimal iterations as a quick
            sanity check.
        latex_mapping: Optional ``{name: label}`` overrides for the LaTeX
            labels used in plots and ``.tex`` output, e.g.
            ``{"E_g": r"$E_g$ [a.u.]", "m_eff": r"$m_{\\mathrm{eff}}$"}``.
            Excluded from ``repr``.
    """

    # ── identification & output ──────────────────────────────────────────
    medium_name: str = "unknown"
    # None -> folder named after medium_name in the current directory.
    output_dir: Optional[Union[str, Path]] = None

    # ── parameter setup ──────────────────────────────────────────────────
    initial_params: Optional[Dict[str, float]] = None
    fixed_params: Optional[List[str]] = None
    relative_bound: float = 5.0
    is_crystal: bool = False

    # ── residual construction ────────────────────────────────────────────
    simultaneous: bool = False
    dt: float = 4.0
    dT: float = 0.25
    ret_electron_density: bool = False

    # ── phase 1: global search ───────────────────────────────────────────
    global_method: str = "cma"
    de_workers: int = 1
    de_maxiter: int = 1000
    de_tol: float = 1e-7
    de_popsize: int = 15
    cma_sigma0: float = 0.3
    cma_maxiter: int = 10_000

    # ── phase 2: local polish ────────────────────────────────────────────
    ls_method: str = "least_squares"

    # ── phase 3: emcee ───────────────────────────────────────────────────
    emcee_nwalkers: Optional[int] = None
    emcee_nsteps: int = 3000
    emcee_burn_fraction: float = 0.3
    emcee_thin: int = 10
    emcee_max_autocorr_loops: int = 5
    # Pool objects are not meaningful in a repr, so hide them.
    emcee_pool: Optional[object] = field(default=None, repr=False)
    trial_run: bool = False

    # ── reporting ────────────────────────────────────────────────────────
    # LaTeX label overrides, e.g. {"E_g": r"$E_g$ [a.u.]"}.
    latex_mapping: Optional[Dict[str, str]] = field(default=None, repr=False)
# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------
[docs]
@dataclass
class RetrievalResult:
    """Everything produced by one retrieval run.

    Attributes:
        params: Final lmfit Parameters holding the best-fit values. After
            MCMC, ``stderr`` is set to the posterior standard deviation.
        var_names: Names of the varied parameters.
        chisqr: Sum of squared residuals at the best-fit point.
        redchi: Reduced chi-squared.
        ndata: Number of data points used in the fit.
        nvarys: Number of varied parameters.
        covar: Covariance matrix of the MCMC posterior.
        flat_samples: Flattened MCMC chain, shape ``(n_samples, ndim)``.
        lnprob: Log-probability of each flat sample.
        val_stats: Out-of-sample validation statistics keyed by subset name.
        de_result: Raw global-search result object (or ``None``). Hidden
            from ``repr``.
        ls_result: Raw lmfit least-squares result (or ``None``). Hidden
            from ``repr``.
        method: Comma-separated list of the phases that were run.
        chain_path: Path of the HDF5 emcee chain file, if one was written.
            Hidden from ``repr``.
    """

    # Required fields (no defaults) must come first.
    params: lmfit.Parameters
    var_names: List[str]

    # Fit statistics.
    chisqr: float = 0.0
    redchi: float = 0.0
    ndata: int = 0
    nvarys: int = 0

    # Posterior summary (None until MCMC has run).
    covar: Optional[npt.NDArray[np.float64]] = None
    flat_samples: Optional[npt.NDArray[np.float64]] = None
    lnprob: Optional[npt.NDArray[np.float64]] = None

    # default_factory gives each instance its own dict.
    val_stats: Dict = field(default_factory=dict)

    # Raw optimiser outputs; bulky, so excluded from repr.
    de_result: object = field(default=None, repr=False)
    ls_result: object = field(default=None, repr=False)

    method: str = "emcee"
    # Location of the persisted HDF5 chain, if any.
    chain_path: Optional[Path] = field(default=None, repr=False)
# ---------------------------------------------------------------------------
# I/O helpers
# ---------------------------------------------------------------------------
[docs]
def save_result(result: RetrievalResult, path: Union[str, Path]) -> None:
    """Serialise a :class:`RetrievalResult` to a JSON file.

    The chain (``flat_samples``/``lnprob``) and raw optimiser results are
    not written. NumPy scalars and arrays anywhere in the payload (commonly
    in ``val_stats``, ``ndata`` or ``redchi``) are converted to plain Python
    types. The JSON text is fully built before the file is opened, so a
    serialisation error cannot leave a truncated file behind.

    Args:
        result: The retrieval result to save.
        path: Output file path (parent directories are created; an existing
            file is overwritten).

    Raises:
        TypeError: If the payload contains an object that is neither
            JSON-native nor a NumPy scalar/array.
    """

    def _to_builtin(obj):
        """``json`` fallback hook: turn NumPy scalars/arrays into Python types."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    path = Path(path)
    covar_list = result.covar.tolist() if result.covar is not None else None
    payload = {
        "params": json.loads(result.params.dumps()),
        "var_names": result.var_names,
        "chisqr": result.chisqr,
        "redchi": result.redchi,
        "ndata": result.ndata,
        "nvarys": result.nvarys,
        "covar": covar_list,
        "val_stats": result.val_stats,
        "method": result.method,
    }
    # Serialise first so a failure does not clobber an existing file.
    text = json.dumps(payload, indent=4, default=_to_builtin)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(text)
    _log.info("Result saved to %s", path)
[docs]
def load_result(path: Union[str, Path]) -> RetrievalResult:
    """Load a :class:`RetrievalResult` from a JSON file written by :func:`save_result`.

    Args:
        path: Path to the JSON file.

    Returns:
        A :class:`RetrievalResult` with ``flat_samples`` and ``lnprob`` set to
        ``None``. These are not persisted in JSON; reload the HDF5 chain for
        them. Missing optional keys fall back to neutral defaults.

    Raises:
        ImportError: If lmfit is not installed (needed to rebuild Parameters).
        KeyError: If the file has no ``"params"`` entry.
    """
    # Fail with the install hint rather than an AttributeError on ``None``.
    _require_lmfit()
    with open(path, encoding="utf-8") as fh:
        data = json.load(fh)
    params = lmfit.Parameters()
    params.loads(json.dumps(data["params"]))
    covar = data.get("covar")
    return RetrievalResult(
        params=params,
        var_names=data.get("var_names", []),
        chisqr=data.get("chisqr", 0.0),
        redchi=data.get("redchi", 0.0),
        ndata=data.get("ndata", 0),
        nvarys=data.get("nvarys", 0),
        covar=np.array(covar) if covar is not None else None,
        val_stats=data.get("val_stats", {}),
        method=data.get("method", "unknown"),
    )
# ---------------------------------------------------------------------------
# Parameter setup
# ---------------------------------------------------------------------------
def setup_parameters(
    params_dict: Dict[str, float],
    relative_bound: float = 5.0,
    fixed_params: Optional[List[str]] = None,
    is_crystal: bool = False,
    param_bounds: Optional[Dict[str, tuple]] = None,
) -> lmfit.Parameters:
    """Build a bounded :class:`lmfit.Parameters` object from a raw dict.

    Bounding rules, applied in order for each entry of *params_dict*:

    1. ``lattice_constant_au`` is skipped (derived constant).
    2. Names in *fixed_params* are added with ``vary=False``.
    3. Crystals: ``a0_per_au3`` is converted to ``a0 = a0_per_au3 * lat**3``
       with a strictly positive lower bound.
    4. Names in *param_bounds* use the explicit ``(min, max)``.
    5. An ``a4`` of exactly zero is held fixed.
    6. Otherwise the range is ``val ± relative_bound*abs(val)``. If that
       half-width is zero (zero value or ``relative_bound == 0``), a ±4
       absolute window is used instead, and an exact zero start is nudged
       to ``1e-6``. ``a4`` is additionally clipped to ``[-5, 5]``
       (``[-1, 5]`` for crystals). ``E_g``, ``a0`` and ``a1`` get a lower
       bound of at least ``1e-8``.

    Args:
        params_dict: Initial parameter values (e.g. from
            :func:`~gasfir.get_parameters`).
        relative_bound: Half-width of the search range as a multiple of
            the absolute value. A value of 5.0 means the range spans
            ``[val - 5*abs(val), val + 5*abs(val)]``.
        fixed_params: Names of parameters to hold fixed. Defaults to
            ``["E_g", "m_eff"]`` for atomic/molecular media. For crystals
            (``is_crystal=True``) only explicitly listed names are fixed.
        is_crystal: When ``True``, ``m_eff`` is treated as a free parameter
            and ``a0_per_au3`` is converted to ``a0`` via the lattice constant.
        param_bounds: Optional explicit ``{name: (min, max)}`` overrides.
            Any parameter listed here uses the given bounds instead of the
            ``relative_bound`` formula. Used for cold-start runs where
            physical rather than relative bounds make more sense.

    Returns:
        A :class:`lmfit.Parameters` instance with bounds set.

    Raises:
        ImportError: If lmfit is not installed.
    """
    _require_lmfit()
    if param_bounds is None:
        param_bounds = {}
    if fixed_params is None:
        fixed_params = [] if is_crystal else ["E_g", "m_eff"]
    # Parameters that are physically positive-definite (loop invariant).
    positive = {"E_g", "a0", "a1"}

    params = lmfit.Parameters()
    for name, val in params_dict.items():
        if name == "lattice_constant_au":
            continue  # derived constant, not a fit parameter
        if name in fixed_params:
            params.add(name, value=val, vary=False)
            continue

        # Crystal: convert a0_per_au3 → a0 in atomic units
        if name == "a0_per_au3" and is_crystal:
            lat = params_dict.get("lattice_constant_au", 1.0)
            a0 = val * lat**3
            offset = relative_bound * abs(a0)
            params.add(
                "a0",
                value=a0,
                vary=True,
                min=max(1e-8, a0 - offset),
                max=a0 + offset,
            )
            continue

        # Explicit bounds override the relative-bound formula
        if name in param_bounds:
            lb, ub = param_bounds[name]
            params.add(name, value=val, vary=True, min=lb, max=ub)
            continue

        # A zero a4 is held fixed. This must be checked before the zero-value
        # substitution below, which previously made this branch unreachable.
        if name == "a4" and float(val) == 0.0:
            params.add(name, value=val, vary=False)
            continue

        offset = relative_bound * abs(val)
        if offset == 0.0:
            # Degenerate relative range: fall back to an absolute ±4 window.
            # Only an exact zero start is moved; a nonzero user value is kept.
            if val == 0.0:
                val = 1e-6
            offset = 4.0

        if name == "a4":
            lb = max(-1.0 if is_crystal else -5.0, val - offset)
            ub = min(5.0, val + offset)
            params.add(name, value=val, vary=True, min=lb, max=ub)
        elif name in positive:
            params.add(
                name,
                value=val,
                vary=True,
                min=max(1e-8, val - offset),
                max=val + offset,
            )
        else:
            params.add(name, value=val, vary=True, min=val - offset, max=val + offset)
    return params
# ---------------------------------------------------------------------------
# Statistical regularisation
# ---------------------------------------------------------------------------
def get_cleanest_value(
    mcmc_median: float,
    mcmc_err: float,
    initial_val: Optional[float] = None,
) -> float:
    """Pick the most "human-friendly" number inside the 1-σ posterior interval.

    The interval is ``[mcmc_median - mcmc_err, mcmc_median + mcmc_err]``
    (inclusive). Candidates are tried in this order:

    1. ``initial_val``, when given and inside the interval.
    2. ``0.0``, when zero lies inside the interval.
    3. For 1 to 14 significant digits in turn, the median, lower edge and
       upper edge rounded to that many digits. The first one that lands
       inside the interval wins.
    4. ``mcmc_median`` itself.

    Args:
        mcmc_median: Median of the posterior sample.
        mcmc_err: Standard deviation of the posterior sample (1-σ).
        initial_val: Prior best estimate, or ``None`` if there is none.

    Returns:
        The chosen representative value.
    """
    lower_edge = mcmc_median - mcmc_err
    upper_edge = mcmc_median + mcmc_err

    def _inside(x: float) -> bool:
        return lower_edge <= x <= upper_edge

    if initial_val is not None and _inside(initial_val):
        return initial_val
    if _inside(0.0):
        return 0.0

    rounded_candidates = (
        float(format(anchor, f".{digits}g"))
        for digits in range(1, 15)
        for anchor in (mcmc_median, lower_edge, upper_edge)
    )
    return next((c for c in rounded_candidates if _inside(c)), mcmc_median)
# ---------------------------------------------------------------------------
# Safe residual wrapper
# ---------------------------------------------------------------------------
def _make_safe_residual(
    base_fn: Callable[[lmfit.Parameters], npt.NDArray[np.float64]],
) -> Callable[[lmfit.Parameters], npt.NDArray[np.float64]]:
    """Wrap a residual function so it never returns NaN or ±Inf.

    Non-finite entries are replaced by a penalty equal to 10× the largest
    finite ``|residual|``, capped at ``1e100``; ``-Inf`` becomes
    ``-penalty``. If there is no finite entry, or every finite entry is
    exactly zero, the penalty is ``200.0``. Otherwise a NaN would look like
    a perfect fit. The output is finally clipped to ``[-1e100, 1e100]``.

    Args:
        base_fn: The original residual function.

    Returns:
        A wrapped callable with the same signature that is safe for all
        optimisers.
    """

    def safe_fn(params: lmfit.Parameters) -> npt.NDArray[np.float64]:
        res = np.asarray(base_fn(params))
        finite = np.isfinite(res)
        scale = float(np.max(np.abs(res[finite]))) if np.any(finite) else 0.0
        # A zero scale would turn every invalid entry into a zero residual.
        penalty = min(scale * 10, 1e100) if scale > 0.0 else 200.0
        return np.clip(
            np.nan_to_num(res, nan=penalty, posinf=penalty, neginf=-penalty),
            -1e100,
            1e100,
        )

    return safe_fn
# ---------------------------------------------------------------------------
# Picklable log-probability for emcee
# ---------------------------------------------------------------------------
class _LogProbability:
    """Picklable callable wrapping a GASFIR residual function for emcee.

    Using a class (rather than a closure) keeps the object picklable, which
    ``multiprocessing.Pool`` workers need for every callable and its
    captured state.

    The log-probability is a flat prior inside the parameter bounds plus a
    Gaussian likelihood ``-0.5 * sum(resid**2)``. It is ``-inf`` outside the
    bounds, or when the residual contains non-finite values or entries with
    ``|r| >= 1e99`` (the clip level of the safe residual wrapper).

    Attributes:
        var_names: Ordered list of varied parameter names.
        bounds: Dict ``{name: (min, max)}`` for each varied parameter.
        params_template: A :class:`lmfit.Parameters` template that is cloned
            on each call to avoid race conditions.
    """

    def __init__(
        self,
        residual_fn: Callable,
        var_names: List[str],
        params: lmfit.Parameters,
    ) -> None:
        self._residual_fn = residual_fn
        self.var_names = var_names
        self.bounds = {name: (params[name].min, params[name].max) for name in var_names}
        self._params_template = params

    def __call__(self, theta: npt.NDArray[np.float64]) -> float:
        """Evaluate log-probability at walker position *theta*.

        Args:
            theta: 1-D array of parameter values in the order of
                ``self.var_names``.

        Returns:
            Log-probability value. It is ``-inf`` for out-of-bounds walkers
            and failed model evaluations, and never NaN, because emcee
            rejects NaN log-probabilities.
        """
        p = self._params_template.copy()
        for name, val in zip(self.var_names, theta):
            lo, hi = self.bounds[name]
            if not (lo <= val <= hi):  # also False for NaN
                return -np.inf
            p[name].value = val
        resid = np.asarray(self._residual_fn(p))
        if not np.all(np.isfinite(resid)) or np.any(np.abs(resid) >= 1e99):
            return -np.inf
        return float(-0.5 * np.sum(resid**2))
# ---------------------------------------------------------------------------
# Phase 1 — global search
# ---------------------------------------------------------------------------
[docs]
def run_global_search(
    residual_fn: Callable,
    params: lmfit.Parameters,
    method: str = "differential_evolution",
    de_workers: int = 1,
    de_maxiter: int = 1000,
    de_tol: float = 1e-7,
    de_popsize: int = 15,
    cma_sigma0: float = 0.3,
    cma_maxiter: int = 10_000,
    trial_run: bool = False,
) -> Tuple[lmfit.Parameters, object]:
    """Global parameter search using Differential Evolution or CMA-ES.

    The result is used as the starting point for :func:`run_local_polish`.
    Both methods minimise ``sum(residual_fn(params)**2)`` within the bounds
    of the varied parameters.

    Args:
        residual_fn: Callable ``f(params) -> residual_array``.
        params: Starting :class:`lmfit.Parameters` with finite bounds on
            every varied parameter. Updated in place.
        method: ``"differential_evolution"`` or ``"cma"``.
        de_workers: Workers for DE. ``1`` leaves the inner-loop parallelism
            to Numba. Other values request scipy's process-level
            parallelism (see the note in the DE branch).
        de_maxiter: Maximum DE iterations.
        de_tol: DE convergence tolerance.
        de_popsize: DE population multiplier.
        cma_sigma0: Initial step size for CMA-ES, as a fraction of each
            parameter's bounds width (CMA-ES runs in normalised [0, 1]^n).
        cma_maxiter: Maximum CMA-ES function evaluations.
        trial_run: Use a very small budget for quick testing.

    Returns:
        ``(params, raw_result)``. *params* is the input Parameters object,
        with varied values set to the global-search optimum. *raw_result* is
        the scipy ``OptimizeResult`` for DE (``x`` in physical units) or the
        ``cma`` result object for CMA-ES (``xbest`` in normalised units).

    Raises:
        ImportError: If ``method="cma"`` but the ``cma`` package is not
            installed.
        ValueError: If ``method`` is not recognised, no parameter is varied,
            or a varied parameter has a missing/non-finite bound.
    """
    if method not in ("differential_evolution", "cma"):
        raise ValueError(
            f"Unknown global_method '{method}'. Choose 'differential_evolution' or 'cma'."
        )

    var_names = [n for n in params if params[n].vary]
    if not var_names:
        raise ValueError("Global search needs at least one varied parameter.")

    # Both DE and the [0, 1] normalisation below need finite bounds
    # (None becomes NaN under dtype=float).
    lo_arr = np.array([params[n].min for n in var_names], dtype=float)
    hi_arr = np.array([params[n].max for n in var_names], dtype=float)
    unbounded = [
        n
        for n, lo, hi in zip(var_names, lo_arr, hi_arr)
        if not (np.isfinite(lo) and np.isfinite(hi))
    ]
    if unbounded:
        raise ValueError(
            f"Global search requires finite bounds; unbounded parameters: {unbounded}"
        )
    width = hi_arr - lo_arr

    # Print parameter bounds table before search begins
    print(
        f"\n {'Parameter':12s} {'Initial':>14s} {'Lower':>14s} {'Upper':>14s}",
        flush=True,
    )
    print(f" {'─'*12} {'─'*14} {'─'*14} {'─'*14}", flush=True)
    for name in var_names:
        p = params[name]
        lo = f"{p.min:.4g}" if p.min is not None and not np.isinf(p.min) else "−∞"
        hi = f"{p.max:.4g}" if p.max is not None and not np.isinf(p.max) else "+∞"
        print(f" {name:12s} {p.value:14.4g} {lo:>14s} {hi:>14s}", flush=True)
    print(flush=True)

    # ── parameter normalisation for CMA-ES ────────────────────────────────
    # CMA-ES starts with an isotropic search. When bounds widths differ by
    # orders of magnitude, a single sigma either freezes the wide dimensions
    # or overshoots the narrow ones. Mapping each parameter to [0, 1]
    # removes that disparity. (scipy's DE already rescales internally, so
    # DE works directly in physical units.)
    def _normalise(theta: npt.NDArray) -> npt.NDArray:
        return (theta - lo_arr) / width  # physical → [0, 1]

    def _denormalise(u: npt.NDArray) -> npt.NDArray:
        return lo_arr + u * width  # [0, 1] → physical

    # Degrees of freedom (len(residual) - n_varied, floored at 1), captured
    # from the first residual evaluation; used to report reduced chi^2.
    _ndata = [1]

    def _objective_physical(theta: npt.NDArray) -> float:
        """Sum of squared residuals at physical parameter vector *theta*."""
        p = params.copy()
        for name, val in zip(var_names, theta):
            p[name].value = float(val)
        resid = residual_fn(p)
        if _ndata[0] == 1:
            _ndata[0] = max(len(resid) - len(var_names), 1)
        return float(np.sum(resid**2))

    def _scalar_objective(u: npt.NDArray) -> float:
        """Same objective evaluated at normalised coordinates *u*."""
        return _objective_physical(_denormalise(u))

    if method == "differential_evolution":
        _log.info("Phase 1: Differential Evolution (workers=%d)", de_workers)
        maxiter = 10 if trial_run else de_maxiter
        # NOTE(review): with workers != 1 scipy ships the objective to a
        # process pool, which requires it to be picklable. This local
        # closure is not, so de_workers != 1 is expected to fail — confirm.
        de_res = differential_evolution(
            _objective_physical,
            bounds=list(zip(lo_arr, hi_arr)),
            maxiter=maxiter,
            tol=de_tol,
            popsize=de_popsize,
            workers=de_workers,
            # scipy forces 'deferred' (with a warning) when workers != 1.
            updating="immediate" if de_workers == 1 else "deferred",
            seed=42,
            polish=False,
        )
        _log.info("DE finished: success=%s f=%g", de_res.success, de_res.fun)
        for name, val in zip(var_names, de_res.x):
            params[name].value = val
        return params, de_res

    # ── CMA-ES ────────────────────────────────────────────────────────────
    if not _HAS_CMA:
        raise ImportError(
            "The 'cma' package is required for CMA-ES. "
            "Install it with: pip install cma"
        )
    _log.info("Phase 1: CMA-ES (sigma0=%.3g, normalised)", cma_sigma0)

    budget = 100 if trial_run else cma_maxiter
    ndim_cma = len(var_names)
    x0_norm = _normalise(np.array([params[n].value for n in var_names]))

    opts = cma.CMAOptions()
    opts["bounds"] = [[0.0] * ndim_cma, [1.0] * ndim_cma]
    opts["maxfevals"] = budget
    opts["verbose"] = -9  # suppress CMA's own output; we print selectively
    opts["verb_log"] = 0  # disable log files
    # CMA's built-in heuristic stopping rules are disabled so that only the
    # evaluation budget and the custom plateau rule below end the run.
    # Options CMA stores as lazy string formulas are set explicitly; leaving
    # them as strings causes TypeError in newer cma versions.
    opts["popsize"] = 4 + int(3 * np.log(ndim_cma))  # default formula evaluated
    opts["CMA_diagonal"] = 0  # disable diagonal-only mode
    opts["tolx"] = 0  # disable position-change tolerance
    opts["tolfun"] = 0  # disable function-value tolerance
    opts["tolstagnation"] = int(1e9)  # disable stagnation criterion
    opts["tolconditioncov"] = 1e30  # disable ill-conditioning stop
    opts["tolflatfitness"] = int(1e9)  # disable flat-fitness stop
    opts["maxiter"] = int(1e9)  # disable default generation cap

    print(
        f"\n {'Gen':>5s} {'Evals':>6s} {'Budget':>7s} {'red-χ² (best)':>15s} {'Improvement':>12s}",
        flush=True,
    )
    print(f" {'─'*5} {'─'*6} {'─'*7} {'─'*15} {'─'*12}", flush=True)

    _last_rchi2 = [np.inf]
    _print_every = max(1, budget // 50)
    # Plateau rule: stop once red-χ² improved by less than _plateau_thr
    # (relative) over the last _plateau_win generations.
    _plateau_win = max(20, budget // 100)
    _plateau_thr = 0.005
    _history: List[float] = []
    _plateau_fired = [False]  # fire once only

    def _cma_callback(es: object) -> None:
        """Print progress and update the plateau flag after each generation."""
        chi2 = es.result.fbest
        rchi2 = chi2 / _ndata[0]
        gen = es.result.iterations
        ev = es.result.evaluations
        pct = 100.0 * ev / budget
        improved = rchi2 < _last_rchi2[0] * 0.99  # >1 % drop
        if gen == 1 or gen % _print_every == 0 or improved:
            delta = (
                f"↓{100*(1 - rchi2/_last_rchi2[0]):5.1f}%"
                if improved and _last_rchi2[0] < np.inf
                else "—"
            )
            print(
                f" {gen:5d} {ev:6d} {pct:6.1f}% {rchi2:15.4g} {delta:>12s}",
                flush=True,
            )
        if improved:
            _last_rchi2[0] = rchi2
        if _plateau_fired[0]:
            return
        _history.append(rchi2)
        if len(_history) >= _plateau_win:
            oldest = _history[-_plateau_win]
            rel_improvement = (oldest - rchi2) / max(oldest, 1e-12)
            if rel_improvement < _plateau_thr:
                _plateau_fired[0] = True
                print(
                    f"\n [CMA-ES] Plateau — red-χ² improved < {_plateau_thr*100:.1f}% "
                    f"over {_plateau_win} gens. Stopping at {ev} evals "
                    f"({pct:.1f}% of budget).",
                    flush=True,
                )

    es = cma.CMAEvolutionStrategy(x0_norm, cma_sigma0, opts)
    while not es.stop():
        X = es.ask()
        fit = [_scalar_objective(x) for x in X]
        es.tell(X, fit)
        _cma_callback(es)
        if _plateau_fired[0]:
            break

    best_x = _denormalise(es.result.xbest)
    rchi2_final = es.result.fbest / _ndata[0]
    print(
        f"\n CMA-ES finished — {es.result.evaluations} evals | "
        f"best red-χ² = {rchi2_final:.4g}",
        flush=True,
    )
    for name, val in zip(var_names, best_x):
        params[name].value = val
    return params, es.result
# ---------------------------------------------------------------------------
# Phase 2 — local polish
# ---------------------------------------------------------------------------
[docs]
def run_local_polish(
    residual_fn: Callable,
    params: lmfit.Parameters,
    method: str = "least_squares",
) -> Tuple[lmfit.Parameters, MinimizerResult]:
    """Locally refine the global-search result with an lmfit minimiser.

    Progress is printed every 10th callback invocation, and once at the end.

    Args:
        residual_fn: Callable ``f(params) -> residual_array``.
        params: :class:`lmfit.Parameters` at the global search optimum.
        method: lmfit minimisation method. The default ``"least_squares"``
            (scipy's trust-region reflective solver) handles bounds natively.
            ``"leastsq"`` selects Levenberg–Marquardt (MINPACK).

    Returns:
        Tuple of (updated params, :class:`lmfit.MinimizerResult`).

    Raises:
        ImportError: If lmfit is not installed.
    """
    _require_lmfit()
    _log.info("Phase 2: Local polish with %s", method)
    print(f" [LS] Starting {method} …", flush=True)
    # Last counter value lmfit passed to the callback (reported as "iters").
    _ls_iter = [0]

    def _iter_cb(params, iter_num, resid, *args, **kw):
        """Progress hook; returns None because a truthy value aborts the fit."""
        _ls_iter[0] = iter_num
        if iter_num % 10 == 0:
            # The callback may see the raw residual before nan_policy is
            # applied, so ignore NaNs in the progress figure.
            print(
                f" [LS] iter {iter_num:4d} | "
                f"‖resid‖ = {float(np.nansum(np.square(resid))):.6g}",
                flush=True,
            )

    mini = lmfit.Minimizer(residual_fn, params, nan_policy="omit", iter_cb=_iter_cb)
    ls_result = mini.minimize(method=method)
    print(
        f" [LS] done — {_ls_iter[0]} iters | "
        f"success={ls_result.success} | redchi={ls_result.redchi:.4g}",
        flush=True,
    )
    _log.info(
        "LS finished: success=%s redchi=%.4g",
        ls_result.success,
        ls_result.redchi,
    )
    return ls_result.params, ls_result
# ---------------------------------------------------------------------------
# Phase 3 — MCMC
# ---------------------------------------------------------------------------
[docs]
def run_mcmc(
    residual_fn: Callable,
    params: lmfit.Parameters,
    n_steps: int = 3000,
    n_walkers: Optional[int] = None,
    burn_fraction: float = 0.3,
    thin: int = 10,
    backend_path: Optional[Union[str, Path]] = None,
    pool: Optional[object] = None,
    max_autocorr_loops: int = 5,
    initial_stored_params: Optional[Dict[str, float]] = None,
    trial_run: bool = False,
) -> Tuple[lmfit.Parameters, npt.NDArray, npt.NDArray, object]:
    """Run emcee MCMC and return posterior summary.

    Wraps the residual function in a log-probability callable and runs emcee
    with optional HDF5 chain persistence. After each run, every varied
    parameter's posterior median and standard deviation are snapped to a
    "clean" value with :func:`get_cleanest_value`.

    If any snapped value differs from its reference, two things happen
    before the next cycle:

    * the walkers are restarted around the snapped values;
    * that parameter's reference moves to the new value.

    The loop stops once no parameter drifts or after *max_autocorr_loops*
    cycles.

    Burn-in is ``2 × max(τ)`` and thinning ``max(1, min(τ) / 2)`` from the
    autocorrelation time τ. If τ cannot be estimated, burn-in falls back to
    ``burn_fraction`` of the chain and thinning to *thin*. Burn-in is always
    capped at half the stored chain.

    Args:
        residual_fn: Safe residual callable ``f(params) -> array``.
        params: :class:`lmfit.Parameters` at the local-polish optimum.
        n_steps: Total MCMC steps per walker.
        n_walkers: Number of walkers. ``None`` → ``max(32, 4 × ndim)``.
            Always raised to at least ``2 × ndim + 2`` and rounded up to an
            even number.
        burn_fraction: Fraction of chain used as burn-in when autocorrelation
            estimation fails.
        thin: Thinning factor for flat samples when autocorrelation
            estimation fails.
        backend_path: HDF5 file for chain persistence. Pass ``None`` to
            skip persistence. A compatible existing chain is resumed on the
            first cycle only.
        pool: Optional ``multiprocessing.Pool`` for walker-level parallelism.
            Pass ``None`` (default) to rely on Numba intra-call parallelism.
        max_autocorr_loops: Maximum number of sample-and-snap cycles
            (must be >= 1).
        initial_stored_params: Initial reference values for drift detection
            (passed to :func:`get_cleanest_value`). Parameters absent from
            this dict are never flagged as drifting. The dict is not
            modified.
        trial_run: Use a very small step count (20) for quick testing.

    Returns:
        ``(final_params, flat_samples, lnprob, sampler)`` where
        *final_params* has ``stderr`` set to posterior standard deviations and
        values snapped to clean estimates via :func:`get_cleanest_value`.

    Raises:
        ImportError: If emcee is not installed.
        ValueError: If *max_autocorr_loops* is smaller than 1.
    """
    _require_lmfit()
    if not _HAS_EMCEE:
        raise ImportError(
            "emcee is required for MCMC sampling. " + _RETRIEVAL_INSTALL_HINT
        )
    if max_autocorr_loops < 1:
        raise ValueError(
            f"max_autocorr_loops must be >= 1, got {max_autocorr_loops}."
        )
    var_names = sorted(n for n in params if params[n].vary)
    ndim = len(var_names)
    if n_walkers is None:
        n_walkers = max(32, 4 * ndim)
    # emcee requires n_walkers > 2 × ndim; also enforce even
    n_walkers = max(n_walkers, 2 * ndim + 2)
    n_walkers += n_walkers % 2
    steps = 20 if trial_run else n_steps
    # Drift references. Copied so the caller's dict is never mutated; updated
    # below whenever a parameter drifts.
    reference: Dict[str, float] = dict(initial_stored_params or {})
    _log.info(
        "Phase 3: emcee ndim=%d nwalkers=%d steps=%d pool=%s",
        ndim,
        n_walkers,
        steps,
        "yes" if pool is not None else "no",
    )
    # Loop-invariant: chain file location.
    bp: Optional[Path] = None
    if backend_path is not None:
        bp = Path(backend_path)
        bp.parent.mkdir(parents=True, exist_ok=True)

    # ----- auto-correction loop -----
    final_params = params.copy()
    flat_samples: npt.NDArray = np.empty(0)
    lnprob_arr: npt.NDArray = np.empty(0)
    sampler: Optional[object] = None
    for loop in range(max_autocorr_loops):
        _log.info("MCMC auto-correction loop %d/%d", loop + 1, max_autocorr_loops)
        # Rebuild the log-probability around the current (snapped) parameters.
        log_prob = _LogProbability(residual_fn, var_names, final_params)
        backend: Optional[emcee.backends.HDFBackend] = (
            emcee.backends.HDFBackend(str(bp)) if bp is not None else None
        )
        # Only the first cycle may resume a compatible chain from disk; later
        # cycles always restart around the snapped values.
        if (
            loop == 0
            and backend is not None
            and bp is not None
            and bp.exists()
            and backend.iteration > 0
            and backend.shape == (n_walkers, ndim)
        ):
            _log.info("Resuming from existing chain at step %d", backend.iteration)
            initial_state = backend.get_last_sample()
            remaining = max(0, steps - backend.iteration)
        else:
            if backend is not None:
                backend.reset(n_walkers, ndim)
            x0 = np.array([final_params[n].value for n in var_names])
            # Small Gaussian ball around x0; seeded per cycle for reproducibility.
            initial_state = x0 + 1e-4 * np.random.default_rng(
                42 + loop
            ).standard_normal((n_walkers, ndim))
            remaining = steps
        sampler = emcee.EnsembleSampler(
            n_walkers, ndim, log_prob, backend=backend, pool=pool
        )
        if remaining > 0:
            with tqdm(
                total=remaining, desc=f"emcee loop {loop+1}", unit="step"
            ) as pbar:
                for _ in sampler.sample(initial_state, iterations=remaining):
                    pbar.update(1)

        # Autocorrelation-based burn-in / thinning.
        n_stored = sampler.get_chain().shape[0]
        try:
            tau = sampler.get_autocorr_time(quiet=True)
            burnin = int(2 * np.max(tau))
            thin_auto = max(1, int(0.5 * np.min(tau)))
        except Exception as exc:  # AutocorrError, NaN/inf tau, chain too short
            _log.info(
                "Autocorrelation time unavailable (%s); using burn_fraction=%.2f",
                exc,
                burn_fraction,
            )
            burnin = max(0, int(n_stored * burn_fraction))
            thin_auto = thin
        # Never discard more than half of what is stored.
        burnin = min(burnin, n_stored // 2)
        flat_samples = sampler.get_chain(discard=burnin, thin=thin_auto, flat=True)
        lnprob_arr = sampler.get_log_prob(discard=burnin, thin=thin_auto, flat=True)
        if len(flat_samples) == 0:
            _log.warning(
                "Empty flat samples — chain too short. Proceeding without thinning."
            )
            flat_samples = sampler.get_chain(flat=True)
            lnprob_arr = sampler.get_log_prob(flat=True)

        # Drift detection and snapping.
        all_stable = True
        next_params = final_params.copy()
        _log.info("Drift analysis:")
        for i, name in enumerate(var_names):
            median = float(np.median(flat_samples[:, i]))
            sigma = float(np.std(flat_samples[:, i]))
            init = reference.get(name)
            clean = get_cleanest_value(median, sigma, init)
            next_params[name].value = clean
            next_params[name].stderr = sigma
            if init is not None and clean != init:
                _log.info(
                    " DRIFT %s: %.4g → %.4g (σ=%.2g)", name, init, clean, sigma
                )
                all_stable = False
                # Re-centre the reference on the drifted value. Otherwise the
                # next cycle compares against the original stored value again
                # and reports the same drift on every loop, so convergence
                # could never be confirmed.
                reference[name] = clean
            else:
                _log.info(" STABLE %s: %.4g (σ=%.2g)", name, clean, sigma)
        final_params = next_params
        if all_stable or loop == max_autocorr_loops - 1:
            if not all_stable:
                _log.warning("Max autocorrection loops reached; using current best.")
            else:
                _log.info("Convergence confirmed — all parameters stable.")
            break

    # Copy fixed params
    for name in params:
        if not params[name].vary:
            final_params[name].value = params[name].value
    # Guaranteed by the max_autocorr_loops >= 1 check; kept for type narrowing.
    assert sampler is not None
    return final_params, flat_samples, lnprob_arr, sampler
# ---------------------------------------------------------------------------
# Validation metrics
# ---------------------------------------------------------------------------
def compute_validation_metrics(
    result_params: lmfit.Parameters,
    validation_df: pd.DataFrame,
    uncertainty: float = 5e-2,
    dt: float = 4.0,
    dT: float = 0.25,
    ret_electron_density: bool = False,
) -> Dict[str, float]:
    """Compute chi-squared and RMSE for an out-of-sample validation set.

    Args:
        result_params: Best-fit parameters.
        validation_df: DataFrame with ``"pulses"`` and ``"Y"`` columns.
        uncertainty: Fractional uncertainty (applied as ``uncertainty × Y``).
        dt: Coarse time step for the kernel.
        dT: Fine time step for the kernel.
        ret_electron_density: Pass ``True`` to use electron density.

    Returns:
        Dict with keys ``"N"`` (row count), ``"chisqr"`` (sum of squared
        residuals), ``"redchi"`` (``chisqr / N``; no degrees-of-freedom
        correction) and ``"rmse"`` (root-mean-square of the residuals
        rescaled by the per-point uncertainty).

    Raises:
        ValueError: If *validation_df* has no rows, since ``chisqr / N``
            would be undefined.
    """
    n = len(validation_df)
    if n == 0:
        raise ValueError("validation_df is empty; cannot compute validation metrics.")
    uncert_arr = uncertainty * validation_df["Y"].to_numpy()
    val_fn = ret_residual_function(
        validation_df,
        uncert_arr,
        None,
        None,
        dt=dt,
        dT=dT,
        ret_electron_density=ret_electron_density,
    )
    resid = np.asarray(val_fn(result_params))
    chisqr = float(np.sum(resid**2))
    # Assumes the residual function returns (model - Y) / uncert_arr, so
    # multiplying back gives absolute residuals — confirm against
    # ret_residual_function.
    rmse = float(np.sqrt(np.mean((resid * uncert_arr) ** 2)))
    return {"N": n, "chisqr": chisqr, "redchi": chisqr / n, "rmse": rmse}
# ---------------------------------------------------------------------------
# Stored-parameter comparison
# ---------------------------------------------------------------------------
[docs]
def compare_to_stored(
    ls_result: MinimizerResult,
    stored_params: Dict[str, float],
    sigma: float = 2.0,
) -> None:
    """Print a table comparing LS-fitted values to stored reference parameters.

    For each varied parameter, reports the stored value, fitted value,
    standard error, number of standard deviations separating them, and
    whether the stored value falls within *sigma*-σ of the fitted value.

    Some parameters have no usable standard error: ``stderr`` is missing,
    non-finite, or non-positive, e.g. when lmfit could not estimate the
    covariance matrix. These are shown as ``n/a`` and left out of the
    pass/fail verdict instead of being counted as failures.

    Args:
        ls_result: :class:`lmfit.minimizer.MinimizerResult` from the LS phase.
        stored_params: Reference parameter dict (e.g. from
            :func:`~gasfir.get_parameters`). Only keys that appear in
            ``ls_result.params`` and were varied are compared.
        sigma: Confidence threshold in units of stderr (default 2 → 95 % CI).
    """
    _require_lmfit()
    # Only varied parameters that also have a stored reference are compared.
    names_to_check = [
        n
        for n in ls_result.params
        if ls_result.params[n].vary and n in stored_params
    ]
    if not names_to_check:
        print(
            " (no overlap between fitted and stored parameters — skipping comparison)",
            flush=True,
        )
        return
    _section(f"Stored vs LS-fitted parameters ({sigma:g}σ tolerance)")
    print(
        f" {'Param':10s} {'Stored':>12s} {'Fitted':>12s} {'±stderr':>10s} "
        f"{'Nσ away':>9s} {f'In {sigma:g}σ?':>8s}",
        flush=True,
    )
    print(f" {'─'*10} {'─'*12} {'─'*12} {'─'*10} {'─'*9} {'─'*8}", flush=True)
    n_fail = 0
    n_unknown = 0
    for name in names_to_check:
        p = ls_result.params[name]
        stored = stored_params[name]
        fitted = p.value
        err = p.stderr if p.stderr is not None else float("nan")
        if np.isfinite(err) and err > 0:
            nsigma = abs(stored - fitted) / err
            ok = nsigma <= sigma
            if not ok:
                n_fail += 1
            nsigma_txt = f"{nsigma:.2f}"
            icon = "✅" if ok else "❌"
        else:
            # Without an uncertainty the distance in σ is undefined.
            n_unknown += 1
            nsigma_txt = "n/a"
            icon = "n/a"
        print(
            f" {name:10s} {stored:12.4g} {fitted:12.4g} {err:10.4g} "
            f"{nsigma_txt:>9s} {icon:>8s}",
            flush=True,
        )
    if n_fail:
        print(
            f"\n Some stored values lie outside {sigma:g}σ — the stored parameters "
            "may be rough starting guesses rather than a prior fit to this dataset.",
            flush=True,
        )
    elif n_unknown == 0:
        print(
            f"\n All stored values are within {sigma:g}σ of the LS fit. ✅",
            flush=True,
        )
    if n_unknown:
        print(
            f"\n {n_unknown} parameter(s) have no usable stderr (covariance may be "
            "unavailable) — not assessed.",
            flush=True,
        )
# ---------------------------------------------------------------------------
# Visualisation & reporting
# ---------------------------------------------------------------------------
def _latex_label(name: str, latex_mapping: Dict[str, str]) -> str:
"""Return the LaTeX label for a parameter, falling back to auto-escaped name."""
escaped = name.replace("_", r"\_")
return latex_mapping.get(name, f"${escaped}$")
[docs]
def generate_publication_corner(
    flat_samples: npt.NDArray[np.float64],
    var_names: List[str],
    output_dir: Union[str, Path],
    medium_name: str,
    truths: Optional[List[float]] = None,
    latex_mapping: Optional[Dict[str, str]] = None,
    constants: Optional[Dict[str, float]] = None,
) -> Optional[Path]:
    """Save a publication-quality corner plot as a PDF.

    The corner plot shows 1-D marginal distributions on the diagonal and
    2-D contour projections off-diagonal. When *truths* are provided, two
    inset tables are added: one with adopted parameter values (plus any fixed
    *constants*) and one with the correlation matrix of *flat_samples*.

    Args:
        flat_samples: Flattened MCMC chain, shape ``(n_samples, ndim)``.
        var_names: Ordered parameter names matching the chain columns.
        output_dir: Directory where the PDF is saved (created if missing).
        medium_name: Used as the file-name stem.
        truths: Best-fit / snapped values to mark on the plot (one per param).
        latex_mapping: ``{name: latex_string}`` overrides for axis labels.
        constants: Fixed parameter values to list in the adopted-values table.

    Returns:
        Path to ``<output_dir>/<medium_name>_publication_corner.pdf``, or
        ``None`` if matplotlib / corner are absent.

    Raises:
        ValueError: If *truths* is given with a length different from
            *var_names*.
    """
    if not _HAS_MPL or not _HAS_CORNER:
        _log.warning("corner / matplotlib not installed — skipping corner plot.")
        return None
    if truths is not None and len(truths) != len(var_names):
        raise ValueError(
            f"truths has {len(truths)} entries but var_names has {len(var_names)}."
        )
    if latex_mapping is None:
        latex_mapping = {}
    if constants is None:
        constants = {}
    display_labels = [_latex_label(n, latex_mapping) for n in var_names]
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / f"{medium_name}_publication_corner.pdf"
    n_par = len(var_names)
    with mpl.rc_context({"font.size": 14, "font.family": "serif"}):
        fig = _corner_lib.corner(
            flat_samples,
            labels=display_labels,
            quantiles=[0.16, 0.5, 0.84],
            show_titles=True,
            truths=truths,
            truth_color="firebrick",
            title_kwargs={"fontsize": 14},
            label_kwargs={"fontsize": 16},
            # 1σ, 2σ, 3σ contours of a 2-D Gaussian.
            levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-4.5)),
            plot_datapoints=False,
            fill_contours=True,
            smooth=1.0,
            color="#005b96",
        )
        # Close the figure even if table construction or saving fails, so
        # pyplot does not accumulate open figures across calls.
        try:
            if truths is not None:
                # Inset table 1: adopted values + fixed constants
                ax_params = fig.add_axes([0.60, 0.75, 0.35, 0.25])
                ax_params.axis("off")
                table_data = [["Parameter", "Adopted Value"]]
                table_data.extend(
                    [label, f"{value:.4g}"]
                    for label, value in zip(display_labels, truths)
                )
                table_data.extend(
                    [_latex_label(c_name, latex_mapping), f"{c_val:.4g} (Fixed)"]
                    for c_name, c_val in constants.items()
                )
                tbl = ax_params.table(
                    cellText=table_data,
                    loc="center",
                    cellLoc="center",
                    edges="horizontal",
                )
                tbl.auto_set_font_size(False)
                tbl.set_fontsize(14)
                tbl.scale(1, 1.8)
                for (row, col), cell in tbl.get_celld().items():
                    text = cell.get_text()
                    if row == 0:
                        text.set_weight("bold")
                        if col == 1:
                            text.set_color("firebrick")
                    elif 1 <= row <= n_par and col == 1:
                        # Highlight fitted values; constants stay plain.
                        text.set_color("firebrick")
                        text.set_weight("bold")

                # Inset table 2: correlation matrix
                ax_corr = fig.add_axes([0.53, 0.6, 0.42, 0.15])
                ax_corr.axis("off")
                ax_corr.set_title(
                    "Correlation Matrix", fontsize=14, fontweight="bold", pad=10
                )
                corr = np.corrcoef(flat_samples, rowvar=False)
                corr_data = [[""] + display_labels]
                for i in range(n_par):
                    # Diagonal is written as "1" without indexing, so a
                    # single-parameter (0-d) corrcoef result is never indexed.
                    corr_data.append(
                        [display_labels[i]]
                        + ["1" if i == j else f"{corr[i, j]:.2g}" for j in range(n_par)]
                    )
                tbl_c = ax_corr.table(
                    cellText=corr_data, loc="center", cellLoc="center", edges="open"
                )
                tbl_c.auto_set_font_size(False)
                tbl_c.set_fontsize(12)
                tbl_c.scale(1, 1.8)
                for (r, c), cell in tbl_c.get_celld().items():
                    if r == 0 and c == 0:
                        cell.visible_edges = "BR"
                    elif r == 0:
                        cell.visible_edges = "B"
                    elif c == 0:
                        cell.visible_edges = "R"
                    else:
                        cell.visible_edges = ""
                    cell.set_linewidth(1.5)
                    # Headers and the diagonal are bold.
                    if r == 0 or c == 0 or r == c:
                        cell.get_text().set_weight("bold")
            fig.savefig(out_path, dpi=300, bbox_inches="tight")
        finally:
            plt.close(fig)
    print(f" 📊 Corner plot saved to: {out_path}", flush=True)
    return out_path
[docs]
def generate_trace_plot(
    chain_path: Union[str, Path],
    var_names: List[str],
    output_dir: Union[str, Path],
    medium_name: str,
    latex_mapping: Optional[Dict[str, str]] = None,
    discard: int = 0,
) -> Optional[Path]:
    """Save a trace plot (walker trajectories) as a PDF.

    Loads the chain from an HDF5 backend file, so no live sampler object is
    required and the plot can be regenerated from persisted runs. The x-axis
    shows the true step index: with ``discard=k`` the trace starts at step
    ``k`` rather than 0.

    Args:
        chain_path: Path to the ``*.h5`` emcee HDF5 backend file.
        var_names: Parameter names in chain column order.
        output_dir: Directory where the PDF is saved (created if missing).
        medium_name: Used as the file-name stem.
        latex_mapping: ``{name: latex_string}`` label overrides.
        discard: Number of initial steps to discard from the trace.

    Returns:
        Path to ``<output_dir>/<medium_name>_trace.pdf``, or ``None`` in any
        of these cases:

        * matplotlib or emcee is absent;
        * the chain file is missing;
        * no steps remain after *discard*.

    Raises:
        ValueError: If *var_names* is empty or lists more parameters than the
            chain contains.
    """
    if not _HAS_MPL or not _HAS_EMCEE:
        _log.warning("matplotlib / emcee not installed — skipping trace plot.")
        return None
    chain_path = Path(chain_path)
    if not chain_path.exists():
        _log.warning("Chain file not found: %s — skipping trace plot.", chain_path)
        return None
    if latex_mapping is None:
        latex_mapping = {}
    backend = emcee.backends.HDFBackend(str(chain_path), read_only=True)
    samples = backend.get_chain(discard=discard)  # shape: (steps, walkers, ndim)
    ndim = len(var_names)
    if ndim == 0 or ndim > samples.shape[-1]:
        raise ValueError(
            f"var_names has {ndim} entries but the chain has "
            f"{samples.shape[-1]} parameters."
        )
    n_kept = samples.shape[0]
    if n_kept == 0:
        _log.warning(
            "No steps left after discarding %d — skipping trace plot.", discard
        )
        return None
    # Actual step numbers, so discarded burn-in shifts the axis correctly.
    step_idx = np.arange(discard, discard + n_kept)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / f"{medium_name}_trace.pdf"
    with mpl.rc_context({"font.size": 12, "font.family": "serif"}):
        fig, axes = plt.subplots(ndim, figsize=(10, 2 * ndim), sharex=True)
        # Close the figure even if plotting or saving fails.
        try:
            if ndim == 1:
                axes = [axes]
            for i, ax in enumerate(axes):
                ax.plot(step_idx, samples[:, :, i], "k", alpha=0.3)
                ax.set_xlim(discard, discard + n_kept)
                ax.set_ylabel(_latex_label(var_names[i], latex_mapping), fontsize=14)
                ax.yaxis.set_label_coords(-0.1, 0.5)
            axes[-1].set_xlabel("Step Number", fontsize=14)
            fig.tight_layout()
            fig.savefig(out_path, dpi=300, bbox_inches="tight")
        finally:
            plt.close(fig)
    print(f" 📈 Trace plot saved to: {out_path}", flush=True)
    return out_path
[docs]
def generate_latex_summary(
    result: RetrievalResult,
    output_dir: Union[str, Path],
    latex_mapping: Optional[Dict[str, str]] = None,
    stem: Optional[str] = None,
    percent_uncertainty: float = 5.0,
) -> Path:
    """Write a self-contained LaTeX section summarising the fit.

    Produces ``<output_dir>/<stem>_fit_summary.tex`` with:

    * Parameter table (initial value, best fit, ±stderr).
    * Out-of-sample validation table (if ``result.val_stats`` is populated).
    * Correlation matrix (if ``result.covar`` is available).
    * ``\\includegraphics`` link to ``<stem>_publication_corner.pdf``.

    The tables use ``\\toprule`` / ``\\midrule`` / ``\\bottomrule``, so the
    including document needs the ``booktabs`` package.

    Args:
        result: :class:`RetrievalResult` from :func:`retrieve`.
        output_dir: Directory where the ``.tex`` file is saved (created if
            missing).
        latex_mapping: ``{name: latex_string}`` label overrides.
        stem: File-name / label stem. ``None`` (default) uses the name of
            *output_dir*. The path is resolved first, so ``"."`` yields the
            current directory's name.
        percent_uncertainty: Relative data uncertainty, in percent, used to
            turn reduced χ² into the "Avg. Error" column
            (``sqrt(redchi) × percent_uncertainty``). The default of 5
            matches the pipeline's default 5 % uncertainty.

    Returns:
        Path to the written ``.tex`` file.
    """
    if latex_mapping is None:
        latex_mapping = {}
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if stem is None:
        # Path(".").name is "", so fall back to the resolved directory name.
        stem = output_dir.name or output_dir.resolve().name or "fit"
    clean = stem.replace("_", r"\_")
    parts: List[str] = [f"\\section*{{Fit Summary for {clean}}}\n\n"]

    # --- 1. Parameter table ---
    parts.append("\\subsection*{Parameters and Uncertainties}\n")
    parts.append("\\begin{table}[htbp]\n\\centering\n")
    parts.append("\\begin{tabular}{l c c c}\n\\toprule\n")
    parts.append(
        "Parameter & Initial Value & Best Fit & Uncertainty ($\\pm$) \\\\\n\\midrule\n"
    )
    for name, param in result.params.items():
        label = _latex_label(name, latex_mapping)
        init = f"{param.init_value:.4g}" if param.init_value is not None else "---"
        if not param.vary:
            unc = "(Fixed)"
        elif param.stderr is None or not np.isfinite(param.stderr):
            unc = "---"  # uncertainty not available; avoid printing "nan"
        else:
            unc = f"{param.stderr:.4g}"
        parts.append(f"{label} & {init} & {param.value:.4g} & {unc} \\\\\n")
    parts.append("\\bottomrule\n\\end{tabular}\n")
    parts.append(f"\\caption{{Best fit parameters for {clean}.}}\n")
    parts.append(f"\\label{{tab:params_{stem}}}\n\\end{{table}}\n\n")

    # --- 2. Validation table (once, not duplicated) ---
    if result.val_stats:
        parts.append("\\subsection*{Out-of-Sample Validation}\n")
        parts.append("\\begin{table}[htbp]\n\\centering\n")
        parts.append("\\begin{tabular}{l c c c c}\n\\toprule\n")
        parts.append(
            "Data Subset & $N$ & Reduced $\\chi^2$ & RMSE & Avg.\\,Error \\\\\n\\midrule\n"
        )
        train_avg_err = np.sqrt(max(result.redchi, 0)) * percent_uncertainty
        parts.append(
            f"Training (baseline) & {result.ndata} & {result.redchi:.3f} & --- & "
            f"{train_avg_err:.1f}\\% \\\\\n\\midrule\n"
        )
        for subset_name, stats in result.val_stats.items():
            subset_label = subset_name.replace("_", r"\_")
            val_err = np.sqrt(max(stats["redchi"], 0)) * percent_uncertainty
            parts.append(
                f"{subset_label} & {stats['N']} & {stats['redchi']:.3f} & "
                f"{stats['rmse']:.3e} & {val_err:.1f}\\% \\\\\n"
            )
        parts.append("\\bottomrule\n\\end{tabular}\n")
        # Every piece is an f-string so "}}" collapses to a single brace. A
        # plain literal concatenated to an f-string is NOT brace-processed and
        # would leave an unbalanced "}" in the LaTeX.
        parts.append(
            f"\\caption{{Out-of-sample validation for {clean}. "
            f"Avg.\\ Error = $\\sqrt{{\\tilde{{\\chi}}^2}}\\times"
            f"{percent_uncertainty:g}\\%$.}}\n"
        )
        parts.append(f"\\label{{tab:val_{stem}}}\n\\end{{table}}\n\n")

    # --- 3. Correlation matrix ---
    if result.covar is not None:
        vn = result.var_names
        covar = np.asarray(result.covar, dtype=float)
        stdevs = np.sqrt(np.diag(covar))
        # Zero-variance parameters give NaN correlations; suppress the warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            corr = covar / np.outer(stdevs, stdevs)
        col_fmt = "c|" + "c" * len(vn)
        header_names = [_latex_label(n, latex_mapping) for n in vn]
        parts.append("\\subsection*{Parameter Correlation Matrix}\n")
        parts.append("\\begin{equation}\n")
        parts.append(f"\\begin{{array}}{{{col_fmt}}}\n")
        parts.append(" & " + " & ".join(header_names) + " \\\\\n\\hline\n")
        for i, ni in enumerate(vn):
            cells = [_latex_label(ni, latex_mapping)]
            cells.extend(
                "\\mathbf{1}" if i == j else f"{corr[i, j]:.4f}"
                for j in range(len(vn))
            )
            parts.append(" & ".join(cells) + " \\\\\n")
        parts.append("\\end{array}\n\\end{equation}\n\n")

    # --- 4. Figure link ---
    parts.append("\\subsection*{Parameter Correlations}\n")
    parts.append("\\begin{figure}[htbp]\n\\centering\n")
    parts.append(
        f"\\includegraphics[width=0.9\\textwidth]{{{stem}_publication_corner.pdf}}\n"
    )
    # The second literal is a plain string, so it carries a single "}".
    parts.append(
        f"\\caption{{Corner plot for {clean}. "
        "Solid lines mark the adopted (snapped) parameter values.}\n"
    )
    parts.append(f"\\label{{fig:corner_{stem}}}\n\\end{{figure}}\n")

    out_path = output_dir / f"{stem}_fit_summary.tex"
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write("".join(parts))
    print(f" 📝 LaTeX summary saved to: {out_path}", flush=True)
    return out_path
# ---------------------------------------------------------------------------
# Post-processing: chain manipulation after a completed run
# ---------------------------------------------------------------------------
[docs]
def post_process_mcmc(
    medium_name: str,
    output_dir: Union[str, Path],
    original_stored_params: Dict[str, float],
    drop_vars: Optional[List[str]] = None,
    derived_exprs: Optional[Dict[str, str]] = None,
    latex_mapping: Optional[Dict[str, str]] = None,
    is_crystal: bool = False,
) -> RetrievalResult:
    """Re-analyse an existing MCMC chain: drop/derive variables, snap, and re-plot.

    Loads the HDF5 chain written by a previous :func:`retrieve` call and
    applies transformations: drop parameters, and add derived quantities via
    pandas ``eval``. It then snaps each parameter to its cleanest 1-σ value
    and writes an updated corner plot, trace plot, and LaTeX summary.

    This mirrors the logic in ``uncertainty_emcee_post.py`` but as a proper
    library function.

    Args:
        medium_name: Stem used for file names (matches the original run).
        output_dir: Directory containing the ``.h5`` chain and ``.json`` result.
        original_stored_params: The reference parameter dict (e.g. from
            :func:`~gasfir.get_parameters`) used for drift detection.
        drop_vars: Parameter names to remove from the chain before re-analysis.
        derived_exprs: ``{new_name: pandas_eval_expression}`` pairs, e.g.
            ``{"a0_per_au3": "a0 / 6.74**3"}``. Expressions are evaluated with
            :meth:`pandas.DataFrame.eval`; pass trusted expressions only.
        latex_mapping: ``{name: latex_string}`` label overrides.
        is_crystal: If ``True``, ``lattice_constant_au`` is kept as a fixed
            constant in the reported tables.

    Returns:
        A new :class:`RetrievalResult` with post-processed parameters.

    Raises:
        FileNotFoundError: If the HDF5 chain or JSON result is missing.
        ImportError: If emcee or h5py are not installed.
        ValueError: If the chain holds no samples, or no parameters remain
            after applying *drop_vars*.
    """
    if not _HAS_EMCEE:
        raise ImportError(
            "emcee is required for post-processing. " + _RETRIEVAL_INSTALL_HINT
        )
    if not _HAS_LMFIT:
        raise ImportError(
            "lmfit is required for post-processing. " + _RETRIEVAL_INSTALL_HINT
        )
    output_dir = Path(output_dir)
    h5_path = output_dir / f"{medium_name}_emcee_chain.h5"
    json_path = output_dir / f"{medium_name}_retrieval_result.json"
    if not h5_path.exists():
        raise FileNotFoundError(f"HDF5 chain not found: {h5_path}")
    if not json_path.exists():
        raise FileNotFoundError(f"JSON result not found: {json_path}")

    # ---- load original result ----
    original = load_result(json_path)
    original_var_names: List[str] = original.var_names

    # ---- load chain ----
    backend = emcee.backends.HDFBackend(str(h5_path), read_only=True)
    # Stored step count from file metadata; avoids reading the whole chain
    # into memory just to measure its length.
    total_steps = backend.iteration
    try:
        tau = backend.get_autocorr_time(quiet=True)
        burnin = int(2 * np.max(tau))
        thin_val = max(1, int(0.5 * np.min(tau)))
    except Exception:  # AutocorrError, NaN/inf tau, or an empty chain
        burnin = total_steps // 3
        thin_val = 1
    # Never discard more than half of the stored chain.
    burnin = min(burnin, total_steps // 2)
    flat_samples = backend.get_chain(discard=burnin, thin=thin_val, flat=True)
    if len(flat_samples) == 0:
        raise ValueError(f"HDF5 chain {h5_path} contains no samples to post-process.")
    df_chain = pd.DataFrame(flat_samples, columns=original_var_names)
    _section(f"Post-processing MCMC chain — {medium_name}")
    print(
        f" Loaded {len(df_chain)} samples (burn={burnin}, thin={thin_val})", flush=True
    )

    # ---- apply transformations ----
    if derived_exprs:
        for new_var, expr in derived_exprs.items():
            print(f" ➕ Derived: {new_var} = {expr}", flush=True)
            # NOTE: DataFrame.eval evaluates the expression string; it must
            # come from the caller, never from untrusted input.
            df_chain[new_var] = df_chain.eval(expr)
    if drop_vars:
        for dv in drop_vars:
            if dv in df_chain.columns:
                print(f" ➖ Dropped: {dv}", flush=True)
                df_chain = df_chain.drop(columns=[dv])
    new_var_names = list(df_chain.columns)
    if not new_var_names:
        raise ValueError("No parameters remain after applying drop_vars.")

    # ---- snap parameters ----
    rebuilt = lmfit.Parameters()
    print("\n Snapping posteriors to cleanest values:", flush=True)
    for name in new_var_names:
        median = float(np.median(df_chain[name]))
        sigma = float(np.std(df_chain[name]))
        init = original_stored_params.get(name)
        clean = get_cleanest_value(median, sigma, init)
        tag = "STABLE" if init is not None and clean == init else "SNAPPED"
        print(f" [{tag}] {name}: {clean:.4g} (σ={sigma:.3g})", flush=True)
        rebuilt.add(name, value=clean, vary=True)
        rebuilt[name].stderr = sigma
    # carry fixed constants from original
    for name in original.params:
        if name not in rebuilt and not original.params[name].vary:
            rebuilt.add(name, value=original.params[name].value, vary=False)
    # add lattice_constant_au if crystal
    if is_crystal and "lattice_constant_au" in original_stored_params:
        lat_val = original_stored_params["lattice_constant_au"]
        rebuilt.add("lattice_constant_au", value=lat_val, vary=False)

    # ---- assemble RetrievalResult ----
    new_flat = df_chain.values
    new_covar = np.cov(new_flat, rowvar=False) if new_flat.shape[1] > 1 else None
    if new_covar is not None:
        stds = np.sqrt(np.diag(new_covar))
        # Zero-variance columns give 0/0 = NaN; suppress the RuntimeWarning.
        with np.errstate(divide="ignore", invalid="ignore"):
            corr = new_covar / np.outer(stds, stds)
        for i, name in enumerate(new_var_names):
            rebuilt[name].correl = {
                other: float(corr[i, j])
                for j, other in enumerate(new_var_names)
                if j != i
            }
    post_result = RetrievalResult(
        params=rebuilt,
        var_names=new_var_names,
        chisqr=original.chisqr,
        redchi=original.redchi,
        ndata=original.ndata,
        nvarys=len(new_var_names),
        covar=new_covar,
        flat_samples=new_flat,
        val_stats=original.val_stats,
        method="emcee_postprocessed",
        chain_path=h5_path,
    )

    # ---- outputs ----
    post_stem = f"{medium_name}_postprocessed"
    truths = [rebuilt[n].value for n in new_var_names]
    fixed_consts: Dict[str, float] = {}
    if is_crystal and "lattice_constant_au" in original_stored_params:
        fixed_consts["lattice_constant_au"] = original_stored_params[
            "lattice_constant_au"
        ]
    generate_publication_corner(
        new_flat,
        new_var_names,
        output_dir,
        post_stem,
        truths=truths,
        latex_mapping=latex_mapping or {},
        constants=fixed_consts,
    )
    # The trace plot reads the raw chain, so it uses the original column names.
    generate_trace_plot(
        h5_path,
        original_var_names,
        output_dir,
        post_stem,
        latex_mapping=latex_mapping or {},
        discard=burnin,
    )
    # NOTE(review): generate_latex_summary names its file after output_dir
    # (``<output_dir.name>_fit_summary.tex``) and links
    # ``<output_dir.name>_publication_corner.pdf``, not the post-processed
    # files above. If retrieve() wrote its summary into the same directory,
    # this call presumably overwrites it — confirm.
    generate_latex_summary(post_result, output_dir, latex_mapping=latex_mapping or {})
    # save updated JSON
    post_json = output_dir / f"{post_stem}_result.json"
    save_result(post_result, post_json)
    print(f" 💾 Post-processed result saved to: {post_json}", flush=True)
    return post_result
# ---------------------------------------------------------------------------
# Cold-start defaults
# ---------------------------------------------------------------------------
def _cold_start_params(
E_g: float,
is_crystal: bool,
) -> tuple:
"""Return ``(initial_values, absolute_bounds)`` for a cold-start run.
Used when no stored parameters and no user-supplied initial guess exist.
Values and bounds are physically motivated broad ranges that cover all
known GASFIR parameter sets across atoms, molecules, and solids.
Args:
E_g: Ionization potential / band gap in atomic units. This is the
one parameter the user **must** supply — it sets the energy scale
and cannot be guessed from the data alone.
is_crystal: If ``True``, use wider bounds appropriate for solids
(larger a0, a1; m_eff free with negative values allowed).
Returns:
Tuple of ``(initial_params_dict, param_bounds_dict)`` where
``param_bounds`` contains ``{name: (min, max)}`` entries that are
passed directly to :func:`setup_parameters`.
"""
if is_crystal:
initial = {
"E_g": E_g,
"m_eff": 0.3, # typical solid; can be negative
"a0": 50.0, # broad: Diamond≈107, SiO2≈5
"a1": 5.0, # broad: Diamond≈14, SiO2≈1.6
"a2": 2.0,
"a3": -2.0, # often negative for crystals
}
bounds = {
"E_g": (max(1e-4, E_g / 5), E_g * 5),
"m_eff": (-2.0, 5.0),
"a0": (1e-6, 1e5),
"a1": (0.1, 200.0),
"a2": (-15.0, 20.0),
"a3": (-30.0, 15.0),
}
else:
initial = {
"E_g": E_g,
"a0": 5.0, # broad: He≈0.82, Ar≈14, H≈3.4
"a1": 3.5, # broad: He≈2, Ar≈4, H≈3.5
"a2": 2.0,
"a3": 1.0,
}
bounds = {
"E_g": (max(1e-4, E_g / 5), E_g * 5),
"a0": (1e-6, 1e3),
"a1": (0.1, 20.0),
"a2": (0.0, 15.0),
"a3": (-10.0, 10.0),
}
return initial, bounds
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def retrieve(
    data_NA: pd.DataFrame,
    config: RetrievalConfig,
    uncertainty_NA: Optional[npt.NDArray[np.float64]] = None,
    data_QS: Optional[pd.DataFrame] = None,
    uncertainty_QS: Optional[Union[npt.NDArray[np.float64], float]] = None,
    data_validation: Optional[pd.DataFrame] = None,
    run_phases: Tuple[bool, bool, bool] = (True, True, True),
) -> RetrievalResult:
    """Run the full DE/CMA-ES → least_squares → emcee retrieval pipeline.

    Args:
        data_NA: Training DataFrame with ``"pulses"`` and ``"Y"`` columns.
        config: :class:`RetrievalConfig` controlling all pipeline settings.
        uncertainty_NA: Per-point uncertainty array. If ``None``, defaults
            to ``0.05 × Y`` (5% relative uncertainty).
        data_QS: Optional quasi-static data with ``"field"`` and ``"Y"``
            columns for simultaneous fitting.
        uncertainty_QS: Uncertainty for QS data. Defaults to 5%.
        data_validation: Optional out-of-sample DataFrame for validation
            metrics (same columns as ``data_NA``).
        run_phases: Three-tuple ``(global, polish, mcmc)`` controlling which
            pipeline phases to execute. Default ``(True, True, True)``
            runs all three.

    Returns:
        A :class:`RetrievalResult` with the best-fit parameters, posterior
        statistics, and optional validation metrics.

    Raises:
        ImportError: If CMA-ES is requested for the global phase but ``cma``
            is not installed, or MCMC is requested but ``emcee`` is not.
            Both checks run before any expensive computation.
            (``_require_lmfit`` presumably raises similarly when lmfit is
            missing.)
        ValueError: On a cold start (medium not in the parameter store) when
            no ``E_g`` is given in ``config.initial_params`` and none can be
            resolved via :func:`get_ionization_potential`.

    Side effects:
        Creates the output directory (``config.output_dir``, or
        ``./<medium_name>`` when unset). The pipeline's artifacts are directed
        there: the emcee HDF5 chain, corner/trace plots and the LaTeX summary
        when MCMC runs, and always ``<medium_name>_retrieval_result.json``.
    """
    _require_lmfit()

    # Validate optional dependencies before any expensive computation.
    run_global, run_polish, run_mcmc_flag = run_phases
    if run_global and config.global_method == "cma" and not _HAS_CMA:
        raise ImportError(
            "CMA-ES is the default global search method but the 'cma' package "
            "is not installed. Either install it with:\n"
            " pip install gasfir[retrieval]\n"
            "or switch to differential evolution:\n"
            " config.global_method = 'differential_evolution'"
        )
    if run_mcmc_flag and not _HAS_EMCEE:
        raise ImportError(
            "emcee is required for MCMC sampling. " + _RETRIEVAL_INSTALL_HINT
        )

    output_dir = (
        Path(config.output_dir)
        if config.output_dir is not None
        else Path(config.medium_name)
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f" Output directory: {output_dir.resolve()}", flush=True)

    # --- uncertainty defaults (5% relative) ---
    if uncertainty_NA is None:
        uncertainty_NA = 0.05 * data_NA["Y"].to_numpy()
    if uncertainty_QS is None and data_QS is not None:
        uncertainty_QS = 0.05 * data_QS["Y"].to_numpy()

    # --- initial parameters ---
    _cold_start = False
    _explicit_bounds: Dict[str, tuple] = {}
    # Look up stored params once. The same dict is reused for the Phase 2
    # comparison instead of querying the store again.
    _stored_ref: Optional[Dict[str, float]] = None
    try:
        _stored_ref = get_parameters(config.medium_name)
    except (ValueError, KeyError):
        pass  # medium not in store → cold start (initial_params act as hints)

    if config.initial_params is not None and _stored_ref is not None:
        # Store hit AND explicit initial params. The user's dict replaces the
        # stored one entirely; it is not merged with it.
        stored = config.initial_params.copy()
    elif _stored_ref is not None:
        # Store hit, no override → use stored params.
        stored = _stored_ref.copy()
    else:
        # ── Cold start ───────────────────────────────────────────────────
        # Medium not in store. E_g must be provided — it sets the energy
        # scale and cannot be inferred from data alone.
        hints = config.initial_params or {}
        # Try to auto-resolve E_g from the atomic database if medium_name
        # looks like an element symbol or name (atoms/molecules only).
        _auto_E_g: Optional[float] = None
        if not config.is_crystal and "E_g" not in hints:
            _auto_E_g = get_ionization_potential(config.medium_name)
        if "E_g" not in hints and _auto_E_g is None:
            raise ValueError(
                f"No stored parameters found for '{config.medium_name}'.\n"
                "For a cold start, supply E_g (ionization potential in a.u.):\n"
                f" config.initial_params = {{'E_g': <value_in_au>}}\n"
                "For atoms/molecules the ionization potential can be looked up:\n"
                " from gasfir import get_ionization_potential\n"
                " E_g = get_ionization_potential('Ar') # example\n"
                "For crystals, E_g must always be provided explicitly.\n"
                "All other parameters are initialised from broad physical defaults."
            )
        E_g_init = hints.get("E_g", _auto_E_g)
        stored, _explicit_bounds = _cold_start_params(E_g_init, config.is_crystal)
        stored.update(hints)  # user hints override defaults
        _cold_start = True
        _src = "NIST atomic database" if _auto_E_g is not None else "user-supplied"
        _E_g_eV = E_g_init * 27.21138383  # Hartree → eV, for display only
        print(
            f"\n ⚠ Cold start — no stored parameters for '{config.medium_name}'.\n"
            f" E_g = {E_g_init:.4g} a.u. ({_E_g_eV:.4g} eV, {_src})\n"
            f" Physical defaults used for all other parameters.\n"
            f" Consider increasing cma_maxiter for a thorough global search.",
            flush=True,
        )

    # Crystal: pre-expand a0_per_au3 → a0 (a0 = a0_per_au3 · lattice_constant³)
    if (
        config.is_crystal
        and "a0_per_au3" in stored
        and "lattice_constant_au" in stored
    ):
        stored["a0"] = stored["a0_per_au3"] * stored["lattice_constant_au"] ** 3
        if _cold_start:
            # Give a0 a broad absolute range unless one is already present.
            _explicit_bounds.setdefault("a0", (1e-6, 1e5))

    params = setup_parameters(
        stored,
        relative_bound=config.relative_bound,
        fixed_params=config.fixed_params,
        is_crystal=config.is_crystal,
        param_bounds=_explicit_bounds if _cold_start else None,
    )
    _section(f"Initial Parameters — {config.medium_name}")
    params.pretty_print()

    # --- residual function ---
    base_residual = ret_residual_function(
        data_NA,
        uncertainty_NA,
        data_QS,
        uncertainty_QS,
        dt=config.dt,
        dT=config.dT,
        ret_electron_density=config.ret_electron_density,
    )
    residual_fn = _make_safe_residual(base_residual)

    # --- track which phases ran ---
    phases_run: List[str] = []
    de_result = None
    ls_result = None

    # === Phase 1: Global search ===
    if run_global:
        params, de_result = run_global_search(
            residual_fn,
            params,
            method=config.global_method,
            de_workers=config.de_workers,
            de_maxiter=config.de_maxiter,
            de_tol=config.de_tol,
            de_popsize=config.de_popsize,
            cma_sigma0=config.cma_sigma0,
            cma_maxiter=config.cma_maxiter,
            trial_run=config.trial_run,
        )
        phases_run.append(config.global_method)
        _section(f"Phase 1 — {config.global_method.upper()} best parameters")
        params.pretty_print()

    # === Phase 2: Local polish ===
    if run_polish:
        params, ls_result = run_local_polish(
            residual_fn, params, method=config.ls_method
        )
        phases_run.append("least_squares")
        _section("Phase 2 — Least Squares Fit Report")
        print(lmfit.fit_report(ls_result), flush=True)
        # Compare to stored reference parameters if they exist.
        if (
            _stored_ref is not None
            and config.medium_name
            and config.medium_name != "unknown"
        ):
            stored_ref = _stored_ref.copy()
            # For crystals: convert a0_per_au3 → a0 for the comparison
            if "a0_per_au3" in stored_ref and "lattice_constant_au" in stored_ref:
                stored_ref["a0"] = (
                    stored_ref["a0_per_au3"]
                    * stored_ref["lattice_constant_au"] ** 3
                )
            try:
                compare_to_stored(ls_result, stored_ref)
            except (ValueError, KeyError) as err:
                # Best-effort diagnostic: never abort the pipeline here, but
                # do not hide the reason either.
                _log.warning("Skipping comparison to stored parameters: %s", err)

    # === Phase 3: MCMC ===
    flat_samples: Optional[npt.NDArray] = None
    lnprob_arr: Optional[npt.NDArray] = None
    covar: Optional[npt.NDArray] = None
    sampler = None
    # HDF5 chain path (written during run_mcmc if emcee ran)
    _chain_path: Optional[Path] = None
    if run_mcmc_flag:
        _chain_path = output_dir / f"{config.medium_name}_emcee_chain.h5"
        params, flat_samples, lnprob_arr, sampler = run_mcmc(
            residual_fn,
            params,
            n_steps=config.emcee_nsteps,
            n_walkers=config.emcee_nwalkers,
            burn_fraction=config.emcee_burn_fraction,
            thin=config.emcee_thin,
            backend_path=_chain_path,
            pool=config.emcee_pool,
            max_autocorr_loops=config.emcee_max_autocorr_loops,
            initial_stored_params=stored,
            trial_run=config.trial_run,
        )
        phases_run.append("emcee")
        # The sample covariance needs an (n_samples, n_params) array with at
        # least two samples. atleast_2d keeps a (1, 1) matrix for a single
        # free parameter, where np.cov would return a 0-d array.
        if (
            flat_samples is not None
            and flat_samples.ndim == 2
            and len(flat_samples) > 1
        ):
            covar = np.atleast_2d(np.cov(flat_samples, rowvar=False))

    # === Goodness of fit at best-fit point ===
    best_resid = residual_fn(params)
    var_names = sorted([n for n in params if params[n].vary])
    ndata = len(data_NA) + (len(data_QS) if data_QS is not None else 0)
    nvarys = len(var_names)
    chisqr = float(np.sum(best_resid**2))
    redchi = chisqr / max(ndata - nvarys, 1)

    # === MCMC fit report (needs chisqr computed above) ===
    if run_mcmc_flag and flat_samples is not None and len(flat_samples) > 0:
        # Report the walker count the sampler actually used. Fall back to
        # max(32, 4·nvarys) only if the sampler does not expose ``nwalkers``.
        n_walkers_used = getattr(sampler, "nwalkers", None) or max(32, 4 * nvarys)
        emcee_mr = _build_emcee_minimizer_result(
            params,
            var_names,
            flat_samples,
            ndata=ndata,
            n_walkers=n_walkers_used,
            n_steps=config.emcee_nsteps,
            chisqr=chisqr,
        )
        _section("Phase 3 — MCMC Posterior Report")
        print(lmfit.fit_report(emcee_mr), flush=True)
        # Corner plot and trace plot (LaTeX summary follows once val_stats exist)
        truths = [params[n].value for n in var_names]
        fixed_consts = {n: params[n].value for n in params if not params[n].vary}
        generate_publication_corner(
            flat_samples,
            var_names,
            output_dir,
            config.medium_name,
            truths=truths,
            latex_mapping=config.latex_mapping or {},
            constants=fixed_consts,
        )
        if _chain_path and _chain_path.exists():
            generate_trace_plot(
                _chain_path,
                var_names,
                output_dir,
                config.medium_name,
                latex_mapping=config.latex_mapping or {},
            )

    # === Out-of-sample validation ===
    val_stats: Dict = {}
    if data_validation is not None and len(data_validation) > 0:
        _log.info("Computing out-of-sample validation metrics...")
        subsets: Dict[str, pd.DataFrame] = {"All_Validation": data_validation}
        if "intens" in data_validation.columns:
            subsets["High_Intensity"] = data_validation[
                data_validation["intens"] >= 1e14
            ]
            subsets["Low_Intensity"] = data_validation[data_validation["intens"] < 1e14]
        if "wavel" in data_validation.columns:
            subsets["Short_wl"] = data_validation[data_validation["wavel"] <= 800]
            subsets["Long_wl"] = data_validation[data_validation["wavel"] > 800]
        for subset_name, df_sub in subsets.items():
            if len(df_sub) == 0:
                continue
            val_stats[subset_name] = compute_validation_metrics(
                params,
                df_sub,
                dt=config.dt,
                dT=config.dT,
                ret_electron_density=config.ret_electron_density,
            )
            v = val_stats[subset_name]
            _log.info(
                " [%s] N=%d redchi=%.3f RMSE=%.2e",
                subset_name,
                v["N"],
                v["redchi"],
                v["rmse"],
            )

    # === Assemble result ===
    result = RetrievalResult(
        params=params,
        var_names=var_names,
        chisqr=chisqr,
        redchi=redchi,
        ndata=ndata,
        nvarys=nvarys,
        covar=covar,
        flat_samples=flat_samples,
        lnprob=lnprob_arr,
        val_stats=val_stats,
        de_result=de_result,
        ls_result=ls_result,
        method="+".join(phases_run),
        chain_path=_chain_path,
    )

    # LaTeX summary (after val_stats are populated)
    if run_mcmc_flag and flat_samples is not None and len(flat_samples) > 0:
        generate_latex_summary(
            result, output_dir, latex_mapping=config.latex_mapping or {}
        )

    # === Final summary ===
    _section(f"Pipeline Complete — {config.medium_name} [{'+'.join(phases_run)}]")
    print(f" Phases run : {' → '.join(phases_run)}", flush=True)
    print(f" N data points : {ndata}", flush=True)
    print(f" N free params : {nvarys}", flush=True)
    print(f" χ² : {chisqr:.4g}", flush=True)
    print(f" red-χ² : {redchi:.4g}", flush=True)
    if val_stats:
        print(" Validation:", flush=True)
        for name, v in val_stats.items():
            print(
                f" {name:30s} N={v['N']:4d} red-χ²={v['redchi']:.3f} RMSE={v['rmse']:.2e}",
                flush=True,
            )
    print(f" Output dir : {output_dir}", flush=True)

    # === Persist ===
    json_path = output_dir / f"{config.medium_name}_retrieval_result.json"
    save_result(result, json_path)
    print(f" JSON saved : {json_path}", flush=True)
    return result
# ---------------------------------------------------------------------------
# Multiprocessing helper
# ---------------------------------------------------------------------------
def _emcee_worker_init() -> None:
    """Pool initializer: limit Numba to a single thread in this worker process.

    Defined at module level, not nested inside :func:`make_emcee_pool`, so it
    can be pickled. The ``spawn`` and ``forkserver`` start methods need that.
    """
    import sys

    # Only effective if Numba has not been imported in this process yet;
    # Numba reads NUMBA_NUM_THREADS when it is imported.
    os.environ["NUMBA_NUM_THREADS"] = "1"
    # If Numba is already loaded (inherited via fork, or imported while the
    # worker unpickled its task), the variable comes too late. Lower the
    # runtime thread count instead, where this Numba version supports it.
    numba = sys.modules.get("numba")
    set_num_threads = getattr(numba, "set_num_threads", None)
    if set_num_threads is not None:
        set_num_threads(1)


def make_emcee_pool(n_workers: Optional[int] = None) -> multiprocessing.Pool:
    """Create a ``multiprocessing.Pool`` for use as ``config.emcee_pool``.

    Each worker is initialised with ``NUMBA_NUM_THREADS=1``. If Numba is
    already imported in the worker, its runtime thread count is also set to 1.
    This keeps Numba from over-subscribing the CPU when combined with
    process-level parallelism. The initializer is a module-level function, so
    the pool works with every ``multiprocessing`` start method.

    Args:
        n_workers: Number of worker processes. ``None`` (or ``0``) → CPU count.

    Returns:
        An initialised :class:`multiprocessing.Pool`. **Remember to close it
        after the retrieval run:** ``pool.close(); pool.join()``.

    Example::

        pool = make_emcee_pool()
        config.emcee_pool = pool
        result = retrieve(data_NA, config)
        pool.close()
        pool.join()
    """
    n = n_workers or multiprocessing.cpu_count()
    return multiprocessing.Pool(processes=n, initializer=_emcee_worker_init)