#!/usr/bin/env python
# -*- coding: utf-8 -*- Time-stamp: <2017-06-30 19:25:27 sander>*-

# xskeleton: execute mechanism reduction to obtain a skeletal mechanism
# Rolf Sander, 2016-2017

##############################################################################

# RUN CONTROL:

# full_calc=0 re-uses the results of an earlier full calculation that are
# already on disk (output/, OIC.dat, StoichNum.dat, ...): CAABA and
# skeleton.exe are not run, and output/ is not cleaned up.
full_calc       = 0 # 0=only plots, 1=full calculation
plot_delta_skel = 1 # 0=no plots,   1=all   (-> output/delta_skel.pdf)
plot_targets    = 1 # 0=no plots,   1=all   (-> output/targets.pdf)
# 2=some: only the scenarios selected in make_scenario_plots are plotted
plot_scenarios  = 2 # 0=no plots,   1=all,              2=some

# epsilon_ep is the OIC threshold: species with OIC > eps are kept. Before
# each new skeletal mechanism, eps is multiplied by eps_increase until the
# number of kept species decreases.
eps = 5E-4          # start value for epsilon_ep
eps_increase = 1.2  # factor for increasing epsilon_ep

# select a file with the scenarios (w/o suffix '.nc'); multirun reads it as
# skeleton/scenarios/<scenariofile>.nc relative to the CAABA directory:
#scenariofile = 'skeleton_scenarios_small'
scenariofile = 'skeleton_scenarios_30'

##############################################################################

# TODO:

# - see also: ~/messy/mecca/mechanism_reduction
# - user manual:
#   e manual.tex ; pdflatex manual.tex ; acro manual.pdf
# - target precision:
#   - use fun_split to calc lifetime (can it be called from outside kpp?)
#   - allow e.g. 10% change after 1 tau (put "10%" or 0.1 into
#     targets.txt instead of reltol)
# - xskeleton.py:
#   - add more targets?
# - is relhum calculated from H2O in scenario file?

##############################################################################

from netCDF4 import Dataset
import matplotlib.pyplot as plt
#plt.rcParams.update({'figure.max_open_warning': 0}) # 'More than 20 figures'
from matplotlib.backends.backend_pdf import PdfPages
import os
import subprocess
import shutil # rmtree = rm -fr
import datetime
import sys
import numpy as np
from glob import glob
from mecca import mecca
from caabaplot import caabaplot
from pyteetime import tee

# Header lines written at the top of the generated skeleton.rpl file:
KPPMODE  = '// -*- kpp -*- kpp mode for emacs'
DONTEDIT = '// This file was created by xskeleton, DO NOT EDIT!'
HLINE =  '-' * 78
# The script is run from the skeleton directory; CAABA lives one level up:
skeletondir = os.path.abspath('.')
caabadir = os.path.abspath('..')
# Start copying stdout to xskeleton.log. The returned object is also passed
# to runcmd() as the destination of the external commands' output.
LOGFILE = tee.stdout_start('xskeleton.log') # stdout
# Read info about targets from file (one row per target: name, abstol,
# reltol):
targetdata = np.genfromtxt('targets.txt', dtype=None, comments='#')
# Note: epslist is created in the __main__ section. A 'global' statement at
# module level has no effect, so none is needed here.

##############################################################################

def runcmd(cmd, comment, logfile):
    print >> logfile, HLINE
    print >> logfile, cmd
    print >> logfile, HLINE
    print '%-12s' % (comment),
    exitstatus = subprocess.call('time -p '+cmd, stdout=logfile, stderr=logfile, shell=True)
    if (exitstatus == 0):
        print '  ...done'
    else:
        sys.exit('ERROR: nonzero exit status')

##############################################################################

def cleanup():
    """Replace the 'output' directory with a fresh, empty one.

    If an 'output' directory exists, it is moved into the directory given by
    the environment variable $TRASH (under the name
    'skeleton-output-YYYY-MM-DD-HH:MM:SS') when $TRASH is set, and deleted
    completely otherwise. If no 'output' directory exists yet (e.g. on the
    very first run), nothing has to be removed. Finally, an empty 'output'
    directory is created.
    """
    trashdir = os.getenv('TRASH')
    if (os.path.isdir('output')):
        if (trashdir):
            # $TRASH exists, move old data to trash directory. Unlike
            # os.rename, shutil.move also works when $TRASH is on a
            # different file system:
            shutil.move('output', os.path.join(trashdir, 'skeleton-output-' +
                        datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')))
        else:
            # Completely delete output directory:
            shutil.rmtree('output')
    os.mkdir('output')

##############################################################################

def get_workdir(skelnum):
    """Return the absolute work directory of one mechanism.

    A false skelnum (None or 0) selects the full mechanism
    (<skeletondir>/output/fullmech). Otherwise the directory is
    <skeletondir>/output/skeleton_NNN, with skelnum zero-padded to three
    digits. When full_calc is set, the directory is also created with
    os.mkdir (so it must not exist yet).
    """
    if not skelnum:
        path = skeletondir + '/output/fullmech'
    else:
        path = '%s/output/skeleton_%3.3d' % (skeletondir, skelnum)
    if full_calc:
        os.mkdir(path)
    return path

##############################################################################

def show_reaction(rxn):
    """Return a readable equation for reaction number rxn (0-based).

    The stoichiometric coefficients are taken from row StoichNum[rxn]:
    species with a negative coefficient are written as reagents, those with
    a positive coefficient as products, each as '<coef> <name>' (coef
    formatted with %g, name from oicdata). The result looks like
    '<TAG>      1 A + 2 B -> 1 C'.
    """
    reagents = []
    products = []
    for spc_num, spc_stoic in enumerate(StoichNum[rxn]): # loop over species
        if (spc_stoic < 0):
            reagents.append('%g %s' % (-spc_stoic, oicdata[spc_num][1]))
        elif (spc_stoic > 0):
            products.append('%g %s' % (spc_stoic, oicdata[spc_num][1]))
    # Join with ' + ' so that neither side starts with a spurious ' + ':
    return '%-10s %s -> %s' % ('<'+EQN_TAGS[rxn]+'>',
                               ' + '.join(reagents), ' + '.join(products))

##############################################################################

def create_skeletal_mechanism(eps):
    global del_rxn
    MECHLOGFILE = open(workdir+'/mechanism.log','w+', 1) # 1=line-buffered
    print >> MECHLOGFILE, '%s\neps = %g\n%s' % (HLINE, eps, HLINE)
    # Create empty list of reactions to delete:
    del_rxn = [False] * len(StoichNum)
    N_var_skel = 0
    # Create mechanism including all species with OIC>eps:
    for num, (oic, name) in enumerate(oicdata): # loop over species
        if (oic > eps): # keep!
            print >> MECHLOGFILE, '\nKEEP   %4d %s %15g' % (num+1, name, oic)
            N_var_skel += 1
        else: # delete!
            print >> MECHLOGFILE, '\nDELETE %4d %s %15g' % (num+1, name, oic)
            # Find rxns that contain this species:
            for rxn in range(N_rxns): # loop over reactions
                if (StoichNum[rxn][num] != 0):
                    if (StoichNum[rxn][num] < 0):
                        del_rxn[rxn] = True # mark for deletion
                        print >> MECHLOGFILE, '  '+show_reaction(rxn)
                        print >> MECHLOGFILE, \
                            '    delete because reagent %10g %15s' % \
                            (StoichNum[rxn][num], name)
                    if (StoichNum[rxn][num] > 0):
                        del_rxn[rxn] = True # mark for deletion
                        print >> MECHLOGFILE, '  '+show_reaction(rxn)
                        print >> MECHLOGFILE, \
                            '    delete because product %10g %15s' % \
                            (StoichNum[rxn][num], name)
    print >> MECHLOGFILE, '\n%s\nSUMMARY OF DELETED REACTIONS\n%s:' % (
        HLINE, HLINE)
    for rxn in range(N_rxns): # loop over reactions
        if (del_rxn[rxn]):
            print >> MECHLOGFILE, 'DELETE: '+show_reaction(rxn)
    print >> MECHLOGFILE, '\n%s\nSUMMARY OF KEPT REACTIONS\n%s:' % (HLINE, HLINE)
    for rxn in range(N_rxns): # loop over reactions
        if (not del_rxn[rxn]):
            print >> MECHLOGFILE, 'KEEP: '+show_reaction(rxn)
    MECHLOGFILE.close()
    print 'nvar = %d/%d, nreact = %d/%d, eps = %g' % \
      (N_var_skel, N_var_full, del_rxn.count(False), N_rxns, eps)
    # Create rpl file:
    create_skeleton_rpl(del_rxn)
    epslist.append(eps)
    del_rxnlist.append(del_rxn)

##############################################################################

def create_skeleton_rpl(delrxn):
    """Write skeleton.rpl for the given reaction selection and archive it.

    A rpl file that deletes several reactions is used to create the
    skeletal mechanism. For every reaction marked True in delrxn, an empty
    '#REPLACE <TAG> ... #ENDREPLACE' pair is written, i.e. the tagged
    reaction is replaced by nothing. With delrxn=None (or an empty list) the
    rpl file contains only the header lines, which yields the full
    mechanism.

    The file is written to <skeletondir>/skeleton.rpl and then copied into
    workdir.
    """
    rplfilename = skeletondir+'/skeleton.rpl'
    with open(rplfilename,'w+') as RPLFILE:
        print >> RPLFILE, KPPMODE + '\n' + DONTEDIT
        if (delrxn):
            for rxn, delete in enumerate(delrxn): # loop over reactions
                if (delete):
                    print >> RPLFILE, '#REPLACE %-10s' % ('<'+EQN_TAGS[rxn]+'>')
                    print >> RPLFILE, '#ENDREPLACE'
    # Save *.rpl file in output directory. shutil.copy2 keeps timestamps and
    # permissions like 'cp -p', but unlike a shell command built by string
    # concatenation it handles any path and raises an error if the copy
    # fails:
    shutil.copy2(rplfilename, workdir)

##############################################################################

def caaba_multirun():
    if (full_calc):
        olddir = os.getcwd()
        os.chdir(caabadir+'/mecca') # cd to MECCA directory
        # Create mechanism:
        runcmd('xmecca skeleton', 'xmecca', LOGFILE)
        os.chdir(caabadir) # cd to CAABA base directory
        # Compile caaba:
        runcmd('gmake', 'gmake', LOGFILE)
        # Run CAABA for all model scenarios ('sample points')
        runcmd('./multirun/multirun.tcsh skeleton/scenarios/' +
               scenariofile + '.nc', 'multirun', LOGFILE)
        os.rename('output/multirun/' + scenariofile, workdir+'/multirun')
        os.chdir(olddir)
    else:
        print 'Skipping CAABA multirun because full_calc=0'

##############################################################################

def get_scenarionames():
    """Return the scenario names of the full-mechanism multirun.

    The names are the basenames of all entries in
    output/fullmech/multirun/runs/, in sorted order. The number of
    scenarios is printed as well.

    Returns:
      (scenarionames, N_scenarios): list of names and its length
    """
    rundirs = sorted(glob('output/fullmech/multirun/runs/*'))
    names = [os.path.basename(rundir) for rundir in rundirs]
    count = len(names)
    print('N_scenarios =  %d' % count)
    return names, count

##############################################################################

def calc_oic():
    global oicdata, N_var_full, StoichNum, N_rxns, EQN_TAGS, EQN_NAMES
    if (full_calc):
        print '\nCalculate DICs and OICs for full mechanism.'
        # Compile skeleton:
        runcmd('gmake', 'gmake', LOGFILE)
        # Run skeleton:
        runcmd('./skeleton.exe', 'skeleton.exe', LOGFILE)
    else:
        print 'Skipping calculation of OIC with skeleton.exe because full_calc=0'
    # Load OIC data from skeleton.exe into a numpy structured array:
    # http://docs.scipy.org/doc/numpy/user/basics.rec.html
    oicdata = np.genfromtxt('OIC.dat', dtype=None)
    N_var_full = len(oicdata) # number of variable species
    # Load StoichNum from skeleton.exe:
    StoichNum = np.genfromtxt('StoichNum.dat')
    N_rxns  = len(StoichNum) # number of reactions
    # Load EQN_TAGS from skeleton.exe:
    EQN_TAGS = np.genfromtxt('EQN_TAGS.dat', dtype=None)
    # Load EQN_NAMES from skeleton.exe:
    EQN_NAMES = [line.rstrip() for line in open('EQN_NAMES.dat')]

##############################################################################

def save_rates(delrxn):
    """Return the end-of-run reaction rates of all scenarios as one array.

    For every scenario, <workdir>/multirun/runs/<scenario>/caaba_mecca_a_end.dat
    is read and its values are stored in the columns of the reactions that
    are kept (delrxn False). Columns of deleted reactions stay 0.

    Parameters:
      delrxn: list of booleans, one per reaction (True = deleted)
    Returns:
      numpy array of shape (N_scenarios, len(delrxn))
    """
    kept = ~np.array(delrxn)
    table = np.zeros((N_scenarios, len(delrxn)))
    for row, name in enumerate(scenarionames):
        datfile = workdir + '/multirun/runs/' + name + '/caaba_mecca_a_end.dat'
        table[row, kept] = np.genfromtxt(datfile, dtype=None)
    return table

##############################################################################

def _calc_target_errors(ncfile_full, ncfile_skel, errorfile):
    # Return the scaled errors of all targets for one scenario; log details.
    print >> errorfile, 'target           abstol      reltol    ' + \
      'mixrat_skel    mixrat_full     err/reltol'
    delta_targets = []
    for target, abstol, reltol in targetdata: # target loop
        # read one number in a 4D array:
        mixrat_full = ncfile_full.variables[target][0][0][0][0]
        mixrat_skel = ncfile_skel.variables[target][0][0][0][0]
        # relative deviation (values below abstol are clipped to abstol),
        # scaled by reltol, so that >1 means "outside tolerance":
        delta = abs(max(mixrat_skel,abstol)/ \
          max(mixrat_full,abstol)-1) / reltol
        print >> errorfile, '%-15s %10G %8G %14G %14G %14G' % \
          (target, abstol, reltol, mixrat_skel, mixrat_full, delta)
        delta_targets.append(delta)
    return delta_targets

def calc_error():
    """Compare the current skeletal mechanism with the full mechanism.

    For every scenario, the end-of-run values in caaba_mecca_c_end.nc of the
    current mechanism (in workdir) and of the full mechanism
    (../fullmech relative to workdir) are compared for every target
    (target, abstol, reltol from targets.txt):

      error = |max(skel,abstol) / max(full,abstol) - 1| / reltol

    so error > 1 means the relative deviation exceeds reltol. Details are
    written to <workdir>/skel_error.dat.

    Returns:
      delta_scenarios: list (one entry per scenario) of lists (one entry
                       per target) of errors.

    The working directory is restored and all files are closed also if an
    error occurs.
    """
    olddir = os.getcwd()
    os.chdir(workdir)
    delta_scenarios = []
    try:
        with open('skel_error.dat','w+', 1) as ERRORFILE: # 1=line-buffered
            for scenario in scenarionames: # scenario loop
                print >> ERRORFILE, 'SCENARIO: %s' % (scenario)
                filename_part2 = 'multirun/runs/'+scenario+'/caaba_mecca_c_end.nc'
                ncfile_full = Dataset('../fullmech/'+filename_part2)
                try:
                    ncfile_skel = Dataset(filename_part2)
                    try:
                        delta_targets = _calc_target_errors(
                            ncfile_full, ncfile_skel, ERRORFILE)
                    finally:
                        ncfile_skel.close()
                finally:
                    ncfile_full.close()
                delta_scenarios.append(delta_targets)
    finally:
        os.chdir(olddir)
    return delta_scenarios

##############################################################################

def analyze_results():
    """Evaluate the current skeletal mechanism against the full mechanism.

    Calls calc_error() for the errors of all scenarios and targets and
    appends the rates of the current mechanism (save_rates(del_rxn)) to the
    global all_rates array along its third axis, so that all_rates[:,:,-1]
    holds the current and all_rates[:,:,-2] the previous mechanism.

    If the maximum error exceeds 1 (a target deviates by more than its
    reltol, see calc_error), the problem scenarios are printed. In that
    case <skeletondir>/rates.dat is (over)written with, for every problem
    scenario:
      - the errors of all targets,
      - and, for every reaction of the previous mechanism, its rate there,
        whether it is still kept, and the rate difference. Reactions are
        sorted by decreasing absolute difference.

    Returns:
      delta_skel:      maximum error over all scenarios and targets
      delta_scenarios: errors per scenario and target, as from calc_error()
    """
    global all_rates
    delta_scenarios = calc_error() # calculate error relative to full mechanism
    delta_skel = np.amax(delta_scenarios) # max error over all targets and scenarios
    # add rates in current skeletal mechanism to all_rates:
    all_rates = np.dstack((all_rates, save_rates(del_rxn)))
    if (delta_skel > 1.):
        RATESFILE = open(skeletondir+'/rates.dat','w+', 1) # 1=line-buffered
        # difference of current to previous skeletal mechanism (for the
        # first skeletal mechanism, "previous" is the full mechanism):
        diff = all_rates[:,:,-1] - all_rates[:,:,-2]
        print 'List of scenarios with delta_skel>1:'
        for scenarionum in range(N_scenarios): # scenario loop
            # reaction indices sorted by increasing |diff|:
            idx = np.argsort(abs(diff[scenarionum,:]))
            # check if this is a problem scenario:
            if (max(delta_scenarios[scenarionum]) > 1.):
                print >> RATESFILE, HLINE+'\n\n*** Scenario: %04d\n' % (scenarionum+1)
                # all targets go to the file, only failing ones to stdout:
                for targetnum, delta_target in enumerate(delta_scenarios[scenarionum]):
                    if (delta_target > 1.):
                        print 'Scenario: %04d, delta_skel: %10G, target: %s' % (
                            scenarionum+1, delta_target, targetdata[targetnum][0])
                    print >> RATESFILE, 'delta_skel: %10G, target: %s' % (
                        delta_target, targetdata[targetnum][0])
                print >> RATESFILE, '\nColumn 1: Reaction rate in previous '+ \
                    '(s%03d) skeletal mechanism [cm-3 s-1]' % (skelnum-1)
                print >> RATESFILE, 'Column 2: Is reaction also included in current ' \
                    '(s%03d) skeletal mechanism?' % (skelnum)
                print >> RATESFILE, 'Column 3: Difference (s%03d-s%03d) of current ' \
                    'minus previous reaction rate\n' % (skelnum, skelnum-1)
                print >> RATESFILE, '        s%03d  s%03d         diff reaction' % (
                    skelnum-1, skelnum)
                # walk idx backwards, i.e. largest |diff| first:
                for i in range(idx.shape[0]):
                    x = idx[-i-1]
                    if (not del_rxnlist[-2][x]): # if rxn was in previous skeleton
                        print >> RATESFILE, '%12G %5s %12G %-10s %s' % (
                            all_rates[scenarionum,x,-2], not del_rxnlist[-1][x],
                            diff[scenarionum,x], '<'+EQN_TAGS[x]+'>', EQN_NAMES[x])
                print >> RATESFILE
        print >> RATESFILE, HLINE
        RATESFILE.close()
    return delta_skel, delta_scenarios

##############################################################################

def list_species(oicdata, epslist):
    """Write the table 'species.dat': which mechanism contains which species.

    The header row lists the mechanism labels ('full' for epsnum 0,
    'sNNN' for skeletal mechanism NNN). Each following row shows, for every
    mechanism, its label if the species is included (blank otherwise),
    followed by the species' OIC value and name. Rows are in the order
    given by np.sort(oicdata).

    Parameters:
      oicdata: structured array of (OIC value, species name) records
      epslist: OIC thresholds; epslist[0] belongs to the full mechanism,
               epslist[k] to skeletal mechanism k
    """
    with open('species.dat','w+', 1) as SPECFILE: # 1=line-buffered
        print >> SPECFILE, HLINE
        for epsnum in range(len(epslist)):
            if (epsnum == 0):
                print >> SPECFILE, 'full',
            else:
                print >> SPECFILE, 's%03d' % (epsnum),
        print >> SPECFILE, 'OIC          species'
        print >> SPECFILE, HLINE
        for oic, species in np.sort(oicdata):
            for epsnum, eps in enumerate(epslist):
                # The full mechanism contains every species. A skeletal
                # mechanism keeps a species if oic > eps, which is the same
                # criterion that create_skeletal_mechanism uses.
                if (epsnum == 0):
                    print >> SPECFILE, 'full',
                elif (oic > eps):
                    print >> SPECFILE, 's%03d' % (epsnum),
                else:
                    print >> SPECFILE, '    ',
            print >> SPECFILE, '%12E %s' % (oic, species)
        print >> SPECFILE, HLINE

##############################################################################

def list_reactions(del_rxnlist):
    """Write the table 'reactions.dat': which mechanism keeps which reaction.

    One row per reaction. For every mechanism in del_rxnlist (index 0 is the
    full mechanism, labelled 'full'; index k is skeletal mechanism k,
    labelled 'sNNN'), the row shows the label if the reaction is kept and
    blanks if it is deleted, followed by the reaction tag and equation.

    Parameters:
      del_rxnlist: one list of booleans per mechanism (True = deleted)
    """
    with open('reactions.dat', 'w+', 1) as RXNSFILE: # 1=line-buffered
        print >> RXNSFILE, HLINE
        for rxn in range(N_rxns):
            for mechnum, delrxn in enumerate(del_rxnlist):
                label = 'full' if mechnum == 0 else 's%03d' % (mechnum)
                if delrxn[rxn]:
                    label = '    '
                print >> RXNSFILE, label,
            tag = '<' + EQN_TAGS[rxn] + '>'
            print >> RXNSFILE, '%-10s %s' % (tag, EQN_NAMES[rxn])
        print >> RXNSFILE, HLINE

##############################################################################

def finalize(info):
    print HLINE
    print 'Summary of results: %d targets, %d scenarios, %d skeletal mechanisms' % (
        info[2], info[1], info[0])
    print HLINE
    print 'Scenarios                      acro scenarios/'+scenariofile+'.pdf'
    print 'Logfile:                       e xskeleton.log'
    print 'List of species:               e species.dat'
    print 'List of reactions:             e reactions.dat'
    print 'List of rates:                 e rates.dat'
    print 'Rates in skeletal mechanisms:  e output/skeleton_*/rates.dat'
    print 'Errors of skeletal mechanisms: e output/skeleton_*/skel_error.dat'
    print 'Plots of errors:               acro output/delta_skel.pdf'
    print 'Plots of target species:       acro output/targets.pdf'
    print 'Plots of scenarios:            acro output/scenario_*.pdf'
    print HLINE
    tee.stdout_stop()

##############################################################################

def make_target_plots(plot_targets):
    from viewport import viewport
    if (not plot_targets): return
    print 'Plotting these skeletal mechanisms:\n', all_skel, '\n'
    viewport.init(4, 4, 'output/targets.pdf', 17, 8) # open pdf
    print HLINE
    for num, (target, abstol, reltol) in enumerate(targetdata): # target loop
        print 'Plotting target %-15s' % (target)
        for scenario in scenarionames: # scenario loop
            caabaplot.plot_0d(
                modelruns = [['output/'+skel+'/multirun/runs/'+scenario, skel]
                               for skel in all_skel],
                species   = target,
                pagetitle = 'Target species: '+target,
                plottitle = 'scenario: '+scenario,
                timeformat  = '%-Hh')
    viewport.exit() # close pdf

##############################################################################

def make_scenario_plots(plot_scenarios):
    if (not plot_scenarios): return
    # delete old plots:
    map(os.remove, glob('output/scenario_*.pdf'))
    # define verbosity to select species for plotting:
    verbosity = 2
    plotspecies = mecca.set_species(verbosity)
    print HLINE, '\n'
    # define plotscenarios:
    if (plot_scenarios == 1):
        plotscenarios = scenarionames
    else:
        plotscenarios = ['0003', '0027'] # selected scenarios
    for scenario in plotscenarios: # scenario loop
        print 'Plotting scenario %-15s' % (scenario)
        caabaplot.xxxg(
            modelruns   = [['output/'+skel+'/multirun/runs/'+scenario, skel]
                           for skel in all_skel],
            plotspecies = plotspecies,
            pdffile     = 'output/scenario_'+scenario,
            pagetitle   = 'Scenario: '+scenario,
            timeformat  = '%-Hh')
        print

##############################################################################

def make_delta_skel_plots(plot_delta_skel):
    from viewport import viewport
    from cycler import cycler
    if (not plot_delta_skel): return
    linecolors = ['k', 'r', 'g', 'b', 'm', 'y', 'c']
    print HLINE, '\nPlotting delta_skel errors'
    viewport.init(4, 4, 'output/delta_skel.pdf', 17, 8) # open pdf
    for scenarionum, scenario in enumerate(scenarionames): # scenario loop
        ax = viewport.next()
        if (viewport.current == 1):
            # on new page, start with legend on a dummy plot:
            for targetnum, (target, abstol, reltol) in enumerate(targetdata): # target loop
                lines = plt.plot([0,0], linewidth=3, label='%s (abstol=%G, reltol=%G)' % (
                    target, abstol, reltol))
            plt.axis('off')
            legend = plt.legend(loc='center',
                                mode='expand',
                                fontsize = 'small',
                                title='delta_skel errors',
                                fancybox=True,
                                shadow=True,
                                borderaxespad=0.)
            plt.setp(legend.get_title(),fontsize='large')
            ax = viewport.next()
            ax.set_prop_cycle(cycler('color', linecolors))
        for targetnum, (target, abstol, reltol) in enumerate(targetdata): # target loop
            mydata = delta_skel_all[:,scenarionum,targetnum]
            xval = np.arange(1, len(mydata)+1, 1)
            plt.plot(xval, mydata[:], '*', linestyle='solid')
        plt.xlim(0,len(mydata)+1)
        plt.ylim(0.,1.)
        # ax.set_yscale("log", nonposy='clip') # qqq doesn't work
        plt.title('Scenario: '+scenario)
        plt.xlabel('skeletal mechanism number')
        plt.ylabel('delta_skel')
    viewport.exit() # close pdf

##############################################################################

if __name__ == '__main__':

    # Main procedure:
    #   1. run CAABA with the full mechanism for all scenarios
    #   2. calculate OICs (skeleton.exe) and load them
    #   3. repeatedly raise eps, build a smaller skeletal mechanism, run it
    #      and compare it to the full mechanism, until the error of some
    #      target exceeds its tolerance (delta_skel > 1)
    #   4. write summary tables and plots
    # With full_calc=0, steps 1-3 re-use results from a previous run.

    if (full_calc):
        cleanup()
        print '\nRun caaba with full mechanism, all scenarios.'
    workdir = get_workdir(None)
    # rpl file without replacements -> full mechanism:
    create_skeleton_rpl(None)
    caaba_multirun()
    scenarionames, N_scenarios = get_scenarionames()
    calc_oic()
    del_rxn = [False] * len(StoichNum)
    all_rates = save_rates(del_rxn)
    skelnum = 1
    # list of all mechanisms (the full mechanism deletes nothing, i.e.
    # del_rxnlist[0] is all False):
    del_rxnlist = [del_rxn]
    # list of all eps (full mech means eps=0):
    epslist = [0]
    N_var = N_var_full
    # delta_skel_all_list is a list of lists of lists containing the errors
    # of all current skeletal mechanisms compared to the full mechanism
    # for all scenarios and for all targets:
    delta_skel_all_list = []
    while True: # loop over skeletal mechanisms for skelnum = 1, 2, 3, ...
        # eps of the previous mechanism (or the start value); restored
        # after the loop:
        eps_last = eps
        workdir = get_workdir(skelnum)
        print '\n***** Skeletal mechanism %d *****' % (skelnum)
        # Increase epsilon_ep until fewer species (and thus fewer
        # reactions) remain:
        eps_too_small = True
        while eps_too_small:
            eps *= eps_increase
            # Confirm that new mechanism contains less species:
            N_var_new = np.sum(oicdata['f0']>eps) # f0=oic values, f1=species
            if (N_var_new < N_var):
                print 'Number of species reduced from %d to %d' % (
                    N_var, N_var_new)
                N_var = N_var_new
                eps_too_small = False
            else:
                print 'Still %d species in mechanism with eps = %g.' % (
                    N_var, eps)
        if (N_var_new == 0):
            sys.exit('ERROR: N_var=0')
        # Create skeletal mechanism excluding species with OIC < eps:
        create_skeletal_mechanism(eps)
        caaba_multirun()
        delta_skel, delta_scenarios = analyze_results()
        delta_skel_all_list.append(delta_scenarios)
        # Exit loop when error is too big (the failing mechanism is still
        # included in all results and plots):
        if (delta_skel > 1.):
            break
        skelnum += 1
    print
    eps = eps_last
    # convert list of lists of lists to 3D numpy array
    # [mechanism, scenario, target]:
    delta_skel_all = np.asarray(delta_skel_all_list)
    list_species(oicdata, epslist)
    list_reactions(del_rxnlist)
    # define directories of full and all skeletal mechanisms:
    all_skel = ['fullmech']
    for skeldir in sorted(glob('output/skeleton_*')):
        all_skel.append(os.path.basename(skeldir))
    # if there are too many skeletal mechanisms, show only the
    # last 4 and the full mechanism:
    if (len(all_skel)>5):
        all_skel = [all_skel[i] for i in [0]+range(-4,0)]
    make_target_plots(plot_targets)
    make_scenario_plots(plot_scenarios)
    make_delta_skel_plots(plot_delta_skel)
    finalize(delta_skel_all.shape)

##############################################################################

# code snippets for debugging:
#qqq+
#sys.exit('END') #qqq
#print os.getcwd() # print current directory
#qqq-
