/*******************************************************************
 **                                                               **
 **  Copyright(C) 2023 Ouster Inc. All Rights Reserved.           **
 **  Contact: https://ouster.io                                   **
 **                                                               **
 *******************************************************************/

import {
  DEFAULT_POINT_COLOR,
  DEFAULT_POINT_SIZE,
} from '@ouster/webviz/build/src/utils/pointcloud';
import {
  Points,
  BufferGeometry,
  BufferAttribute,
  Float32BufferAttribute,
  Uint8BufferAttribute,
  DynamicDrawUsage,
  Color,
  Matrix4,
  ShaderMaterial,
  UniformsUtils,
} from 'three';
import { colormaps } from '@ouster/webviz';
import { Preset } from '../app/components/pane/cloudsDisplay/CloudColorPresets';
import {
  HEIGHT_RANGE,
  MAX_CLOUDS_IN_NODE,
  MAX_NODE_RESOLUTION,
  MAX_SENSOR_RESOLUTION,
} from '../constants';
import { CloudDisplayMode, CloudDisplayModes, MinMax } from '../types';

/**
 * @brief Builds a fresh set of the uniforms shared by both point-cloud
 * shaders, initialised to their defaults. Every call allocates new `Color`,
 * `Matrix4` and `Float32Array` instances, so two results never share
 * mutable state.
 * NOTE(review): `transform` is not declared by either vertex shader in this
 * file; confirm whether it is still needed.
 */
const createBaseUniforms = () => ({
  colorMin: { value: new Color(DEFAULT_POINT_COLOR) },
  colorMax: { value: new Color(DEFAULT_POINT_COLOR) },
  uMin: { value: 0 },
  uMax: { value: 1 },
  pointSize: { value: DEFAULT_POINT_SIZE },
  colorMode: { value: CloudDisplayModes.indexOf('Fixed') },
  transform: { value: new Matrix4() },
  heightMin: { value: HEIGHT_RANGE.default.min },
  heightMax: { value: HEIGHT_RANGE.default.max },
  // Calibrated-reflectivity palette flattened to [r0, g0, b0, r1, ...] for
  // the shader's `vec3 calRef[256]` uniform array.
  calRef: {
    value: new Float32Array([...colormaps.calRefColorMap.palette.flat()]),
  },
});

/**
 * @brief Default uniforms for sensor clouds (`CloudMaterial`).
 * @note The contained objects are mutable; deep copy them (e.g. with
 * `UniformsUtils.clone`) before handing them to a material.
 */
const uniforms = createBaseUniforms();

/**
 * @brief Default uniforms for node clouds (`NodeMaterial`): an independent
 * copy of the shared set plus the flat `extrinsics` array read by the node
 * vertex shader.
 */
const nodeUniforms = {
  ...createBaseUniforms(),
  extrinsics: { value: new Float32Array() },
};

type Uniform = typeof uniforms;
type NodeUniform = typeof nodeUniforms;

/**
 * @brief GLSL helper, spliced into both vertex shaders, that chooses a
 * per-vertex colour according to the `colorMode` uniform.
 *
 * The including shader must declare every name this snippet uses:
 * uniforms `uMin`, `uMax`, `colorMin`, `colorMax`, `colorMode`,
 * `calRef[256]`, and attributes `signal`, `reflectivity`, `nearIR`.
 *
 * Mode indices (they must follow the order of `CloudDisplayModes`, because
 * the TS side uploads `CloudDisplayModes.indexOf(mode)`):
 *  - 0 HEIGHT: uses `positionWorld.z`
 *  - 1 RANGE: uses `positionWorldMagnitude`
 *  - 2 SIGNAL, 3 REFLECTIVITY, 5 NEARIR: use the matching attribute
 *  - 4 CALIBRATED REFLECTIVITY: returns `calRef[int(reflectivity)]`,
 *    bypassing the colorMin/colorMax ramp
 *  - 6 FIXED (and any unlisted value): interpolation stays 0, i.e. `colorMin`
 * For ramp modes the value is normalised against [uMin, uMax], clamped to
 * [0, 1], and used to `mix` from colorMin to colorMax.
 *
 * The if/else chain is acceptable because it branches on a uniform, so every
 * vertex in a draw call takes the same branch. Making it branch-free is not
 * worth the readability hit. More here: https://stackoverflow.com/a/37837060
 *
 * NOTE(review): if uMax == uMin, `scalingRange` is 0 and the division result
 * is undefined in GLSL; callers presumably keep the range non-empty.
 */
const shaderGetFragmentColourHelper = `
vec3 getFragmentColor(vec3 positionWorld, float positionWorldMagnitude) {
  float colorInterpolation = 0.;
  float scalingRange = uMax - uMin;

  if (colorMode == 0) // HEIGHT
  {
    colorInterpolation = (positionWorld.z - uMin) / scalingRange;
  }
  else if (colorMode == 1) // RANGE
  {
    colorInterpolation = ( positionWorldMagnitude - uMin) / scalingRange;
  }
  else if (colorMode == 2) // SIGNAL
  {
    colorInterpolation = (signal - uMin) / scalingRange;
  }
  else if (colorMode == 3) // REFLECTIVITY
  {
    colorInterpolation = (reflectivity - uMin) / scalingRange;
  }
  // (colorMode == 4) => CALIBRATED REFLECTIVITY
  else if (colorMode == 5) // NEARIR
  {
    colorInterpolation = (nearIR - uMin) / scalingRange;
  }
  else if (colorMode == 6) // FIXED
  {
    colorInterpolation = 0.;
  }

  colorInterpolation = clamp(colorInterpolation, 0., 1.);
  return (colorMode == 4) ? calRef[int(reflectivity)] : mix(colorMin, colorMax, colorInterpolation);
}`;

/**
 * @brief Vertex shader for single-sensor clouds (used by `CloudMaterial`).
 * These clouds need no per-point extrinsics.
 *
 * - Inputs: the built-in `position` plus per-point attributes `signal`,
 *   `reflectivity` and `nearIR`.
 * - `positionWorld` is `modelMatrix * position`. A point is hidden
 *   (`gl_PointSize` becomes 0) unless `heightMin < positionWorld.z < heightMax`.
 * - The colour is computed per vertex by `getFragmentColor` and handed to the
 *   fragment shader through the single varying `fragmentColorVarying`, with
 *   alpha fixed at 1.0.
 */
const vertexShaderCloud = `
  uniform float pointSize;
  uniform float heightMin;
  uniform float heightMax;

  // For fragment shader colour computation
  uniform float uMin;
  uniform float uMax;
  uniform vec3 colorMin;
  uniform vec3 colorMax;
  uniform int colorMode;
  uniform vec3 calRef[256];

  attribute float signal;
  attribute float reflectivity;
  attribute float nearIR;

  varying vec4 fragmentColorVarying;

  ${shaderGetFragmentColourHelper}

  void main() {
    vec3 positionWorld = (modelMatrix *  vec4( position, 1.0 )).xyz;
    float positionWorldMagnitude = length( positionWorld );

    // Don't render the point if below min or above max.
    gl_PointSize = pointSize * float((positionWorld.z > heightMin) && (positionWorld.z < heightMax));

    gl_Position = projectionMatrix * modelViewMatrix * vec4( position, 1.0 );

    fragmentColorVarying = vec4(
      getFragmentColor(positionWorld, positionWorldMagnitude), 
      1.0
    );
  }
`;

/**
 * @brief Vertex shader for compound (node) clouds, used by `NodeMaterial`.
 * Before colouring and height filtering, each point is moved by one of
 * several per-cloud extrinsics.
 *
 * - `extrinsics` is a flat float array with 7 entries per cloud, for at most
 *   `MAX_CLOUDS_IN_NODE` clouds. Each group is a translation (x, y, z)
 *   followed by a rotation quaternion (x, y, z, w).
 * - The per-point attribute `transformIndex` selects which 7-float group
 *   applies to that point.
 * - `fromTranslationRotation` builds a column-major rigid transform from one
 *   group. Its formula assumes a unit quaternion.
 * - Height filtering (strict `heightMin < z < heightMax`) and colouring via
 *   `getFragmentColor` follow the same pattern as the sensor-cloud shader.
 *
 * NOTE(review): `positionWorld` here is only `extrinsics * position`, with no
 * `modelMatrix`, while `gl_Position` still applies `modelViewMatrix` on top.
 * Height filtering and range colouring therefore ignore this object's own
 * transform. Confirm that is intended.
 */
const vertexShaderNode = `
  uniform float pointSize;
  uniform float heightMin;
  uniform float heightMax;
  uniform float extrinsics[7 * ${MAX_CLOUDS_IN_NODE}];

  // For fragment shader colour computation
  uniform float uMin;
  uniform float uMax;
  uniform vec3 colorMin;
  uniform vec3 colorMax;
  uniform int colorMode;
  uniform vec3 calRef[256];

  attribute float signal;
  attribute float reflectivity;
  attribute float nearIR;
  attribute float transformIndex;

  varying vec4 fragmentColorVarying;

  ${shaderGetFragmentColourHelper}

  mat4 fromTranslationRotation(vec3 translation, vec4 rotation) {
    float x2 = rotation.x + rotation.x;
    float y2 = rotation.y + rotation.y;
    float z2 = rotation.z + rotation.z;
    float xx = rotation.x * x2;
    float xy = rotation.x * y2;
    float xz = rotation.x * z2;
    float yy = rotation.y * y2;
    float yz = rotation.y * z2;
    float zz = rotation.z * z2;
    float wx = -rotation.w * x2;
    float wy = -rotation.w * y2;
    float wz = -rotation.w * z2;

    return mat4(
        vec4(1.0 - (yy + zz), xy - wz, xz + wy, 0.0),
        vec4(xy + wz, 1.0 - (xx + zz), yz - wx, 0.0),
        vec4(xz - wy, yz + wx, 1.0 - (xx + yy), 0.0),
        vec4(translation, 1.0)
    );
  }

  void main() {
    int exIndex = int(transformIndex);

    // Unflatten the extrinsics corresponding to this point
    vec3 translation = vec3(
      extrinsics[exIndex * 7 + 0],
      extrinsics[exIndex * 7 + 1],
      extrinsics[exIndex * 7 + 2]
    );
    vec4 rotation = vec4(
      extrinsics[exIndex * 7 + 3],
      extrinsics[exIndex * 7 + 4],
      extrinsics[exIndex * 7 + 5],
      extrinsics[exIndex * 7 + 6]
    );

    mat4 extrinsicsMatrix = fromTranslationRotation(translation, rotation);

    // Apply those extrinsics
    vec4 updatedPositionWorld = extrinsicsMatrix * vec4(position, 1.0);
    vec3 positionWorld = updatedPositionWorld.xyz;
    float positionWorldMagnitude = length(positionWorld);

    // Don't render the point if below min or above max.
    gl_PointSize = pointSize * float((positionWorld.z > heightMin) && (positionWorld.z < heightMax));

    gl_Position = projectionMatrix * modelViewMatrix * updatedPositionWorld;

    fragmentColorVarying = vec4(
      getFragmentColor(positionWorld, positionWorldMagnitude), 
      1.0
    );
  }
`;

/**
 * @brief Fragment shader shared by both materials. It writes the colour
 * already computed in the vertex shader (`fragmentColorVarying`) unchanged.
 */
const fragmentShader = `
  varying vec4 fragmentColorVarying;

  void main( void ) {
    gl_FragColor = fragmentColorVarying;
  }
`;

/**
 * @brief Shader material for single-sensor clouds.
 * Pairs the sensor vertex shader with the shared fragment shader. Each
 * instance gets its own deep copy of the default `uniforms`, so changing one
 * cloud's settings never affects another.
 */
class CloudMaterial extends ShaderMaterial {
  constructor() {
    const ownUniforms = UniformsUtils.clone(uniforms);
    super({
      uniforms: ownUniforms,
      fragmentShader,
      vertexShader: vertexShaderCloud,
    });
  }
}
/**
 * @brief Shader material for compound (node) clouds.
 * Pairs the extrinsics-aware vertex shader with the shared fragment shader.
 * Each instance gets its own deep copy of `nodeUniforms`, including a private
 * `extrinsics` slot.
 */
class NodeMaterial extends ShaderMaterial {
  constructor() {
    const ownUniforms = UniformsUtils.clone(nodeUniforms);
    super({
      uniforms: ownUniforms,
      fragmentShader,
      vertexShader: vertexShaderNode,
    });
  }
}

/** Either of the shader materials a `Pointcloud3JS` can be created with. */
type Material = CloudMaterial | NodeMaterial;

/**
 * Typed view of the geometry attributes that `Pointcloud3JS` allocates.
 * `position` holds 3 floats per point. The other channels hold one byte per
 * point. `transformIndex` is only allocated for node (non-sensor) clouds.
 */
type Attributes = {
  position: BufferAttribute;
  reflectivity: BufferAttribute;
  signal: BufferAttribute;
  nearIR: BufferAttribute;
  transformIndex?: BufferAttribute;
};

/**
 * @brief A Three.js `Points` object whose geometry is backed by
 * preallocated, dynamically updated attribute buffers.
 *
 * Sensor clouds (`isSensor = true`) render with `CloudMaterial` and hold up
 * to `MAX_SENSOR_RESOLUTION` points. Node clouds (the default) render with
 * `NodeMaterial` and hold up to `MAX_NODE_RESOLUTION` points. They also carry
 * a per-point `transformIndex` that selects one of the extrinsics passed to
 * `setExtrinsics`.
 *
 * Each update overwrites the start of the buffers and sets the draw range to
 * the number of points supplied. Input beyond capacity is truncated.
 */
export class Pointcloud3JS extends Points {
  /** When true, bounding box and sphere are recomputed after each update. */
  public updateBoundingVolumes = true;
  /** Capacity of every attribute buffer, in points. */
  private readonly maxPoints: number;

  /**
   * @param isSensor `true` for a single-sensor cloud, `false` (default) for a
   * node cloud.
   */
  constructor(isSensor = false) {
    super(
      new BufferGeometry(),
      isSensor ? new CloudMaterial() : new NodeMaterial(),
    );

    const attributes = this.geometry.attributes as Attributes;
    this.maxPoints = isSensor ? MAX_SENSOR_RESOLUTION : MAX_NODE_RESOLUTION;

    // Buffers are allocated once at full capacity and rewritten in place on
    // every update, hence DynamicDrawUsage.
    attributes.position = new Float32BufferAttribute(this.maxPoints * 3, 3);
    attributes.position.setUsage(DynamicDrawUsage);

    attributes.reflectivity = new Uint8BufferAttribute(this.maxPoints, 1);
    attributes.reflectivity.setUsage(DynamicDrawUsage);

    attributes.signal = new Uint8BufferAttribute(this.maxPoints, 1);
    attributes.signal.setUsage(DynamicDrawUsage);

    attributes.nearIR = new Uint8BufferAttribute(this.maxPoints, 1);
    attributes.nearIR.setUsage(DynamicDrawUsage);

    if (!isSensor) {
      // One index per point into the node shader's `extrinsics` uniform.
      attributes.transformIndex = new Uint8BufferAttribute(this.maxPoints, 1);
      attributes.transformIndex.setUsage(DynamicDrawUsage);
    }

    (this.material as Material).transparent = true;
  }

  /** This cloud's shader uniforms, typed with the layout both materials share. */
  private get shaderUniforms(): Uniform {
    return (this.material as Material).uniforms as Uniform;
  }

  /** Recomputes bounding box and sphere when `updateBoundingVolumes` is set. */
  private refreshBoundingVolumes(): void {
    if (!this.updateBoundingVolumes) {
      return;
    }
    this.geometry.computeBoundingBox();
    this.geometry.computeBoundingSphere();
  }

  /**
   * @brief Replaces the point positions and leaves the other channels as-is.
   * @param position Flat `[x0, y0, z0, x1, ...]` coordinates. Values beyond
   * the buffer capacity (`maxPoints * 3` floats) are dropped with a warning.
   */
  setPoints = (position: Float32Array): void => {
    // The position buffer holds 3 floats per point, so capacity is in floats.
    const capacity = this.maxPoints * 3;
    let xyz = position;
    if (xyz.length > capacity) {
      console.warn('More values have been supplied than allocated');
      // subarray is a view, so truncating costs no copy.
      xyz = xyz.subarray(0, capacity);
    }

    const attributes = this.geometry.attributes as Attributes;
    attributes.position.set(xyz);
    // setDrawRange takes a vertex *count*, not a last index. A trailing
    // partial triple is not a point.
    this.geometry.setDrawRange(0, Math.floor(xyz.length / 3));
    // set dirty flags
    attributes.position.needsUpdate = true;

    this.refreshBoundingVolumes();
  };

  /**
   * @brief Replaces positions and per-point channels in one go.
   * @param position Flat xyz coordinates, 3 floats per point.
   * @param signal One value per point. An empty array leaves the channel
   * untouched.
   * @param reflectivity One value per point. Empty leaves it untouched.
   * @param nearIR One value per point. Empty leaves it untouched.
   * @param transformIndex Node clouds only. One extrinsics index per point.
   *
   * Oversized inputs are reported with `console.error` and then truncated to
   * the buffer capacity. Without truncation they would overflow the
   * fixed-size backing arrays, which throws a RangeError.
   */
  setAttributes = (
    position: Float32Array,
    signal: Uint8Array,
    reflectivity: Uint8Array,
    nearIR: Uint8Array,
    transformIndex?: Uint8Array,
  ): void => {
    const maxValues = this.maxPoints * 3;
    if (
      position.length > maxValues ||
      signal.length > this.maxPoints ||
      reflectivity.length > this.maxPoints ||
      nearIR.length > this.maxPoints ||
      (transformIndex !== undefined && transformIndex.length > this.maxPoints)
    ) {
      console.error('Too many values for cloud attributes have been supplied.');
    }

    // Clamp every input to its buffer's capacity (views, no copies).
    const xyz = position.subarray(0, maxValues);
    const sig = signal.subarray(0, this.maxPoints);
    const refl = reflectivity.subarray(0, this.maxPoints);
    const nir = nearIR.subarray(0, this.maxPoints);
    const tIdx = transformIndex?.subarray(0, this.maxPoints);

    const attributes = this.geometry.attributes as Attributes;

    attributes.position.set(xyz);
    attributes.reflectivity.set(refl);
    attributes.signal.set(sig);
    attributes.nearIR.set(nir);
    if (tIdx) {
      attributes.transformIndex?.set(tIdx);
    }

    // Vertex count, not last index (see setPoints).
    this.geometry.setDrawRange(0, Math.floor(xyz.length / 3));

    this.refreshBoundingVolumes();

    // set dirty flags; empty channels keep their previous GPU data
    attributes.position.needsUpdate = true;
    attributes.reflectivity.needsUpdate = refl.length > 0;
    attributes.signal.needsUpdate = sig.length > 0;
    attributes.nearIR.needsUpdate = nir.length > 0;
    if (attributes.transformIndex) {
      attributes.transformIndex.needsUpdate = true;
    }
  };

  /** Sets the value mapped onto `colorMin` in ramp colour modes. */
  setMin = (min: number): void => {
    this.shaderUniforms.uMin.value = min;
  };

  /** Sets the value mapped onto `colorMax` in ramp colour modes. */
  setMax = (max: number): void => {
    this.shaderUniforms.uMax.value = max;
  };

  /** Sets the ramp's low-end colour from any CSS colour string. */
  setColorMin = (colorMin: string): void => {
    this.shaderUniforms.colorMin.value.setStyle(colorMin);
  };

  /** Sets the ramp's high-end colour from any CSS colour string. */
  setColorMax = (colorMax: string): void => {
    this.shaderUniforms.colorMax.value.setStyle(colorMax);
  };

  /**
   * Selects the colouring mode. The shader switches on the mode's index in
   * `CloudDisplayModes`.
   */
  setColorMode = (colorMode: CloudDisplayMode): void => {
    this.shaderUniforms.colorMode.value = CloudDisplayModes.indexOf(colorMode);
  };

  /**
   * Only points whose height lies strictly between `range.min` and
   * `range.max` are drawn.
   */
  setHeightRange = (range: MinMax): void => {
    const m = this.shaderUniforms;
    m.heightMin.value = range.min;
    m.heightMax.value = range.max;
    // NOTE(review): uniform changes alone do not normally require
    // material.needsUpdate; kept to preserve existing behaviour.
    (this.material as Material).needsUpdate = true;
  };

  /** Sets the rendered point size passed to the vertex shader. */
  setPointSize = (pointSize: number): void => {
    this.shaderUniforms.pointSize.value = pointSize;
  };

  /** Applies a colour preset's value range and ramp colours in one call. */
  setPreset = (preset: Preset): void => {
    const m = this.shaderUniforms;
    m.uMin.value = preset.min;
    m.uMax.value = preset.max;
    m.colorMin.value.setStyle(preset.colorMin);
    m.colorMax.value.setStyle(preset.colorMax);
  };

  /**
   * @brief Sets the per-cloud extrinsics of a node cloud.
   * @param extrinsics Flat array with 7 floats per cloud: translation xyz,
   * then quaternion xyzw, indexed by each point's `transformIndex`.
   * Sensor clouds have no `extrinsics` uniform. For them the call is ignored
   * with a warning instead of failing with a TypeError.
   */
  setExtrinsics = (extrinsics: Float32Array): void => {
    const shaderUniforms = (this.material as Material).uniforms;
    if (!('extrinsics' in shaderUniforms)) {
      console.warn('setExtrinsics() is only supported by node clouds; ignoring.');
      return;
    }
    const m = (this.material as NodeMaterial).uniforms as NodeUniform;
    m.extrinsics.value = extrinsics;
  };
}
