import matplotlib.pyplot as plt
import numpy as np
import sys,os
sep=os.path.sep
ToolBoxDir=os.path.abspath(os.path.dirname(os.path.abspath(__file__))+sep+"..")
sys.path.append(ToolBoxDir)

from pyVecFEMP1Light.mesh import readFreeFEM
from pyVecFEMP1Light.pde import initPDE,setBC_PDE,buildPDEsystem,splitPDEsol
from pyVecFEMP1Light.operators import Hoperator,Loperator,StiffElasHoperators
from pyVecFEMP1Light.common import  mkdir_p,run_from_ipython
from pyVecFEMP1Light.graphics import PlotVal,PlotMesh,PlotBounds,showSparsity
from scipy.sparse.linalg import spsolve

# Mesh location: <ToolBoxDir>/meshes/2D/bar4-15.msh
MeshDir = sep.join([ToolBoxDir, 'meshes', '2D'])
MeshFile = sep.join([MeshDir, 'bar4-15.msh'])

plt.close("all")  # start from a clean set of figures

# Problem dimensions: space dimension d and number of unknown components m.
d = 2
m = 2

# Material: Young's modulus E and Poisson's ratio nu, converted to the Lamé
# coefficients lam and mu.
E, nu = 21e5, 0.45
mu = E / (2*(1 + nu))
lam = E*nu / ((1 + nu)*(1 - 2*nu))

# Ordering of the discrete unknowns: 0 = alternate basis, 1 = block basis.
Num = 1

# Step 1: read the mesh file with readFreeFEM and report its sizes.
print('1. Get mesh %s using pyVecFEMP1Light' % (MeshFile,))
Th = readFreeFEM(MeshFile)
mesh_sizes = (Th.nq, Th.nme, Th.nbe)
print('  -> Mesh sizes : nq=%d, nme=%d, nbe=%d' % mesh_sizes)

# Step 2: describe the 2D elasticity boundary value problem.
print('2. Set elasticity 2D problem')
gam = lam + 2*mu  # diagonal Lamé coefficient lambda + 2*mu

# Hand-built m-by-m operator: entry H[i][j] is Loperator(d, A) with the 2x2
# coefficient matrix A (None presumably marks a zero coefficient).
# NOTE(review): Hop is not used below -- the PDE is built from Hop1
# (StiffElasHoperators). Presumably Hop is kept as an explicit worked example
# of the same elasticity operator; confirm the two agree before relying on it.
Hop = Hoperator(d=d, m=m)
_elas_coefs = (
    ((0, 0), [[gam, None], [None, mu]]),
    ((0, 1), [[None, lam], [mu, None]]),
    ((1, 0), [[None, mu], [lam, None]]),
    ((1, 1), [[mu, None], [None, gam]]),
)
for (_i, _j), _coef in _elas_coefs:
    Hop.H[_i][_j] = Loperator(d=d, A=_coef)

# Operator actually used for the PDE: built by the library from the Lamé
# coefficients.
Hop1 = StiffElasHoperators(Th.d, lam, mu)
pde = initPDE(Hop1, Th)
# Right-hand side, one value per component (presumably a unit downward load).
pde.f = [0, -1]
# Dirichlet data [0, 0] on boundary label 4 for components 0 and 1
# (argument roles inferred from the values -- confirm against setBC_PDE).
pde = setBC_PDE(pde, 4, [0, 1], 'Dirichlet', [0, 0], None)

# Step 3: assemble the linear system and display the sparsity of its matrix.
print('3. Build 2D elasticity linear system')
# Use the basis choice from the configuration (Num) instead of a second
# hard-coded flag that could silently disagree with it.
# num=1: block basis, num=0: alternate basis.
num = Num
A, b, ID, IDc, gD = buildPDEsystem(pde, Num=num)
plt.ion()  # interactive mode
plt.figure(1)
showSparsity(A)
basis_name = "block basis" if num == 1 else "alternate basis"
plt.title("2D elasticity linear system : sparsity of the matrix (%s)" % basis_name)


# Step 4: solve A X = b with the Dirichlet DOFs (ID) fixed to gD and only the
# remaining DOFs (IDc) treated as unknowns.
print('4. Solve linear system')
X = np.zeros((pde.m*Th.nq,))
X[ID] = gD[ID]
# Rows of the free DOFs, extracted once and reused for the lifting and the
# reduced matrix.
A_free = A[IDc]
# Move the known Dirichlet contribution to the right-hand side. Only the
# Dirichlet columns are used, so the result does not depend on whatever gD
# holds on free DOFs. Use `@` for the product: for scipy.sparse *array*
# types `*` is element-wise, while `@` is a matrix product for both sparse
# arrays and legacy sparse matrices.
bb = b[IDc] - A_free[:, ID] @ gD[ID]
X[IDc] = spsolve(A_free[:, IDc], bb)
U = splitPDEsol(pde, X, num)

# Figure 2: pointwise Euclidean norm of the two solution components.
normU = np.sqrt(U[0]**2 + U[1]**2)
plt.figure(2)
plt.clf()
PlotVal(Th, normU)
plt.title(r"2D elasticity - norm of the solution")
plt.axis('off')

# Figure 3: draw the mesh with its vertices moved by the computed solution,
# magnified by `scale` so the deformation is visible. Th.q is temporarily
# overwritten and restored in `finally`, so the mesh is left unchanged even
# if plotting raises.
scale = 50
q = Th.q
Th.q = Th.q + scale*np.c_[U[0], U[1]]
try:
    plt.figure(3)
    plt.clf()
    # plt.hold was deprecated in Matplotlib 2.0 and removed in 3.0 (calling it
    # raises AttributeError). Successive plot calls now overlay on the same
    # axes by default, so mesh and boundaries are drawn together without it.
    PlotMesh(Th)
    PlotBounds(Th)
    plt.axis('off')
    plt.title("2D elasticity - mesh deformation scaled by %d" % scale)
finally:
    Th.q = q