/*
=============================================================================
  sign_glove_inference.ino
  Arabic Sign Language Glove — ESP32 TFLite Micro Inference
=============================================================================
*/
#include "esp_wifi.h"
#include <TensorFlowLite_ESP32.h>
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

#include "sign_model.h"
#include "preprocessing.h"

#include <MPU6050_tockn.h>
#include <Wire.h>

// IMU driver instance bound to the default I2C bus object; setup() calls
// begin()/calcGyroOffsets() on it and loop() reads accel/gyro values from it.
MPU6050 mpu(Wire);

// =============================================================================
// HARDWARE CONFIG
// =============================================================================
// GPIO numbers of the five flex-sensor inputs, sampled with analogRead() in
// loop(). NOTE(review): the finger each index corresponds to is not stated
// anywhere in this file — confirm against the wiring / training data order.
const int FLEX_PINS[5] = {39, 33, 35, 34, 32};

// Raw ADC readings that normFlex() maps to 0.0 and 1.0 respectively; readings
// outside this window are clamped.
#define FLEX_MIN  200
#define FLEX_MAX  1350

// Minimum time between two sensor reads in loop(), in milliseconds.
#define SAMPLE_INTERVAL_MS  20
// Only every DOWNSAMPLE_STEP-th sensor read is turned into a sequence step.
#define DOWNSAMPLE_STEP     2

// Top-1 probability below which runInference() reports the sign as "uncertain".
#define CONFIDENCE_THRESHOLD  0.55f

// Printable label for each model output index, used by runInference().
// NOTE(review): the order presumably mirrors the label encoding used at
// training time; "FEH" sits last, outside the otherwise alphabetical order —
// confirm against the training pipeline. The table size (28) is hard-coded
// here while N_CLASSES comes from another header; the two must agree.
const char* CLASS_NAMES[28] = {
  "AAIN",  "ALEF",  "BEH",   "DAAD",  "DAAL",
  "GEEM",  "GHEEN", "HAHH",  "HEEHH", "KAAF",
  "KHAA",  "LAAM",  "MEEM",  "NOON",  "QAAF",
  "RAAH",  "SAAD",  "SEEN",  "SHEEN", "TAH",
  "TEH",   "THEH",  "WOW",   "YEH",   "ZAAL",
  "ZAH",   "ZEEN",  "FEH"
};

// =============================================================================
// TFLite Micro — globals
// Arena is heap-allocated dynamically at runtime based on available memory.
// Size is determined by querying largest free block — never hardcoded.
// =============================================================================
// Working memory handed to the interpreter. Both values are filled in by
// setup(); they stay nullptr / 0 if allocation fails.
static uint8_t* tensor_arena      = nullptr;
static size_t   tensor_arena_size = 0;

// Error sink passed to the interpreter constructor in setup().
static tflite::MicroErrorReporter micro_error_reporter;
static tflite::ErrorReporter*     error_reporter = &micro_error_reporter;

// Op registry with room for 20 ops; setup() registers exactly 20, so adding
// another op there requires raising this template argument.
static tflite::MicroMutableOpResolver<20> resolver;

// Model handle and the input/output tensors, all assigned in setup().
// runInference() refuses to run while the tensor pointers are still null.
static const tflite::Model*      tfl_model     = nullptr;
static TfLiteTensor*             input_tensor  = nullptr;
static TfLiteTensor*             output_tensor = nullptr;

// Static, 16-byte-aligned storage for the interpreter object. setup()
// constructs the interpreter into this buffer with placement new, so the
// object lives in static memory rather than on the stack or the heap.
static uint8_t interp_buf[sizeof(tflite::MicroInterpreter)] __attribute__((aligned(16)));
static tflite::MicroInterpreter* interpreter = nullptr;

// =============================================================================
// FEATURE ENGINEERING
// =============================================================================
// Set to true to make loop() print the full raw feature vector for the first
// three steps of every sequence. Must stay a macro: it is tested with #if.
#define DEBUG_FEATURES false

// Rescales a raw flex-sensor ADC reading so that FLEX_MIN maps to 0.0 and
// FLEX_MAX maps to 1.0, clamping anything outside that window.
float normFlex(int adc) {
  const float scaled = (float)(adc - FLEX_MIN) / (float)(FLEX_MAX - FLEX_MIN);
  if (scaled < 0.0f) return 0.0f;
  if (scaled > 1.0f) return 1.0f;
  return scaled;
}

// Builds the 43-element raw feature vector for one time step.
//
//   flex_raw          five raw flex-sensor ADC readings
//   ax, ay, az        accelerometer axes
//   gx, gy, gz        gyroscope axes
//   acc_mag, gyro_mag Euclidean norms of the two triples above
//   f                 output array, must hold at least 43 floats
//
// Layout written to f:
//   [0..4]   normalized flex            [5..9]   raw flex / 4095
//   [10..13] neighbouring-finger deltas [14]     flex mean
//   [15]     flex std-dev (population)  [16..21] accel xyz, gyro xyz
//   [22..23] accel / gyro magnitude     [24..29] accel, gyro scaled by flex mean
//   [30..34] normalized flex squared    [35..36] magnitudes squared
//   [37]     flex range (max - min)     [38]     flex sum
//   [39]     (ax^2 + ay^2) / acc_mag^2  [40..42] accel axes squared
void computeFeatures(int* flex_raw,
                     float ax, float ay, float az,
                     float gx, float gy, float gz,
                     float acc_mag, float gyro_mag,
                     float* f)
{
  // Per-finger features, filled in a single pass.
  float bend[5];
  float bend_sum = 0.0f;
  for (int k = 0; k < 5; ++k) {
    bend[k]   = normFlex(flex_raw[k]);
    f[k]      = bend[k];
    f[5 + k]  = flex_raw[k] / 4095.0f;
    f[30 + k] = bend[k] * bend[k];
    bend_sum += bend[k];
  }

  // Difference between each finger and the previous one.
  for (int k = 0; k < 4; ++k) {
    f[10 + k] = bend[k + 1] - bend[k];
  }

  // Mean and population standard deviation across the five fingers.
  const float bend_mean = bend_sum / 5.0f;
  float sq_dev = 0;
  for (int k = 0; k < 5; ++k) {
    sq_dev += (bend[k]-bend_mean)*(bend[k]-bend_mean);
  }
  f[14] = bend_mean;
  f[15] = sqrtf(sq_dev / 5.0f);

  // Spread between the most and least bent finger, plus the plain sum.
  float hi = bend[0];
  float lo = bend[0];
  for (int k = 1; k < 5; ++k) {
    if (bend[k] > hi) hi = bend[k];
    if (bend[k] < lo) lo = bend[k];
  }
  f[37] = hi - lo;
  f[38] = bend_sum;

  // IMU values passed through unchanged.
  f[16] = ax;
  f[17] = ay;
  f[18] = az;
  f[19] = gx;
  f[20] = gy;
  f[21] = gz;
  f[22] = acc_mag;
  f[23] = gyro_mag;

  // IMU axes weighted by the mean finger bend.
  f[24] = ax*bend_mean;
  f[25] = ay*bend_mean;
  f[26] = az*bend_mean;
  f[27] = gx*bend_mean;
  f[28] = gy*bend_mean;
  f[29] = gz*bend_mean;

  // Squared magnitudes and squared accelerometer axes.
  f[35] = acc_mag*acc_mag;
  f[36] = gyro_mag*gyro_mag;
  f[40] = ax*ax;
  f[41] = ay*ay;
  f[42] = az*az;

  // Share of the squared acceleration that lies in the x/y plane; defined as
  // zero when the magnitude is too small to divide by.
  const float acc_sq = acc_mag*acc_mag;
  if (acc_sq > 1e-6f) {
    f[39] = (ax*ax + ay*ay) / acc_sq;
  } else {
    f[39] = 0.0f;
  }
}

// =============================================================================
// PREPROCESSING
// =============================================================================
// Converts one raw feature vector into the quantized values for one time step.
//
//   raw43  raw feature vector, indexed through FEATURES_KEPT
//   out40  receives N_FEATURES int8 values
//
// For each kept feature: clip to [CLIP_LOW, CLIP_HIGH], rescale that interval
// to [-1, 1], apply INPUT_SCALE / INPUT_ZP, round, and saturate to int8.
void preprocessStep(float* raw43, int8_t* out40) {
  for (int k = 0; k < N_FEATURES; ++k) {
    const auto lo = CLIP_LOW[k];
    const auto hi = CLIP_HIGH[k];

    const float sample  = raw43[FEATURES_KEPT[k]];
    const float clipped = fmaxf(lo, fminf(hi, sample));
    const float scaled  = 2.0f*(clipped - lo) / (hi - lo) - 1.0f;
    const float quant   = scaled / INPUT_SCALE + (float)INPUT_ZP;

    // Saturate to the int8 range before the narrowing cast.
    const float rounded = roundf(quant);
    out40[k] = (int8_t)fmaxf(-128.0f, fminf(127.0f, rounded));
  }
}

// =============================================================================
// SEQUENCE BUFFER
// =============================================================================
// Quantized feature rows for the sequence currently being collected; loop()
// fills one row per kept sample and runInference() copies the whole buffer
// into the model input.
static int8_t        seq_buf[SEQ_LEN][N_FEATURES];
// Index of the next row of seq_buf to fill; reset to 0 after each inference.
static int           step_count  = 0;
// Number of sensor reads since the last inference, used for downsampling.
static int           raw_counter = 0;
// millis() timestamp of the most recent sensor read.
static unsigned long last_tick   = 0;

// =============================================================================
// INFERENCE
// =============================================================================
void runInference() {
  if (!input_tensor || !output_tensor) {
    Serial.println("ERROR: null tensor pointer");
    return;
  }

  for (int t = 0; t < SEQ_LEN; t++)
    for (int f = 0; f < N_FEATURES; f++)
      input_tensor->data.int8[t * N_FEATURES + f] = seq_buf[t][f];

  unsigned long t0 = micros();
  if (interpreter->Invoke() != kTfLiteOk) {
    Serial.println("ERROR: Invoke() failed");
    return;
  }
  unsigned long lat_ms = (micros() - t0) / 1000;

  float logits[N_CLASSES];
  for (int c = 0; c < N_CLASSES; c++)
    logits[c] = (output_tensor->data.int8[c] - OUTPUT_ZP) * OUTPUT_SCALE;

  float mx = logits[0];
  for (int c = 1; c < N_CLASSES; c++) if (logits[c] > mx) mx = logits[c];
  float se = 0;
  for (int c = 0; c < N_CLASSES; c++) se += expf(logits[c] - mx);

  int   top[3]  = {0, 1, 2};
  float topp[3] = {-1.0f, -1.0f, -1.0f};
  for (int c = 0; c < N_CLASSES; c++) {
    float p = expf(logits[c] - mx) / se;
    if (p > topp[0]) {
      topp[2]=topp[1]; top[2]=top[1];
      topp[1]=topp[0]; top[1]=top[0];
      topp[0]=p;       top[0]=c;
    } else if (p > topp[1]) {
      topp[2]=topp[1]; top[2]=top[1];
      topp[1]=p;       top[1]=c;
    } else if (p > topp[2]) {
      topp[2]=p; top[2]=c;
    }
  }

  Serial.println("──────────────────────────────");
  if (topp[0] >= CONFIDENCE_THRESHOLD) {
    Serial.printf("  SIGN : %-8s  %.1f%%\n", CLASS_NAMES[top[0]], topp[0]*100);
  } else {
    Serial.printf("  SIGN : uncertain (best: %s %.1f%%)\n",
                  CLASS_NAMES[top[0]], topp[0]*100);
  }
  Serial.printf("  Top3 : %s(%.0f%%)  %s(%.0f%%)  %s(%.0f%%)\n",
    CLASS_NAMES[top[0]], topp[0]*100,
    CLASS_NAMES[top[1]], topp[1]*100,
    CLASS_NAMES[top[2]], topp[2]*100);
  Serial.printf("  Time : %lu ms\n", lat_ms);
}

// =============================================================================
// SETUP
// =============================================================================
// One-time initialisation. The order matters: the tensor arena is allocated
// first (before Serial, I2C or the IMU driver allocate anything), then Serial
// is started, ops are registered, the sensors are initialised, and finally the
// model is loaded and its tensors are allocated inside the arena.
// On any fatal condition the function prints a message and halts in an
// endless delay loop.
void setup() {

  // Shut the Wi-Fi driver down before the arena is sized.
  // NOTE(review): presumably intended to release Wi-Fi memory; nothing in this
  // file initialises Wi-Fi, so confirm these calls have any effect here.
  esp_wifi_stop();
  esp_wifi_deinit();

  // ═══════════════════════════════════════════════════════════════════════════
  // STEP 1 — Allocate arena BEFORE anything else touches the heap.
  // We measure the largest free block dynamically and use that minus a small
  // overhead margin. This is the only reliable approach on a non-PSRAM ESP32.
  // ═══════════════════════════════════════════════════════════════════════════
{
  // Try PSRAM first
  size_t psram_block = heap_caps_get_largest_free_block(MALLOC_CAP_SPIRAM);
  if (psram_block > 1024) {
    tensor_arena_size = psram_block - 8;
    tensor_arena = (uint8_t*)heap_caps_malloc(
        tensor_arena_size, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  }

  // Fall back to internal RAM with manual alignment
  if (!tensor_arena) {
    size_t largest = heap_caps_get_largest_free_block(MALLOC_CAP_8BIT);

    // Request 8 bytes less than the largest free block.
    // NOTE(review): the 8 bytes are assumed to cover the allocator's per-block
    // overhead — confirm for the ESP-IDF version in use.
    size_t alloc_size = (largest > 8) ? (largest - 8) : 0;
    static uint8_t* raw_ptr = nullptr;   // keep raw pointer so we never lose it
    raw_ptr = (uint8_t*)heap_caps_malloc(alloc_size, MALLOC_CAP_8BIT);

    if (raw_ptr) {
      // Manually align to 16 bytes inside the buffer
      uintptr_t addr         = (uintptr_t)raw_ptr;
      uintptr_t aligned_addr = (addr + 15u) & ~15u;   // round up to 16
      size_t    wasted       = aligned_addr - addr;    // 0-15 bytes

      tensor_arena      = (uint8_t*)aligned_addr;
      tensor_arena_size = alloc_size - wasted;         // remaining usable bytes
    }
  }
}

  // ═══════════════════════════════════════════════════════════════════════════
  // STEP 2 — Start Serial and report memory state
  // ═══════════════════════════════════════════════════════════════════════════
  Serial.begin(115200);
  delay(1000);

  Serial.println("\n=== Arabic Sign Language Glove ===");
  Serial.printf("  Model      : %u bytes\n", sign_model_len);
  Serial.printf("  Classes    : %d\n", N_CLASSES);
  Serial.println("\n--- Memory report after arena alloc ---");
  Serial.printf("  Arena ptr  : %s\n", tensor_arena ? "OK" : "NULL");
  Serial.printf("  Arena size : %u bytes (%.1f KB)\n",
                tensor_arena_size, tensor_arena_size / 1024.0f);
  Serial.printf("  Free heap  : %u bytes\n", esp_get_free_heap_size());
  Serial.printf("  Lgst block : %u bytes\n",
                heap_caps_get_largest_free_block(MALLOC_CAP_8BIT));
  Serial.println("---------------------------------------");

  // Halt if there is no arena or it is below the accepted minimum.
  // NOTE(review): 90000 is a magic number, presumably derived from the
  // "Arena used" figure printed further down — confirm and name it.
  if (!tensor_arena || tensor_arena_size < 90000) {
    Serial.println("FATAL: Arena allocation failed or too small.");
    Serial.println("  Your ESP32 does not have enough contiguous RAM.");
    Serial.println("  Solution: Use an ESP32-WROVER (has 4MB PSRAM).");
    while (1) delay(100);
  }

  // ═══════════════════════════════════════════════════════════════════════════
  // STEP 3 — Register ops
  // Exactly 20 ops are added below, matching the capacity of the
  // MicroMutableOpResolver<20> declared at file scope.
  // ═══════════════════════════════════════════════════════════════════════════
  resolver.AddFullyConnected();
  resolver.AddReshape();
  resolver.AddSoftmax();
  resolver.AddQuantize();
  resolver.AddDequantize();
  resolver.AddUnidirectionalSequenceLSTM();
  resolver.AddConv2D();
  resolver.AddDepthwiseConv2D();
  resolver.AddMean();
  resolver.AddPack();
  resolver.AddUnpack();
  resolver.AddAdd();
  resolver.AddMul();
  resolver.AddRelu();
  resolver.AddStridedSlice();
  resolver.AddConcatenation();
  resolver.AddTranspose();
  resolver.AddSplit();
  resolver.AddLogistic();
  resolver.AddTanh();

  // ═══════════════════════════════════════════════════════════════════════════
  // STEP 4 — Hardware init
  // ═══════════════════════════════════════════════════════════════════════════
  // 12-bit ADC readings (0..4095) with 11 dB attenuation on every flex pin.
  analogReadResolution(12);
  for (int i = 0; i < 5; i++)
    analogSetPinAttenuation(FLEX_PINS[i], ADC_11db);

  // I2C on SDA=21 / SCL=22 at 100 kHz with a 50 ms bus timeout.
  Wire.begin(21, 22);
  Wire.setClock(100000);
  Wire.setTimeOut(50);
  mpu.begin();
  Serial.print("  Calibrating gyro (keep glove still)...");
  mpu.calcGyroOffsets(true);
  Serial.println(" done");

  Serial.println("\n--- Memory report after MPU init ---");
  Serial.printf("  Free heap  : %u bytes\n", esp_get_free_heap_size());
  Serial.printf("  Lgst block : %u bytes\n",
                heap_caps_get_largest_free_block(MALLOC_CAP_8BIT));
  Serial.println("-------------------------------------");

  // ═══════════════════════════════════════════════════════════════════════════
  // STEP 5 — Load model and create interpreter
  // ═══════════════════════════════════════════════════════════════════════════
  tfl_model = tflite::GetModel(sign_model);
  if (tfl_model->version() != TFLITE_SCHEMA_VERSION) {
    Serial.println("ERROR: schema version mismatch");
    while (1) delay(100);
  }

  Serial.printf("  Creating interpreter (arena = %u bytes)...\n",
                tensor_arena_size);

  // Construct the interpreter in the static interp_buf storage (placement new).
  interpreter = new(interp_buf) tflite::MicroInterpreter(
    tfl_model, resolver, tensor_arena, tensor_arena_size, error_reporter
  );

  Serial.println("  Running AllocateTensors()...");
  TfLiteStatus alloc_status = interpreter->AllocateTensors();

  // The report is printed before the status check so the numbers are visible
  // even when allocation failed.
  Serial.println("\n--- Memory report after AllocateTensors ---");
  Serial.printf("  Free heap  : %u bytes\n", esp_get_free_heap_size());
  Serial.printf("  Arena used : %u bytes\n", interpreter->arena_used_bytes());
  Serial.printf("  Arena size : %u bytes\n", tensor_arena_size);
  Serial.printf("  Margin     : %d bytes\n",
                (int)tensor_arena_size - (int)interpreter->arena_used_bytes());
  Serial.println("-------------------------------------------");

  if (alloc_status != kTfLiteOk) {
    Serial.println("FATAL: AllocateTensors() failed.");
    Serial.println("  The arena is too small for this model.");
    Serial.println("  You need an ESP32-WROVER with PSRAM.");
    while (1) delay(100);
  }

  input_tensor  = interpreter->input(0);
  output_tensor = interpreter->output(0);

  // NOTE(review): the prints below assume a rank-3 input and a rank-2 output;
  // dims->size is not checked before indexing dims->data.
  Serial.printf("  Input  : [%d,%d,%d] type=%d\n",
    input_tensor->dims->data[0],
    input_tensor->dims->data[1],
    input_tensor->dims->data[2],
    input_tensor->type);
  Serial.printf("  Output : [%d,%d] type=%d\n",
    output_tensor->dims->data[0],
    output_tensor->dims->data[1],
    output_tensor->type);

  Serial.println("\n  Ready — hold a sign for ~5 seconds\n");
  // Start the sampling clock used by loop().
  last_tick = millis();
}

// =============================================================================
// LOOP
// =============================================================================
// Main sampling loop.
//
// At most once every SAMPLE_INTERVAL_MS it reads the five flex sensors and
// the IMU. Every DOWNSAMPLE_STEP-th read is converted into a feature vector,
// quantized, and stored as the next row of seq_buf. Once SEQ_LEN rows have
// been collected, runInference() is called and both counters start over.
void loop() {
  const unsigned long now = millis();
  if (now - last_tick < SAMPLE_INTERVAL_MS) return;
  last_tick = now;

  // Sensors are read on every tick, including ticks that downsampling drops.
  int flex_raw[5];
  for (int finger = 0; finger < 5; ++finger) {
    flex_raw[finger] = analogRead(FLEX_PINS[finger]);
  }

  mpu.update();
  const float ax = mpu.getAccX();
  const float ay = mpu.getAccY();
  const float az = mpu.getAccZ();
  const float gx = mpu.getGyroX();
  const float gy = mpu.getGyroY();
  const float gz = mpu.getGyroZ();

  // Keep only every DOWNSAMPLE_STEP-th reading.
  if (++raw_counter % DOWNSAMPLE_STEP != 0) return;

  const float acc_mag  = sqrtf(ax*ax + ay*ay + az*az);
  const float gyro_mag = sqrtf(gx*gx + gy*gy + gz*gz);

  float feat43[N_FEATURES_RAW];
  computeFeatures(flex_raw, ax, ay, az, gx, gy, gz, acc_mag, gyro_mag, feat43);

  #if DEBUG_FEATURES
  if (step_count < 3) {
    Serial.printf("[step %d] ", step_count);
    for (int i = 0; i < N_FEATURES_RAW; i++) Serial.printf("%.3f ", feat43[i]);
    Serial.println();
  }
  #endif

  // Quantize straight into the next free row of the sequence buffer.
  preprocessStep(feat43, seq_buf[step_count]);

  if (++step_count >= SEQ_LEN) {
    runInference();
    step_count  = 0;
    raw_counter = 0;
  }
}