import { Quaternion, Matrix3, Matrix4 } from 'three';

import { CartesianPose } from '@sb/geometry';

import type { ArmJointPositions } from './ArmJointPositions';
import type { DHParams, Offsets } from './DHParams';
import { DEFAULT_DH_PARAMS, DEFAULT_OFFSETS } from './DHParams';

const { sin, cos } = Math;

/**
 * Computes the end-effector pose of a six-joint arm from its joint angles
 * using a closed-form Denavit–Hartenberg forward-kinematics solution
 * parameterised by `d1, a2, a3, d4, d5, d6`.
 *
 * NOTE(review): the parameter set and expressions match the closed-form
 * solution commonly used for UR-style arms — confirm against the `DHParams`
 * definition before relying on that.
 *
 * @param jointAngles - Six joint angles, in radians (they are fed directly to
 *   `Math.sin` / `Math.cos`). This array is not modified.
 * @param dhParams - DH parameters. The returned position is expressed in the
 *   same length units as these values.
 * @param offsets - Per-joint angle offsets. Each one is subtracted from the
 *   matching entry of `jointAngles` before the kinematics are evaluated.
 * @returns The pose, built with `CartesianPose.parse`. It holds the
 *   translation (`x`, `y`, `z`) and the orientation as a quaternion
 *   (`w` = scalar part, `i`/`j`/`k` = vector part).
 */
export function forwardKinematics(
  jointAngles: ArmJointPositions,
  dhParams: DHParams = DEFAULT_DH_PARAMS,
  offsets: Offsets = DEFAULT_OFFSETS,
): CartesianPose {
  // Apply the per-joint offsets into a fresh array instead of reassigning
  // the caller-visible parameter.
  const angles = jointAngles.map(
    (angle, index) => angle - offsets[index],
  ) as ArmJointPositions;

  const s1 = sin(angles[0]);
  const c1 = cos(angles[0]);
  const s2 = sin(angles[1]);
  const c2 = cos(angles[1]);
  const s4 = sin(angles[3]);
  const c4 = cos(angles[3]);
  const s5 = sin(angles[4]);
  const c5 = cos(angles[4]);
  const s6 = sin(angles[5]);
  const c6 = cos(angles[5]);

  // The solution depends on the cumulative angles of joints 2..3 and 2..4.
  // They are summed in the same order as before: (q2 + q3) + q4.
  const q23 = angles[1] + angles[2];
  const q234 = q23 + angles[3];

  const s23 = sin(q23);
  const c23 = cos(q23);
  const s234 = sin(q234);
  const c234 = cos(q234);

  // Rotation matrix entries; rIJ is row I, column J.
  const r11 = c234 * c1 * s5 - c5 * s1;
  const r12 = c6 * (s1 * s5 + c234 * c1 * c5) - s234 * c1 * s6;
  const r13 = -s6 * (s1 * s5 + c234 * c1 * c5) - s234 * c1 * c6;
  const r21 = c1 * c5 + c234 * s1 * s5;
  const r22 = -c6 * (c1 * s5 - c234 * c5 * s1) - s234 * s1 * s6;
  const r23 = s6 * (c1 * s5 - c234 * c5 * s1) - s234 * c6 * s1;
  const r31 = -s234 * s5;
  const r32 = -c234 * s6 - s234 * c5 * c6;
  const r33 = s234 * c5 * s6 - c234 * c6;

  const { d1, a2, a3, d4, d5, d6 } = dhParams;

  // Translation components of the end-effector frame.
  const xx =
    d6 * c234 * c1 * s5 -
    a3 * c23 * c1 -
    a2 * c1 * c2 -
    d6 * c5 * s1 -
    d5 * s234 * c1 -
    d4 * s1;

  const yy =
    d6 * (c1 * c5 + c234 * s1 * s5) +
    d4 * c1 -
    a3 * c23 * s1 -
    a2 * c2 * s1 -
    d5 * s234 * s1;

  // By angle addition, (c23*c4 - s23*s4) equals cos(q234) and
  // (c23*s4 + s23*c4) equals sin(q234). The expanded form is kept so the
  // floating-point result stays exactly as it was.
  const zz =
    d1 +
    a3 * s23 +
    a2 * s2 -
    d5 * (c23 * c4 - s23 * s4) -
    d6 * s5 * (c23 * s4 + s23 * c4);

  // Quaternion.setFromRotationMatrix reads a Matrix4's element layout, so the
  // 3x3 rotation is embedded in a Matrix4 before conversion.
  // Matrix3.set takes its arguments in row-major order.
  const rotationMatrix = new Matrix3();
  rotationMatrix.set(r11, r12, r13, r21, r22, r23, r31, r32, r33);

  const mat4 = new Matrix4();
  mat4.setFromMatrix3(rotationMatrix);

  const quaternion = new Quaternion();
  quaternion.setFromRotationMatrix(mat4);

  return CartesianPose.parse({
    x: xx,
    y: yy,
    z: zz,
    w: quaternion.w,
    i: quaternion.x,
    j: quaternion.y,
    k: quaternion.z,
  });
}
