#!/usr/bin/env python
# -*- coding: utf-8 -*-

##############################################################################

# import os
# import subprocess
# import datetime
import sys
# import numpy as np
import re
import graph_tool.all as gt
import matplotlib.pyplot as plt

HLINE =  "-" * 78
DEBUG = False

##############################################################################

def create_graphviz_draw(g, vcolor=None, output='mecca_graph.pdf'):
    """Render the graph ``g`` with graphviz and write it to ``output``.

    Vertices are drawn as ovals labelled with ``g.vp.name``. Edges are blue
    forward arrows labelled with ``g.ep.eqntag``.

    Parameters
    ----------
    g : graph_tool Graph or GraphView
        Graph to draw. It must provide the ``name`` vertex property and the
        ``eqntag`` edge property.
    vcolor : vertex property map or color, optional
        Passed unchanged to ``gt.graphviz_draw(vcolor=...)``. When omitted,
        the module-level ``fraction`` property map is used. That map is
        defined in the ``__main__`` section, so omitting the argument keeps
        the original behaviour.
    output : str, optional
        Output file name (default ``'mecca_graph.pdf'``, as before).
    """
    if vcolor is None:
        # Backward-compatible default: previously the global was hard-wired.
        vcolor = fraction
    # Attribute reference: http://www.graphviz.org/doc/info/attrs.html
    # Alternatives tried earlier: overlap=prism0/scalexy, rankdir=LR,
    # concentrate=true, splines=ortho/polyline/spline, vertex style=filled,
    # edge style=dashed/dotted, multi-color edges.
    gt.graphviz_draw(g,
                     gprops={'overlap': 'voronoi',
                             'splines': 'curved',
                             'size': '40,20'},
                     vprops={'shape': 'oval',
                             'fontsize': 20,
                             # NOTE(review): the author noted this seems to
                             # have no effect. Graphviz applies a colorscheme
                             # only to colors given as scheme-relative names,
                             # so it is probably ignored here. TODO confirm.
                             'colorscheme': 'blues9',
                             'label': g.vp.name},
                     eprops={'label': g.ep.eqntag,
                             'dir': 'forward',  # forward arrows
                             'arrowhead': 'normal',
                             'arrowsize': 1,
                             'fontsize': 10,
                             'color': 'blue',
                             'penwidth': 1},
                     # NOTE(review): 'overlap' is set both in gprops
                     # ('voronoi') and here ('compress'). Which one reaches
                     # graphviz depends on graphviz_draw's internals. TODO
                     # verify.
                     overlap='compress',
                     vcolor=vcolor,
                     output=output)

##############################################################################

def create_max_flow(g):
    """Compute, list and draw the maximum flow from ``src`` to ``tgt`` in ``g``.

    Steps:
      1. Set each edge capacity to (product stoichiometric factor) *
         (reaction rate). The rates come from read_rxn_rates() for the file
         'caaba_mecca_rr.nc'. The capacity is stored in the edge property
         ``g.ep.cap``.
      2. Run graph_tool's Edmonds-Karp max-flow from ``src`` to ``tgt`` and
         turn the returned residual capacities into flows
         (flow = capacity - residual).
      3. Print every edge that carries flow, sorted in descending order
         (capacity first, then flow).
      4. Hide every vertex whose mean of total in-flow and total out-flow is
         below max_flow/50 (2 %). This writes ``g.vp.myfilter`` and installs
         it as the vertex filter of ``g``.
      5. Draw the remaining graph to 'mecca_edmonds-karp.pdf'. ``src`` is
         red, ``tgt`` is green, other vertices are yellow, and edge widths
         scale with the flow.

    NOTE(review): ``src`` and ``tgt`` are not parameters. They are
    module-level globals assigned in the ``__main__`` section.
    ``g.vp.myfilter`` must already exist (``__main__`` creates it). Calling
    this function changes the vertex filter of ``g``.
    """
    
    g.ep.cap = g.new_edge_property("double")
    rxnrates = read_rxn_rates('caaba_mecca_rr.nc')
    for e in g.edges():
        eqntag = g.ep.eqntag[e]
        # raises KeyError if the netCDF file has no 'RR<eqntag>' variable
        rxnrate = rxnrates[g.ep.eqntag[e]]
        prodstoic = g.ep.prodstoic[e] # stoic factor of product
        print '%9s: %7g %g' % (eqntag, prodstoic, rxnrate)
        # capacity = rate at which this reaction produces the edge's target
        g.ep.cap[e] = prodstoic*rxnrate

    # choose one:
    res = gt.edmonds_karp_max_flow(g, src, tgt, g.ep.cap)
    #res = gt.push_relabel_max_flow(g, src, tgt, g.ep.cap)
    #res = gt.boykov_kolmogorov_max_flow(g, src, tgt, g.ep.cap)
    # the max-flow routines return residual capacities; flow = cap - residual
    res.a = g.ep.cap.a - res.a  # the actual flow
    #print res.a
    # total flow = everything arriving at the target vertex
    max_flow = sum(res[e] for e in tgt.in_edges())
    print 'max flow: %g' % (max_flow)
    # edge widths for the drawing, scaled from the flow values
    edgewidth = gt.prop_to_size(res, mi=0.1, ma=5, power=0.3)
    #-------------------------------------------------------------------------
    print ; print ; print 'Sorted listing of important reactions from %s to %s:' % (
        g.vp.name[src], g.vp.name[tgt])
    # put all info into mylist:
    # each entry: [capacity, flow, edgewidth, eqntag, source name, target name]
    mylist = []
    for e in g.edges():
        if (res[e]>0):
            mylist.append(
                [g.ep.cap[e], res[e], edgewidth[e], g.ep.eqntag[e],
                 g.vp.name[e.source()], g.vp.name[e.target()]])
    # sort and print mylist:
    # (lists compare element-wise, so this sorts by capacity, then by flow)
    print '            rxn rate                 flow            edgewidth     eqntag'
    #for myitem in sorted(mylist, reverse=True, key=lambda tup: tup[0]):
    for myitem in sorted(mylist, reverse=True):
        print '%20s %20s %20s %10s %s -> %s' % (
            myitem[0], myitem[1], myitem[2], myitem[3], myitem[4], myitem[5])
    #-------------------------------------------------------------------------
    # # filter out all edges that contribute < 1 %
    # for e in g.edges():
    #     if (res[e] < max_flow/100.):
    #         g.ep.myfilter[e] = False
    #     else:
    #         g.ep.myfilter[e] = True
    #     print g.ep.myfilter[e], res[e], g.ep.eqntag[e], g.ep.reaction[e]
    # g.set_edge_filter(g.ep.myfilter) # keep if True
    #-------------------------------------------------------------------------
    # filter out all vertices that contribute < 2 % (threshold max_flow/50)
    keep = []
    delete = []
    for v in g.vertices():
        # mean of total in-flow and total out-flow through v
        # (src has only out-flow, so it gets max_flow/2 and is kept)
        rxnrate = (sum(res[e] for e in v.in_edges()) +
                    sum(res[e] for e in v.out_edges())) / 2.
        if (rxnrate < max_flow/50.):
            g.vp.myfilter[v] = False
            delete.append(g.vp.name[v])
        else:
            g.vp.myfilter[v] = True
            keep.append(g.vp.name[v])
    print 'Keep: ', sorted(keep)
    print 'Delete: ', sorted(delete)
    # replaces the vertex filter of g (for a GraphView: its previous filter)
    g.set_vertex_filter(g.vp.myfilter) # keep if True
    #-------------------------------------------------------------------------
    # vertex colors:
    myvcolor = g.new_vertex_property('string')
    for v in g.vertices():
        myvcolor[v] = 'yellow'
    myvcolor[src] = 'red'
    myvcolor[tgt] = 'green'
    #-------------------------------------------------------------------------

    gt.graphviz_draw(g, # http://www.graphviz.org/doc/info/attrs.html
                     #pos=pos,
                     gprops={#'concentrate':'true',
                         #'overlap':'prism0',
                         'overlap':'voronoi', 
                         #'rankdir':'LR',
                         'splines':'curved',
                         #'splines':'ortho',
                         #'splines':'polyline',
                         #'splines':'spline',
                         'size':'40,20'},
                     vprops={'shape':'oval',
                             'fontsize':20,#gt.prop_to_size(x, mi=1, ma=30),
                             #'color':'Black', # circle around vertex
                             'style':'filled',
                             'label':g.vp.name},
                     eprops={'label':g.ep.eqntag,
                             'dir':'forward', # forward arrows
                             'arrowhead':'normal',
                             'arrowsize':1,
                             'fontsize':20,
                             #'style':'dashed',
                             #'style':'dotted',
                             #'color':'black;0.25:red;0.5:yellow;0.25',
                             'color':'blue',
                             'penwidth':edgewidth},  # line width ~ flow
                     overlap='compress',
                     vcolor=myvcolor,
                     #vcolor='yellow', # vertex fill color
                     output='mecca_edmonds-karp.pdf')

def read_rxn_rates(ncfile):
    from netCDF4 import Dataset, num2date
    rxnfilename = 'caaba_mecca_rr.nc'
    ncid = Dataset(rxnfilename)
    time = ncid.variables['time']
    mytime = len(time)-36 # last day at noon if delta_t = 20 min
    print 'selected time: ', num2date(time[mytime],time.units)
    if DEBUG: print ncid.variables['RRJ41000'][:,0,0,0] # show rxn rates
    rxnrates = {}
    for rxn in ncid.variables:
        if (rxn[0:2]=='RR'):
            mydata = ncid.variables[rxn][mytime,0,0,0]
            rxnrates[rxn[2:]] = mydata
    ncid.close()
    return rxnrates # dict with eqntags -> rxnrates

##############################################################################

def create_interactive_window(g):
    """Open graph_tool's interactive window for ``g``.

    Edges are labelled with ``g.ep.eqntag`` and vertices with ``g.vp.name``.
    Vertices are filled yellow, and ``keypressed`` is installed as the
    key-press callback.

    Window controls:
      middle mouse button ........... move
      mouse wheel ................... zoom
      shift + mouse wheel ........... zoom (including vertex + edge sizes)
      control + mouse wheel ......... rotate
      'a' ........................... autozoom
      'r' ........................... resize and center
      's' ........................... spring-block layout for all
                                      non-selected vertices
      'z' ........................... zoom to selected vertices
      left mouse button ............. select vertex
      right mouse button ............ unselect all
      shift + left mouse drag ....... select several vertices
    """
    # https://graph-tool.skewed.de/static/doc/draw.html#graph_tool.draw.interactive_window
    # for options, see "List of vertex properties" at:
    # https://graph-tool.skewed.de/static/doc/draw.html
    window_options = dict(
        geometry=(1000, 800),            # initial window size
        edge_text=g.ep.eqntag,
        vertex_text=g.vp.name,
        key_press_callback=keypressed,
        vertex_fill_color='yellow',
        display_props=[g.vp.name, g.vp.atoms],
        display_props_size=20,
    )
    # Disabled options kept for reference:
    #   layout_callback=layoutchanged
    #   vertex_aspect=2          (very slow!)
    #   display_props=g.vp.name
    gt.interactive_window(g, **window_options)

def keypressed(gtk, g, keyval, picked, pos, vprops, eprops):
    print 'keypressed'
    #print 'gtk='    ,gtk    # gtk_draw.GraphWidget
    #print 'g='      ,g      # graph being drawn
    print 'keyval=' ,keyval # key id (a=97, z=122)
    if picked:
        print 'picked=' ,g.vp.name[picked] # vertex or boolean vertex property map for selected vertices
    #print 'pos='    ,pos    # vertex positions
    print 'vprops='         # vertex property dictionary
    print vprops['text'][picked]
    for vprop in vprops:
        print '%s = %s' % (vprop, vprops[vprop])
    print 'eprops='         # edge property dictionary
    for eprop in eprops:
        print '%s = %s' % (eprop, eprops[eprop])
    # for eqntag in eprops['text']:
    #     print eqntag

def layoutchanged(gtk, g, picked, pos, vprops, eprops):
    print 'layoutchanged'
    print 'gtk='    ,gtk    # gtk_draw.GraphWidget
    print 'g='      ,g      # graph being drawn
    print 'picked=' ,picked # vertex or boolean vertex property map for selected vertices
    print 'pos='    ,pos    # 
    print 'vprops=' ,vprops # vertex property dictionary 
    print 'eprops=' ,eprops # edge property dictionary
            
##############################################################################

def n2v(species_name, graph=None): # n2v = name to vertex
    """Return the unique vertex whose ``name`` property equals ``species_name``.

    Parameters
    ----------
    species_name : str
        Species name to look up in ``graph.vp.name``.
    graph : graph_tool Graph, optional
        Graph to search. Defaults to the module-level ``g`` loaded in the
        ``__main__`` section, which is the original implicit behaviour.

    Raises
    ------
    ValueError
        If no vertex, or more than one vertex, has that name. Previously a
        message was printed and None returned. A None passed on as
        ``source`` to ``gt.shortest_distance`` silently computes all-pairs
        distances instead of failing.
    """
    if graph is None:
        graph = g
    vertex = gt.find_vertex(graph, graph.vp.name, species_name)
    if len(vertex) == 0:
        raise ValueError('ERROR: %s is missing' % (species_name))
    if len(vertex) > 1:
        raise ValueError('ERROR: %s is duplicate (%d vertices)'
                         % (species_name, len(vertex)))
    return vertex[0] # return a vertex

##############################################################################

def list_vertices(g):
    for v in g.vertices():
        print g.vp.name[v],
    print '--> %d species' % (g.num_vertices())

##############################################################################

def list_edges(g):
    for e in g.edges():
        print g.ep.eqntag[e], g.ep.reaction[e]
        print '%s (%d) -> %s (%d)' % (
            g.vp.name[e.source()], elem_count[g.vp.name[e.source()]]['C'],
            g.vp.name[e.target()], elem_count[g.vp.name[e.target()]]['C'])
    print '--> %d reactions' % (g.num_edges())

##############################################################################

def elemental_composition(g):
    """Return the elemental composition of every species (vertex) in ``g``.

    The ``atoms`` vertex property is parsed as a '+'-separated list of
    tokens. Each token is an optional count followed by an element symbol,
    e.g. ``'5C+8H'``.

    Returns
    -------
    dict
        ``{species_name: {element_symbol: count}}``. The elements O, H, N,
        C, F, Cl, Br, I, S and Hg are always present (0 if absent). Any other
        symbol found in ``atoms`` is added as an extra key. A symbol that
        appears more than once in ``atoms`` is summed.

    Raises
    ------
    ValueError
        If a non-empty token contains no element symbol.
    """
    token_re = re.compile(r'([0-9]*)([A-Za-z]+)')
    default_elements = ('O', 'H', 'N', 'C', 'F', 'Cl', 'Br', 'I', 'S', 'Hg')
    elem_count = {}
    for v in g.vertices():
        name = g.vp.name[v]
        atoms = g.vp.atoms[v]
        # start at zero so lookups like elem_count[x]['C'] never raise
        tmp_dict = dict.fromkeys(default_elements, 0)
        for element in atoms.split('+'):
            if not element.strip():
                # tolerate empty tokens, e.g. an empty atoms string; the
                # original crashed here with AttributeError
                continue
            search_result = token_re.search(element)
            if search_result is None:
                raise ValueError('cannot parse atom token %r of species %s'
                                 % (element, name))
            digits, symbol = search_result.group(1), search_result.group(2)
            count = int(digits) if digits else 1
            # accumulate instead of overwrite, so a repeated symbol is summed
            tmp_dict[symbol] = tmp_dict.get(symbol, 0) + count
        elem_count[name] = tmp_dict
        if DEBUG:
            print('%-15s %-15s %s' % (name, atoms, tmp_dict))
    # The example lookup used to run inside the loop and raised KeyError
    # until HCHO had been processed; show it once, only if HCHO exists.
    if DEBUG and 'HCHO' in elem_count:
        print('example usage: HCHO has %d H atoms.' % (elem_count['HCHO']['H']))
    return elem_count

##############################################################################

if __name__ == '__main__':

    # Script entry point. NOTE: g, elem_count, src, tgt and fraction are
    # module-level globals. They are read as globals by n2v (g),
    # list_edges (elem_count), create_max_flow (src, tgt) and
    # create_graphviz_draw (fraction).

    #-------------------------------------------------------------------------

    # load graph produced by define_graph.py:
    g = gt.load_graph("mecca_graph.xml.gz")
    elem_count = elemental_composition(g) # define elemental composition
    g.vp.myfilter = g.new_vertex_property('bool') # vertex filter (internal property map)
    g.ep.myfilter = g.new_edge_property('bool')   # edge filter   (internal property map)
    # Any finite shortest distance is < N_v. Distances >= N_v are treated
    # below as "unreachable" (graph_tool reports these as a very large value).
    N_v = g.num_vertices()
    #src = n2v('MACR') # define source species
    #src = n2v('APINENE') # define source species
    #src = n2v('CH3COCH3') # define source species
    #src = n2v('HCHO') # define source species
    src = n2v('C5H8') # define source species
    #src = n2v('CH4')   # define source species
    #tgt = n2v('HCOOH') # define target species
    tgt = n2v('CO2') # define target species
    #tgt = n2v('HCHO') # define target species
    #-------------------------------------------------------------------------

    # filters:

    # filter option 1 = only organic species (C>=1):
    # g.vp.myfilter.a = [elem_count[g.vp.name[v]]['C']>0 for v in g.vertices()]
    # g.set_vertex_filter(g.vp.myfilter) # keep if True

    # filter option 2 = remove unreachable species:
    # s_dist = gt.shortest_distance(g, source=src)
    # g.set_reversed(True)
    # t_dist = gt.shortest_distance(g, source=tgt)
    # g.set_reversed(False)
    # g.vp.myfilter.a = [s_dist[v]<N_v and t_dist[v]<N_v for v in g.vertices()]
    # g.set_vertex_filter(g.vp.myfilter) # keep if True

    # for v in gC.vertices():
    #     print '%-15s %d %d %d %d' % (
    #         gC.vp.name[v], int(gC.vp.myfilter[v]),
    #         elem_count[gC.vp.name[v]]['C'], s_dist[v], t_dist[v])

    # g.clear_filters()

    #-------------------------------------------------------------------------

    # graph views:

    # create graph view gC with only organic species (C>=1):
    gC = gt.GraphView(g, vfilt=lambda v: elem_count[g.vp.name[v]]['C']>0)
    
    # remove edges where number of C atoms increases:
    # (keep an edge only if the source has at least as many C atoms as the target)
    #print HLINE ; list_edges(gC)
    gC = gt.GraphView(gC, efilt=lambda e:
                      elem_count[gC.vp.name[e.source()]]['C'] >=
                      elem_count[gC.vp.name[e.target()]]['C'])
    #print HLINE ; list_edges(gC)

    # s_dist: hops from src to each vertex.
    # t_dist: hops from each vertex to tgt, computed as a search from tgt on
    # the temporarily reversed view.
    s_dist = gt.shortest_distance(gC, source=src)
    gC.set_reversed(True)
    t_dist = gt.shortest_distance(gC, source=tgt)
    gC.set_reversed(False)
    # relative position of each vertex between src (0.0) and tgt (1.0);
    # used as the vertex color by create_graphviz_draw
    fraction = g.new_vertex_property('float')
    list_vertices(gC)
    # remove unreachable species from graph:
    # (keep only vertices that lie on some path src -> v -> tgt)
    gC = gt.GraphView(gC, vfilt=lambda v: s_dist[v]<N_v and t_dist[v]<N_v)
    list_vertices(gC)
    print 'name, #C, %s->, ->%s, fraction' % (gC.vp.name[src], gC.vp.name[tgt])
    for v in gC.vertices():
        # NOTE(review): divides by zero if s_dist[v] + t_dist[v] == 0,
        # i.e. only when src and tgt are the same vertex.
        fraction[v] = float(s_dist[v])/(s_dist[v]+t_dist[v])
        print '%-15s %d %d %d %s' % (
            gC.vp.name[v], elem_count[gC.vp.name[v]]['C'],
            s_dist[v], t_dist[v], fraction[v])

    #-------------------------------------------------------------------------

    # show shortest path from A to B:
    A = src
    B = tgt # n2v('CO2')
    print 'from %s to %s' % (gC.vp.name[A], gC.vp.name[B])
    vlist, elist = gt.shortest_path(gC, A, B)
    for v in vlist:
        print gC.vp.name[v], ' ',
    print
    for e in elist:
        print gC.ep.eqntag[e], gC.ep.reaction[e]

    #-------------------------------------------------------------------------

    # centrality:
    # x = gt.katz(g)
    # print x.a

    #-------------------------------------------------------------------------

    # correlations:
    # https://graph-tool.skewed.de/static/doc/correlations.html
    # h = gt.corr_hist(g, "out", "out")
    # plt.clf()
    # plt.xlabel("Source out-degree")
    # plt.ylabel("Target out-degree")
    # plt.imshow(h[0].T, interpolation="nearest", origin="lower")
    # plt.colorbar()
    # plt.savefig("corr.pdf")

    #-------------------------------------------------------------------------

    # outputs: graphviz drawing of gC, then max-flow analysis
    # (create_max_flow changes the vertex filter of gC, so it runs last)
    #create_interactive_window(gC)
    create_graphviz_draw(gC)
    create_max_flow(gC)
    #list_edges(gC)

    #-------------------------------------------------------------------------

    #sys.exit('END') #qqq
