import numpy as np
import time
from pyOptFEMP1.FEM import *
from pyOptFEMP1.mesh import HyperCube
from pyOptFEMP1.FEMtools import NormInf,genericFunc,setFdata

import matplotlib.pyplot as plt

def benchStiff(d,Versions,LN,**kwargs):
  """Bench the assembly of the Stiffness Matrix by :math:`P_1`-Lagrange finite
  elements using various versions (see report). Only on HyperCube mesh.

  For each ``N`` in *LN* a ``HyperCube(d,N)`` mesh is built, and for each
  version suffix ``V`` in *Versions* the function ``AssemblyStiffP1<V>(Th)``
  is timed *nbruns* times. The first version is the reference: for the other
  versions the infinity-norm difference to the reference matrix and the ratio
  ``time(reference)/time(version)`` are printed.

  :param d: space dimension
  :param Versions: a list of versions (suffixes of ``AssemblyStiffP1*``)
  :param LN: HyperCube mesh refinements
  :param nbruns: (keyword, default 3) number of timed runs averaged per version
  :param isplot: (keyword, default False) plot the results with ``plotBench``
  :returns: ``(Lndof, Tcpu)``: ``Lndof`` is a ``(len(LN),1)`` array of matrix
            row counts, ``Tcpu[i,v]`` the mean time (s) of version ``v`` on mesh ``i``
  :raises ValueError: if ``nbruns < 1``
  """
  nbruns=kwargs.get('nbruns',3)
  isplot=kwargs.get('isplot',False)
  if nbruns<1:
    # without at least one run no matrix is produced and the mean is undefined
    raise ValueError('nbruns must be >= 1, got %r'%(nbruns,))

  nLN=len(LN)
  nV=len(Versions)
  Tcpu=np.zeros((nLN,nV))
  Lndof=np.zeros((nLN,1))
  # 1-D so that T[l] is a scalar: %-formatting a 1-element array is deprecated in NumPy
  T=np.zeros(nbruns)
  for i,N in enumerate(LN):
    print(' -> (%d/%d) Creation of the mesh - HyperCube(%d,%d) ...'%(i+1,nLN,d,N))
    Th=HyperCube(d,N)
    print('    %dd-mesh : nq=%d, nme=%d'%(Th.d,Th.nq,Th.nme))
    for v,version in enumerate(Versions):
      # resolve the assembly function once, outside the timed region
      assembly=globals()['AssemblyStiffP1'+version]
      for l in range(nbruns):
        tstart=time.perf_counter()
        M=assembly(Th)
        T[l]=time.perf_counter()-tstart
        print('         run (%d/%d) : %.4f(s)'%(l+1,nbruns,T[l]))
      if v==0:
        M1=M  # reference matrix for the error check of later versions
      Tcpu[i,v]=T.mean()
      print('       AssemblyStiffP1%6s (%dd) : %d-by-%d matrix in %.4f(s)'%(version,d,M.shape[0],M.shape[1],Tcpu[i,v]))
      if v>0:
        print('       Error with %6s matrix   : %.5e'%(Versions[0],NormInf(M-M1)))
        print('       %6s speedup             : x%.3f'%(Versions[0],Tcpu[i,0]/Tcpu[i,v]))
      Lndof[i]=M.shape[0]
  if isplot:
    plt.ion()
    plotBench(Versions,Lndof,Tcpu)
  return Lndof,Tcpu
    
def benchMass(d,Versions,LN,**kwargs):
  """Bench the assembly of the Mass Matrix by :math:`P_1`-Lagrange finite
  elements using various versions. Only on HyperCube mesh.

  For each ``N`` in *LN* a ``HyperCube(d,N)`` mesh is built, and for each
  version suffix ``V`` in *Versions* the function ``AssemblyMassP1<V>(Th)``
  is timed *nbruns* times. The first version is the reference: for the other
  versions the infinity-norm difference to the reference matrix and the ratio
  ``time(reference)/time(version)`` are printed.

  :param d: space dimension
  :param Versions: a list of versions (suffixes of ``AssemblyMassP1*``)
  :param LN: HyperCube mesh refinements
  :param nbruns: (keyword, default 3) number of timed runs averaged per version
  :param isplot: (keyword, default False) plot the results with ``plotBench``
  :returns: ``(Lndof, Tcpu)``: ``Lndof`` is a ``(len(LN),1)`` array of matrix
            row counts, ``Tcpu[i,v]`` the mean time (s) of version ``v`` on mesh ``i``
  :raises ValueError: if ``nbruns < 1``
  """
  nbruns=kwargs.get('nbruns',3)
  isplot=kwargs.get('isplot',False)
  if nbruns<1:
    # without at least one run no matrix is produced and the mean is undefined
    raise ValueError('nbruns must be >= 1, got %r'%(nbruns,))

  nLN=len(LN)
  nV=len(Versions)
  Tcpu=np.zeros((nLN,nV))
  Lndof=np.zeros((nLN,1))
  # 1-D so that T[l] is a scalar: %-formatting a 1-element array is deprecated in NumPy
  T=np.zeros(nbruns)
  for i,N in enumerate(LN):
    print(' -> (%d/%d) Creation of the mesh - HyperCube(%d,%d) ...'%(i+1,nLN,d,N))
    Th=HyperCube(d,N)
    print('    %dd-mesh : nq=%d, nme=%d'%(Th.d,Th.nq,Th.nme))
    for v,version in enumerate(Versions):
      # resolve the assembly function once, outside the timed region
      assembly=globals()['AssemblyMassP1'+version]
      for l in range(nbruns):
        tstart=time.perf_counter()
        M=assembly(Th)
        T[l]=time.perf_counter()-tstart
        print('         run (%d/%d) : %.4f(s)'%(l+1,nbruns,T[l]))
      if v==0:
        M1=M  # reference matrix for the error check of later versions
      Tcpu[i,v]=T.mean()
      print('       AssemblyMassP1%6s (%dd) : %d-by-%d matrix in %.4f(s)'%(version,d,M.shape[0],M.shape[1],Tcpu[i,v]))
      if v>0:
        print('       Error with %6s matrix   : %.5e'%(Versions[0],NormInf(M-M1)))
        print('       %6s speedup             : x%.3f'%(Versions[0],Tcpu[i,0]/Tcpu[i,v]))
      Lndof[i]=M.shape[0]
  if isplot:
    plt.ion()
    plotBench(Versions,Lndof,Tcpu)
  return Lndof,Tcpu
    
def benchMassW(d,Versions,LN,**kwargs):
  """Bench the assembly of the weighted Mass Matrix by :math:`P_1`-Lagrange
  finite elements using various versions. Only on HyperCube mesh.

  For each ``N`` in *LN* a ``HyperCube(d,N)`` mesh is built, the weight *w*
  is converted with ``setFdata(w,Th)``, and for each version suffix ``V`` in
  *Versions* the function ``AssemblyMassWP1<V>(Th,W)`` is timed *nbruns*
  times. The first version is the reference: for the other versions the
  infinity-norm difference to the reference matrix and the ratio
  ``time(reference)/time(version)`` are printed.

  :param d: space dimension
  :param Versions: a list of versions (suffixes of ``AssemblyMassWP1*``)
  :param LN: HyperCube mesh refinements
  :param nbruns: (keyword, default 3) number of timed runs averaged per version
  :param isplot: (keyword, default False) plot the results with ``plotBench``
  :param w: (keyword) weight passed to ``setFdata``; defaults to
            ``genericFunc(d,'cos(1+x1+...+xd)')``
  :returns: ``(Lndof, Tcpu)``: ``Lndof`` is a ``(len(LN),1)`` array of matrix
            row counts, ``Tcpu[i,v]`` the mean time (s) of version ``v`` on mesh ``i``
  :raises ValueError: if ``nbruns < 1``
  """
  nbruns=kwargs.get('nbruns',3)
  isplot=kwargs.get('isplot',False)
  if nbruns<1:
    # without at least one run no matrix is produced and the mean is undefined
    raise ValueError('nbruns must be >= 1, got %r'%(nbruns,))
  # build the default weight only when the caller did not supply one
  if 'w' in kwargs:
    w=kwargs['w']
  else:
    w=genericFunc(d,'cos(1'+''.join('+x%d'%k for k in range(1,d+1))+')')

  nLN=len(LN)
  nV=len(Versions)
  Tcpu=np.zeros((nLN,nV))
  Lndof=np.zeros((nLN,1))
  # 1-D so that T[l] is a scalar: %-formatting a 1-element array is deprecated in NumPy
  T=np.zeros(nbruns)
  for i,N in enumerate(LN):
    print(' -> (%d/%d) Creation of the mesh - HyperCube(%d,%d) ...'%(i+1,nLN,d,N))
    Th=HyperCube(d,N)
    print('    %dd-mesh : nq=%d, nme=%d'%(Th.d,Th.nq,Th.nme))
    W=setFdata(w,Th)
    for v,version in enumerate(Versions):
      # resolve the assembly function once, outside the timed region
      assembly=globals()['AssemblyMassWP1'+version]
      for l in range(nbruns):
        tstart=time.perf_counter()
        M=assembly(Th,W)
        T[l]=time.perf_counter()-tstart
        print('         run (%d/%d) : %.4f(s)'%(l+1,nbruns,T[l]))
      if v==0:
        M1=M  # reference matrix for the error check of later versions
      Tcpu[i,v]=T.mean()
      print('       AssemblyMassWP1%6s (%dd) : %d-by-%d matrix in %.4f(s)'%(version,d,M.shape[0],M.shape[1],Tcpu[i,v]))
      if v>0:
        print('       Error with %6s matrix   : %.5e'%(Versions[0],NormInf(M-M1)))
        print('       %6s speedup             : x%.3f'%(Versions[0],Tcpu[i,0]/Tcpu[i,v]))
      Lndof[i]=M.shape[0]
  if isplot:
    plt.ion()
    plotBench(Versions,Lndof,Tcpu)
  return Lndof,Tcpu

def benchStiffElas(d,Versions,LN,**kwargs):
  """Bench the assembly of the elasticity Stiffness Matrix by
  :math:`P_1`-Lagrange finite elements using various versions. Only on
  HyperCube mesh, in dimension 2 or 3.

  For each ``N`` in *LN* a ``HyperCube(d,N)`` mesh is built, the Lame
  parameters are converted with ``setFdata``, and for each version suffix
  ``V`` in *Versions* the function ``AssemblyStiffElasP1<V>(Th,Lam,Mu)`` is
  timed *nbruns* times. The first version is the reference: for the other
  versions the infinity-norm difference to the reference matrix and the ratio
  ``time(reference)/time(version)`` are printed.

  :param d: space dimension (2 or 3)
  :param Versions: a list of versions (suffixes of ``AssemblyStiffElasP1*``)
  :param LN: HyperCube mesh refinements
  :param nbruns: (keyword, default 3) number of timed runs averaged per version
  :param isplot: (keyword, default False) plot the results with ``plotBench``
  :param lam: (keyword) first Lame parameter; default derived from E=21e5, nu=0.45
  :param mu: (keyword) second Lame parameter; default derived from E=21e5, nu=0.45
  :returns: ``(Lndof, Tcpu)``: ``Lndof`` is a ``(len(LN),1)`` array of matrix
            row counts, ``Tcpu[i,v]`` the mean time (s) of version ``v`` on mesh ``i``
  :raises ValueError: if ``d`` is not 2 or 3, or if ``nbruns < 1``
  """
  # explicit check instead of assert, which is stripped under ``python -O``
  if d not in (2,3):
    raise ValueError('benchStiffElas only supports d=2 or d=3, got %r'%(d,))
  # Default material: Young's modulus E and Poisson's ratio nu,
  # converted to the Lame parameters mu and lam
  E = 21e5; nu = 0.45
  mu= E/(2*(1+nu)); lam = E*nu/((1+nu)*(1-2*nu))
  nbruns=kwargs.get('nbruns',3)
  isplot=kwargs.get('isplot',False)
  lam=kwargs.get('lam',lam)
  mu=kwargs.get('mu',mu)
  if nbruns<1:
    # without at least one run no matrix is produced and the mean is undefined
    raise ValueError('nbruns must be >= 1, got %r'%(nbruns,))

  nLN=len(LN)
  nV=len(Versions)
  Tcpu=np.zeros((nLN,nV))
  Lndof=np.zeros((nLN,1))
  # 1-D so that T[l] is a scalar: %-formatting a 1-element array is deprecated in NumPy
  T=np.zeros(nbruns)
  for i,N in enumerate(LN):
    print(' -> (%d/%d) Creation of the mesh - HyperCube(%d,%d) ...'%(i+1,nLN,d,N))
    Th=HyperCube(d,N)
    print('    %dd-mesh : nq=%d, nme=%d'%(Th.d,Th.nq,Th.nme))
    Lam=setFdata(lam,Th)
    Mu=setFdata(mu,Th)
    for v,version in enumerate(Versions):
      # resolve the assembly function once, outside the timed region
      assembly=globals()['AssemblyStiffElasP1'+version]
      for l in range(nbruns):
        tstart=time.perf_counter()
        M=assembly(Th,Lam,Mu)
        T[l]=time.perf_counter()-tstart
        print('         run (%d/%d) : %.4f(s)'%(l+1,nbruns,T[l]))
      if v==0:
        M1=M  # reference matrix for the error check of later versions
      Tcpu[i,v]=T.mean()
      print('       AssemblyStiffElasP1%6s (%dd) : %d-by-%d matrix in %.4f(s)'%(version,d,M.shape[0],M.shape[1],Tcpu[i,v]))
      if v>0:
        print('       Error with %6s matrix   : %.5e'%(Versions[0],NormInf(M-M1)))
        print('       %6s speedup             : x%.3f'%(Versions[0],Tcpu[i,0]/Tcpu[i,v]))
      Lndof[i]=M.shape[0]
  if isplot:
    plt.ion()
    plotBench(Versions,Lndof,Tcpu)
  return Lndof,Tcpu
 
def plotBench(versions,Lndof,T):
  """Log-log plot of the benchmark CPU times against the number of dofs.

  One curve is drawn per version. Two dashed reference curves, anchored at
  the first mesh, show :math:`O(n_{dof})` and :math:`O(n_{dof}^2)` growth.
  The figure is then displayed with ``plt.show()``.

  :param versions: list of version names, one per column of *T*
  :param Lndof: number of dofs of each mesh; any array-like holding
                ``len(T)`` values (e.g. the ``(n,1)`` array returned by the
                ``bench*`` functions, or a plain list)
  :param T: array-like of shape ``(n, len(versions))`` with times in seconds
  """
  # accept lists as well as (n,1) / (n,) arrays
  ndof=np.asarray(Lndof,dtype=float).ravel()
  T=np.asarray(T,dtype=float)
  nV=len(versions)
  for i,version in enumerate(versions):
    plt.loglog(ndof,T[:,i],label=version)
  # reference slopes scaled from the timings on the first (coarsest) mesh
  plt.loglog(ndof,1.2*T[0,:].max()*ndof/ndof[0],'k--',label="$O(n_{dof})$")
  plt.loglog(ndof,T[0,:].mean()*ndof**2/ndof[0]**2,'k-.',label="$O(n_{dof}^2)$")
  plt.grid()
  plt.xlabel('$n_{dof}$')
  plt.ylabel('cputime(s)')
  # legend laid out in a single row above the axes
  plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
       ncol=nV+2, mode="expand", borderaxespad=0.)
  plt.show()