import { GLShader } from '@zeainc/zea-engine'

// import './GLSL/stack-gl/inverse.js'
// import './GLSL/stack-gl/transpose.js'
// import './GLSL/materialparams.js'

/**
 * Shader that draws FEA element data stored in data textures.
 *
 * Each instance draws one element. Node indices are read from
 * `elementNodeIndicesTexture`, node positions from `nodePositionsTexture` and
 * per-node scalar values from `nodeValuesTexture`. Vertices 0-3 are the
 * element's first four nodes. Vertices 4-7 form a cap polygon where the cut
 * plane (`cutPlaneNormal`, `cutPlaneDist`) slices the tetrahedron spanned by
 * those four nodes.
 */
class FEAShader extends GLShader {
  /**
   * Creates the shader and registers its vertex and fragment stages.
   * @param {WebGLRenderingContext} gl - The webgl rendering context.
   */
  constructor(gl) {
    super(gl)
    this.setShaderStage(
      'VERTEX_SHADER',
      `
precision highp float;
precision highp int;

uniform int elementIndexOffset;
uniform int elementNodeCount;
uniform mat4 modelMatrix;
uniform mat4 viewMatrix;
uniform mat4 projectionMatrix;

uniform vec3 cutPlaneNormal;
uniform float cutPlaneDist;

uniform sampler2D elementNodeIndicesTexture;
uniform sampler2D nodePositionsTexture;
uniform sampler2D nodeValuesTexture;

/* VS Outputs */
varying float v_elementId;
varying vec3 v_viewPos;
varying vec3 v_worldPos;
varying float v_feaValue;
varying float v_isCut;

// Maps a linear index to texel coordinates in a data texture holding one
// entry per texel. Texture data is laid out row by row, so the row stride is
// the texture width (.x), not its height.
ivec2 dataTexelCoords(int index, ivec2 texSize) {
  return ivec2(index % texSize.x, index / texSize.x);
}

// True when a world-space point lies on the cut-away side of the plane.
bool isCutaway(in vec3 worldPos) {
  return dot(worldPos, -cutPlaneNormal) > cutPlaneDist;
}

// Fraction along pos0 -> pos1 at which the segment meets the cut plane.
// Only called for edges with exactly one cut-away endpoint, so d0 != d1.
float cutFract(in vec3 pos0, in vec3 pos1) {
  float d0 = dot(pos0, -cutPlaneNormal);
  float d1 = dot(pos1, -cutPlaneNormal);
  return (cutPlaneDist - d0) / (d1 - d0);
}

// Angle between vec0 and vec1 (both expected to be unit length), negated
// when cross(vec0, vec1) points away from -cutPlaneNormal.
float signedAngle(in vec3 vec0, in vec3 vec1) {
  // Rounding can push the dot of two unit vectors just outside [-1, 1],
  // where acos is undefined.
  float angle = acos(clamp(dot(vec0, vec1), -1.0, 1.0));
  return dot(cross(vec0, vec1), -cutPlaneNormal) < 0.0 ? -angle : angle;
}

vec4 fetchNodePosition(int nodeId) {
  ivec2 coords = dataTexelCoords(nodeId, textureSize(nodePositionsTexture, 0));
  return vec4(texelFetch(nodePositionsTexture, coords, 0).xyz, 1.0);
}

float fetchNodeValue(int nodeId) {
  ivec2 coords = dataTexelCoords(nodeId, textureSize(nodeValuesTexture, 0));
  return texelFetch(nodeValuesTexture, coords, 0).r;
}

// Node id of corner 'cornerId' of the element drawn by this instance.
// NOTE(review): v_elementId treats elementIndexOffset as an element offset
// (elementIndexOffset + gl_InstanceID), but here it is added unscaled to a
// node-index offset. Confirm which unit the uniform is in.
int fetchNodeIndex(int cornerId) {
  int offset = elementIndexOffset + gl_InstanceID * elementNodeCount + cornerId;
  ivec2 coords = dataTexelCoords(offset, textureSize(elementNodeIndicesTexture, 0));
  return int(texelFetch(elementNodeIndicesTexture, coords, 0).r);
}

// Projects a world-space position and writes every vertex output.
void writeOutputs(in vec4 worldPos, in float feaValue, in float isCut) {
  vec4 viewPos = viewMatrix * worldPos;
  gl_Position  = projectionMatrix * viewPos;
  v_elementId  = float(elementIndexOffset + gl_InstanceID);
  v_viewPos    = -viewPos.xyz;
  v_worldPos   = worldPos.xyz;
  v_feaValue   = feaValue;
  v_isCut      = isCut;
}

// Emits a cap vertex outside the clip volume (x > w) with defined outputs.
// All cap vertices of an instance take this path together, so the cap
// triangles lie entirely outside the clip volume and are not drawn.
void writeCulledOutputs() {
  gl_Position = vec4(2.0, 2.0, 2.0, 1.0);
  v_elementId = float(elementIndexOffset + gl_InstanceID);
  v_viewPos   = vec3(0.0);
  v_worldPos  = vec3(0.0);
  v_feaValue  = 0.0;
  v_isCut     = 0.0;
}

void main(void) {
  if (gl_VertexID < 4) {
    // Surface vertex: corner gl_VertexID of the element. Fragments on the
    // cut-away side are discarded by the fragment shader (v_isCut = 1).
    int nodeId = fetchNodeIndex(gl_VertexID);
    writeOutputs(modelMatrix * fetchNodePosition(nodeId), fetchNodeValue(nodeId), 1.0);
    return;
  }

  // Cap vertex (gl_VertexID >= 4): a corner of the polygon where the cut
  // plane slices the tetrahedron spanned by the element's first four nodes.
  // Elements with fewer than four nodes have no such tetrahedron.
  if (elementNodeCount < 4) {
    writeCulledOutputs();
    return;
  }

  int nodeIds[4];
  vec3 positions[4];
  bool positionCuts[4];
  for (int i = 0; i < 4; i++) {
    nodeIds[i] = fetchNodeIndex(i);
    positions[i] = (modelMatrix * fetchNodePosition(nodeIds[i])).xyz;
    positionCuts[i] = isCutaway(positions[i]);
  }

  // The six edges of the tetrahedron, as pairs of corner indices.
  ivec2 edges[6];
  edges[0] = ivec2(0, 1);
  edges[1] = ivec2(1, 2);
  edges[2] = ivec2(0, 2);
  edges[3] = ivec2(0, 3);
  edges[4] = ivec2(1, 3);
  edges[5] = ivec2(2, 3);

  // An edge crosses the plane when exactly one endpoint is cut away. A plane
  // crosses 0, 3 or 4 edges of a tetrahedron.
  int cutEdges[4];
  int found = 0;
  for (int i = 0; i < 6; i++) {
    if (found < 4 && positionCuts[edges[i].x] != positionCuts[edges[i].y]) {
      cutEdges[found] = i;
      found++;
    }
  }

  if (found == 0) {
    // The plane misses this element: there is no cap to draw.
    writeCulledOutputs();
    return;
  }

  // Intersection points, and the angle of each point around point 0,
  // measured from the direction of point 1.
  float cutEdgeFracts[4];
  vec3 cutEdgePos[4];
  float cutEdgeAngles[4];
  cutEdgeAngles[0] = 0.0;
  cutEdgeAngles[1] = 0.0;
  cutEdgeAngles[2] = 0.0;
  cutEdgeAngles[3] = 0.0;
  vec3 baseVec = vec3(0.0);
  for (int i = 0; i < found; i++) {
    ivec2 edge = edges[cutEdges[i]];
    vec3 pos0 = positions[edge.x];
    vec3 pos1 = positions[edge.y];
    cutEdgeFracts[i] = cutFract(pos0, pos1);
    cutEdgePos[i] = mix(pos0, pos1, cutEdgeFracts[i]);
    if (i == 1) {
      baseVec = normalize(cutEdgePos[1] - cutEdgePos[0]);
    } else if (i > 1) {
      cutEdgeAngles[i] = signedAngle(baseVec, normalize(cutEdgePos[i] - cutEdgePos[0]));
    }
  }

  // Order the points as a fan around point 0, which always stays first.
  // A triangle needs no sorting. For a quad, points 1..3 are sorted by
  // descending angle.
  int order[4];
  order[0] = 0;
  order[1] = 1;
  order[2] = 2;
  order[3] = 3;
  if (found > 3) {
    int tmp;
    if (cutEdgeAngles[order[1]] < cutEdgeAngles[order[2]]) {
      tmp = order[1]; order[1] = order[2]; order[2] = tmp;
    }
    if (cutEdgeAngles[order[2]] < cutEdgeAngles[order[3]]) {
      tmp = order[2]; order[2] = order[3]; order[3] = tmp;
    }
    if (cutEdgeAngles[order[1]] < cutEdgeAngles[order[2]]) {
      tmp = order[1]; order[1] = order[2]; order[2] = tmp;
    }
  }

  // The fan has 3 or 4 vertices. With 3, the 4th cap vertex is placed on
  // top of the 3rd.
  int slot = order[min(gl_VertexID - 4, found - 1)];
  ivec2 cutEdge = edges[cutEdges[slot]];
  float feaValue = mix(
    fetchNodeValue(nodeIds[cutEdge.x]),
    fetchNodeValue(nodeIds[cutEdge.y]),
    cutEdgeFracts[slot]
  );

  // Cap points are already in world space. v_isCut = 0 keeps the fragment
  // shader from discarding them.
  writeOutputs(vec4(cutEdgePos[slot], 1.0), feaValue, 0.0);
}
`
    )

    this.setShaderStage(
      'FRAGMENT_SHADER',
      `
precision highp float;
precision highp int;

uniform int elementNodeCount;
uniform vec3 gradientRange;
uniform int bands;
uniform vec3 cutPlaneNormal;
uniform float cutPlaneDist;

uniform mat4 cameraMatrix;
uniform int nodeValuesTextureConnected;

<%include file="stack-gl/gamma.glsl"/>
<%include file="GLSLUtils.glsl"/>
<%include file="PBRSurfaceRadiance.glsl"/>

#if defined(DRAW_COLOR)
uniform sampler2D nodeValueColorGradient;
uniform float exposure;

#elif defined(DRAW_GEOMDATA)
uniform int passId;
uniform int itemId;
#elif defined(DRAW_HIGHLIGHT)
uniform vec4 highlightColor;
#endif

/* VS Outputs */
varying float v_elementId;
varying vec3 v_viewPos;
varying vec3 v_worldPos;
varying float v_feaValue;
varying float v_isCut;

out vec4 fragColor;

// Opacity ramp for volumetric elements. The value is 0 at gradientRange.y
// and rises linearly to 1 at gradientRange.x (below) and gradientRange.z
// (above).
float valueOpacity(float value) {
  if (value < gradientRange.y) {
    return (gradientRange.y - value) / (gradientRange.y - gradientRange.x);
  }
  return (value - gradientRange.y) / (gradientRange.z - gradientRange.y);
}

void main(void) {
  // Cutaway: surface fragments (v_isCut = 1) on the cut-away side of the
  // plane are discarded. Cap fragments (v_isCut = 0) are always kept.
  if (v_isCut > 0.5 && dot(v_worldPos, -cutPlaneNormal) > cutPlaneDist) {
    discard;
  }

  // Position of the value within [gradientRange.x, gradientRange.z],
  // optionally quantized into 'bands' steps.
  float f = (v_feaValue - gradientRange.x) / (gradientRange.z - gradientRange.x);
  if (bands > 0) {
    f = round(f * float(bands)) / float(bands);
  }

  // The shader is compiled in one of three modes:
  //  - DRAW_COLOR: the shaded pixel color is written to the fragment.
  //  - DRAW_GEOMDATA: identifying data is written for pointer interaction.
  //  - DRAW_HIGHLIGHT: a flat highlight color is written.
#if defined(DRAW_COLOR)

  // /////////////////////
  // Material
  MaterialParams material;

  if (nodeValuesTextureConnected > 0) {
    vec2 texCoord = vec2(0.5, 1.0 - f);
    material.baseColor     = toLinear(texture2D(nodeValueColorGradient, texCoord).rgb);
  } else {
    material.baseColor     = toLinear(vec3(0.7, 0.7, 0.7));
  }
  material.metallic      = 0.0;
  material.roughness     = 0.75;
  material.reflectance   = 0.0;

  material.emission         = 0.0;
  material.opacity          = 1.0;
  material.ambientOcclusion = 1.0;

  // /////////////////////
  // PBR Lighting
  // Flat-shading normal derived from screen-space derivatives of the position.
  vec3 normal = normalize(cross(dFdx(v_worldPos), dFdy(v_worldPos)));
  vec3 viewVector = normalize(mat3(cameraMatrix) * normalize(v_viewPos));
  fragColor = pbrSurfaceRadiance(material, normal, viewVector);

  // Only volumetric elements (4+ nodes) drawn with a color gradient are made
  // transparent. All others stay opaque.
  if (elementNodeCount >= 4 && nodeValuesTextureConnected > 0) {
    fragColor.a = valueOpacity(v_feaValue);
  } else {
    fragColor.a = 1.0;
  }

  // /////////////////////
  // Gamma
  fragColor.rgb = toGamma(fragColor.rgb * exposure);

#elif defined(DRAW_GEOMDATA)
  // Fragments of volumetric elements that are (almost) fully transparent
  // cannot be picked.
  // NOTE(review): the color pass applies this ramp only when
  // nodeValuesTextureConnected > 0, but this pass does not check it. Without
  // a gradient, opaque fragments near gradientRange.y are still discarded
  // here. Confirm whether that is intended.
  if (elementNodeCount >= 4 && valueOpacity(v_feaValue) < 0.001) {
    discard;
  }

  vec3 camPos = cameraMatrix[3].xyz;

  // The Geom Data buffer is an offscreen buffer that stores identifying
  // information in each pixel for the geometry rasterized to that pixel.
  // The Viewport queries this buffer to detect pointer interactions.
  // The red channel stores the pass id, which tells the Viewport which
  // pass to call getGeomItemAndDist on.
  // The pass decides what to pack into the other 3 channels to identify
  // the drawn geometry. Here that is the item id, the (possibly banded)
  // value mapped back into the gradient range, and the distance from the
  // camera to the rendered fragment.
  fragColor.r = float(passId);
  fragColor.g = float(itemId);
  fragColor.b = (f * (gradientRange.z - gradientRange.x)) + gradientRange.x;
  fragColor.a = length(camPos - v_worldPos);
#elif defined(DRAW_HIGHLIGHT)
  fragColor = highlightColor;
#endif
}
`
    )
  }

  /**
   * Binds the shader and sets the uniforms this class manages itself.
   *
   * Also binds `renderstate.envMap` when present. When `renderstate.unifs`
   * has an `exposure` entry (the fragment stage declares `exposure` only in
   * DRAW_COLOR builds), it uploads `renderstate.exposure` to that uniform.
   * @param {object} renderstate - The object tracking the current state of the renderer.
   * @param {string} key - The key value, forwarded to `GLShader.bind`.
   * @return {boolean} - Always true.
   */
  bind(renderstate, key) {
    super.bind(renderstate, key)

    const gl = this.__gl
    if (renderstate.envMap) {
      renderstate.envMap.bind(renderstate)
    }

    const { exposure } = renderstate.unifs
    if (exposure) {
      gl.uniform1f(exposure.location, renderstate.exposure)
    }
    return true
  }
}

export { FEAShader }
