import numpy as np
import pyHyperMesh.CartesianGrid as CG
from pyHyperMesh.EltMesh import EltMesh,LabelBaseName
from pyHyperMesh.colors import select_colors
from pyHyperMesh.tools import set_axes_equal

class OrthMesh:
  """Mesh of a d-dimensional box together with its lower-dimensional faces.

  The mesh is generated on a Cartesian grid by ``pyHyperMesh.CartesianGrid``
  and its nodes are rescaled onto ``box`` by ``MappingBox``.

  Attributes:
    d: space dimension.
    box: array indexed as box[i, 0] / box[i, 1] (lower / upper bound along
      axis i); defaults to the unit hypercube, i.e. shape (d, 2).
    type: element-type string as passed ('orthotope' selects orthotope
      elements, any other value falls back to simplices).
    Mesh: EltMesh of the d-dimensional elements (label 1).
    Faces: list of lists of EltMesh; Faces[k] holds the faces of dimension
      m = d-1-k (use getFacesIndex(m) to obtain k).
  """

  def __init__(self,d,N,**kwargs):
    """Build the d-dimensional mesh and all of its m-faces, 0 <= m < d.

    Args:
      d: space dimension.
      N: grid resolution per direction, either a scalar (used for every
        direction) or a sequence of length d.
      **kwargs:
        type: 'orthotope' for orthotope elements; anything else (default
          'simplicial') gives simplicial elements.
        box: domain bounds, indexed as box[i][0] / box[i][1]; defaults to
          [0, 1] along every axis.

    Raises:
      ValueError: if N has neither 1 nor d entries.
    """
    ctype = kwargs.get('type', 'simplicial')
    box = np.array(kwargs.get('box', np.ones((d, 1)) * np.array([0, 1])))
    N = np.atleast_1d(N)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(N) not in (1, d):
      raise ValueError('N must be a scalar or have d=%d entries, got %d' % (d, len(N)))
    if len(N) == 1:
      N = N[0] * np.ones(d)
    if ctype == 'orthotope':
      funMesh, funFaces = CG.TessHyp, CG.TessFaces
    else:  # default 'simplicial'
      funMesh, funFaces = CG.Triangulation, CG.TriFaces
    self.d = d
    self.box = box
    self.type = ctype
    q, me = funMesh(N)
    # Rescale node coordinates by (box extent) / N onto self.box.
    mapping = lambda Q: MappingBox(Q, N, self.box)
    self.Mesh = EltMesh(d, d, mapping(q), me, None, label=1)
    self.Faces = []
    # Face lists are stored by decreasing dimension d-1, ..., 0 (see getFacesIndex).
    for m in np.arange(d - 1, -1, -1):
      sTh = funFaces(N, m)
      colors = select_colors(len(sTh))
      self.Faces.append([
          EltMesh(d, m, mapping(F.q), F.me, F.toGlobal, label=j, color=colors[j])
          for j, F in enumerate(sTh)])

  def getFacesIndex(self,m):
    """Return the index k such that self.Faces[k] holds the m-dimensional faces.

    Faces are stored by decreasing dimension (d-1, ..., 0), hence k = d-1-m.

    Raises:
      IndexError: if m is not an integer value in 0..d-1.
    """
    if m not in range(self.d):
      raise IndexError('face dimension m=%s must satisfy 0 <= m < %d' % (m, self.d))
    return self.d - 1 - int(m)

  def plot(self,**kwargs):
    """Plot the mesh (m == d, the default) or all of its m-dimensional faces.

    Keyword args:
      m: dimension of the entities to draw (default self.d).
      legend: if True, add a draggable legend (default False).
      Any other keyword is forwarded to EltMesh.plot.
    """
    m = kwargs.pop('m', self.d)
    legend = kwargs.pop('legend', False)
    import matplotlib.pyplot as plt
    plt.gcf()  # make sure a current figure exists (created if needed) before drawing
    if m == self.d:
      handle, label = self.Mesh.plot(**kwargs)
      handles, labels = [handle], [label]
    else:
      handles, labels = [], []
      for face in self.Faces[self.getFacesIndex(m)]:
        handle, label = face.plot(**kwargs)
        handles.append(handle)
        labels.append(label)
    fig = plt.gcf()
    ax = fig.axes[0]
    if self.d == 3:
      ax.set_zlim(self.box[2, 0], self.box[2, 1])
    ax.set_xlim(self.box[0, 0], self.box[0, 1])
    ax.set_ylim(self.box[1, 0], self.box[1, 1])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if self.d == 3:
      set_axes_equal(ax)
      ax.set_zlabel('z')
    else:
      ax.set_aspect('equal')
    # No background
    ax.patch.set_facecolor('None')
    fig.patch.set_visible(False)
    if legend:
      leg = plt.legend(handles, labels, loc='best', ncol=len(handles) // 10 + 1)
      # Legend.draggable() was deprecated in Matplotlib 3.0 in favour of
      # set_draggable() and later removed; keep the old call only as fallback.
      if hasattr(leg, 'set_draggable'):
        leg.set_draggable(True)
      else:
        leg.draggable()
      
def MappingBox(q,N,box):
  """Affinely rescale point coordinates onto a box.

  For i in range(len(N)), row i of ``q`` is mapped to
      box[i][0] + (box[i][1] - box[i][0]) / N[i] * q[i]
  and any further rows of ``q`` are copied unchanged.

  Args:
    q: array-like of coordinates, one row per axis (q[i] is axis i).
    N: per-axis resolution; len(N) axes are mapped.
    box: bounds indexed as box[i][0] (lower) and box[i][1] (upper).

  Returns:
    A new float ndarray with the mapped coordinates; ``q`` is not modified.
  """
  box = np.asarray(box, dtype=float)
  # Work on a float copy: writing back into `q` would mutate the caller's
  # array and truncate the results whenever `q` has an integer dtype.
  out = np.array(q, dtype=float)
  for i in range(len(N)):
    lo, hi = box[i, 0], box[i, 1]
    out[i] = lo + (hi - lo) / N[i] * out[i]
  return out