# Source code for gasfir.kernels

# Copyright (c) 2024 Manoram Agarwal
"""Kernel functions for computing ionization rates and probabilities.

This module provides functions to compute ionization rates and probabilities using
different kernel methods. The kernels are used to model the ionization process in
atoms and molecules under strong laser fields.

The module implements three different methods:

1. GASFIR (General Approximator for Strong Field Ionization Rates)
2. Exact SFA (Strong Field Approximation)
3. QS (Quasi-Static)

Parameters can be obtained either by using pre-defined values for common gases
through :func:`~gasfir.get_parameters` or by specifying custom values.

Example:
    >>> from gasfir import create_pulse, get_parameters
    >>> from gasfir.kernels import get_diabatic_ionization_rate
    >>> laser = create_pulse(800, 1e14, 0, 30)
    >>> params = get_parameters("H_SFA")
    >>> t_grid, rates = get_diabatic_ionization_rate(None, laser, params)
"""

import warnings
from typing import Dict, NamedTuple, Tuple, Union

import numpy as np
import numpy.typing as npt
import pandas
from numba import njit, prange, types
from numba.typed import List
from scipy.integrate import simpson

from ._parameter_store import get_parameters
from .pulse import Pulse, create_pulse
from .utils import (
    find_extrema_positions,
    find_zero_crossings,
    get_momentum_grid,
)
from .utils import integrate_oscillating_function_jit as IOF
from .utils import (
    meshgrid,
    safe_diff_linear,
    safe_diff_squared,
    trapz_jit,
)

# Conservative LLVM fast-math flags for the JIT kernels: assume operands are
# neither NaN nor infinite, and allow fused multiply-add contraction.
safe_fastmath = {"contract", "ninf", "nnan"}

# Broader flag set used by the hot kernels. "afn" (approximate transcendental
# functions) is deliberately left out so phase angles stay accurate.
aggressive_safe_fastmath = {
    "arcp",  # allow reciprocal approximation of divisions
    "contract",  # hardware fused multiply-add
    "ninf",  # assume no infinities
    "nnan",  # assume no NaNs (NOTE(review): authors rely on a1 > 0, m_eff != 0)
    "nsz",  # treat -0.0 exactly like +0.0
    "reassoc",  # allow reordering of floating-point operations
}


########################
class KernelParams(NamedTuple):
    """Fit parameters of the GASFIR kernel.

    Only ``a0``, ``a1``, ``a2`` and ``E_g`` are required. Every other field has a
    default, so older parameter dictionaries keep working. All fields are real
    floats; the kernel functions combine them with complex quantities internally.

    Fields (declaration order also fixes the tuple layout used by the JIT code):
        a0: Normalization factor of the pre-exponential term.
        a1: Time constant of the Gaussian smearing (related to the tunneling time).
        a2: Stark-shift correction; weights xi1 (``term1``) in the exponent.
        E_g: Band gap or ionization potential.
        a3: Dynamic Stark-shift correction; weights xi2 (``term2``) in the exponent.
        a4: Extra power of ``term0`` in the pre-exponential factor (intended for
            Coulomb effects).
        m_eff: Effective mass for the dispersion relation in solids (1.0 for atoms).
        T_thresh: Diabatic-time threshold (a.u.) beyond which the kernel was meant
            to be replaced by its analytic continuation.
            NOTE(review): no active code in this module reads it; confirm
            before relying on it.
        a5: Multiplier of xi2 (``term2``) in the pre-exponential numerator.
        a6: Multiplier of xi1 (``term1``) in the pre-exponential numerator.
        αPol: Polarizability-like coefficient of the field-squared (Stark) term in
            the exponent; 0.0 disables that term.
        div_theta: Division factor for the theta grid of the exact SFA kernel
            (default 1.0; smaller values increase resolution).
        div_p: Division factor for the momentum grid of the exact SFA kernel
            (default 1.0; smaller values increase resolution).
    """

    # Tunable exponents of xi1/xi2 in the numerator (internal_param_pxi1/2) are
    # currently disabled; see the commented-out fields further down.
    # Required parameters (no defaults):
    a0: float
    a1: float
    a2: float
    E_g: float
    a3: float = 0.0
    a4: float = 0.0
    m_eff: float = 1.0  # only relevant for solids
    # Mostly for internal testing:
    T_thresh: float = (
        200.0  # 200 atomic units (~4.8 fs) before the kernel would be replaced by its analytic continuation
    )
    a5: float = 1.0
    a6: float = 1.0
    αPol: float = 0.0
    # internal_param_pxi1 : float = 1.0
    # internal_param_pxi2 : float= 1.0
    div_theta: float = 1.0
    div_p: float = 1.0


from scipy.special import factorial, gamma


def C_n_star_sq_asymptotic(n_star):
    """Return the asymptotic ADK coefficient ``|C_{n*l*}|**2``.

    Uses the large-n* approximation ``(2e/n*)**(2n*) / (2*pi*n*)`` instead of the
    Gamma-function expression, which has a pole for atoms where n* < l.
    """
    normalization = 1.0 / (2 * np.pi * n_star)
    return normalization * ((2 * np.e / n_star) ** (2 * n_star))


def f_lm(l, m):
    """Return the ADK magnetic-quantum-number factor.

    ``f(l, m) = (2l + 1) (l + |m|)! / (2**|m| |m|! (l - |m|)!)``
    """
    m_abs = abs(m)
    numerator = (2 * l + 1) * factorial(l + m_abs)
    denominator = (2**m_abs) * factorial(m_abs) * factorial(l - m_abs)
    return numerator / denominator


def adk_rate(F, Ip, Z=1, l=0, m=None):
    """Standard ADK tunneling ionization rate using the asymptotic coefficient.

    Args:
        F: Field strength(s); scalar or array-like of any shape. The sign is
            ignored and entries equal to zero get a rate of zero.
        Ip: Ionization potential.
        Z: Charge of the residual ion.
        l: Orbital angular momentum quantum number.
        m: Magnetic quantum number. ``None`` sums the rate over m = -l..l
            (for l = 0 this is just m = 0).

    Returns:
        float64 ndarray of rates with the same shape as ``np.asarray(F)``.
        Integer input is promoted, so rates are never truncated to integers.
    """
    F_arr = np.asarray(F, dtype=np.float64)
    flat = F_arr.reshape(-1)  # 0-d (scalar) input becomes shape (1,)
    rate = np.zeros_like(flat)
    mask = np.abs(flat) > 0
    F_m = np.abs(flat[mask])

    kappa = np.sqrt(2 * Ip)
    n_star = Z / kappa
    C2 = C_n_star_sq_asymptotic(n_star)

    barrier = (2 * (kappa**3)) / F_m
    exponent = -(2 * (kappa**3)) / (3 * F_m)

    # A given m contributes a single term; m=None sums every sublevel of l.
    m_values = range(-l, l + 1) if m is None else (m,)
    prefactor = 0.0
    for m_q in m_values:
        prefactor = prefactor + (
            C2 * f_lm(l, m_q) * Ip * barrier ** (2 * n_star - abs(m_q) - 1)
        )

    rate[mask] = prefactor * np.exp(exponent)
    return rate.reshape(F_arr.shape)


@njit(parallel=True, fastmath=aggressive_safe_fastmath, cache=True, inline="always")
def Kernel_f_term(
    EFp: float | npt.NDArray[np.float64],
    EFm: float | npt.NDArray[np.float64],
    term0: complex | npt.NDArray[np.complex128],
    term1: float | npt.NDArray[np.float64],
    term2: float | npt.NDArray[np.float64],
    Ti: float | npt.NDArray[np.float64],
    params: NamedTuple,
) -> complex | npt.NDArray[np.complex128]:
    """Compute the pre-exponential factor of the GASFIR kernel.

    Evaluates ``(1j*pi)**1.5 * a0 * EFp * EFm * term0**(1.5 + a4)
    * (2*term0 - 4*a6*term1 - 4*a5*term2*term0**2*Ti**2)``.
    The expression is element-wise, so scalars or equally shaped arrays may be
    passed. The quasi-static rate path passes arrays.

    Args:
        EFp: Electric field at t+Ti
        EFm: Electric field at t-Ti
        term0: Complex term 1j/(Ti + 1j*a1), equivalently (a1 + 1j*Ti)/(Ti**2 + a1**2)
        term1: xi1 = (A(t+Ti) - A(t-Ti))**2 / 4
        term2: xi2 = ((A(t+Ti) + A(t-Ti))/2 - A_bar)**2, where A_bar is the
            window average of A computed by the caller
        Ti: Half-width of the time window around the moment of ionization
        params: ``KernelParams`` named tuple; a0, a4, a5 and a6 enter the result
    Returns:
        Complex pre-exponential factor (scalar or array matching the inputs)
    """
    a1 = params.a1  # read but not used in the expression below
    a0 = params.a0
    a4 = params.a4
    a5 = params.a5
    a6 = params.a6
    # m_eff = params.m_eff
    return (
        (1j * np.pi) ** 1.5
        * a0
        * EFp
        * EFm
        * (term0) ** (1.5 + a4)
        * (2 * term0 - 4 * a6 * term1 - 4 * a5 * term2 * term0**2 * Ti**2)
    )  # /m_eff**2
    # Earlier variant, kept for reference: tunable powers of xi1/xi2 and an
    # explicit m_eff scaling.
    # pxi1 = params.internal_param_pxi1
    # pxi2 = params.internal_param_pxi2

    # return (1j*np.pi)**1.5 * a0/m_eff**2 * EFp * EFm * term0**(1.5+a4) * (
    #     2*term0*m_eff - 4*a6*term1**pxi1 - 4*a5*(term2*term0**2*Ti**2)**pxi2
    # )


@njit(parallel=True, fastmath=aggressive_safe_fastmath, cache=True, inline="always")
def Kernel_phase_term(
    term0: complex | npt.NDArray[np.complex128],
    term1: float | npt.NDArray[np.float64],
    term2: float | npt.NDArray[np.float64],
    DelAbar: float | npt.NDArray[np.float64],
    DelA2bar: float | npt.NDArray[np.float64],
    E2diff: float | npt.NDArray[np.float64],
    Ti: float | npt.NDArray[np.float64],
    params: NamedTuple,
) -> complex | npt.NDArray[np.complex128]:
    """Compute the complex argument of the exponential function of the kernel.

    Returns
    ``1j*(Ti*(2*E_g + DelA2bar/m_eff - DelAbar**2/m_eff) + αPol*E2diff/(2*m_eff))
    - a1*(a2*term1 + a3*term2*Ti/(Ti + 1j*a1)) - 3j*pi/4``.
    The element-wise form accepts scalars or equally shaped arrays.

    Args:
        term0: Complex term 1j/(Ti + 1j*a1) (not used by this function)
        term1: xi1 = (A(t+Ti) - A(t-Ti))**2 / 4
        term2: xi2 = ((A(t+Ti) + A(t-Ti))/2 - A_bar)**2
        DelAbar: Window average of A, i.e. (integral of A over [t-Ti, t+Ti]) / (2*Ti)
        DelA2bar: Window average of A**2 over [t-Ti, t+Ti]
        E2diff: Field-squared term of the Stark shift. Callers derive it from the
            cumulative integral of E**2; the quasi-static path passes 2*F**2*Ti.
        Ti: Half-width of the time window around the moment of ionization
        params: ``KernelParams`` named tuple; reads
            E_g: Band gap or ionization potential
            αPol: Coefficient of the field-squared (Stark) term
            a1: Time constant for Gaussian smearing
            a2: Weight of term1 in the real (damping) part
            a3: Weight of term2 in the damping part, scaled by Ti/(Ti + 1j*a1)
            m_eff: Effective mass

    Returns:
        Complex exponent of the kernel (scalar or array matching the inputs)
    """
    E_g = params.E_g
    αPol = params.αPol
    a1 = params.a1
    a2 = params.a2
    a3 = params.a3
    m_eff = params.m_eff
    # The quasi-static rate (get_quasi_static_rate_for_field) reuses this function
    # at a complex saddle point with term2 = 0, so the a3 term drops out there.
    # The constant -3j*pi/4 cancels the phase of (1j*pi)**1.5 = pi**1.5*exp(3j*pi/4)
    # in Kernel_f_term.

    return (
        1j
        * (
            Ti * (2 * E_g + DelA2bar / m_eff - DelAbar**2 / m_eff)
            + αPol * E2diff / 2 / m_eff
        )
        - a1 * (a2 * term1 + a3 * term2 * Ti / (Ti + 1j * a1))
        - 3j * np.pi / 4
    )


@njit(parallel=False, fastmath=aggressive_safe_fastmath, cache=True, inline="always")
def Kernel_jit_helper(
    tar, Tar, params, EF, EF2, VP, intA, intA2, dT, N, n, nmin, Ti_ar, EF_max
):
    """Evaluate the GASFIR kernel on the (T_grid, t_grid) mesh.

    Row ``i`` corresponds to the half-window ``T = Ti_ar[i] * dT``. Column ``j``
    corresponds to the ionization moment stored at dense-grid index
    ``tj = N + nmin + j * n``. Entries whose window ``[tj - Ti, tj + Ti]`` leaves
    the dense arrays are left at zero.

    Args:
        tar (np.ndarray): grid of ionization moments (only its size is used)
        Tar (np.ndarray): grid of half-window times T (only its size is used)
        params: ``KernelParams`` named tuple
        EF (np.ndarray): electric field on the dense grid
        EF2 (np.ndarray): cumulative integral of the electric field squared
        VP (np.ndarray): vector potential on the dense grid
        intA (np.ndarray): cumulative integral of the vector potential
        intA2 (np.ndarray): cumulative integral of the vector potential squared
        dT (float64): step size of the dense arrays
        N (int64): index offset into the dense arrays (callers build them as
            ``arange(-N, N + 1) * dT``, so index N is t = 0)
        n (int64): number of dT steps between consecutive ionization moments
        nmin (int64): dense-grid offset of the first ionization moment
        Ti_ar (np.ndarray): integer half-window sizes in units of dT, one per row
        EF_max (float64): peak |E|; currently unused, kept for the call signature
    Returns:
        f0 (np.ndarray, shape=(Tar.size, tar.size)): pre-exponential factor
        phase0 (np.ndarray, shape=(Tar.size, tar.size)): complex exponent
    """
    f0 = np.zeros((Tar.size, tar.size), dtype=np.cdouble)
    phase0 = np.zeros((Tar.size, tar.size), dtype=np.cdouble)
    a1 = params.a1
    # The Stark-shift integral is only needed for a non-zero polarizability.
    calc_E2 = params.αPol != 0.0
    E2diff = 0.0
    for i in prange(Tar.size):
        Ti = Ti_ar[i]
        T = Ti * dT
        if T == 0.0:
            # Zero-width window: every window integral vanishes and term0 = 1/a1.
            E2diff = 0.0
            term1 = 0.0
            DelAbar = 0.0
            term2 = 0.0
            DelA2bar = 0.0
            term0 = 1 / a1
            # The exponent does not depend on tj when T == 0; compute it once.
            phase_val = Kernel_phase_term(
                term0, term1, term2, DelAbar, DelA2bar, E2diff, T, params
            )
            for j in range(tar.size):
                tj = N + nmin + j * n
                # Ti == 0, so tp == tm == tj; only tj needs a bounds check.
                if tj >= 0 and tj < EF.size:
                    phase0[i, j] = phase_val
                    f0[i, j] = Kernel_f_term(
                        EF[tj], EF[tj], term0, term1, term2, T, params
                    )
        else:
            # term0 = 1j / (T + 1j * a1), written out to avoid a complex division.
            inv_denom = 1.0 / (T * T + a1 * a1)
            term0 = complex(a1 * inv_denom, T * inv_denom)
            for j in range(tar.size):
                tj = N + nmin + j * n
                tp = tj + Ti
                tm = tj - Ti
                # Read the dense arrays only after the bounds check: compiled code
                # does not bounds-check, so reading first could touch memory
                # outside EF.
                if tp >= 0 and tp < EF.size and tm >= 0 and tm < EF.size:
                    EFp = EF[tp]
                    EFm = EF[tm]
                    term1 = (VP[tp] - VP[tm]) ** 2 / 4
                    # Window averages of A and A**2 from the cumulative integrals.
                    DelAbar = (intA[tp] - intA[tm]) / (2.0 * T)
                    DelA2bar = (intA2[tp] - intA2[tm]) / (2.0 * T)
                    if calc_E2:
                        # Full integral of E**2 over [t-T, t+T]. This matches the
                        # T == 0 branch (0) and the quasi-static limit (2*F**2*T).
                        # The Stark phase alpha*E2diff/2 is the time integral of
                        # the shift.
                        E2diff = EF2[tp] - EF2[tm]

                    term2 = (VP[tp] / 2 + VP[tm] / 2 - DelAbar) ** 2
                    phase0[i, j] = Kernel_phase_term(
                        term0, term1, term2, DelAbar, DelA2bar, E2diff, T, params
                    )
                    f0[i, j] = Kernel_f_term(EFp, EFm, term0, term1, term2, T, params)

    return f0, phase0


@njit(parallel=False, fastmath=False, cache=True)
def exact_SFA_jit_helper(
    tar,
    Tar,
    params,
    EF,
    EF2,
    VP,
    intA,
    intA2,
    dT,
    N,
    n,
    nmin,
    Ti_ar,
    p_grid,
    Theta_grid,
    window,
    p,
    theta,
):
    """Evaluate the exact-SFA kernel on the (T_grid, t_grid) mesh.

    Row ``i`` corresponds to ``T = Ti_ar[i] * dT``. Column ``j`` corresponds to
    dense-grid index ``tj = N + nmin + j * n``. Entries whose window
    ``[tj - Ti, tj + Ti]`` leaves the dense arrays stay zero. With
    ``parallel=False`` the ``prange`` loop runs serially.

    Args:
        tar (np.ndarray): grid of ionization moments (only its size is used)
        Tar (np.ndarray): grid of half-window times T (only its size is used)
        params: ``KernelParams`` named tuple; only ``E_g`` is read
        EF (np.ndarray): electric field on the dense grid
        EF2 (np.ndarray): cumulative integral of E**2 (not used by this kernel)
        VP (np.ndarray): vector potential on the dense grid
        intA (np.ndarray): cumulative integral of the vector potential
        intA2 (np.ndarray): cumulative integral of the vector potential squared
        dT (float64): step size of the dense arrays
        N (int64): index offset into the dense arrays (index of t = 0 as built by Kernel_jit)
        n (int64): number of dT steps between consecutive ionization moments
        nmin (int64): dense-grid offset of the first ionization moment
        Ti_ar (np.ndarray): integer half-window sizes in units of dT, one per row
        p_grid (np.ndarray): 1-D momentum-magnitude grid
        Theta_grid (np.ndarray): 1-D polar-angle grid
        window (np.ndarray): weights applied on the momentum grid before the p integration
        p, theta (np.ndarray): mesh built from p_grid and Theta_grid by ``meshgrid``;
            the theta integration runs along their last axis (np.trapezoid default)

    Returns:
        f0 (np.ndarray, shape=(Tar.size, tar.size)): pre-exponential factor
        phase0 * 1j (np.ndarray, shape=(Tar.size, tar.size)): complex exponent
            (the real-valued phase is multiplied by 1j before returning)
    """
    f0 = np.zeros((Tar.size, tar.size), dtype=np.cdouble)
    phase0 = np.zeros((Tar.size, tar.size), dtype=np.cdouble)
    E_g = params.E_g
    # Momentum component along the polarization axis on the (p, theta) mesh.
    pz = p * np.cos(theta)
    for i in prange(Tar.size):
        Ti = Ti_ar[i]
        for j in range(tar.size):
            tj = N + nmin + j * n
            tp = tj + Ti
            tm = tj - Ti
            if tp >= 0 and tp < EF.size and tm >= 0 and tm < EF.size:
                VPt = VP[tj]
                T = Ti * dT
                # Integral of (A(t') - A(t)) over [t-T, t+T]
                # (assuming intA is the cumulative integral of A).
                DelA = (intA[tp] - intA[tm]) - 2 * VPt * T
                VP_p = VP[tp] - VPt
                VP_m = VP[tm] - VPt
                # Product of the two transition-matrix-element factors evaluated
                # at the window edges, on the (p, theta) mesh.
                f_t_1 = (
                    (pz + VP_p)
                    * (pz + VP_m)
                    / (p**2 + VP_p**2 + 2 * pz * VP_p + 2 * E_g) ** 3
                    / (p**2 + VP_m**2 + 2 * pz * VP_m + 2 * E_g) ** 3
                )
                # Angular integration (last axis), then radial integration.
                G1_T_p = np.trapezoid(
                    f_t_1 * np.exp(1j * pz * DelA) * np.sin(theta), Theta_grid
                )
                G1_T = np.trapezoid(
                    G1_T_p * window * p_grid**2 * np.exp(1j * p_grid**2 * T), p_grid
                )
                # G1_T_p = IOF(p_grid,f_t_1,phase_t)
                # G1_T=IOF(p_grid,G1_T_p*window*p_grid**2,p_grid**2*T)
                # Restore DelA to the plain integral of A over the window.
                DelA = DelA + 2 * VPt * T
                # Equals 2*E_g*T plus half the integral of (A(t') - A(t))**2 over
                # [t-T, t+T]: expand (A - A_t)**2 and use the cumulative integrals.
                phase0[i, j] = (
                    (intA2[tp] - intA2[tm]) / 2 + 2 * E_g * T + T * VPt**2 - VPt * DelA
                )
                # NOTE(review): the 2**9 * (2*E_g)**2.5 / pi normalization presumably
                # comes from a hydrogen-like ground-state dipole element; confirm.
                f0[i, j] = EF[tp] * EF[tm] * 2**9 * (2 * E_g) ** 2.5 / np.pi * G1_T
    return f0, phase0 * 1j


def Kernel_jit(
    t_grid: npt.NDArray[np.float64],
    T_grid: npt.NDArray[np.float64],
    pulse: Pulse,
    param_dict: Dict[str, float],
    kernel_type: str = "GASFIR",
) -> Tuple[npt.NDArray[np.complex128], npt.NDArray[np.complex128]]:
    """Evaluate the ionization kernel on the (T_grid, t_grid) mesh.

    The pulse is sampled once on a dense grid of step ``dT = min(diff(T_grid))``
    that covers the whole pulse. The selected JIT helper then reads all window
    integrals from cumulative arrays on that grid.

    Args:
        t_grid: Uniformly spaced grid of ionization moments. Its spacing must be
            a (non-zero) integer multiple of ``dT``.
        T_grid: Integration grid of half-window times T (multiples of its
            smallest spacing).
        pulse: Pulse object defining the laser field.
        param_dict: Kernel parameters. Keys that are not ``KernelParams`` fields
            are ignored.
        kernel_type: ``"GASFIR"`` or ``"exact_SFA"``.

    Returns:
        ``(f0, phase0)``: complex arrays of shape ``(T_grid.size, t_grid.size)``
        holding the pre-exponential factor and the complex exponent of the kernel.

    Raises:
        NotImplementedError: If ``kernel_type`` is not supported.
        AssertionError: If ``t_grid`` is not uniform or its spacing is not a
            multiple of ``dT``.
    """
    if kernel_type not in ("GASFIR", "exact_SFA"):
        raise NotImplementedError(
            f"Unsupported kernel_type {kernel_type!r}. Use 'GASFIR' or 'exact_SFA'."
        )
    t = np.asarray(t_grid, dtype=np.float64)
    T = np.asarray(T_grid, dtype=np.float64)

    if t.size > 1:
        # Rounding removes floating-point noise from the grid spacing.
        steps = np.round(np.diff(t), 4)
        dt = float(np.min(steps))
    else:
        steps = np.empty(0)
        dt = 1.0
    dT = float(np.min(np.diff(T)))

    # Compare against the nearest integer multiple rather than using dt % dT,
    # which misreports exact multiples (e.g. 0.3 % 0.1 ~ 0.1) and makes
    # dt // dT undercount.
    n = int(round(dt / dT))
    assert n != 0 and np.isclose(dt, n * dT, rtol=0.0, atol=1e-8), (
        f"dt should be approximately a multiple of dT, got dt={dt}, dT={dT}, dt%dT={dt % dT}"
    )
    # The helpers place column j at t[0] + j*dt, so a non-uniform grid would
    # silently produce rates at the wrong times.
    assert np.allclose(steps, dt, rtol=0.0, atol=1e-4), (
        f"t_grid must be uniformly spaced, got spacings between {steps.min()} and {steps.max()}"
    )

    t_min, t_max = pulse.get_time_interval()
    a1_injection = int(max(abs(t_min), abs(t_max))) + 1
    N = int(a1_injection // dT) + 1
    nmin = int(t[0] // dT)

    # Dense, symmetric time grid on which all field quantities are sampled.
    tAr = np.arange(-N, N + 1, 1.0) * dT
    VP = pulse.get_vector_potential(tAr)
    EF = pulse.get_electric_field(tAr)
    EF_max = np.max(np.abs(EF))
    intA = pulse.get_cummulative_vector_potential(tAr)  # ~ np.cumsum(VP*dT)
    intA2 = pulse.get_cummulative_vector_potential_squared(tAr)  # ~ np.cumsum(VP**2*dT)
    EF2 = pulse.get_cummulative_electric_field_squared(tAr)  # ~ np.cumsum(EF**2*dT)
    Ti_ar = (T // dT).astype(np.int64)

    valid_keys = KernelParams._fields
    params = KernelParams(**{k: v for k, v in param_dict.items() if k in valid_keys})

    if kernel_type == "exact_SFA":
        p_grid, Theta_grid, window = get_momentum_grid(
            params.div_p, params.div_theta, pulse, Ip=params.E_g
        )
        p, theta = meshgrid(p_grid, Theta_grid)
        return exact_SFA_jit_helper(
            t,
            T,
            params,
            EF,
            EF2,
            VP,
            intA,
            intA2,
            dT,
            N,
            n,
            nmin,
            Ti_ar,
            p_grid,
            Theta_grid,
            window,
            p,
            theta,
        )

    return Kernel_jit_helper(
        t, T, params, EF, EF2, VP, intA, intA2, dT, N, n, nmin, Ti_ar, EF_max
    )


def get_diabatic_ionization_rate(
    t_grid: npt.NDArray[np.float64],
    pulse: Pulse,
    param_dict: Dict[str, float],
    dT: float = 0.25,
    kernel_type: str = "GASFIR",
    ret_tgrid: bool = False,
) -> Union[
    npt.NDArray[np.float64], Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]
]:
    """Compute ionization rates for a defined pulse.

    The kernel can only be evaluated on a uniform grid whose points are
    multiples of ``dT``. Any other grid is handled in two steps. First the rate
    is evaluated on an auxiliary grid of step ``4 * dT``, aligned to multiples
    of ``dT``. Then it is linearly interpolated onto the requested points.

    Args:
        t_grid: Time grid for ionization moments. Accepted forms:
            - None: ``pulse.get_tgrid(dt=1.0)`` is used and ``ret_tgrid`` is
              forced to True
            - str: a JSON-encoded list of times
            - anything ``np.array`` accepts (list, pandas Series, ndarray)
        pulse: Pulse object defining the laser field
        param_dict: Dictionary containing kernel parameters
        dT: Time step for integration (default: 0.25)
        kernel_type: Type of kernel to use ("GASFIR" or "exact_SFA")
        ret_tgrid: Whether to return the time grid (default: False)

    Returns:
        If ``ret_tgrid`` is False, an array of ionization rates at each time point.
        If ``ret_tgrid`` is True, a tuple ``(time_grid, rates)``.
    """
    import json  # only needed for JSON-encoded grids

    t_min, t_max = pulse.get_time_interval()
    a1_injection = int(max(abs(t_min), abs(t_max))) + 1
    T_grid = np.ascontiguousarray(np.arange(0, a1_injection + dT, dT, dtype=np.float64))

    if isinstance(t_grid, str):
        t_grid = np.array(json.loads(t_grid), dtype=np.float64)
    elif t_grid is None:
        t_grid = np.asarray(pulse.get_tgrid(dt=1.0), dtype=np.float64)
        ret_tgrid = True
    else:
        t_grid = np.array(t_grid, dtype=np.float64)

    def _finish(grid, rate):
        return (grid, rate) if ret_tgrid else rate

    if t_grid.size == 0:
        return _finish(t_grid, np.empty(0, dtype=np.float64))

    def _rate_on(grid):
        """Evaluate the kernel on a uniform, dT-aligned grid and integrate over T."""
        f, phase = Kernel_jit(grid, T_grid, pulse, param_dict, kernel_type=kernel_type)
        # The kernel is symmetric in T and only T >= 0 is integrated, hence
        # the factor of two.
        return np.array(2 * np.real(IOF(T_grid, f, phase)))

    steps = np.diff(t_grid)
    on_kernel_grid = (
        t_grid[0] % dT == 0
        and np.all(steps % dT == 0)
        and np.all(steps == steps[:1])
    )
    if on_kernel_grid:
        return _finish(t_grid, _rate_on(t_grid))

    # Kernel_jit evaluates the kernel at multiples of dT starting from
    # floor(grid[0] / dT) * dT. Aligning the auxiliary grid's start keeps its
    # time labels equal to the times actually evaluated.
    user_grid = t_grid
    step = 4 * dT
    start = np.floor(user_grid.min() / dT) * dT
    aux_grid = np.unique(np.arange(start, user_grid.max() + step, step))
    rate = np.interp(user_grid, aux_grid, _rate_on(aux_grid))
    return _finish(user_grid, rate)
# @njit(parallel=True,fastmath = False, cache=True)
def get_quasi_static_rate_for_field(
    field: float | npt.NDArray[np.float64],
    param_dict: Dict[str, float],
    type="GASFIR",
    Z=1,
    l=0,
    m=None,
) -> npt.NDArray[np.float64]:
    """Return the quasi-static ionization rate for given electric field strengths.

    Args:
        field (float64/np.ndarray): electric field strength(s); the sign is ignored
        param_dict: dictionary defining the medium's parameters (keys that are not
            ``KernelParams`` fields are ignored)
        type: "GASFIR", "ADK" or "Tong-Lin" (default: "GASFIR")
        Z: charge of the residual ion (default: 1 for neutral atoms)
        l: orbital angular momentum quantum number (default: 0)
        m: magnetic quantum number (default: None, which sums over all m for l)

    Returns:
        np.ndarray of rates. For "GASFIR", a scalar field gives shape (1,) and an
        array keeps its own shape. For "ADK"/"Tong-Lin" the shape of
        ``np.asarray(field)`` is kept.

    Raises:
        ValueError: for "Tong-Lin" when no alpha is tabulated for E_g.
        NotImplementedError: for an unsupported ``type``.
    """
    valid_keys = KernelParams._fields
    filtered_dict = {k: v for k, v in param_dict.items() if k in valid_keys}
    params = KernelParams(**filtered_dict)  # NamedTuple
    E_g = params.E_g

    if type == "GASFIR":
        αPol = params.αPol
        a1 = params.a1
        a2 = params.a2
        m_eff = params.m_eff

        # Scalars are promoted to shape (1,); N-d inputs keep their shape.
        field = np.atleast_1d(np.asarray(field, dtype=np.float64))
        out_shape = field.shape
        field = np.abs(field).ravel()
        tmp = np.zeros(field.size, dtype=np.complex128)
        cond = field > 0
        F = field[cond]

        # Saddle point of the kernel exponent for a constant field F. It is
        # purely imaginary, so the combined result should be real up to rounding.
        T_saddle = (
            1j
            * (
                -a2 * a1
                + np.sqrt(αPol * m_eff + a2**2 * a1**2 + 2 * E_g * m_eff**2 / F**2)
            )
            / m_eff
        )
        # Window quantities of a constant field F (A(t') = -F t' around t):
        E2diff = 2 * F**2 * T_saddle  # integral of F**2 over [-T, T]
        term1 = (-2 * F * T_saddle) ** 2 / 4
        DelAbar = 0
        term2 = 0
        DelA2bar = (F * T_saddle) ** 2 / 3
        term0 = 1j / (T_saddle + 1j * a1)
        # Gaussian width of the steepest-descent integral around the saddle.
        saddle_contribution = (
            np.sqrt(np.pi * m_eff / (a1 * a2 - 1j * T_saddle * m_eff)) / F
        )
        tmp[cond] = (
            np.exp(
                Kernel_phase_term(
                    term0, term1, term2, DelAbar, DelA2bar, E2diff, T_saddle, params
                )
            )
            * Kernel_f_term(F, F, term0, term1, term2, T_saddle, params)
            * saddle_contribution
        )
        # The previous assert checked imag(real(tmp)), which is always zero.
        # Check the complex result instead, relative to its magnitude.
        if np.any(np.abs(tmp.imag) > 1e-6 * np.abs(tmp)):
            warnings.warn(
                "Quasi-static GASFIR rate has a non-negligible imaginary part; "
                "only the real part is returned. Check the kernel parameters."
            )
        return np.real(tmp).reshape(out_shape)

    if type in ("ADK", "Tong-Lin"):
        field = np.asarray(field, dtype=np.float64)
        rate = adk_rate(field, E_g, Z=Z, l=l, m=m)
        if type == "ADK":
            return rate
        # Tong-Lin alpha values keyed by E_g rounded to 2 decimals.
        alpha_dic = {
            0.5: 6.0,  # Hydrogen
            0.90: 7.0,  # Helium
            0.58: 9.0,  # Argon
            0.79: 9.0,  # Neon
        }
        alpha = alpha_dic.get(np.round(E_g, 2))
        if alpha is None:
            raise ValueError(
                f"No Tong-Lin alpha tabulated for E_g={E_g}; "
                f"known values (rounded): {sorted(alpha_dic)}"
            )
        kappa = np.sqrt(2 * E_g)
        # Tong & Lin (2005): exp(-alpha * Z**2 / Ip * |F| / kappa**3). Use |F|
        # like the ADK rate does, so negative fields are not enhanced.
        correction = np.exp(-alpha * Z**2 * np.abs(field) / (kappa**3) / E_g)
        return rate * correction

    raise NotImplementedError(
        f"Unsupported type {type}. Use 'GASFIR', 'ADK' or 'Tong-Lin'."
    )
# @njit(parallel=True,fastmath = False, cache=True)
def get_rate_quasi_static_limit(
    t_grid: npt.NDArray[np.float64],
    pulse: Pulse,
    param_dict: Dict[str, float],
    type="GASFIR",
    Z=1,
    l=0,
    m=None,
) -> npt.NDArray[np.float64]:
    """Evaluate the quasi-static ionization rate along a pulse.

    The field magnitude |E(t)| is sampled on ``t_grid`` and passed to
    ``get_quasi_static_rate_for_field``.

    Args:
        t_grid: Time points at which the rate is evaluated
        pulse: Pulse object defining the laser field
        param_dict: Dictionary containing the parameters
        type: "GASFIR", "ADK" or "Tong-Lin" (default: "GASFIR")
        Z: Charge of the residual ion (ADK / Tong-Lin only)
        l: Orbital angular momentum quantum number (ADK / Tong-Lin only)
        m: Magnetic quantum number; None sums over all m (ADK / Tong-Lin only)

    Returns:
        Array of ionization rates, one per time point
    """
    field_magnitude = np.abs(pulse.get_electric_field(t_grid))
    return get_quasi_static_rate_for_field(
        field_magnitude,
        param_dict,
        type=type,
        Z=Z,
        l=l,
        m=m,
    )
# @njit(parallel=True,fastmath = False, cache=True)
def get_probability_quasi_static_limit(
    pulse: Pulse,
    param_dict: Dict[str, float],
    dt: float = 2.0,
    type="GASFIR",
    Z=1,
    l=0,
    m=None,
) -> float:
    """Calculate the quasi-static ionization probability.

    Integrates the quasi-static rate over the pulse's time interval with
    Simpson's rule. NOTE(review): the integrated rate itself is returned (no
    ``1 - exp(-integral)`` as in ``get_diabatic_ionization_probability``), so
    the two only agree for small probabilities.

    Args:
        pulse: Pulse object defining the laser field
        param_dict: Dictionary containing the parameters
        dt: Time step for integration
        type: Rate model, "GASFIR", "ADK" or "Tong-Lin" (default: "GASFIR")
        Z: Charge of the residual ion (ADK / Tong-Lin only)
        l: Orbital angular momentum quantum number (ADK / Tong-Lin only)
        m: Magnetic quantum number; None sums over all m (ADK / Tong-Lin only)

    Returns:
        Integrated quasi-static rate as a float
    """
    t_min, t_max = pulse.get_time_interval()
    t_grid = np.arange(t_min, t_max + dt, dt)
    rate = get_rate_quasi_static_limit(t_grid, pulse, param_dict, type=type, Z=Z, l=l, m=m)
    return float(simpson(rate, x=t_grid, axis=-1))
def _restrict_to_strong_field(pulse, t_grid, dt, threshold):
    """Shrink ``t_grid`` to the zero crossings enclosing the strong extrema.

    Keeps the extrema of E(t) with |E| >= threshold * max|E| and extends the grid
    to the zero crossings just before the first and just after the last of them.
    This avoids cutting the grid at a sub-cycle peak. Falls back to the ends of
    ``t_grid`` when no such extremum or zero crossing exists.
    """
    field = pulse.get_electric_field
    extr = np.asarray(find_extrema_positions(t_grid, field(t_grid)))
    if extr.size == 0:
        return t_grid
    f_extr = np.abs(field(extr))
    extr = extr[f_extr >= np.max(f_extr) * threshold]
    if extr.size == 0:
        return t_grid
    if extr[0] > 0 and extr[-1] < 0:
        extr = extr[::-1]

    zero_cr = np.asarray(find_zero_crossings(t_grid, field(t_grid)))
    if zero_cr.size and zero_cr[0] > 0 and zero_cr[-1] < 0:
        zero_cr = zero_cr[::-1]
    before = zero_cr[zero_cr < extr[0]]
    after = zero_cr[zero_cr > extr[-1]]
    start = before[-1] if before.size else t_grid[0]
    stop = after[0] if after.size else t_grid[-1]
    return np.arange(np.floor(start), np.ceil(stop) + dt, dt, dtype=np.float64)


def get_diabatic_ionization_probability(
    pulse: Pulse,
    param_dict: Dict[str, float],
    dt: float = 2.0,
    dT: float = 0.25,
    filterTreshold: float = 0.0,
    kernel_type: str = "GASFIR",
    ret_Rate: bool = False,
    ret_electron_density: bool = False,
) -> float | Tuple[float, npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Compute ionization probability for a defined pulse.

    Args:
        pulse: Pulse object defining the pulse
        param_dict: Dictionary containing kernel parameters
        dt: Time step for rate calculation (default: 2.0). Capped at 0.5 for
            wavelengths below 140 nm.
        dT: Time step for kernel integration (default: 0.25). Capped at 0.25
            for wavelengths below 140 nm.
        filterTreshold: If > 0, restrict the time grid to the part of the pulse
            whose extrema reach this fraction of the peak field (default: 0.0)
        kernel_type: Type of kernel to use ("GASFIR" or "exact_SFA")
        ret_Rate: Whether to return rates array (default: False)
        ret_electron_density: Return the integrated rate instead of
            ``1 - exp(-integrated rate)`` (default: False)

    Returns:
        If ``ret_Rate`` is False, the total ionization probability (float).
        If ``ret_Rate`` is True, a tuple ``(probability, t_grid, rates)``.
    """
    t_min, t_max = pulse.get_time_interval()
    if np.any(np.asarray(pulse.get_central_wavelength()) < 140):
        # Short wavelengths need finer grids. Only ever refine the caller's
        # choice, never coarsen it.
        dt = min(dt, 0.5)
        dT = min(dT, 0.25)
    t_grid = np.ascontiguousarray(np.arange(int(t_min), int(t_max) + dt, dt))

    if filterTreshold > 0:
        t_grid = _restrict_to_strong_field(pulse, t_grid, dt, filterTreshold)

    rate_result = get_diabatic_ionization_rate(
        t_grid, pulse, param_dict, dT, kernel_type=kernel_type
    )
    if isinstance(rate_result, tuple):
        t_grid_out, rate = rate_result
    else:
        t_grid_out, rate = t_grid, rate_result

    integral = np.double(simpson(rate, x=t_grid_out, axis=-1))
    if ret_electron_density:
        ionization_probability = integral
    else:
        ionization_probability = 1 - np.exp(-1 * integral)

    if ret_Rate:
        return ionization_probability, t_grid_out, rate
    return ionization_probability
def get_diabatic_ionization_probability_vec(
    pulses: List[Pulse],
    param_dict: Dict[str, float],
    dt: float = 2.0,
    dT: float = 0.25,
    kernel_type: str = "GASFIR",
    ret_electron_density: bool = False,
) -> npt.NDArray[np.float64]:
    """Compute GASFIR ionization probabilities for many pulses at once.

    The pulses are pre-sampled with ``precompute_pulse_batch`` and evaluated in
    parallel by ``get_diabatic_ionization_probability_batch``.

    Args:
        pulses: Pulse objects to evaluate
        param_dict: Kernel parameters (keys that are not ``KernelParams`` fields
            are ignored)
        dt: Upper bound for the ionization-time step (refined per wavelength)
        dT: Upper bound for the kernel-integration step (refined per wavelength)
        kernel_type: Only "GASFIR" is supported by the batch path
        ret_electron_density: Return integrated rates instead of probabilities

    Returns:
        float64 array with one probability (or integrated rate) per pulse

    Raises:
        NotImplementedError: for any ``kernel_type`` other than "GASFIR".
    """
    # The batched JIT path only implements the GASFIR kernel; previously other
    # kernel types were silently evaluated with GASFIR.
    if kernel_type != "GASFIR":
        raise NotImplementedError(
            f"Unsupported kernel_type {kernel_type!r} for batch evaluation; use 'GASFIR'."
        )
    precomputed_tuple = precompute_pulse_batch(pulses, dt_val=dt, dT_val=dT)
    valid_keys = KernelParams._fields
    kernel_params = KernelParams(
        **{k: v for k, v in param_dict.items() if k in valid_keys}
    )
    return get_diabatic_ionization_probability_batch(
        kernel_params, precomputed_tuple, ret_electron_density
    )
# The "nnan"/"ninf" fast-math flags are deliberately omitted here. They let LLVM
# assume values are never NaN/Inf, which can fold the finite-value guard below
# to "always false". NOTE(review): these flags may also apply to helpers
# inlined into this function.
@njit(parallel=True, fastmath={"contract", "reassoc", "nsz", "arcp"}, cache=True)
def get_diabatic_ionization_probability_batch(
    params: NamedTuple,
    precomputed_tuple: Tuple[List],
    ret_electron_density: bool = False,
):
    """Evaluate many pre-sampled pulses in parallel.

    Args:
        params: ``KernelParams`` named tuple.
        precomputed_tuple: Output of ``precompute_pulse_batch``.
        ret_electron_density: Store the integrated rate instead of
            ``1 - exp(-integrated rate)``.

    Returns:
        float64 array with one entry per pulse. Pulses whose integrated rate is
        non-finite or negative get 0.0.
    """
    (
        EF_list,
        EF2_list,
        VP_list,
        intA_list,
        intA2_list,
        t_grid_list,
        T_grid_list,
        Ti_ar_list,
        dT_arr,
        N_arr,
        n_arr,
        nmin_arr,
        EF_max_list,
    ) = precomputed_tuple
    num_pulses = np.int64(len(EF_list))
    probabilities = np.zeros(num_pulses, dtype=np.float64)

    # Pulses are distributed across CPU cores; each pulse runs sequentially.
    for p in prange(num_pulses):
        f0, phase0 = Kernel_jit_helper(
            t_grid_list[p],
            T_grid_list[p],
            params,
            EF_list[p],
            EF2_list[p],
            VP_list[p],
            intA_list[p],
            intA2_list[p],
            dT_arr[p],
            N_arr[p],
            n_arr[p],
            nmin_arr[p],
            Ti_ar_list[p],
            EF_max_list[p],
        )
        # Factor 2: the kernel is symmetric in T and only T >= 0 is integrated.
        rate = 2.0 * np.real(IOF(T_grid_list[p], f0, phase0))
        integral = trapz_jit(rate, t_grid_list[p])

        # "integral is np.inf" compared identity, not value; use isfinite.
        if not np.isfinite(integral) or integral < 0:
            probabilities[p] = 0.0
        elif ret_electron_density:
            probabilities[p] = integral
        else:
            probabilities[p] = 1.0 - np.exp(-1.0 * integral)
    return probabilities
def precompute_pulse_batch(pulse_list, dt_val=4.0, dT_val=0.25):
    """Sample every pulse on its dense grid and pack the arrays for numba.

    The raw NumPy arrays are stripped out of the Python ``Pulse`` objects and
    stored in Numba typed Lists, so ``get_diabatic_ionization_probability_batch``
    can process them without Python overhead.

    Step sizes are chosen per pulse as ``min(cap, dt_val)`` and
    ``min(cap, dT_val)``. The caps depend on the shortest central wavelength:
    < 140 nm -> (1.0, 0.25), < 500 nm -> (2.0, 0.5), otherwise (4.0, 0.5).

    Args:
        pulse_list: Iterable of Pulse objects.
        dt_val: Upper bound for the ionization-time step.
        dT_val: Upper bound for the kernel-integration step.

    Returns:
        Tuple ``(EF_list, EF2_list, VP_list, intA_list, intA2_list, t_grid_list,
        T_grid_list, Ti_ar_list, dT_arr, N_arr, n_arr, nmin_arr, EF_max_list)``.
    """
    EF_list, EF2_list, VP_list = List(), List(), List()
    intA_list, intA2_list = List(), List()
    t_grid_list, T_grid_list, Ti_ar_list = List(), List(), List()
    EF_max_list = List()
    # Plain lists for the per-pulse scalars; converted to arrays at the end.
    dT_vals, N_vals, n_vals, nmin_vals = [], [], [], []

    try:
        from tqdm.auto import tqdm as _tqdm

        _pulse_iter = _tqdm(pulse_list, desc="Precomputing pulses", unit="pulse")
    except ImportError:
        _pulse_iter = pulse_list

    for pulse in _pulse_iter:
        t_min, t_max = pulse.get_time_interval()
        # Multi-colour pulses may report several wavelengths; the shortest one
        # decides (as in get_diabatic_ionization_probability).
        wavelength = np.min(np.asarray(pulse.get_central_wavelength()))
        if wavelength < 140:
            dt_cap, dT_cap = 1.0, 0.25
        elif wavelength < 500:
            dt_cap, dT_cap = 2.0, 0.5
        else:
            dt_cap, dT_cap = 4.0, 0.5
        # Per-pulse steps. The caller's dt_val/dT_val must not be overwritten,
        # otherwise one short-wavelength pulse would refine every later pulse
        # and the results would depend on the pulse order.
        dt_p = min(dt_cap, dt_val)
        dT_p = min(dT_cap, dT_val)

        a1_injection = int(max(abs(t_min), abs(t_max))) + 1
        t_grid = np.ascontiguousarray(
            np.arange(int(t_min), int(t_max) + dt_p, dt_p, dtype=np.float64)
        )
        T_grid = np.ascontiguousarray(
            np.arange(0, a1_injection + dT_p, dT_p, dtype=np.float64)
        )
        n = int(dt_p // dT_p)
        N = int(a1_injection // dT_p) + 1
        nmin = int(t_grid[0] // dT_p)
        tAr = np.arange(-N, N + 1, 1.0) * dT_p
        Ti_ar = (T_grid // dT_p).astype(np.int64)

        EF = pulse.get_electric_field(tAr)
        EF_list.append(EF)
        EF_max_list.append(np.max(np.abs(EF)))
        EF2_list.append(pulse.get_cummulative_electric_field_squared(tAr))
        VP_list.append(pulse.get_vector_potential(tAr))
        intA_list.append(pulse.get_cummulative_vector_potential(tAr))
        intA2_list.append(pulse.get_cummulative_vector_potential_squared(tAr))
        t_grid_list.append(t_grid)
        T_grid_list.append(T_grid)
        Ti_ar_list.append(Ti_ar)

        dT_vals.append(dT_p)
        N_vals.append(N)
        n_vals.append(n)
        nmin_vals.append(nmin)

    return (
        EF_list,
        EF2_list,
        VP_list,
        intA_list,
        intA2_list,
        t_grid_list,
        T_grid_list,
        Ti_ar_list,
        np.array(dT_vals, dtype=np.float64),
        np.array(N_vals, dtype=np.int64),
        np.array(n_vals, dtype=np.int64),
        np.array(nmin_vals, dtype=np.int64),
        EF_max_list,
    )
def ret_gasfir_P_for_dataFrame(
    data_Nadiabatic: pandas.DataFrame,
    dt: float = 2,
    dT: float = 0.25,
    ret_electron_density: bool = False,
) -> "Callable[[Dict[str, float]], np.ndarray]":
    """Build a parameter-only diabatic ionization-probability function.

    The expensive, parameter-independent per-pulse precomputation
    (``precompute_pulse_batch``) runs exactly once here. Its result is
    captured by the returned closure, so repeated evaluations only pay for
    the batch kernel call, for example inside an optimizer loop.

    Args:
        data_Nadiabatic: DataFrame with a ``pulses`` column holding the pulse
            objects to evaluate.
        dt: Upper bound for the outer time-grid step, forwarded to
            ``precompute_pulse_batch``.
        dT: Upper bound for the field-sampling step, forwarded to
            ``precompute_pulse_batch``.
        ret_electron_density: Forwarded unchanged to
            ``get_diabatic_ionization_probability_batch``.

    Returns:
        A function ``diabatic_P(params)``. It takes a mapping of kernel
        parameters, ignoring keys that are not fields of ``KernelParams``.
        It returns the output of ``get_diabatic_ionization_probability_batch``
        for all pulses in ``data_Nadiabatic``, presumably an array with one
        probability per pulse (TODO confirm).
    """
    print("Initializing Optimization Engine...")
    # Run the heavy pre-computation ONCE per dataset.
    pulse_list = data_Nadiabatic.pulses.tolist()
    precomputed_tuple = precompute_pulse_batch(pulse_list, dt, dT)
    # Hoisted out of the closure: the allowed keys never change between calls.
    valid_keys = frozenset(KernelParams._fields)

    def diabatic_P(params):
        """Evaluate the batch kernel for one mapping of kernel parameters."""
        # Drop unknown keys so callers may pass a larger parameter dict.
        filtered_dict = {k: v for k, v in params.items() if k in valid_keys}
        kernel_params = KernelParams(**filtered_dict)
        return get_diabatic_ionization_probability_batch(
            kernel_params, precomputed_tuple, ret_electron_density
        )

    return diabatic_P


if __name__ == "__main__":
    import time

    # ==========================================
    # 1. SETUP YOUR LASER AND PARAMS HERE
    # ==========================================
    # Replace these with your actual pulse creation and parameter fetching.
    print("Initializing Pulse and Parameters...")
    pulse = create_pulse(wavel=800, intens=1e14, cep=0, fwhmCyc=3)
    param_dict = get_parameters("H_SFA")

    # ==========================================
    # 2. THE WARM-UP RUN (Triggers LLVM Compiler)
    # ==========================================
    print("\n[1/3] Warming up Numba compiler (This will be slow)...")
    _ = get_diabatic_ionization_probability(pulse, param_dict)
    print("Compilation complete.")

    # ==========================================
    # 3. GRANULAR BENCHMARKING
    # ==========================================
    print("\n[2/3] Running Granular Benchmark...")
    # Manually recreate the steps of Kernel_jit to time them.
    t_min, t_max = pulse.get_time_interval()
    dt, dT = 2.0, 0.25
    t_grid = np.arange(t_min, t_max + dt, dt)
    a1_injection = int(max(abs(t_min), abs(t_max))) + 1
    T_grid = np.ascontiguousarray(
        np.arange(0, a1_injection + dT, dT, dtype=np.float64)
    )
    n = int(dt // dT)
    N = int(a1_injection // dT) + 1
    nmin = int(t_grid[0] // dT)
    tAr = np.arange(-N, N + 1, 1.0) * dT
    Ti_ar = (T_grid // dT).astype(np.int64)

    # --- Time Step A: Array Generation ---
    t0 = time.perf_counter()
    VP = pulse.get_vector_potential(tAr)
    EF = pulse.get_electric_field(tAr)
    EF_max = np.max(np.abs(EF))
    intA = pulse.get_cummulative_vector_potential(tAr)
    intA2 = pulse.get_cummulative_vector_potential_squared(tAr)
    EF2 = pulse.get_cummulative_electric_field_squared(tAr)
    t1 = time.perf_counter()
    time_arrays = (t1 - t0) * 1000

    # Repeat the same calls to check whether the pulse caches its arrays.
    VP = pulse.get_vector_potential(tAr)
    EF = pulse.get_electric_field(tAr)
    EF_max = np.max(np.abs(EF))
    intA = pulse.get_cummulative_vector_potential(tAr)
    intA2 = pulse.get_cummulative_vector_potential_squared(tAr)
    EF2 = pulse.get_cummulative_electric_field_squared(tAr)
    t2 = time.perf_counter()
    time_arrays2 = (t2 - t1) * 1000

    # --- Time Step B: Dictionary to NamedTuple ---
    t0 = time.perf_counter()
    valid_keys = KernelParams._fields
    filtered_dict = {k: v for k, v in param_dict.items() if k in valid_keys}
    params = KernelParams(**filtered_dict)
    t1 = time.perf_counter()
    time_tuple = (t1 - t0) * 1000

    # --- Time Step C: The JIT Helper Loop ---
    t0 = time.perf_counter()
    f0, phase0 = Kernel_jit_helper(
        t_grid, T_grid, params, EF, EF2, VP, intA, intA2,
        dT, N, n, nmin, Ti_ar, EF_max,
    )
    t1 = time.perf_counter()
    time_jit = (t1 - t0) * 1000

    # --- Time Step D: Integration (IOF) ---
    t0 = time.perf_counter()
    rate = 2 * np.real(IOF(T_grid, f0, phase0))
    t1 = time.perf_counter()
    time_iof = (t1 - t0) * 1000

    # ==========================================
    # 4. OVERALL FUNCTION BENCHMARK
    # ==========================================
    print("\n[3/3] Running Overall Function Benchmark...")
    t0 = time.perf_counter()
    prob = get_diabatic_ionization_probability(
        pulse, param_dict, dt=2.0, dT=0.25, filterTreshold=0.0, kernel_type="GASFIR"
    )
    t1 = time.perf_counter()
    time_total = (t1 - t0) * 1000

    # ==========================================
    # RESULTS
    # ==========================================
    print("\n" + "=" * 40)
    print(" PERFORMANCE BREAKDOWN (in milliseconds)")
    print("=" * 40)
    print(f"1. Pulse Array Generation: {time_arrays:.3f} ms")
    print(f"1. Cached Pulse Array Generation: {time_arrays2:.3f} ms")
    print(f"2. NamedTuple Conversion: {time_tuple:.3f} ms")
    print(f"3. Kernel JIT Loop: {time_jit:.3f} ms")
    print(f"4. Kernel Integration: {time_iof:.3f} ms")
    print("-" * 40)
    print(f"TOTAL Top-Level Runtime: {time_total:.3f} ms")
    print("=" * 40)

    # Share of the (uncached) granular steps spent in the JIT loop.
    sum_parts = time_arrays + time_tuple + time_jit + time_iof
    print(f"JIT Loop accounts for {(time_jit/sum_parts)*100:.1f}% of the compute time.")