# -*- coding: utf-8 -*-
# Copyright (c) 2019, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""This file contains classes to represent order parameters.

CD: adding coordination number thingy.  Following instructions at 
http://www.pyretis.org/current/user/orderparameters.html

v2: Now adding classes for a) min. OH distance and b) c.num. computed
from min. OH distance.

v3: adding in complicated scheme to track smallest covalent COO-H distance.

v31: changing covalent thingy a bit.  First assign all H's to an O,
then based on that find the shortest distance from O to H, or from COO- to H3O+
depending.

v32: Also add Lambda2 c.v., ie. longest OH for H's 
assigned to water O closest to carboxylate.  Still not sure how to
use multiple OP's though... untested

v4: Need to add some stuff for so3 etc. reactions.
1) add boolean "negative" to choose if the OP should be -ive
(needed if the OP gets smaller for higher interfaces)
2) add SumDist OP, sum of n different distances.

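Important classes defined here
------------------------------

CoordNum (:py:class:`.CoordNum`)
    Coordination number from a rational switching function, reported as
    ``1 - cnum``.
MinDist (:py:class:`.MinDist`)
    Minimum distance between two groups of atoms.
SumDist (:py:class:`.SumDist`)
    Sum of the pairwise distances i1-j1, i2-j2, ...
MinCovDist (:py:class:`.MinCovDist`)
    Minimum carboxylate O - reactive H distance, restricted to
    hydronium-like hydrogens when one is present.
CNMinDist (:py:class:`.CNMinDist`)
    Switching function applied to the minimum distance only.
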
"""
import logging

import mdtraj
import numpy as np
from pyretis.orderparameter import OrderParameter

class CoordNum(OrderParameter):
    """A coordination-number order parameter.

    The coordination number is the sum, over every pair ``(i, j)`` with
    ``i`` in `indexi` and ``j`` in `indexj`, of the rational switching
    function::

        s(r) = (1 - (r / rcut)**expn) / (1 - (r / rcut)**expm)

    where ``r`` is the distance between atoms ``i`` and ``j``. Because
    PyRETIS expects the order parameter to increase from reactant to
    product, the reported value is ``1 - cnum``.

    Attributes
    ----------
    indexi : list of integers
        Indices of the "from" atoms (coordination centres); outer loop.
    indexj : list of integers
        Indices of the "to" atoms; inner loop.
    expn : number
        Numerator exponent of the switching function.
    expm : number
        Denominator exponent of the switching function.
    rcut : float
        Reference (cut-off) distance of the switching function.
    periodic : boolean
        This determines if periodic boundaries should be applied to
        the distance or not.
    negative : boolean
        If True, the order parameter is multiplied by -1 (needed when
        the order parameter decreases for higher interfaces).

    """

    # Debug output goes through logging instead of print() so it does not
    # flood stdout every time the order parameter is evaluated.
    _logger = logging.getLogger(__name__)

    # s(r) is 0/0 at r == rcut; inside this window around rcut (in units
    # of rcut) a first-order Taylor expansion is used instead, see _switch.
    _SWITCH_EPS = 1.0e-6

    def __init__(self, indexi, indexj, expn, expm, rcut, periodic=True,
                 negative=False):
        """Initialise order parameter.

        Parameters
        ----------
        indexi : list of ints
            Indices of the atoms used as the coordination centre(s).
        indexj : list of ints
            Indices of the atoms used as the "to" atoms.
        expn : number
            Numerator exponent (usually 6).
        expm : number
            Denominator exponent (usually 12 or 18).
        rcut : float
            Reference (cut-off) distance of the switching function.
        periodic : boolean, optional
            This determines if periodic boundary conditions should be
            applied to the distances.
        negative : boolean, optional
            Determines if we want to multiply by -1.

        """
        txt = '''Setting up CoordNum order parameter,n={},m={},rcut={},
i={},j={}'''.format(
            expn,
            expm,
            rcut,
            indexi,
            indexj
            )
        super().__init__(description=txt, velocity=False)

        self.periodic = periodic
        self.negative = negative
        self.expn = expn
        self.expm = expm
        self.rcut = rcut
        self.indexi = [int(i) for i in indexi]
        self.indexj = [int(j) for j in indexj]
        self._logger.debug('CoordNum: indexi=%s, indexj=%s',
                           self.indexi, self.indexj)

    def _distance(self, system, idx1, idx2):
        """Return the distance between particles `idx1` and `idx2`.

        The minimum-image convention of ``system.box`` is applied to the
        separation vector when `periodic` is set.
        """
        delta = system.particles.pos[idx1] - system.particles.pos[idx2]
        if self.periodic:
            delta = system.box.pbc_dist_coordinate(delta)
        return np.sqrt(np.dot(delta, delta))

    def _switch(self, dist):
        """Evaluate the switching function s(r) at ``r = dist``.

        The direct formula is 0/0 at ``dist == rcut`` (NaN in floating
        point); its limit there is ``expn / expm``. For
        ``|dist / rcut - 1| < _SWITCH_EPS`` the first-order expansion
        ``expn/expm * (1 + (expn - expm) / 2 * (dist / rcut - 1))`` is
        returned instead of the direct formula.
        """
        ratio = dist / self.rcut
        offset = ratio - 1.0
        if abs(offset) < self._SWITCH_EPS:
            return (self.expn / self.expm) * (
                1.0 + 0.5 * (self.expn - self.expm) * offset)
        return (1.0 - ratio**self.expn) / (1.0 - ratio**self.expm)

    def calculate(self, system):
        """Calculate the order parameter.

        Here, the order parameter is ``1 - sum_ij s(r_ij)`` over all pairs
        of atoms from `indexi` and `indexj`.

        Parameters
        ----------
        system : object like :py:class:`.System`
            The object containing the positions and box used for the
            calculation.

        Returns
        -------
        out : list of one float
            The coordination number order parameter.

        """
        cnum = 0.0
        for i in self.indexi:
            for j in self.indexj:
                cnum += self._switch(self._distance(system, i, j))

        value = 1.0 - cnum
        if self.negative:
            value = -value

        self._logger.debug('CoordNum value: %s', value)
        return [value]

# NOTE(review): the triple-quoted string below is a bare string expression,
# so it is never executed. It only preserves an earlier, alternative
# CoordNum.calculate() that re-loaded the configuration file with mdtraj,
# and it omits the `negative` handling of the live version.
# If it is ever revived: mdtraj documents Trajectory.xyz as having shape
# (n_frames, n_atoms, 3), so `frame.xyz[i]` would index frames rather than
# atom i (presumably `frame.xyz[0, i]` was intended). mdtraj coordinates
# are also in nm, which may differ from the simulation units -- confirm.
""" alternate method using mdtraj to import the config for some reason

    def calculate(self,system):
        cfg = system.particles.config[0]
        trj = mdtraj.load(cfg)
        frame = trj[0]

        cnum = 0.0
        for i in self.indexi:
            for j in self.indexj:
                delta = ( frame.xyz[i] 
                         - frame.xyz[j] )
                if self.periodic:
                    delta = system.box.pbc_dist_coordinate(delta)

                dist = np.sqrt(np.dot(delta,delta))

                cnum += ( (1.0-(dist/self.rcut)**self.expn)
                         / (1.0-(dist/self.rcut)**self.expm) )

        cnum2 = 1 - cnum
        return [cnum2]
"""

"""
MinDist (:py:class:'.MinDist')
"""

class MinDist(OrderParameter):
    """An order parameter based on the minimum separation between
       two lists of atoms i and j.

    Attributes
    ----------
    indexi : list of integers
        Indices of the "from" atoms; outer loop.
    indexj : list of integers
        Indices of the "to" atoms; inner loop.
    periodic : boolean
        This determines if periodic boundaries should be applied to
        the distance or not.
    negative : boolean
        If True, the order parameter is multiplied by -1.

    """

    # Debug output goes through logging instead of print() so it does not
    # flood stdout every time the order parameter is evaluated.
    _logger = logging.getLogger(__name__)

    def __init__(self, indexi, indexj, periodic=True, negative=False):
        """Initialise order parameter.

        Parameters
        ----------
        indexi : list of ints
            These are the indices of the atoms we will use as the center(s) i.
        indexj : list of ints
            These are the indices of the atoms we will use as the atoms j.
        periodic : boolean, optional
            This determines if periodic boundary conditions should be
            applied to the distances.
        negative : boolean, optional
            Mult. by -1?

        Raises
        ------
        ValueError
            If `indexi` or `indexj` is empty, since no distance (and hence
            no minimum) could be computed.

        """
        txt = 'Setting up MinDist order parameter,i={},j={}'.format(
            indexi,
            indexj
            )
        super().__init__(description=txt, velocity=False)

        self.periodic = periodic
        self.negative = negative
        self.indexi = [int(i) for i in indexi]
        self.indexj = [int(j) for j in indexj]
        if not self.indexi or not self.indexj:
            raise ValueError(
                'MinDist requires non-empty indexi and indexj.')
        self._logger.debug('MinDist: indexi=%s, indexj=%s',
                           self.indexi, self.indexj)

    def _distance(self, system, idx1, idx2):
        """Return the distance between particles `idx1` and `idx2`.

        The minimum-image convention of ``system.box`` is applied to the
        separation vector when `periodic` is set.
        """
        delta = system.particles.pos[idx1] - system.particles.pos[idx2]
        if self.periodic:
            delta = system.box.pbc_dist_coordinate(delta)
        return np.sqrt(np.dot(delta, delta))

    def calculate(self, system):
        """Calculate the order parameter.

        Here, the order parameter is the minimum distance
        between two groups of atoms.

        Parameters
        ----------
        system : object like :py:class:`.System`
            The object containing the positions and box used for the
            calculation.

        Returns
        -------
        out : list of one float
            The minimum i-j distance (negated if `negative` is set).

        """
        mindist = min(self._distance(system, i, j)
                      for i in self.indexi for j in self.indexj)

        if self.negative:
            mindist = -mindist

        self._logger.debug('MinDist value: %s', mindist)
        return [mindist]

"""
SumDist (:py:class:`.SumDist`)
"""

class SumDist(OrderParameter):
    """An order parameter based on the sum of some number of atom-atom distances.

    Attributes
    ----------
    indexi : list of integers
        Indices of the first atom of each pair.
    indexj : list of integers
        Indices of the second atom of each pair; the distances i1-j1,
        i2-j2, ... are summed, so `indexi` and `indexj` must have the
        same length.
    periodic : boolean
        This determines if periodic boundaries should be applied to
        the distance or not.
    negative : boolean
        If True, the order parameter is multiplied by -1.

    """

    # Debug output goes through logging instead of print() so it does not
    # flood stdout every time the order parameter is evaluated.
    _logger = logging.getLogger(__name__)

    def __init__(self, indexi, indexj, periodic=True, negative=False):
        """Initialise order parameter.

        Parameters
        ----------
        indexi : list of ints
            First atom of each pair.
        indexj : list of ints
            Second atom of each pair.
        periodic : boolean, optional
            This determines if periodic boundary conditions should be
            applied to the distances.
        negative : boolean, optional
            Mult. by -1?

        Raises
        ------
        ValueError
            If `indexi` and `indexj` differ in length. Pairing them with
            ``zip`` would otherwise silently drop the surplus atoms.

        """
        txt = 'Setting up SumDist order parameter,i={},j={}'.format(
            indexi,
            indexj,
            )
        super().__init__(description=txt, velocity=False)

        self.periodic = periodic
        self.negative = negative
        self.indexi = [int(i) for i in indexi]
        self.indexj = [int(j) for j in indexj]
        if len(self.indexi) != len(self.indexj):
            raise ValueError(
                'SumDist requires indexi and indexj of equal length, '
                'got {} and {}.'.format(len(self.indexi), len(self.indexj)))
        self._logger.debug('SumDist: indexi=%s, indexj=%s',
                           self.indexi, self.indexj)

    def _distance(self, system, idx1, idx2):
        """Return the distance between particles `idx1` and `idx2`.

        The minimum-image convention of ``system.box`` is applied to the
        separation vector when `periodic` is set.
        """
        delta = system.particles.pos[idx1] - system.particles.pos[idx2]
        if self.periodic:
            delta = system.box.pbc_dist_coordinate(delta)
        return np.sqrt(np.dot(delta, delta))

    def calculate(self, system):
        """Calculate the order parameter.

        Here, the order parameter is the sum of the distances between
        paired atoms ``indexi[k]`` and ``indexj[k]``.

        Parameters
        ----------
        system : object like :py:class:`.System`
            The object containing the positions and box used for the
            calculation.

        Returns
        -------
        out : list of one float
            The SumDist order parameter.

        """
        sumdist = sum(self._distance(system, i, j)
                      for i, j in zip(self.indexi, self.indexj))

        if self.negative:
            sumdist = -sumdist

        self._logger.debug('SumDist value: %s', sumdist)
        return [sumdist]

"""
MinCovDist (:py:class:'.MinCovDist')
"""

class MinCovDist(OrderParameter):
    """Minimum carboxylate-O to reactive-H distance, tracking hydronium.

    Each reactive hydrogen in `indexj` is first assigned to its nearest
    oxygen among the carboxylate oxygens (`indexi`) and the water oxygens
    (`indexow`); ties go to the oxygen listed first. Then:

    * if no water oxygen has three or more hydrogens assigned to it, the
      order parameter is the minimum distance between any carboxylate
      oxygen and any reactive hydrogen;
    * otherwise, it is the minimum distance between any carboxylate oxygen
      and the hydrogens assigned to those triply (or more) coordinated
      water oxygens, i.e. the hydronium-like species.

    Attributes
    ----------
    indexi : list of integers
        Indices of the carboxylate (COO) oxygens.
    indexow : list of integers
        Indices of the water oxygens.
    indexj : list of integers
        Indices of all reactive hydrogens.
    indexox : list of integers
        ``indexi + indexow``: every oxygen a hydrogen can be assigned to.
    periodic : boolean
        This determines if periodic boundaries should be applied to
        the distance or not.
    negative : boolean
        If True, the order parameter is multiplied by -1.

    """

    # Debug output goes through logging instead of print() so it does not
    # flood stdout every time the order parameter is evaluated.
    _logger = logging.getLogger(__name__)

    def __init__(self, indexi, indexow, indexj, periodic=True,
                 negative=False):
        """Initialise order parameter.

        Parameters
        ----------
        indexi : list of ints
            Indices of the carboxylate oxygens (the centres i).
        indexow : list of ints
            Indices of the water oxygen atoms.
        indexj : list of ints
            Indices of the reactive hydrogens (the atoms j).
        periodic : boolean, optional
            This determines if periodic boundary conditions should be
            applied to the distances.
        negative : boolean, optional
            Mult by -1?

        Raises
        ------
        ValueError
            If `indexi` or `indexj` is empty, since no distance (and hence
            no minimum) could be computed.

        """
        txt = '''Setting up MinCovDist order parameter,
i={},O={},j={}'''.format(
            indexi,
            indexow,
            indexj,
            )
        super().__init__(description=txt, velocity=False)

        self.periodic = periodic
        self.negative = negative
        self.indexi = [int(i) for i in indexi]
        self.indexow = [int(i) for i in indexow]
        self.indexj = [int(j) for j in indexj]
        if not self.indexi or not self.indexj:
            raise ValueError(
                'MinCovDist requires non-empty indexi and indexj.')

        # Combined oxygen list, built once here rather than at every step.
        # Carboxylate oxygens come first: _hydronium_hydrogens relies on
        # positions len(indexi) onwards being the water oxygens.
        self.indexox = self.indexi + self.indexow
        self._logger.debug('MinCovDist: indexi=%s, indexow=%s, indexj=%s',
                           self.indexi, self.indexow, self.indexj)

    def _distance(self, system, idx1, idx2):
        """Return the distance between particles `idx1` and `idx2`.

        The minimum-image convention of ``system.box`` is applied to the
        separation vector when `periodic` is set.
        """
        delta = system.particles.pos[idx1] - system.particles.pos[idx2]
        if self.periodic:
            delta = system.box.pbc_dist_coordinate(delta)
        return np.sqrt(np.dot(delta, delta))

    def _assign_hydrogens(self, system):
        """Return, per position in `indexox`, the hydrogens nearest to it.

        The result is a list parallel to `indexox`; entry ``k`` lists the
        hydrogens whose nearest oxygen is ``indexox[k]``. Ties go to the
        oxygen listed first.
        """
        assigned = [[] for _ in self.indexox]
        slots = range(len(self.indexox))
        for hyd in self.indexj:
            dists = [self._distance(system, oxy, hyd) for oxy in self.indexox]
            # min() keeps the first minimum, i.e. the earliest oxygen.
            nearest = min(slots, key=dists.__getitem__)
            assigned[nearest].append(hyd)
        return assigned

    def _hydronium_hydrogens(self, assigned):
        """Return hydrogens assigned to water oxygens holding >= 3 of them.

        Only the water-oxygen slots of `assigned` (those after the
        carboxylate oxygens) are considered.
        """
        water_slots = assigned[len(self.indexi):]
        return [hyd for hyds in water_slots if len(hyds) >= 3
                for hyd in hyds]

    def calculate(self, system):
        """Calculate the order parameter.

        Here, the order parameter is the minimum distance between the
        carboxylate oxygens and either the hydronium-like hydrogens (when
        such a species exists) or all reactive hydrogens (otherwise).

        Parameters
        ----------
        system : object like :py:class:`.System`
            The object containing the positions and box used for the
            calculation.

        Returns
        -------
        out : list of one float
            The minimum covalent-tracking distance (negated if `negative`
            is set).

        """
        assigned = self._assign_hydrogens(system)
        hydronium = self._hydronium_hydrogens(assigned)
        self._logger.debug('MinCovDist: oxygens=%s, assigned H=%s, '
                           'hydronium H=%s',
                           self.indexox, assigned, hydronium)

        # Without a hydronium-like water, fall back to every reactive H.
        targets = hydronium if hydronium else self.indexj
        mindist = min(self._distance(system, i, j)
                      for i in self.indexi for j in targets)

        if self.negative:
            mindist = -mindist

        self._logger.debug('MinCovDist value: %s', mindist)
        return [mindist]

"""
CNMinDist (:py:class:'.CNMinDist')
"""

class CNMinDist(OrderParameter):
    """An order parameter based on the coordination number function,
       but only for the one pair of atoms with minimum dist. between
       two lists of atoms i and j.

    The value is ``1 - s(r_min)`` with the rational switching function::

        s(r) = (1 - (r / rcut)**expn) / (1 - (r / rcut)**expm)

    Attributes
    ----------
    indexi : list of integers
        Indices of the "from" atoms; outer loop.
    indexj : list of integers
        Indices of the "to" atoms; inner loop.
    expn : number
        Numerator exponent of the switching function.
    expm : number
        Denominator exponent of the switching function.
    rcut : float
        Reference (cut-off) distance of the switching function.
    periodic : boolean
        This determines if periodic boundaries should be applied to
        the distance or not.
    negative : boolean
        If True, the order parameter is multiplied by -1.

    """

    # Debug output goes through logging instead of print() so it does not
    # flood stdout.
    _logger = logging.getLogger(__name__)

    # s(r) is 0/0 at r == rcut; inside this window around rcut (in units
    # of rcut) a first-order Taylor expansion is used instead, see _switch.
    _SWITCH_EPS = 1.0e-6

    def __init__(self, indexi, indexj, expn, expm, rcut, periodic=True,
                 negative=False):
        """Initialise order parameter.

        Parameters
        ----------
        indexi : list of ints
            Indices of the atoms used as the coordination centre(s).
        indexj : list of ints
            Indices of the atoms used as the "to" atoms.
        expn : number
            Numerator exponent (usually 6).
        expm : number
            Denominator exponent (usually 12 or 18).
        rcut : float
            Reference (cut-off) distance of the switching function.
        periodic : boolean, optional
            This determines if periodic boundary conditions should be
            applied to the distances.
        negative : boolean, optional
            Mult. by -1?

        Raises
        ------
        ValueError
            If `indexi` or `indexj` is empty, since no distance (and hence
            no minimum) could be computed.

        """
        txt = '''Setting up CNMinDist order parameter,n={},m={},rcut={},
i={},j={}'''.format(
            expn,
            expm,
            rcut,
            indexi,
            indexj
            )
        super().__init__(description=txt, velocity=False)

        self.periodic = periodic
        self.negative = negative
        self.expn = expn
        self.expm = expm
        self.rcut = rcut
        self.indexi = [int(i) for i in indexi]
        self.indexj = [int(j) for j in indexj]
        if not self.indexi or not self.indexj:
            raise ValueError(
                'CNMinDist requires non-empty indexi and indexj.')
        self._logger.debug('CNMinDist: indexi=%s, indexj=%s',
                           self.indexi, self.indexj)

    def _distance(self, system, idx1, idx2):
        """Return the distance between particles `idx1` and `idx2`.

        The minimum-image convention of ``system.box`` is applied to the
        separation vector when `periodic` is set.
        """
        delta = system.particles.pos[idx1] - system.particles.pos[idx2]
        if self.periodic:
            delta = system.box.pbc_dist_coordinate(delta)
        return np.sqrt(np.dot(delta, delta))

    def _switch(self, dist):
        """Evaluate the switching function s(r) at ``r = dist``.

        The direct formula is 0/0 at ``dist == rcut`` (NaN in floating
        point); its limit there is ``expn / expm``. For
        ``|dist / rcut - 1| < _SWITCH_EPS`` the first-order expansion
        ``expn/expm * (1 + (expn - expm) / 2 * (dist / rcut - 1))`` is
        returned instead of the direct formula.
        """
        ratio = dist / self.rcut
        offset = ratio - 1.0
        if abs(offset) < self._SWITCH_EPS:
            return (self.expn / self.expm) * (
                1.0 + 0.5 * (self.expn - self.expm) * offset)
        return (1.0 - ratio**self.expn) / (1.0 - ratio**self.expm)

    def calculate(self, system):
        """Calculate the order parameter.

        Here, the order parameter is ``1 - s(r_min)``, where ``r_min`` is
        the minimum distance between the two groups of atoms.

        Parameters
        ----------
        system : object like :py:class:`.System`
            The object containing the positions and box used for the
            calculation.

        Returns
        -------
        out : list of one float
            The coordination number order parameter.

        """
        mindist = min(self._distance(system, i, j)
                      for i in self.indexi for j in self.indexj)

        value = 1.0 - self._switch(mindist)
        if self.negative:
            value = -value

        return [value]
