import * as poseDetection from "@tensorflow-models/pose-detection";
import Vector2 from "./vector2";

/**
 * Default rendering options. drawPose spreads caller-supplied overrides on top of
 * these; the helper functions fall back to this object when no params are passed.
 */
const DEFAULT_DRAW_PARAMS = {
  // Minimum keypoint score (compared with >=) needed to draw a keypoint or a skeleton edge.
  threshold: 0.3,
  // Stroke width used for skeleton edges.
  lineWidth: 2,
  // Radius of keypoint circles in canvas pixels; it is NOT multiplied by `scale`.
  circleRadius: 4,
  // Per-axis multiplier applied to keypoint x/y before drawing. This single instance is
  // reused by every call that does not supply its own scale.
  scale: new Vector2(1, 1),
  // Fill colours for keypoints that pose-detection classifies as left / middle / right.
  leftKpColor: "Green",
  midKpColor: "Red",
  rightKpColor: "Orange",
  // Colour of skeleton edges and of the outline stroked around each keypoint circle.
  skeletonColor: "White"
};

/**
 * Options accepted by drawPose: any subset of DEFAULT_DRAW_PARAMS, where `scale`
 * may also be a single number that is applied to both axes.
 */
type DrawPoseParams = Partial<Omit<typeof DEFAULT_DRAW_PARAMS, "scale"> & { scale: Vector2 | number }>;

/**
 * Clears the canvas, then draws the skeleton and the keypoints of one or more poses.
 *
 * If `pose` is falsy, a message is logged and the function returns WITHOUT clearing
 * the canvas, so whatever was drawn previously stays visible.
 *
 * @param ctx - Canvas context to draw into; the whole canvas is cleared before drawing.
 * @param model - Model whose keypoint layout defines the skeleton edges and side colours.
 * @param pose - A single pose or an array of poses. Poses without keypoints are skipped.
 * @param params - Overrides for DEFAULT_DRAW_PARAMS. A numeric `scale` is used for both axes.
 */
export default function drawPose(
  ctx: CanvasRenderingContext2D,
  model: poseDetection.SupportedModels,
  pose: poseDetection.Pose | poseDetection.Pose[],
  params?: DrawPoseParams
): void {
  if (!pose) {
    console.log("No pose to draw");
    return;
  }

  const scale = params?.scale;
  const filledParams = {
    ...DEFAULT_DRAW_PARAMS,
    ...params,
    // Narrow with `typeof` rather than truthiness so a numeric scale of 0 is still
    // expanded to a Vector2 instead of being passed through as a bare number.
    scale: typeof scale === "number" ? new Vector2(scale, scale) : scale ?? DEFAULT_DRAW_PARAMS.scale
  };
  ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

  // Treat a single pose as a one-element list so both input shapes get the same
  // empty-keypoints guard; an empty single pose previously crashed in drawSkeleton.
  const poses = Array.isArray(pose) ? pose : [pose];
  for (const p of poses) {
    if (p.keypoints.length) {
      // Skeleton first so the keypoint circles are painted on top of the edges.
      drawSkeleton(ctx, model, p.keypoints, filledParams);
      drawKeypoints(ctx, model, p.keypoints, filledParams);
    }
  }
}

// Also expose the implementation under its own name so `import { drawPose }` has a real
// runtime binding; type-only overloads on a named export would be erased at emit.
export { drawPose };

/**
 * Draws one pose's keypoints as filled, outlined circles. Keypoints are grouped by the
 * side pose-detection assigns to each index and drawn in the order middle, left, right,
 * each group with its own fill colour.
 *
 * @param ctx - Canvas context; its fillStyle, strokeStyle and lineWidth are overwritten.
 * @param model - Model whose keypoint-index-by-side table selects each keypoint's colour.
 * @param keypoints - Keypoints of a single pose, indexed as the model's layout defines.
 * @param params - Drawing options; keypoints scoring below `params.threshold` are not drawn.
 */
function drawKeypoints(
  ctx: CanvasRenderingContext2D,
  model: poseDetection.SupportedModels,
  keypoints: poseDetection.Keypoint[],
  params = DEFAULT_DRAW_PARAMS
) {
  const keypointInd = poseDetection.util.getKeypointIndexBySide(model);
  ctx.strokeStyle = params.skeletonColor;
  // NOTE(review): the circle outline width is fixed at 3 and does not follow
  // params.lineWidth (which only affects skeleton edges) — confirm this is intended.
  ctx.lineWidth = 3;

  const drawGroup = (indices: readonly number[], color: string) => {
    ctx.fillStyle = color;
    for (const i of indices) {
      const keypoint = keypoints[i];
      // Indices come from the model layout; skip any the pose does not actually
      // contain instead of crashing on `undefined.score` inside drawKeypoint.
      if (keypoint) {
        drawKeypoint(ctx, keypoint, params);
      }
    }
  };

  drawGroup(keypointInd.middle, params.midKpColor);
  drawGroup(keypointInd.left, params.leftKpColor);
  drawGroup(keypointInd.right, params.rightKpColor);
}

/**
 * Fills and strokes a single keypoint circle using the context's current fillStyle,
 * strokeStyle and lineWidth.
 *
 * The keypoint's position is multiplied by `params.scale`; the radius is not.
 * A keypoint without a score is treated as fully confident (score 1).
 *
 * @param ctx - Canvas context to draw into.
 * @param keypoint - Keypoint to draw.
 * @param params - Drawing options providing threshold, scale and circleRadius.
 */
function drawKeypoint(ctx: CanvasRenderingContext2D, keypoint: poseDetection.Keypoint, params = DEFAULT_DRAW_PARAMS) {
  const confidence = keypoint.score ?? 1;
  // Written as a negated `>=` (not `<`) so NaN scores/thresholds are skipped, as before.
  if (!(confidence >= params.threshold)) {
    return;
  }

  const { x, y } = keypoint;
  const marker = new Path2D();
  marker.arc(x * params.scale.x, y * params.scale.y, params.circleRadius, 0, Math.PI * 2);
  ctx.fill(marker);
  ctx.stroke(marker);
}

/**
 * Draws the skeleton of one pose as straight line segments between adjacent keypoints.
 *
 * An edge is drawn only when both of its endpoints exist and both scores reach
 * `params.threshold`; a keypoint without a score is treated as fully confident.
 *
 * @param ctx - Canvas context; its fillStyle, strokeStyle and lineWidth are overwritten.
 * @param model - Model whose adjacent-pair table defines which keypoints are connected.
 * @param keypoints - Keypoints of a single pose, indexed as the model's layout defines.
 * @param params - Drawing options providing threshold, scale, lineWidth and skeletonColor.
 */
function drawSkeleton(ctx: CanvasRenderingContext2D, model: poseDetection.SupportedModels, keypoints: poseDetection.Keypoint[], params = DEFAULT_DRAW_PARAMS) {
  ctx.fillStyle = params.skeletonColor;
  ctx.strokeStyle = params.skeletonColor;
  ctx.lineWidth = params.lineWidth;

  poseDetection.util.getAdjacentPairs(model).forEach(([i, j]) => {
    const kp1 = keypoints[i];
    const kp2 = keypoints[j];

    // The pair table comes from the model layout; skip edges whose endpoints the pose
    // does not contain instead of throwing on `undefined.score`.
    if (!kp1 || !kp2) {
      return;
    }

    // If score is null, just show the keypoint.
    const score1 = kp1.score ?? 1;
    const score2 = kp2.score ?? 1;

    if (score1 >= params.threshold && score2 >= params.threshold) {
      ctx.beginPath();
      ctx.moveTo(kp1.x * params.scale.x, kp1.y * params.scale.y);
      ctx.lineTo(kp2.x * params.scale.x, kp2.y * params.scale.y);
      ctx.stroke();
    }
  });
}