import * as tf from "@tensorflow/tfjs";
import latestModelVersion from "../../latestModelVersion";
import { run } from "js-coroutines";
// Location of the bundled model JSON, resolved against PUBLIC_URL.
const model_url = `${process.env.PUBLIC_URL}/models/model.json`;

// Label names for the model's class channels. detectAndBox looks up the
// argMax index of the class scores in this array, so the order must match
// the model's output order exactly.
const laundryClasses = [
  // wash
  "hand_wash",
  "no_machine_wash",
  "machine_wash",
  "machine_wash_cold",
  "machine_wash_warm",
  "machine_wash_hot",
  "perm_press",
  "delicate",
  // bleach
  "bleach",
  "bleach_non_chlorine",
  "no_bleach",
  // tumble dry
  "dryer",
  "dryer_low",
  "dryer_warm",
  "dryer_hot",
  "no_dryer",
  // iron
  "iron",
  "iron_low",
  "iron_warm",
  "iron_hot",
  "no_iron",
  // air dry
  "line_dry",
  "flat_dry",
  "drip_dry",
  // dry clean
  "no_dry_clean",
];

/**
 * Loads the laundry-symbol detection model, preferring the copy cached in
 * IndexedDB under "indexeddb://laundry-ai".
 *
 * - No usable cached copy: the model is downloaded from `model_url`.
 * - Cached copy needs an update (see _checkIfUpdateNeeded): a fresh copy is
 *   downloaded and the stale one is disposed. If that download fails, the
 *   cached model is returned instead of failing outright.
 *
 * @returns {Promise<tf.GraphModel>} The model to run inference with.
 */
export async function loadModel() {
  let cachedModel;
  try {
    cachedModel = await tf.loadGraphModel("indexeddb://laundry-ai");
  } catch {
    // Loading from IndexedDB throws when nothing is cached there yet; this is
    // an expected path, so fall back to the server copy without logging.
    return _getNewModel();
  }

  console.log(`Current TF Backend: ${tf.getBackend()}`);
  if (!_checkIfUpdateNeeded(cachedModel)) {
    return cachedModel;
  }

  try {
    const freshModel = await _getNewModel();
    // The stale model is no longer reachable by callers; free its weights.
    cachedModel.dispose();
    return freshModel;
  } catch (err) {
    console.warn("Model update failed; using cached model", err);
    return cachedModel;
  }
}
// Post-processing settings used by detectAndBox.
// Upper bound on the number of boxes kept by non-max suppression.
const maxObjects = 30;
// Score threshold passed to non-max suppression; lower-scoring boxes are dropped.
const minScoreTh = 0.25;
// IoU threshold passed to non-max suppression for suppressing overlapping boxes.
const iouTh = 0.45;
// Label names indexed by predicted class id.
const classes = laundryClasses;

/**
 * Reports whether a loaded model carries a `suds.version` entry in its metadata.
 *
 * The previous check compared `version` against `typeof "undefined"`, which is
 * the string "string". A missing version therefore reported as present.
 *
 * @param {{metadata?: {suds?: {version?: *}}}} model Loaded graph model.
 * @returns {boolean} true when `model.metadata.suds.version` is neither
 *   `undefined` nor `null`; false otherwise (including missing metadata).
 */
function _checkIfVersionExists(model) {
  const version = model.metadata?.suds?.version;
  return version !== undefined && version !== null;
}

/**
 * Decides whether a model loaded from IndexedDB should be replaced by a fresh
 * download.
 *
 * @param {tf.GraphModel} model Model loaded from the cache.
 * @returns {boolean} false only when the model has a version and that
 *   version, parsed with parseFloat, is >= latestModelVersion.version; true
 *   otherwise (including when either side parses to NaN).
 */
function _checkIfUpdateNeeded(model) {
  if (!_checkIfVersionExists(model)) {
    return true;
  }

  // NOTE(review): parseFloat treats versions as decimals, so "1.10" parses
  // lower than "1.9"; confirm the version scheme never uses multi-digit minors.
  const cachedVersion = parseFloat(model.metadata.suds?.version);
  const latestVersion = parseFloat(latestModelVersion.version);
  return !(cachedVersion >= latestVersion);
}

/**
 * Downloads the model from `model_url` and caches it in IndexedDB under
 * "indexeddb://laundry-ai".
 *
 * Caching is fire-and-forget, so the caller is not blocked on IndexedDB.
 * A failed save is logged instead of becoming an unhandled promise rejection.
 *
 * @returns {Promise<tf.GraphModel>} The freshly downloaded model.
 */
async function _getNewModel() {
  const model = await tf.loadGraphModel(model_url, {
    onProgress: (percent) => console.log(`Loaded Model: ${percent}`),
  });

  model.save("indexeddb://laundry-ai").catch((err) => {
    console.warn("Failed to cache model in IndexedDB", err);
  });

  return model;
}

/**
 * Runs the detector on one frame and returns the boxes that survive
 * non-max suppression.
 *
 * The work is driven by js-coroutines' `run`. Each `yield` hands a promise to
 * the runner and resumes with its resolved value. All intermediate tensors
 * are released, including when an op throws part-way through.
 *
 * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement|ImageData} input
 *   Pixel source accepted by `tf.browser.fromPixels`.
 * @param {tf.GraphModel} model Model returned by `loadModel()`.
 * @returns {Promise<Array<{top: number, left: number, bottom: number,
 *   right: number, height: number, width: number, score: number,
 *   label: string}>>} One entry per kept box. Coordinates are the model's box
 *   outputs scaled by the input's width/height (see boxesAndScores).
 */

async function detectAndBox(input, model) {
  return await run(function* () {
    if (process.env.NODE_ENV === "development") {
      const mem = tf.memory();
      console.log(`Num of tensors: ${mem.numTensors} , Memory: ${mem.numBytes / 1e6}`);
    }

    // Every tensor created below is tracked so the `finally` can free it even
    // if an op throws. Tensor#dispose is a no-op on an already-disposed
    // tensor, so tensors are still released early once no longer needed.
    const owned = [];
    const track = (tensor) => {
      owned.push(tensor);
      return tensor;
    };

    try {
      const imageTensor = track(
        tf.tidy(() =>
          tf.browser.fromPixels(input, 3).expandDims(0).cast("float32").div(tf.scalar(255))
        )
      );
      const [height, width] = imageTensor.shape.slice(1, 3);
      const features = track(model.predict(imageTensor));
      imageTensor.dispose();

      // features shape = [batch, totalGrids, 5 + numClasses]
      // x,y,w,h,conf,...classes
      const { boxes, scores: boxScores } = boxesAndScores(features, height, width);
      track(boxes);
      track(boxScores);
      features.dispose();

      const yPred = track(tf.argMax(boxScores, -1));
      const boxPred = track(tf.max(boxScores, -1));

      const nmsIndex = track(
        yield tf.image.nonMaxSuppressionAsync(boxes, boxPred, maxObjects, iouTh, minScoreTh)
      );
      if (nmsIndex.size === 0) {
        return [];
      }

      const classBoxes = track(tf.gather(boxes, nmsIndex));
      const classBoxScores = track(tf.gather(boxPred, nmsIndex));
      const classLabels = track(yPred.gather(nmsIndex));

      // One read per tensor instead of splitting classBoxes into per-box tensors.
      const [boxVals, scoreVals, labelVals] = yield Promise.all([
        classBoxes.array(),
        classBoxScores.data(),
        classLabels.data(),
      ]);

      // Each box row is [top, left, bottom, right] (see xywh2yxyx).
      return boxVals.map(([top, left, bottom, right], i) => ({
        top,
        left,
        bottom,
        right,
        height: bottom - top,
        width: right - left,
        score: scoreVals[i],
        label: classes[labelVals[i]],
      }));
    } finally {
      owned.forEach((tensor) => tensor.dispose());
    }
  });
}

/**
 * Converts center-format boxes to corner format.
 *
 * @param {tf.Tensor} xywh Tensor whose last axis has size 4 and holds
 *   (x_center, y_center, width, height).
 * @returns {tf.Tensor} Same leading shape with last axis (y1, x1, y2, x2).
 *   This is the corner order tf.image.nonMaxSuppression expects.
 */
function xywh2yxyx(xywh) {
  return tf.tidy(() => {
    const [cx, cy, w, h] = tf.split(xywh, [1, 1, 1, 1], -1);
    const halfWidth = w.div(tf.scalar(2));
    const halfHeight = h.div(tf.scalar(2));
    const y1 = cy.sub(halfHeight);
    const x1 = cx.sub(halfWidth);
    const y2 = cy.add(halfHeight);
    const x2 = cx.add(halfWidth);
    return tf.concat([y1, x1, y2, x2], -1);
  });
}

/**
 * Splits raw detector output into pixel-scaled corner boxes and per-class scores.
 *
 * @param {tf.Tensor} features Model output. The last axis is laid out as
 *   x, y, w, h, conf, followed by one channel per class. Axis 0 is squeezed
 *   away, so this assumes batch size 1 (TODO confirm for any other caller).
 * @param {number} height Input frame height in pixels.
 * @param {number} width Input frame width in pixels.
 * @returns {{boxes: tf.Tensor, scores: tf.Tensor}} `boxes` holds
 *   [y1, x1, y2, x2] per grid cell, and `scores` is conf * class scores.
 *   The caller owns both tensors and must dispose them.
 */
function boxesAndScores(features, height, width) {
  return tf.tidy(() => {
    // Derive the class count from the tensor rather than hard-coding it, so
    // the split stays correct if the model's class list changes.
    const numClassChannels = features.shape[features.shape.length - 1] - 5;
    const [xywhNormalized, conf, classScores] = tf.split(
      features,
      [4, 1, numClassChannels],
      -1
    );

    // Scale x, y, w, h by the frame size (the model output is treated as normalized).
    const xywh = tf.tensor1d([width, height, width, height]).mul(xywhNormalized);
    const boxes = xywh2yxyx(xywh).squeeze(0);
    const scores = tf.mul(conf, classScores).squeeze(0);
    return { boxes, scores };
  });
}

export { detectAndBox };
