import * as poseDetection from "@tensorflow-models/pose-detection";
import DynamicTimeWarping from "dynamic-time-warping-ts";
import _ from "lodash";
import Vector2 from "./vector2";

/**
 * Steps through `video` at `fps` samples per second and runs `detector` on every sampled frame.
 *
 * The video is paused and seeked from 0 towards its end. When sampling finishes or fails,
 * the `onseeked` handler is detached and the playhead is restored to where it was.
 * Keypoints are converted with `keypointsToNormalizedKeypoints` using the video's
 * `videoWidth`/`videoHeight`. If a later frame yields no pose, the previous pose is repeated,
 * so the result keeps one entry per sampled frame.
 *
 * @param video Video element to sample. Its `onseeked` handler is overwritten while sampling.
 * @param detector Detector run on each frame. Only the first pose it returns is kept.
 * @param fps Sampling rate; must be a finite number greater than 0.
 * @param detectionConfig Forwarded unchanged to `detector.estimatePoses`.
 * @param onProgress Called with the completed fraction after each frame, then once with `(1, true)`.
 * @returns One pose per sampled frame. Rejects when `fps` is invalid, when no pose is detected
 *   on the first frame, or when the detector throws.
 */
export function getPosesFromVideo(
  video: HTMLVideoElement, detector: poseDetection.PoseDetector,
  fps = 10, detectionConfig: poseDetection.MoveNetEstimationConfig = {},
  onProgress?: (progress: number, done: boolean) => void,
): Promise<poseDetection.Pose[]> {
  // With fps <= 0 or a non-finite fps, the seek step is 0 or negative.
  // The end of the video would then never be reached and onseeked would re-fire forever.
  if (!Number.isFinite(fps) || fps <= 0) {
    return Promise.reject(new RangeError(`fps must be a positive finite number, got ${fps}`));
  }

  const originalTime = video.currentTime;
  const interval = 1 / fps;
  const stepAmount = Math.ceil(video.duration / interval) + 1;    // +1 as last step should be resolving the promise
  const poses = Array<poseDetection.Pose>();

  return new Promise<poseDetection.Pose[]>((res, rej) => {
    // End sampling. Detach the handler first so that restoring the playhead does not start
    // another detection pass, then seek back to where the caller left the video.
    const stop = () => {
      video.onseeked = null;
      video.currentTime = originalTime;
    };

    video.pause();

    video.onseeked = async () => {
      try {
        let pose = (await detector.estimatePoses(video, detectionConfig))[0];

        if (!poses.length) {
          // Run the detector twice to make sure we get the correct pose, unaffected by smoothing
          pose = (await detector.estimatePoses(video, detectionConfig))[0];
        }

        if (!pose) {
          if (!poses.length) {
            // There is no earlier pose to fall back on. Abort instead of seeking through the rest of the video.
            stop();
            rej(new Error("No pose detected on first frame."));
            return;
          }
          // Repeat the previous pose so the output still has one entry per sampled frame.
          poses.push(_.last(poses)!);
        } else {
          pose.keypoints = poseDetection.calculators.keypointsToNormalizedKeypoints(
            pose.keypoints, { width: video.videoWidth, height: video.videoHeight }
          );
          poses.push(pose);
        }

        onProgress?.(poses.length / stepAmount, false);

        if (video.currentTime + interval < video.duration) {
          // Seeking fires `seeked` again, which drives the next iteration.
          video.currentTime += interval;
        } else {
          stop();
          res(poses);
          onProgress?.(1, true);
        }
      } catch (e) {
        stop();
        rej(e);
      }
    };

    video.currentTime = 0;
  });
}

/**
 * Euclidean (L2) distance between two equal-length numeric vectors.
 *
 * @param a First vector.
 * @param b Second vector; must have the same length as `a`.
 * @returns `sqrt(sum((a[i] - b[i])^2))`, computed with `Math.hypot`. Returns 0 for two empty vectors.
 * @throws Error when the vectors differ in length.
 */
export function euclidianDistance(a: number[], b: number[]) {
  if (a.length !== b.length) {
    throw new Error("Euclidian distance requires same length vectors");
  }

  const componentDeltas = a.map((value, index) => value - b[index]);
  return Math.hypot(...componentDeltas);
}

/**
 * Mirrors a pose (or every pose in an array) horizontally within a box of width `bboxWidth`.
 *
 * For each keypoint, `x` is reflected around `bboxWidth / 2`. The first "left" in its name
 * becomes "right" (or vice versa). The keypoints are then re-sorted into `modelType`'s
 * index order. All other pose fields are copied unchanged.
 */
export function flipPose(
  pose: poseDetection.Pose, bboxWidth: number, modelType?: poseDetection.SupportedModels
): poseDetection.Pose;
export function flipPose(
  pose: poseDetection.Pose[], bboxWidth: number, modelType?: poseDetection.SupportedModels
): poseDetection.Pose[];
export function flipPose(
  pose: poseDetection.Pose | poseDetection.Pose[],
  bboxWidth: number,
  modelType = poseDetection.SupportedModels.MoveNet
): poseDetection.Pose | poseDetection.Pose[] {
  if (Array.isArray(pose)) {
    return pose.map(single => flipPose(single, bboxWidth, modelType));
  }

  const center = bboxWidth / 2;
  const indexByName = poseDetection.util.getKeypointIndexByName(modelType);

  // A mirror image swaps body sides, so "left_*" keypoints become "right_*" and vice versa.
  const swapSide = (name: string | undefined) => {
    if (name?.includes("left")) {
      return name.replace("left", "right");
    }
    if (name?.includes("right")) {
      return name.replace("right", "left");
    }
    return name;
  };

  const mirrored = pose.keypoints.map(kp => ({
    ...kp,
    x: center + (center - kp.x),
    name: swapSide(kp.name),
  }));
  // The names were swapped, so put the keypoints back into the model's canonical order.
  mirrored.sort((p, q) => indexByName[p.name!] - indexByName[q.name!]);

  return { ...pose, keypoints: mirrored };
}

/**
 * Averages the torso keypoints: every keypoint whose name ends with "shoulder" or "hip".
 *
 * Assumes `Vector2.add` sums its arguments and `div` divides component-wise
 * (NOTE(review): confirm in ./vector2). Every keypoint must have a `name`.
 * If no torso keypoint is present, `div` receives 0.
 */
export function getTorsoCenter(keypoints: poseDetection.Keypoint[]) {
  const isTorso = (kp: poseDetection.Keypoint) =>
    kp.name!.endsWith("shoulder") || kp.name!.endsWith("hip");

  const torso = keypoints.filter(isTorso);
  const summed = Vector2.add(...Vector2.fromObjectArray(torso));
  return summed.div(torso.length);
}

// Expose getTorsoCenter as a global, presumably for debugging from the browser console.
// The guard keeps importing this module from throwing in non-browser environments
// (e.g. node-based tests), where `window` is not defined.
if (typeof window !== "undefined") {
  Object.assign(window, { getTorsoCenter });
}

/**
 * MoveNet keypoint names, listed in the model's keypoint index order.
 * The order matters: computeStaticPoseDiff turns names into keypoint indices with `indexOf`.
 */
const MOVENET_KP_NAMES = [
  // Head
  "nose",
  "left_eye", "right_eye",
  "left_ear", "right_ear",
  // Arms
  "left_shoulder", "right_shoulder",
  "left_elbow", "right_elbow",
  "left_wrist", "right_wrist",
  // Legs
  "left_hip", "right_hip",
  "left_knee", "right_knee",
  "left_ankle", "right_ankle",
];

/** The head keypoints: every name that precedes "left_shoulder". */
const MOVENET_HEAD_KP_NAMES = MOVENET_KP_NAMES.slice(0, MOVENET_KP_NAMES.indexOf("left_shoulder"));

/**
 * Named bones of the MoveNet skeleton. Each entry is the pair of keypoint names the bone connects.
 * The keys are the values accepted by the `whitelist` comparison parameter.
 * `Object.freeze` is shallow: the set of entries is fixed, but the inner arrays are not frozen.
 */
export const MOVENET_NAMED_KP_PAIRS = Object.freeze({
  // Torso
  clavicle: ["left_shoulder", "right_shoulder"],
  left_torso: ["left_shoulder", "left_hip"],
  right_torso: ["right_shoulder", "right_hip"],
  pelvis: ["left_hip", "right_hip"],
  // Arms: upper arm ("arm") and forearm
  left_arm: ["left_shoulder", "left_elbow"],
  right_arm: ["right_shoulder", "right_elbow"],
  left_forearm: ["left_elbow", "left_wrist"],
  right_forearm: ["right_elbow", "right_wrist"],
  // Legs: thigh and lower leg ("leg")
  left_thigh: ["left_hip", "left_knee"],
  right_thigh: ["right_hip", "right_knee"],
  left_leg: ["left_knee", "left_ankle"],
  right_leg: ["right_knee", "right_ankle"],
});

/**
 * Compares two static poses bone by bone.
 *
 * Bones are MoveNet's adjacent keypoint pairs, minus any pair that touches a head keypoint
 * (indices 0-4). This head filter is MoveNet-specific. When `params.whitelist` is non-empty,
 * only bones matching one of the whitelisted named pairs are kept. For each bone, the angle
 * between its direction vector in `a` and in `b` is summed. A bone with an endpoint scored
 * below `minScore` in either pose is not skipped; it adds the maximum 180 instead.
 *
 * @param a Keypoints of the first pose, in MoveNet index order.
 * @param b Keypoints of the second pose, in the same order.
 * @param params Overrides for the default comparison parameters (`minVisible` is not read here).
 * @returns Total deviation divided by `180 * bones compared`: a 0-1 score of how different
 *   the poses are, assuming `Vector2.angle` returns degrees in [0, 180]
 *   (NOTE(review): confirm). NaN when no bone is compared.
 */
export function computeStaticPoseDiff(
  a: poseDetection.Keypoint[],
  b: poseDetection.Keypoint[],
  params?: Partial<typeof DEFAULT_POSE_COMP_PARAMS>
) {
  // _.defaults writes into its first argument, so hand it a copy of the caller's params
  const { minScore, whitelist } = _.defaults({ ...params }, DEFAULT_POSE_COMP_PARAMS);

  // Drop every adjacent pair that involves a head keypoint (MoveNet indices 0-4)
  const bodyBones = poseDetection.util
    .getAdjacentPairs(poseDetection.SupportedModels.MoveNet)
    .filter(bone => bone.every(kpIndex => kpIndex >= 5));

  let bones = bodyBones;
  if (whitelist.length) {
    const allowedBones = whitelist.map(boneName =>
      MOVENET_NAMED_KP_PAIRS[boneName].map(kpName => MOVENET_KP_NAMES.indexOf(kpName))
    );
    bones = bodyBones.filter(bone =>
      allowedBones.some(allowed => allowed.every(kpIndex => bone.includes(kpIndex)))
    );
  }

  const hasLowScore = (keypoints: poseDetection.Keypoint[], bone: number[]) =>
    bone.some(kpIndex => keypoints[kpIndex].score! < minScore);

  // Sum of per-bone angular deviations in degrees; unreliable bones count as the maximum
  const totalDeviation = bones.reduce((sum, bone) => {
    if (hasLowScore(a, bone) || hasLowScore(b, bone)) {
      return sum + 180;
    }
    const [from, to] = bone;
    return sum + Vector2.angle(
      Vector2.fromObject(a[from]).sub(Vector2.fromObject(a[to])),
      Vector2.fromObject(b[from]).sub(Vector2.fromObject(b[to]))
    );
  }, 0);

  return totalDeviation / (bones.length * 180);
}

/**
 * Default pose-comparison parameters. Partial params from callers are merged over these
 * with `_.defaults` in computeStaticPoseDiff and computeDynamicPoseDiff.
 */
const DEFAULT_POSE_COMP_PARAMS = {
  /** The minimum score for a keypoint to be considered visible. */
  minScore: 0.3,
  /**
   * Minimum fraction (0-1) of comparison frames in which a keypoint must be visible for its
   * motion to be compared. Below this, the keypoint gets a fixed penalty instead.
   * Read by computeDynamicPoseDiff; computeStaticPoseDiff does not use it.
   */
  minVisible: 0.6,
  /**
   * Restrict comparison to these named bones (keys of MOVENET_NAMED_KP_PAIRS).
   * An empty list means no restriction.
   */
  whitelist: Array<keyof typeof MOVENET_NAMED_KP_PAIRS>()
};

/** Caller-facing comparison options; every omitted field falls back to DEFAULT_POSE_COMP_PARAMS. */
export type PoseComparisonParams = Partial<typeof DEFAULT_POSE_COMP_PARAMS>;

/**
 * Scores how far a recorded motion (`comparison`) deviates from a reference motion.
 *
 * Each non-head MoveNet keypoint becomes a trajectory of `[x, y, t]` points per sequence,
 * where `t` is measured from that sequence's first frame. When `params.whitelist` is
 * non-empty, only keypoints belonging to a whitelisted bone are used. The two trajectories
 * are compared with dynamic time warping, using `euclidianDistance` as the point metric.
 * All reference frames are used. Comparison frames are kept only where the keypoint's score
 * is at least `minScore`. A keypoint visible in less than `minVisible` of the comparison
 * frames gets a fixed penalty instead of a DTW distance.
 *
 * The input frames are not modified.
 *
 * @param reference Frames of the reference motion; must be non-empty.
 * @param comparison Frames of the motion being scored; must be non-empty.
 * @param params Overrides for the default comparison parameters.
 * @returns Mean over the considered keypoints of `DTW distance / comparison.length`.
 *   A penalized keypoint contributes exactly 1.
 * @throws Error if `reference` or `comparison` is empty.
 */
export function computeDynamicPoseDiff(
  reference: { keypoints: poseDetection.Keypoint[], timestamp: number }[],
  comparison: { keypoints: poseDetection.Keypoint[], timestamp: number }[],
  params?: PoseComparisonParams
) {
  if (!reference.length || !comparison.length) {
    throw new Error("computeDynamicPoseDiff requires non-empty reference and comparison sequences");
  }

  // Spread caller params since _.defaults mutates its first argument
  const filledParams = _.defaults({ ...params }, DEFAULT_POSE_COMP_PARAMS);
  const KP_IDX_MAP = poseDetection.util.getKeypointIndexByName(poseDetection.SupportedModels.MoveNet);
  const whitelistedPairs = filledParams.whitelist.map(bone => MOVENET_NAMED_KP_PAIRS[bone]);
  const distArr = Array<number>();

  // Time origin of each sequence. It is subtracted while building the curves instead of
  // being written back into frame.timestamp, so the caller's frames are not mutated.
  const referenceStart = reference[0].timestamp;
  const comparisonStart = comparison[0].timestamp;

  for (const [key, kpIdx] of Object.entries(KP_IDX_MAP)) {
    const isConsidered =
      !MOVENET_HEAD_KP_NAMES.includes(key)
      && (!whitelistedPairs.length || whitelistedPairs.some(wPair => wPair.includes(key)));
    if (!isConsidered) {
      continue;
    }

    const cCurveVisible = comparison.filter(frame => frame.keypoints[kpIdx].score! >= filledParams.minScore);

    if (cCurveVisible.length / comparison.length >= filledParams.minVisible) {
      const rCurve = reference.map(frame =>
        [frame.keypoints[kpIdx].x, frame.keypoints[kpIdx].y, frame.timestamp - referenceStart]
      );
      const cCurve = cCurveVisible.map(frame =>
        [frame.keypoints[kpIdx].x, frame.keypoints[kpIdx].y, frame.timestamp - comparisonStart]
      );
      distArr.push(new DynamicTimeWarping(rCurve, cCurve, euclidianDistance).getDistance());
    } else {
      // Too rarely visible to compare. After the division by comparison.length below, this becomes 1.
      distArr.push(comparison.length);
    }
  }

  return _.mean(distArr.map(v => v / comparison.length));
}