#!/usr/bin/env python3
'''Implementation of entrainment-limited bimolecular reaction kinetics

Includes exact and approximate entrainment-limited bimolecular kinetics,
thin-cloud approximation, cloud partitioning method, and a
reference two-box model of cloud chemistry with explicit entrainment.

When run as a script, this program reproduces Figure 1, which compares
numerical predictions among all of these methods, and the timing tests.

Christopher D. Holmes
cdholmes@fsu.edu
''' 
#%%
import numpy as np
import scipy.optimize as opt 
from scipy.integrate import solve_ivp
from timeit import timeit

def k2eff(T,patm,Lc,pH):
    '''Effective rate coefficient for in-cloud oxidation of S(IV) by H2O2

    The aqueous reaction HSO3- + H2O2 + H+ = SO4-- + 2H+ + H2O is expressed
    as a bimolecular reaction between gaseous SO2 and H2O2.

    Parameters
    ----------
    T : float
        Temperature, K
    patm : float
        atmospheric pressure, atm
    Lc : float
        cloud liquid water content, m3/m3
    pH : float
        cloud water pH

    Returns
    -------
    KaqH2O2 : float
        rate coefficient for SO2 + H2O2, m3/molec/s
    '''
    # Ideal gas constant, L atm / K / mol
    R = 0.08205

    # Hydrogen ion concentration in cloud water, mol/L
    Hplus = 10**(-pH)

    # Shared temperature term used by all equilibrium constants (reference 298.15 K)
    tdep = 298.15e+0 / T - 1.e+0

    # Aqueous rate for HSO3- + H2O2 + H+ => SO4-- + 2H+ + H2O [Jacob, 1986], L/mol/s
    KH2O2 = 6.31e+14 * np.exp( -4.76e+3 / T )

    ## Sulfur dioxide
    # Henry's law SO2 + H2O = SO2.H2O, M/atm
    HSO2 = 1.22e+0 * np.exp( 10.55e+0 * tdep )
    # First dissociation SO2(aq) = HSO3- + H+, M
    Ks1 = 1.30e-2 * np.exp( 6.75e+0 * tdep )
    # Second dissociation HSO3- = SO3= + H+, M
    Ks2 = 6.31e-8 * np.exp( 5.05e+0 * tdep )
    # Effective Henry's law coefficient including both dissociations, M/atm
    HeffSO2 = HSO2 * (1.e+0 + (Ks1/Hplus) + (Ks1*Ks2 / (Hplus*Hplus)))
    # Gas-phase fraction of SO2
    xSO2g = 1.e+0 / ( 1.e+0 + ( HeffSO2 * R * T * Lc ) )

    ## Hydrogen peroxide
    # Henry's law, M/atm
    HH2O2 = 7.45e+4 * np.exp( 22.21e+0 * tdep )
    # Dissociation H2O2 = HO2- + H+, M
    Kh1 = 2.20e-12 * np.exp( -12.52e+0 * tdep )
    # Effective Henry's law coefficient, M/atm
    HeffH2O2 = HH2O2 * (1.e+0 + (Kh1 / Hplus))
    # Gas-phase fraction of H2O2
    xH2O2g = 1.e+0 / ( 1.e+0 + ( HeffH2O2 * R * T * Lc ) )

    # Effective rate constant when concentrations are expressed in mol/mol
    k_mixing_ratio = KH2O2 * Ks1 * HeffH2O2 * HSO2 * xH2O2g * xSO2g * patm * Lc * R * T

    # Air number density, molec/m3, converts the coefficient -> m3/molec/s
    n_air = patm*101300/8.31/T*6.02e23

    return k_mixing_ratio / n_air

# Entrainment limited bimolecular
def dmdt_el(t,x,fc,kc,kab,V=1,rtol=1e-4,inputIsConc=False):
    ''' Mass-balance equation for entrainment-limited bimolecular reactions in cloud
    
    Parameters
    ----------
    t    : float
        time, not used
    x    : ndarray
        mass or concentration of the two reactants [A, B] at grid scale
        (see note below on units)
    fc   : float
        cloud fraction, in range 0-1 
    kc   : float
        cloud air detrainment loss frequency, 1/s
            (1/kc = mean residence time of air in cloud)
    kab  : float
        bimolecular reaction rate coefficient (see note below on units)
    V    : float
        Grid volume (see note below on units)
    rtol : float
        Currently not used by this implementation; retained so that existing
        positional calls keep working.
    inputIsConc : bool
        specifies whether the units are intensive concentration-like (True) 
        or extensive burden-like (False) (see note below on units)
     
    Returns
    -------
    dmdt : list of 2 floats
        dm/dt for [A, B]; both entries are the same (negative) loss rate.
        Zero is returned when fc is below the internal threshold fcmin or
        when the less abundant reactant is absent (<= 0).

    Note on units
    The values of x, kab, and V can be specified in several unit systems, as long as they are mutually consistent 
    1. x in molecules, V in m3,  kab in m3/molecule/s, inputIsConc=False
    2. x in molecules, V in cm3, kab in cm3/molecule/s, inputIsConc=False
    3. x in kg, V in m3, kab in m3/kg/s, inputIsConc=False
    4. x in molecules/m3, V (not used), kab in m3/molecule/s, inputIsConc=True
    5. x in molecules/cm3, V (not used), kab in cm3/molecule/s, inputIsConc=True
    6. x in kg/m3, V (not used), kab in m3/kg/s, inputIsConc=True
    The returned value will always be in units of x per second
    '''

    if inputIsConc:
        # Input is already grid-scale concentration
        ca = x[0]
        cb = x[1]
    else:
        # Input is grid-scale burden
        ma = x[0]
        mb = x[1]
        # Grid-scale concentration, burden per volume
        ca = ma/V
        cb = mb/V

    # Users may choose to ignore cloud chemistry for computational speed, 
    # when cloud fraction is below a very small threshold (e.g. 1e-4).
    # This parameter should be carefully chosen by the user.
    # Set fcmin = 0, to always compute cloud chemistry regardless of cloud fraction. 
    fcmin = 1e-4
    if fc < fcmin:
        # Assume zero reaction for very small cloud fraction
        return [0.,0.]

    # For very large cloud fractions, the thin-cloud approximation has acceptable low error.
    # Use it in these conditions because it's faster.
    # The error is under ~0.1% for fc > 0.97 and under ~0.01% for fc > 0.99.
    # This branch also keeps fc == 1 away from the 1/(1-fc) term below.
    fcmax = 0.99 
    if fc > fcmax:
        if inputIsConc:
            loss = fc * kab * ca * cb
        else:
            # ma * cb = ma * mb / V, so the result is burden per second
            loss = fc * kab * ma * cb
        return [-loss,-loss]

    # Order the reactants so that A is always the more abundant one.
    # The loss is symmetric in A and B, so there is no need to swap back.
    if cb > ca:
        ca, cb = cb, ca

    # Without the limiting reactant there is nothing to react. Returning here
    # also avoids dividing by zero in ca/cb and an undefined root bracket.
    if cb <= 0:
        return [0.,0.]

    # Ratio of cloudy to clear volume
    ff = fc / (1-fc)

    def rcalc(rother,cother):
        # Calculate the ratio of reactant concentration in cloud to grid average, given the ratio for other reactant
        # cother = grid-average concentration of other reactant
        # rother = ratio ci/cgrid for other reactant 
        ki = kab * rother * cother   # pseudo first-order in-cloud loss frequency
        kk = ki / kc                 # loss frequency relative to detrainment
        xr = 0.5*(ff-kk-1) + 0.5*np.sqrt(1+kk**2+ff**2+2*kk+2*ff-2*ff*kk)
        return xr / (fc*(1+xr))

    if ca / cb > 100:
        # if ca >> cb, then assume that A is not depleted in cloud (ra = 1)
        ra = 1
    else:
        # Solve for the self-consistent ra with root finding
        def xroot2(r,ca,cb):
            # Update ra = cia / ca via rb, then subtract the input value
            return rcalc(rcalc(r,ca),cb) - r
        s = opt.root_scalar( xroot2, args=(ca,cb), method='brentq', bracket=[1-cb/ca,1] )
        ra = s.root
        
    # Corresponding ratio for the limiting reactant
    rb = rcalc(ra,ca)
    # loss, burden/volume/s
    loss = kab * ca * cb * ra * rb * fc

    if not inputIsConc:
        # Convert burden/volume/s -> burden/s
        loss = loss * V

    return [-loss, -loss]

# Approximate entrainment-limited kinetics, requires no iteration
def dmdt_elapprox(t,x,fc,kc,kab,V=1,inputIsConc=False):
    '''Mass-balance equation for approximate entrainment-limited bimolecular reactions in cloud

    The loss frequency of each reactant is the harmonic combination of an
    entrainment-limited term and an in-cloud chemical term, so no iteration
    is required.

    Parameters
    ----------
    t    : float
        time, not used
    x    : ndarray
        mass or concentration of the two reactants [A, B] at grid scale
        (see note below on units)
    fc   : float
        cloud fraction, in range 0-1 
    kc   : float
        cloud air detrainment loss frequency, 1/s
            (1/kc = mean residence time of air in cloud)
    kab  : float
        bimolecular reaction rate coefficient (see note below on units)
    V    : float
        Grid volume (see note below on units)
    inputIsConc : bool
        specifies whether the units are intensive concentration-like (True) 
        or extensive burden-like (False) (see note below on units)
     
    Returns
    -------
    dmdt : list of 2 floats
        dm/dt for [A, B]. Zero is returned when there is no cloud (fc <= 0),
        no reaction (kab <= 0), no exchange with the cloud (kc <= 0, fc < 1),
        or when either reactant is absent (<= 0).

    Note on units: See dmdt_el documentation
    '''

    if inputIsConc:
        # Input is already grid-scale concentration
        ca = x[0]
        cb = x[1]
    else:
        # Input is grid-scale burden
        ma = x[0]
        mb = x[1]
        # Grid-scale concentration, burden per volume
        ca = ma/V
        cb = mb/V

    # Concentration of limiting reactant
    cmin = np.minimum(ca,cb)

    # Limiting cases in which one of the terms below would divide by zero.
    # In each of them the combined loss frequency tends to zero.
    if fc <= 0 or kab <= 0 or cmin <= 0:
        return [0.,0.]

    if fc >= 1:
        # Fully cloudy grid cell: fc/(1-fc) is unbounded, so the entrainment
        # term vanishes and only the chemical loss frequency remains.
        ka = kab*cb
        kb = kab*ca
    elif kc <= 0:
        # No air exchange between cloudy and clear air
        return [0.,0.]
    else:
        # Ratio of cloudy to clear volume
        ff = fc / (1-fc)
        # Loss frequency of each reactant, 1/s: entrainment-limited and
        # chemical terms combined like resistances in series
        ka = 1 / ( 1/(ff*kc*cmin/ca) + 1/(fc*kab*cb) )
        kb = 1 / ( 1/(ff*kc*cmin/cb) + 1/(fc*kab*ca) )
    
    if inputIsConc:
        return [-ka*ca, -kb*cb]
    else:
        return [-ka*ma, -kb*mb]

def dmdt_thincloud(t,m,kthincloud,V):
    '''Mass balance equation for thin-cloud method

    Both reactants are lost at the same second-order rate
    kthincloud * m[0] * m[1] / V.

    Parameters
    ----------
    t    : float
        time, not used
    m    : ndarray
        grid-scale mass of the two reactants [A, B]
    kthincloud : float
        bimolecular rate coefficient applied to the whole grid volume
    V    : float
        Grid volume
     
    Returns
    -------
    dmdt : list of 2 floats
        dm/dt for [A, B]
    '''
    # Second-order loss from the product of the two burdens over the volume
    loss = kthincloud * m[0] * m[1] / V

    # One unit of A is consumed for each unit of B
    return [-loss, -loss]

def dmdt_partition(t,m,ma0,mb0,fc,kc,kab,V):
    '''Mass balance equation for partitioning method

    A fraction (1-fc) of the initial mass of each reactant is treated as
    residing in clear air; only the remainder reacts, inside the cloud volume.

    Parameters
    ----------
    t    : float
        time, not used
    m    : ndarray
        current grid-scale mass of the two reactants [A, B]
    ma0  : float
        initial mass of A in cell
    mb0  : float
        initial mass of B in cell
    fc   : float
        cloud fraction, in range 0-1 
    kc   : float
        cloud air detrainment loss frequency, 1/s; not used by this method
    kab  : float
        bimolecular reaction rate coefficient 
    V    : float
        Grid volume 
     
    Returns
    -------
    dmdt : list of 2 floats
        dm/dt for [A, B]
    '''
    # Fraction of the cell that is cloud free
    clear_fraction = 1-fc

    # Mass currently available inside the cloud
    cloud_a = m[0] - clear_fraction * ma0
    cloud_b = m[1] - clear_fraction * mb0

    # Volume occupied by cloud
    cloud_volume = V*fc

    # Second-order reaction of the in-cloud masses
    loss = kab * cloud_a * cloud_b / cloud_volume
    return [-loss,-loss]

def solve_partition_stepping( trange, tstep, m0, fc,kc,kab,V, tout=60 ):
    '''Integrate partitioning method with periodic re-partitioning

    The interval trange is split into sub-intervals of length tstep. At the
    start of every sub-interval the current masses become the reference
    masses passed to dmdt_partition, and the ODE is integrated to the end
    of the sub-interval.

    Parameters
    ----------
    trange : ndarray
        start and end times for integration
    tstep : float
        time step for re-partitioning reactants
    m0 : ndarray
        initial grid-scale masses of the reactants [A, B]
    fc   : float
        cloud fraction, in range 0-1 
    kc   : float
        cloud air detrainment loss frequency, 1/s
            (1/kc = mean residence time of air in cloud)
    kab  : float
        bimolecular reaction rate coefficient 
    V    : float
        Grid volume 
    tout : float
        spacing of the output times within each sub-interval (default 60)

    Returns
    -------
    sol_step : ODE solution structure
        Result object of the last solve_ivp call, with .t and .y replaced by
        the output concatenated over all sub-intervals. Output times start
        at the beginning of each sub-interval and exclude its end. All other
        fields describe the last sub-interval only.

    Raises
    ------
    ValueError
        if trange is empty or tstep/tout is not positive
    RuntimeError
        if solve_ivp reports failure on a sub-interval
    '''
    if not trange[0] < trange[1]:
        raise ValueError('trange must satisfy trange[0] < trange[1]')
    if tstep <= 0 or tout <= 0:
        raise ValueError('tstep and tout must be positive')

    # Initialize
    tstart = trange[0]
    mstart = np.asarray(m0, dtype=float)
    t_chunks = []
    y_chunks = []
    # Continue stepping until we reach end of trange
    while tstart < trange[1]:
        # End of current time step
        tend = min( tstart+tstep, trange[1] )
        # Output time points for this step (end point excluded)
        ttstep = np.arange(tstart,tend,tout)
        ttstep = ttstep[ttstep < tend]
        # Also evaluate at tend, so that the next step restarts from the
        # state at tend rather than from the last output time before it
        sol_step = solve_ivp( dmdt_partition, [tstart,tend], mstart,
                              args=(mstart[0],mstart[1],fc,kc,kab,V),
                              method='Radau', t_eval=np.append(ttstep,tend) )
        if not sol_step.success:
            raise RuntimeError('solve_ivp failed on [{}, {}]: {}'.format(tstart,tend,sol_step.message))
        # Store the output points; the end point belongs to the next step
        t_chunks.append(sol_step.t[:-1])
        y_chunks.append(sol_step.y[:,:-1])
        # Re-partition: state at the end of this step starts the next one
        mstart = sol_step.y[:,-1]
        tstart = tend
    sol_step.t = np.concatenate(t_chunks)
    sol_step.y = np.hstack(y_chunks)
    return sol_step

def dmdt_2box(t,m,fc,kc,kab,V):
    '''Mass balance equation for two-box model with entrainment, bimolecular reaction in cloudy fraction
    
    Parameters
    ----------
    t : float
        time, not used
    m : ndarray
        mass [A outside cloud, A in cloud, B outside cloud, B in cloud]
    fc : float
        cloud fraction, must be strictly between 0 and 1
    kc : float
        cloud air detrainment loss frequency, 1/time
    kab : float
        rate coefficient, volume/mass/time
    V : float
        volume

    Returns
    -------
    dmdt : list of 4 floats
        dm/dt, in the same order as m

    Raises
    ------
    ValueError
        if fc is exactly 0 or 1, for which the two-box model is not defined
    '''
    # Validate before computing fc/(1-fc), which is undefined for fc == 1
    if fc == 0 or fc == 1:
        raise ValueError('Cloudy/Clear box model not applicable to cloud fraction {:f}'.format(fc))

    moa = m[0] # mass of a outside cloud
    mia = m[1] # mass of a in cloud
    mob = m[2] # mass of b outside cloud
    mib = m[3] # mass of b in cloud

    # Ratio of cloudy to clear volume; sets the entrainment frequency ff*kc
    ff = fc / (1-fc)
    # Cloud volume
    Vc = fc * V

    # Clear box: loses mass by entrainment into cloud, gains by detrainment.
    # Cloudy box: the reverse, plus second-order reaction in the cloud volume.
    return [-ff*kc*moa + kc*mia,
             ff*kc*moa - kc*mia - kab*mia*mib/Vc,
            -ff*kc*mob + kc*mib,
             ff*kc*mob - kc*mib - kab*mia*mib/Vc ]


if __name__ == '__main__':
    # Create Figure 1, then run the timing tests.
    # NOTE: everything stays at module level because the timeit statements
    # below are strings evaluated with globals=locals().

    import matplotlib.pyplot as plt
    import copy

    # Cloud fraction, m3/m3
    fc = 0.2
    # Cloud residence time, s
    tauc = 3600
    kc = 1/tauc
    # Temperature, K
    T    = 284   
    # Pressure, atm
    patm = 0.80  
    # In-cloud liquid water content, m3/m3
    Lc   = 1e-6 
    # Grid average liquid water content, m3/m3
    L    = Lc*fc
    # Cloud pH
    pH    = 5
    # Volume of grid cell, m3
    V = 1

    # Rate coefficient for in-cloud reaction, m3/molec/s
    kab = k2eff(T,patm,Lc,pH)
    # Rate coefficient for thin-cloud, m3/molec/s
    kthincloud = k2eff(T,patm,L,pH)

    # Initial Conditions, molec/m3
    # Typical value for H2O2 = 1 ppb, SO2 = 1 ppb; 1ppb = 2e16 molec/m3 @800 hPa
    # m0 = [SO2 clear, SO2 cloud, H2O2 clear, H2O2 cloud]
    m0 = np.array( [1,0,1,0] ) * 1e-9 * (patm*101300/8.31/T*6.02e23)

    # Convert all units from m3 to cm3
    fac = 1e6
    m0  /= fac
    kab *= fac
    kthincloud *= fac
    iscm3 = True

    # Simulation time step, end time
    tmax = 4 # end time, hr
    dt = 60  # time step, s
    tt = np.arange(0,tmax*3600,dt)

    # Spin up to get steady ratio of reactants in cloud to out of cloud.
    # Spinup function that keeps total mass of each reactant constant.
    def dmdt_2box_spin(t,m,fc,kc,kab,V):
        dmdt = dmdt_2box(t,m,fc,kc,kab,V)
        da = dmdt[0]+dmdt[1]
        db = dmdt[2]+dmdt[3]
        # Add the reacted mass back in proportion to current mass
        dmdt[0] -= da * m[0] / (m[0]+m[1])
        dmdt[1] -= da * m[1] / (m[0]+m[1])
        dmdt[2] -= db * m[2] / (m[2]+m[3])
        dmdt[3] -= db * m[3] / (m[2]+m[3])
        return dmdt
    # Partitioning at end of the 10-hour spin-up, mass inside cloud / outside cloud
    sol_spinup = solve_ivp( dmdt_2box_spin, [0,10*3600], m0, method='Radau', args=(fc,kc,kab,V) )
    m0 = sol_spinup.y[:,-1]

    # Total initial concentrations of A, B
    ma0 = np.sum(m0[:2])
    mb0 = np.sum(m0[2:])

    # Solve initial value problem with each method
    sol_2boxfull  = solve_ivp( dmdt_2box,      [0,tmax*3600], m0, method='Radau', t_eval=tt, args=(fc,kc,kab,V) )
    sol_el        = solve_ivp( dmdt_el,        [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(fc,kc,kab,V,1e-6,False) )
    sol_elapprox  = solve_ivp( dmdt_elapprox,  [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(fc,kc,kab,V) )
    sol_thincloud = solve_ivp( dmdt_thincloud, [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(kthincloud,V) )
    # Partitioning without re-partitioning; computed for reference, not plotted below
    sol_partition = solve_ivp( dmdt_partition, [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(ma0,mb0,fc,kc,kab,V) )
    sol_partition_step60 = solve_partition_stepping( [0,tmax*3600], 3600, [ma0,mb0], fc,kc,kab,V )
    sol_partition_step30 = solve_partition_stepping( [0,tmax*3600], 1800, [ma0,mb0], fc,kc,kab,V )
    sol_partition_step10 = solve_partition_stepping( [0,tmax*3600], 600,  [ma0,mb0], fc,kc,kab,V )

    # Condense 2-box model to cell total
    sol_2box = copy.deepcopy( sol_2boxfull )
    sol_2box.y = np.stack( [np.sum( sol_2boxfull.y[:2,:], axis=0), np.sum( sol_2boxfull.y[2:,:], axis=0)] )

    #%%
    # Power of ten used to scale the plotted concentrations
    if iscm3:
        exponent = 10
        unit = 'molec cm$^{-3}$'
        strfac = '10$^{10}$'
    else:
        exponent = 16
        unit='molec m$^{-3}$'
        strfac = '10$^{16}$'
    scale = 10**exponent

    # Show numerical box solution
    fig = plt.figure(figsize=(5,4))
    ax1 = plt.subplot()
    ax1.plot(sol_2box.t/3600,      sol_2box.y[0,:]/scale, 'k', label='Reference (two-box) solution', linewidth=4)
    ax1.plot(sol_el.t/3600,        sol_el.y[0,:]/scale,        'C3', label='Entrainment-limited')
    ax1.plot(sol_elapprox.t/3600,  sol_elapprox.y[0,:]/scale,  'C3--',label='Entrainment-limited approx.')
    ax1.plot(sol_thincloud.t/3600, sol_thincloud.y[0,:]/scale, 'C1', label='Thin cloud')
    ax1.plot(sol_partition_step60.t/3600, sol_partition_step60.y[0,:]/scale, 'C0-.',   label='Partition (60 min)')
    ax1.plot(sol_partition_step30.t/3600, sol_partition_step30.y[0,:]/scale, 'C0--',  label='Partition (30 min)')
    ax1.plot(sol_partition_step10.t/3600, sol_partition_step10.y[0,:]/scale, 'C0-', label='Partition (10 min)')

    ax1.legend()
    ax1.set_xlabel('Time, hours')
    ax1.set_ylabel(r'SO$_2$, {:s} {:s}'.format(strfac,unit))

    # Label right axis with S(VI) production = initial SO2 - remaining SO2
    ax1b = ax1.twinx()
    mn,mx = ax1.get_ylim()
    ax1b.set_ylim(ma0/scale-mn, ma0/scale-mx)
    ax1b.set_ylabel(r'S(VI), {:s} {:s}'.format(strfac,unit))

    fig.tight_layout()
    fig.savefig('Figure1.png')

    ## Timing test for the calculations in Figure 1 ##
    print('\n\nPerformance comparison of solution methods.\n')

    # Number of repetitions
    N = 1000

    # Mass balance derivative speed test
    print('Execution time (seconds) to compute derivative n={:d} times'.format(N))
    print('{:15s}{:15s}{:15s}{:15s}'.format('Method','Time(s)','Time(norm)','1/Time(norm)'))
    for i, (name, stmt) in enumerate( zip( ['Exact','Approximate','Thin cloud'],
                [ "dmdt_el(0,[ma0,mb0],fc,kc,kab,V,1e-6,False) ",
                  "dmdt_elapprox(0,[ma0,mb0],fc,kc,kab,V) ",
                  "dmdt_thincloud(0,[ma0,mb0],kthincloud,V) " ] ) ):
        elapsed = timeit( stmt=stmt, number=N, globals=locals() )
        if i==0:
            # The first (exact) method is the reference for normalization
            norm_time = elapsed

        print('{:15s}{:<15f}{:<15f}{:<15f}'.format( name, elapsed, elapsed/norm_time, norm_time/elapsed ))

    # ODE integration speed
    print('\nExecution time to complete ODE integration n={:d} times'.format(N))    
    print('{:15s}{:15s}{:15s}{:15s}'.format('Method','Time(s)','Time(norm)','1/Time(norm)'))
    for i, (name, stmt) in enumerate( zip( ['Exact','Approximate','Thin cloud'],
                [ "solve_ivp( dmdt_el,        [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(fc,kc,kab,V,1e-6,False) )",
                  "solve_ivp( dmdt_elapprox,  [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(fc,kc,kab,V) )",
                  "solve_ivp( dmdt_thincloud, [0,tmax*3600], [ma0,mb0], method='Radau', t_eval=tt, args=(kthincloud,V) )" ] ) ):
        elapsed = timeit( stmt=stmt, number=N, globals=locals() )
        if i==0:
            # The first (exact) method is the reference for normalization
            norm_time = elapsed

        print('{:15s}{:<15f}{:<15f}{:<15f}'.format( name, elapsed, elapsed/norm_time, norm_time/elapsed ))
    
    print("""
    Timing results on the author's computer
    Execution time (seconds) to compute derivative n=1000 times
    Method         Time(s)        Time(norm)     1/Time(norm)   
    Exact          0.066079       1.000000       1.000000       
    Approximate    0.004000       0.060537       16.518806      
    Thin cloud     0.000567       0.008585       116.475834     

    Execution time to complete ODE integration n=1000 times
    Method         Time(s)        Time(norm)     1/Time(norm)   
    Exact          6.150929       1.000000       1.000000       
    Approximate    2.621767       0.426239       2.346100       
    Thin cloud     4.995044       0.812080       1.231406    
    """ )

# %%
