/* eslint-disable camelcase */
import { Color, Registry } from '@zeainc/zea-engine'
import './GLSLCADConstants.js'
import './GLSLMath.js'
import './GLSLCADSurfaceDrawing.js'

const GLDrawCADSurfaceShader_VERTEX_SHADER = `
precision highp float;

<%include file="GLSLUtils.glsl"/>
<%include file="GLSLCADConstants.glsl"/>
<%include file="stack-gl/transpose.glsl"/>
<%include file="stack-gl/inverse.glsl"/>

attribute vec3 positions;
instancedattribute vec4 drawCoords;  // body ID, Surface index in Body, Surface Id, TrimSet Id
// instancedattribute vec2 drawItemTexAddr;  // Address of the data in the draw item texture. (mat4)

uniform mat4 viewMatrix;
uniform mat4 cameraMatrix;
uniform mat4 projectionMatrix;
uniform ivec2 quadDetail;
uniform vec3 assetCentroid;

// #define DEBUG_SURFACES
uniform int numSurfacesInLibrary;


<%include file="GLSLCADSurfaceDrawing.vertexShader.glsl"/>

varying vec4 v_drawCoords;
varying vec3 v_viewPos;
varying vec3 v_worldPos;
varying vec3 v_viewNormal;
varying vec2 v_textureCoord;
varying vec3 v_bodyDescAddr;
varying float v_surfaceType;
varying vec2 v_quadDetail;

void main(void) {
    // Decode the per-instance draw coordinates. The trim set id (drawCoords.a)
    // is only consumed by the fragment shader, which receives it via v_drawCoords.
    int cadBodyId = ftoi(drawCoords.r);
    int drawItemIndexInBody = ftoi(drawCoords.g);
    int surfaceId = ftoi(drawCoords.b);

    // Offset the quad positions by 0.5 (presumably mapping [-0.5, 0.5] to [0, 1]).
    vec2 texCoords = positions.xy + 0.5;

    v_drawCoords = drawCoords;

    vec4 cadBodyPixel0 = getCADBodyPixel(cadBodyId, 0);

    // int bodyDescId = ftoi(cadBodyPixel0.r);
    int cadBodyFlags = ftoi(cadBodyPixel0.g);

    // Always write v_bodyDescAddr: the fragment shader reads it when
    // ENABLE_PER_FACE_COLORS is defined, and a varying left unwritten on some
    // preprocessor paths would deliver undefined values there.
    ivec2 bodyDescAddr = ftoi(cadBodyPixel0.ba);
    v_bodyDescAddr = vec3(float(bodyDescAddr.x), float(bodyDescAddr.y), float(drawItemIndexInBody));

    //////////////////////////////////////////////
    // Visibility
    if(testFlag(cadBodyFlags, BODY_FLAG_INVISIBLE)) {
        // Place the vertex outside the clip volume so nothing is rasterized.
        gl_Position = vec4(-3.0, -3.0, -3.0, 1.0);
        return;
    }

    //////////////////////////////////////////////
    // Transforms
#ifdef DEBUG_SURFACES
    mat4 modelMatrix = mat4(1.0);
    // if(v_surfaceType == SURFACE_TYPE_NURBS_SURFACE) {
    //     // int drawItemIndexInBody = int(metadata.b+0.5);
    //     int sideLen = int(ceil(sqrt(float(numSurfacesInLibrary))));
    //     int x = drawItemIndexInBody % sideLen;
    //     int y = drawItemIndexInBody / sideLen;
    //     modelMatrix = mat4(1.0, 0.0, 0.0, 0.0, 
    //                     0.0, 1.0, 0.0, 0.0, 
    //                     0.0, 0.0, 1.0, 0.0,  
    //                     float(x), float(y), 0.0, 1.0);
    // }
#else

#ifdef CALC_GLOBAL_XFO_DURING_DRAW
    // Compose the body transform with the draw item's transform inside the body.
    mat4 bodyMat = getCADBodyMatrix(cadBodyId);
    mat4 surfaceMat = getDrawItemMatrix(bodyDescAddr, drawItemIndexInBody);
    mat4 modelMatrix = bodyMat * surfaceMat;
#else
    mat4 modelMatrix = getModelMatrix();
    // Note: on mobile GPUs, we get only FP16 math in the
    // fragment shader, causing inaccuracies in modelMatrix
    // calculation. By offsetting the data to the origin
    // we calculate a modelMatrix in the asset space, and
    //  then add it back on during final drawing.
    // modelMatrix[3][0] += assetCentroid.x;
    // modelMatrix[3][1] += assetCentroid.y;
    // modelMatrix[3][2] += assetCentroid.z;
#endif
#endif
    // modelMatrix = mat4(1.0);
    mat4 modelViewMatrix = viewMatrix * modelMatrix;
    // Inverse-transpose so normals stay perpendicular under non-uniform scale.
    mat3 normalMatrix = mat3(transpose(inverse(modelViewMatrix)));

    //////////////////////////////////////////////
    // Vertex Attributes

    GLSLBinReader surfaceLayoutDataReader;
    GLSLBinReader_init(surfaceLayoutDataReader, surfaceAtlasLayoutTextureSize, 16);
    vec4 surfaceDataAddr = GLSLBinReader_readVec4(surfaceLayoutDataReader, surfaceAtlasLayoutTexture, surfaceId * 8);
    int surfaceFlags = GLSLBinReader_readInt(surfaceLayoutDataReader, surfaceAtlasLayoutTexture, surfaceId * 8 + 6);

    // A zero second quad dimension marks a fan surface, sampled with one extra
    // vertex in each direction.
    bool isFan = int(quadDetail.y) == 0;
    vec2 vertexCoords = texCoords * (isFan ? vec2(quadDetail) + vec2(1.0, 1.0) : vec2(quadDetail));
    vec4 surfaceVertex = getSurfaceVertex(surfaceDataAddr.xy, vertexCoords);
    v_surfaceType = surfaceVertex.a;
    vec3 normal  = getSurfaceNormal(surfaceDataAddr.xy, vertexCoords);
    vec4 pos     = vec4(surfaceVertex.rgb, 1.0);

    bool flippedNormal = testFlag(surfaceFlags, SURFACE_FLAG_FLIPPED_NORMAL);
    if(flippedNormal)
        normal = -normal;

    vec4 viewPos = modelViewMatrix * pos;
    v_viewPos    = viewPos.xyz;
    v_worldPos   = (modelMatrix * pos).xyz;
    gl_Position  = projectionMatrix * viewPos;
    v_viewNormal = normalMatrix * normal;

    v_quadDetail = vec2(quadDetail);

    {
        // Pull back-facing vertices towards us ever so slightly.
        // This avoids the z-fighting that occurs when we see the inside
        // of a surface that is resting on another surface.
        vec3 worldNormal = normalize(mat3(cameraMatrix) * v_viewNormal);

        vec3 viewVector = normalize(mat3(cameraMatrix) * normalize(-v_viewPos));
        float ndotv = dot(worldNormal, viewVector);
        bool backFacing = ndotv <= 0.0;
        if (backFacing) {
            gl_Position.z *= 0.99999;
        }
    }

    if(isFan) {
        // We are drawing a Fan surface, so the uv coords
        // simply come from the vertex positions.
        v_textureCoord = positions.xy;
    }
    else {
        v_textureCoord = texCoords;
        if(testFlag(surfaceFlags, SURFACE_FLAG_FLIPPED_UV)) {
            // Swap u/v and the matching grid dimensions together.
            v_textureCoord = vec2(v_textureCoord.y, v_textureCoord.x);
            v_quadDetail = vec2(v_quadDetail.y, v_quadDetail.x);
        }

        // v_textureCoord.y = 1.0 - v_textureCoord.y; // Flip y
    }
}`

const GLDrawCADSurfaceShader_FRAGMENT_SHADER = `
precision highp float;

<%include file="math/constants.glsl"/>
<%include file="GLSLUtils.glsl"/>
<%include file="stack-gl/gamma.glsl"/>
<%include file="materialparams.glsl"/>
<%include file="GGX_Specular.glsl"/>
<%include file="PBRSurfaceRadiance.glsl"/>

<%include file="GLSLCADConstants.glsl"/>
<%include file="GLSLBinReader.glsl"/>

uniform mat4 cameraMatrix;

uniform bool headLighting;
uniform bool displayWireframes;
uniform bool displayEdges;


#ifdef ENABLE_INLINE_GAMMACORRECTION
uniform float exposure;
#endif

varying vec4 v_drawCoords;
varying vec3 v_viewPos;
varying vec3 v_worldPos;
varying vec3 v_viewNormal;
varying vec2 v_textureCoord;
varying vec3 v_bodyDescAddr;
varying float v_surfaceType;
varying vec2 v_quadDetail;

// Maps an integer id onto a fixed palette of 16 debug colors.
// floor(x + 0.5) is used instead of the built-in rounding function, which only
// exists in GLSL ES 3.00. mod() returns a value in [0, 16), so every palette
// entry below is reachable.
vec3 getDebugColor(int id){

    int sel = int(floor(mod(float(id), 16.0) + 0.5));

    if(sel==0)
        return vec3(0.0, 1.0, 1.0);
    else if (sel==1)
        return vec3(0.0, 1.0, 0.0);
    else if (sel==2)
        return vec3(1.0, 0.0, 1.0);
    else if (sel==3)
        return vec3(0.75, 0.75, 0.0);
    else if (sel==4)
        return vec3(0.0, 0.75, 0.75);
    else if (sel==5)
        return vec3(0.75, 0.0, 0.75);
    else if (sel==6)
        return vec3(0.45, 0.95, 0.0);
    else if (sel==7)
        return vec3(0.0, 0.45, 0.95);
    else if (sel==8)
        return vec3(0.95, 0.0, 0.45);
    else if (sel==9)
        return vec3(0.95, 0.45, 0.0);
    else if (sel==10)
        return vec3(0.0, 0.95, 0.45);
    else if (sel==11)
        return vec3(0.45, 0.0, 0.95);
    else if (sel==12)
        return vec3(0.45, 0.45, 0.95);
    else if (sel==13)
        return vec3(0.0, 0.0, 0.45);
    else if (sel==14)
        return vec3(0.0, 0.45, 0.45);
    else if (sel==15)
        return vec3(0.45, 0.0, 0.45);
    else return vec3(0.2, 0.2, 0.2);
}

<%include file="GLSLCADSurfaceDrawing.fragmentShader.glsl"/>

// const float gridSize = 0.02;
const float gridSize = 0.2;

#ifdef ENABLE_ES3
out vec4 fragColor;
#endif

void main(void) {

    // Decode the per-instance draw coordinates passed through by the vertex shader.
    int cadBodyId = int(floor(v_drawCoords.r + 0.5));
    int drawItemIndexInBody = int(floor(v_drawCoords.g + 0.5));
    int trimSetId = int(floor(v_drawCoords.a + 0.5));


    // TODO: pass as varying from the vertex shader.
    vec4 cadBodyPixel0 = getCADBodyPixel(cadBodyId, 0);
    vec4 cadBodyPixel1 = getCADBodyPixel(cadBodyId, 1);

    int flags = int(floor(cadBodyPixel0.g + 0.5));
    vec2 materialCoords = cadBodyPixel1.xy;
    //////////////////////////////////////////////
    // Trimming
    vec4 trimPatchQuad;
    // Initialized so that, for untrimmed surfaces, the DEBUG_TRIMTEXELS block
    // sees a negative x and is skipped instead of reading an undefined value.
    vec3 trimCoords = vec3(-1.0);
    if(trimSetId >= 0) {
        GLSLBinReader trimsetLayoutDataReader;
        GLSLBinReader_init(trimsetLayoutDataReader, trimSetsAtlasLayoutTextureSize, 16);
        trimPatchQuad = GLSLBinReader_readVec4(trimsetLayoutDataReader, trimSetsAtlasLayoutTexture, trimSetId*4);

        if(applyTrim(trimPatchQuad, trimCoords, flags)){
            discard;
            return;
        }
    }

    ///////////////////////////////////////////
    // Normal

    vec3 normal = normalize(mat3(cameraMatrix) * v_viewNormal);
    vec3 viewNormal = normalize(v_viewNormal);

    vec3 viewVector = normalize(mat3(cameraMatrix) * normalize(-v_viewPos));
    bool backFacing = dot(normal, viewVector) <= 0.0;
    if(backFacing){
        // Light the inside of the surface as if it faced the viewer.
        normal = -normal;
        viewNormal = -viewNormal;
    }

    //////////////////////////////////////////////
    // Material

    vec4 matValue0 = getMaterialValue(materialCoords, 0);

    MaterialParams material;

    /////////////////
    bool clayRendering = false;

    material.baseColor             = matValue0.rgb;
    material.opacity               = matValue0.a;

    /////////////////
    // Face color
#ifdef ENABLE_PER_FACE_COLORS
    vec4 faceColor = getDrawItemColor(ftoi(v_bodyDescAddr.xy), ftoi(v_bodyDescAddr.z));
    material.baseColor = mix(material.baseColor, faceColor.rgb, faceColor.a);
#endif

    if(clayRendering) {
        material.baseColor          = vec3(0.45, 0.26, 0.13);
        material.opacity            = 1.0;
    } 

    //////////////////////////////////////////////
    // Cutaways
    // if (applyCutaway(cadBodyId, flags)) {
    //     discard;
    //     return;
    // }
    if (testFlag(flags, BODY_FLAG_CUTAWAY)) {
        vec4 cadBodyPixel6 = getCADBodyPixel(cadBodyId, 6);
        vec3 cutNormal = normalize(cadBodyPixel6.xyz);
        float cutPlaneDist = cadBodyPixel6.w;
        if (cutaway(v_worldPos, cutNormal, cutPlaneDist)) {
            discard;
            return;
        }
        // If we are not cut away, but we can see a back facing face,
        // then set the normal to the cut plane so the lighting is flat.
        if (backFacing){
            normal = cutNormal;
        }
    }

    /////////////////
    // Debug backFacing
    // if(backFacing) {
    //     material.baseColor = mix(material.baseColor, vec3(1.0, 0.0, 0.0), 0.5);
    // }

    /////////////////
    // Debug materialId
#ifdef DEBUG_MATERIALID
    {
        // mod/floor instead of the integer % operator, which is reserved in
        // GLSL ES 1.00. Matches the integer form for non-negative coordinates.
        material.baseColor = vec3(mod(floor(materialCoords.x), 5.0) / 5.0, mod(floor(materialCoords.y), 5.0) / 5.0, 0.0);
    }
#endif

    /////////////////
    // Debug bodyId
#ifdef DEBUG_BODYID
    {
        material.baseColor       = getDebugColor(cadBodyId);
    }
#endif

    /////////////////
    // Debug drawItemIndexInBody
#ifdef DEBUG_SURFACEID
    {
        material.baseColor       = getDebugColor(drawItemIndexInBody);
    }
#endif

    /////////////////
    // Debug surface Type
#ifdef DEBUG_SURFACETYPE
    {
        // getDebugColor takes an int; GLSL has no implicit float-to-int conversion.
        material.baseColor       = getDebugColor(int(floor(v_surfaceType + 0.5)));
    }
#endif

    /////////////////
    // bool flippedNormal = testFlag(flags, SURFACE_FLAG_FLIPPED_NORMAL);
    // if(flippedNormal) {
    //    material.baseColor = mix(material.baseColor, vec3(1,0,0), 0.75);
    // }

    // if (backFacing) {
    //     material.baseColor = mix(material.baseColor, vec3(1,0,0), 0.75);
    // }

    /////////////////
    // Debug UV layout.
    // {
    //     material.baseColor = vec3(v_textureCoord.x);
    //     // material.baseColor.r = mix(0.0, 1.0, v_textureCoord.x);
    //     // material.baseColor.g = mix(0.0, 1.0, v_textureCoord.y);
    // }

    /////////////////
    // if(testFlag(flags, SURFACE_FLAG_FLIPPED_UV)){
    //     material.baseColor = mix(material.baseColor, vec3(1,1,1), 0.5);
    // }

    /////////////////
    // if(v_quadDetail.x > 512.0 || v_quadDetail.y > 512.0){
    //     material.baseColor = mix(material.baseColor, vec3(1,0,0), 0.75);
    // } else {
    //     // discard;
    // }

    /////////////////
    // Debug trim texture.
#ifdef DEBUG_TRIMTEXELS
    if(trimCoords.x >= 0.0) {
        // trimCoords = (trimPatchQuad.xy + 0.5) + ((trimPatchQuad.zw - 0.5) * v_textureCoord);
        trimCoords.xy = trimPatchQuad.xy + (trimPatchQuad.zw * v_textureCoord);
        vec2 trimUv = (trimCoords.xy) / vec2(trimSetAtlasTextureSize);
        vec4 trimTexel = texture2D(trimSetAtlasTexture, trimUv);

        vec2 texelOffset = trimCoords.xy - (floor(trimCoords.xy) + 0.5);
        float texelDist = length(texelOffset);

        material.baseColor = trimTexel.rgb * texelDist;

        // if (trimTexel.r > 0.5 && trimTexel.g > 0.5) {
        //     material.baseColor = vec3(0,0,0);
        // }

        // material.baseColor = mix(material.baseColor, vec3(0,0,0), texelDist);
        // material.baseColor = mix(material.baseColor, vec3(0,0,0), trimCoords.z);
        // material.baseColor = mix(material.baseColor, vec3(0,0,0), (trimCoords.z < 0.5) ? 1.0 : 0.0);

        // if(trimCoords.z < 0.5) {
        //     material.baseColor = mix(material.baseColor, vec3(0,0,0), 0.1);
        // }
        // else{
        //     float total = floor(trimCoords.x) +
        //                   floor(trimCoords.y);
        //     if(mod(total,2.0)==0.0)
        //         material.baseColor = mix(material.baseColor, vec3(0,0,0), 0.25);
        //     else
        //         material.baseColor = mix(material.baseColor, vec3(1,1,1), 0.25);
        // }
    }
#endif

    //////////////////////////////////////////////
    // Transparency
    // Simple screen door transparency.
    // float threshold = gridSize * opacity * (1.0 - (v_viewPos.z / 300.0));
    // // if(mod(v_viewPos.x / v_viewPos.z, gridSize) > threshold || mod(v_viewPos.y/v_viewPos.z, gridSize) > threshold)// || mod(v_viewPos.z, gridSize) > threshold)
    // if(mod(abs(v_worldPos.x), gridSize) > threshold || mod(abs(v_worldPos.y), gridSize) > threshold || mod(abs(v_worldPos.z), gridSize) > threshold)
    //     discard;


    ///////////////////////////////////////////
    // Lighting
    vec3 radiance;

    // Second material texel: metallic, roughness, reflectance, emission.
    vec4 matValue1;
    if(clayRendering)
        matValue1          = vec4(0.0, 0.9, 0.1, 0.0);
    else
        matValue1          = getMaterialValue(materialCoords, 1);

    material.metallic       = matValue1.r;
    material.roughness      = matValue1.g;
    material.reflectance    = matValue1.b;
    material.emission       = matValue1.a;

#ifndef ENABLE_ES3
    vec4 fragColor;
#endif
    fragColor = pbrSurfaceRadiance(material, normal, viewVector);

    /////////////////////////////
    // fragColor = vec4(material.baseColor, 1.0);
    // fragColor = vec4( normalize(viewNormal), 1.0);
    // fragColor = vec4( normalize(normal), 1.0);

    // fragColor = vec4(sampleEnvMap(viewNormal, material.roughness), 1.0);

    ////////////////////
    // Wireframe overlay
    {
        // vec4 wireColor = vec4(0.1, 0.1, 0.1, 1.0);
        //vec4 wireColor = vec4(0.6, 0.6, 0.6, 1.0);
        vec4 wireColor = vec4(0.0, 0.0, 0.0, 1.0);

        // Derivatives are computed before any branching: fwidth() is undefined
        // inside non-uniform control flow, and isFan derives from a varying.
        vec2 vertexCoords = v_textureCoord * v_quadDetail;
        vec2 vcD = fwidth(vertexCoords);
        vec2 vcW = fract(vertexCoords);

        // Fans have no quad grid to outline.
        bool isFan = v_quadDetail.y < 0.5;
        if(displayWireframes && !isFan) {
            float lerpVal = smoothstep(0.0, vcD.x, vcW.x) * smoothstep(1.0, 1.0 - vcD.x, vcW.x) * smoothstep(0.0, vcD.y, vcW.y) * smoothstep(1.0, 1.0 - vcD.y, vcW.y);

            // Display a thin line at 50% opacity.
            fragColor = mix(fragColor, wireColor, (1.0-smoothstep(0.0, 0.5, lerpVal)) * 0.5 );

            //fragColor = mix(fragColor, wireColor, (mod(vertexCoords.x, 2.0) < 1.0) ? 0.5 : 0.0 );
        }
    }

#ifdef ENABLE_INLINE_GAMMACORRECTION
    fragColor.rgb = toGamma(fragColor.rgb * exposure);
#endif

#ifndef ENABLE_ES3
    gl_FragColor = fragColor;
#endif
}
`

import { GLCADShader } from './GLCADShader.js'

/** Class representing a GL draw CAD surface shader.
 * @extends GLCADShader
 * @ignore
 */
class GLDrawCADSurfaceShader extends GLCADShader {
  /**
   * Create a GL draw CAD surface shader and install its vertex and fragment stages.
   * @param {any} gl - The GL context, forwarded to the base class (presumably a WebGL context).
   */
  constructor(gl) {
    super(gl)
    this.setShaderStage('VERTEX_SHADER', GLDrawCADSurfaceShader_VERTEX_SHADER)
    this.setShaderStage('FRAGMENT_SHADER', GLDrawCADSurfaceShader_FRAGMENT_SHADER)
  }

  /**
   * Bind this shader via the base class, then bind the environment map (if the
   * render state has one) and upload the `exposure` uniform when the bound
   * program exposes it.
   *
   * @param {object} renderstate - The object tracking the current state of the renderer.
   * @param {string} key - The key value, forwarded to the base class bind.
   * @return {boolean} - `false` if the base class bind returned `false`, otherwise `true`.
   */
  bind(renderstate, key) {
    // Only an explicit `false` counts as failure, so a base implementation that
    // returns nothing keeps the previous behaviour.
    if (super.bind(renderstate, key) === false) {
      // Presumably renderstate.unifs does not describe this program when the
      // base bind fails, so uploading uniforms here could hit the wrong locations.
      return false
    }

    const gl = this.__gl
    if (renderstate.envMap) {
      renderstate.envMap.bind(renderstate)
    }

    // The exposure uniform is only declared under ENABLE_INLINE_GAMMACORRECTION,
    // so it may be absent from the compiled program.
    const { exposure } = renderstate.unifs
    if (exposure) {
      gl.uniform1f(exposure.location, renderstate.exposure)
    }
    return true
  }

  /**
   * Extend the base class parameter declarations with this shader's PBR
   * material parameters and their default values.
   * @return {Array<object>} - The base declarations followed by this shader's `{ name, defaultValue, ... }` entries.
   */
  static getParamDeclarations() {
    const paramDescs = super.getParamDeclarations()
    paramDescs.push({
      name: 'BaseColor',
      defaultValue: new Color(1.0, 1.0, 0.5),
    })
    paramDescs.push({
      name: 'EmissiveStrength',
      defaultValue: 0.0,
    })
    paramDescs.push({
      name: 'Metallic',
      defaultValue: 0.0,
    })
    paramDescs.push({
      name: 'Roughness',
      defaultValue: 0.25,
    })
    paramDescs.push({
      name: 'Normal',
      defaultValue: new Color(0.0, 0.0, 0.0),
    })
    paramDescs.push({
      name: 'TexCoordScale',
      defaultValue: 1.0,
      texturable: false,
    })
    // F0 = reflectance and is a physical property of materials
    // It also has direct relation to IOR so we need to dial one or the other
    // For simplicity sake, we don't need to touch this value as metallic can dictate it
    // such that non metallic is mostly around (0.01-0.025) and metallic around (0.7-0.85)
    paramDescs.push({
      name: 'Reflectance',
      defaultValue: 0.025,
    })
    return paramDescs
  }

  /**
   * Pack a material into the 8 floats consumed by the surface fragment shader:
   * `[r, g, b, a, metallic, roughness, reflectance, emissiveStrength]`.
   *
   * The fragment shader reads the first four values as base color and opacity,
   * and the last four as metallic, roughness, reflectance and emission.
   * Materials without an `EmissiveStrength` parameter get a fallback of
   * metallic 0, roughness 1, reflectance 0 and emission 0.
   *
   * @param {object} material - The material; must provide a `BaseColor` parameter.
   * @return {Float32Array} - The packed material data (length 8).
   */
  static getPackedMaterialData(material) {
    const matData = new Float32Array(8)
    const baseColor = material.getParameter('BaseColor').getValue()
    matData[0] = baseColor.r
    matData[1] = baseColor.g
    matData[2] = baseColor.b
    matData[3] = baseColor.a

    const emissiveParam = material.getParameter('EmissiveStrength')
    if (emissiveParam) {
      matData[4] = material.getParameter('Metallic').getValue()
      matData[5] = material.getParameter('Roughness').getValue()
      matData[6] = material.getParameter('Reflectance').getValue()
      matData[7] = emissiveParam.getValue()
    } else {
      // Float32Array is zero-initialized; only roughness needs a non-zero fallback.
      matData[5] = 1.0
    }
    return matData
  }
}

// Register the shader class with the engine Registry under the name
// 'GLDrawCADSurfaceShader' (presumably so it can be looked up by that string
// elsewhere in the engine — confirm against Registry usage).
Registry.register('GLDrawCADSurfaceShader', GLDrawCADSurfaceShader)

// Export the raw shader sources alongside the class so other modules can reuse them.
export { GLDrawCADSurfaceShader_VERTEX_SHADER, GLDrawCADSurfaceShader_FRAGMENT_SHADER, GLDrawCADSurfaceShader }
