#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

#include "cpuidGet.h"
#include "lap.h"

// Index and value aliases used by the demo driver below.
using row = int;      // row index
using col = int;      // column index
using cost = double;  // single entry of the cost matrix

// File-local SIMD capability flags; call_lap() queries hasAVX2() on this object.
static SIMDFlags simd_flags{};

/// Runs the `lap` solver on type-erased buffers, choosing the `lap<true>` or
/// `lap<false>` instantiation according to `simd_flags.hasAVX2()`.
///
/// @tparam F          element type that `cost_matrix`, `u` and `v` are cast to.
/// @param dim         problem dimension, forwarded to `lap` unchanged.
/// @param cost_matrix type-erased pointer, interpreted as `const F*`.
/// @param verbose     when true, prints whether the AVX2 path is taken; also
///                    forwarded to `lap`.
/// @param row_ind     output index buffer, forwarded to `lap` unchanged.
/// @param col_ind     output index buffer, forwarded to `lap` unchanged.
/// @param u           type-erased pointer, interpreted as `F*`.
/// @param v           type-erased pointer, interpreted as `F*`.
/// @return the value returned by `lap`, converted to `double`.
template <typename F>
static always_inline double call_lap(int dim, const void *restrict cost_matrix, bool verbose,
                                     int *restrict row_ind, int *restrict col_ind,
                                     void *restrict u, void *restrict v) {
    const bool use_avx2 = simd_flags.hasAVX2();
    if (verbose) {
        printf("AVX2: %s\n", use_avx2 ? "enabled" : "disabled");
    }

    // The buffers arrive as void*; recover the element type requested by the caller.
    const F *typed_costs = static_cast<const F *>(cost_matrix);
    F *typed_u = static_cast<F *>(u);
    F *typed_v = static_cast<F *>(v);

    // The SIMD choice is a template argument of lap, so each branch names its
    // own instantiation.
    return use_avx2
        ? lap<true>(dim, typed_costs, verbose, row_ind, col_ind, typed_u, typed_v)
        : lap<false>(dim, typed_costs, verbose, row_ind, col_ind, typed_u, typed_v);
}

/// Demo driver: pads a small track/detection cost table to a square matrix,
/// solves the assignment with call_lap<double>(), and prints the matching,
/// its total cost and the elapsed time.
///
/// The command-line arguments are not used.
///
/// @return EXIT_SUCCESS on completion, EXIT_FAILURE if the cost table is empty.
int main(int argc, char** argv) {
    (void)argc;
    (void)argv;

    // square matrix
    const std::vector<std::vector<double>> costs = {
      {9.0, 7.6, 7.5, 7.0},    // track 0
      {3.5, 8.5, 5.5, 6.5},    // track 1
      {12.5, 9.5, 9.0, 10.5},  // track 2
      {4.5, 11.0, 9.5, 11.5},  // track 3
    };

    // more rows than cols: more tracks than detections (4x3)
    // std::vector<std::vector<double>> costs = {
    //   {9.0, 7.6, 7.5},    // track 0
    //   {3.5, 8.5, 5.5},    // track 1
    //   {12.5, 9.5, 9.0},   // track 2
    //   {4.5, 11.0, 9.5},   // track 3
    // };

    // fewer rows than cols: fewer tracks than detections (3x4)
    // std::vector<std::vector<double>> costs = {
    //   {9.0 , 7.6, 7.5, 7.0},   // track 0
    //   {3.5 , 8.5, 5.5, 6.5},   // track 1
    //   {12.5, 9.5, 9.0, 10.5},  // track 2
    // };
    // std::vector<std::vector<double>> costs = {
    //   {90, 76, 75, 70},    // track 0
    //   {35, 85, 55, 65},    // track 1
    //   {125, 95, 90, 105},  // track 2
    //   {45, 110, 95, 115},  // track 3
    // };

    // costs[0] is indexed below, so an empty table must be rejected first.
    if (costs.empty() || costs[0].empty()) {
        fprintf(stderr, "cost matrix is empty\n");
        return EXIT_FAILURE;
    }

    const bool PRINTCOST = true;
    // The table is treated as rectangular: every row is expected to have as
    // many entries as row 0.
    const std::size_t n_rows = costs.size();
    const std::size_t n_cols = costs[0].size();
    const std::size_t side = std::max(n_rows, n_cols);
    const int dim = static_cast<int>(side);
    const double missing_cost = 1000;

    // call_lap() reinterprets the cost buffer as `const double*`, so the matrix
    // has to be one contiguous block of dim*dim doubles, not an array of row
    // pointers. Row-major order is assumed here -- TODO confirm against lap.h.
    // Cells outside the real table keep missing_cost.
    std::vector<double> assigncost(side * side, missing_cost);
    for (std::size_t r = 0; r < n_rows; ++r) {
        for (std::size_t c = 0; c < n_cols; ++c) {
            assigncost[r * side + c] = costs[r][c];
        }
    }

    if (PRINTCOST) {
        for (std::size_t r = 0; r < side; ++r) {
            printf("\n");
            for (std::size_t c = 0; c < side; ++c) {
                printf("%f ", assigncost[r * side + c]);
            }
        }
    }

    // Output and work buffers sized to the padded problem; vectors release
    // everything automatically on every return path.
    std::vector<int> rowsol(side);
    std::vector<int> colsol(side);
    std::vector<double> u(side);
    std::vector<double> v(side);

    printf("\nstart\n");
    std::cout << "Beginning lapjv method." << std::endl;

    const bool verbose = true;
    // steady_clock is monotonic, which is what interval timing needs.
    const auto start = std::chrono::steady_clock::now();
    const double lapcost = call_lap<double>(dim,
                                            assigncost.data(),
                                            verbose,
                                            rowsol.data(),
                                            colsol.data(),
                                            u.data(),
                                            v.data());
    const auto end = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = end - start;

    // rowsol is read as "column assigned to row r" (it was declared with the
    // column index type) -- NOTE(review): confirm against lap.h. Rows matched
    // to a padding column have no real detection and are skipped.
    double cost_tot = 0;
    for (std::size_t r = 0; r < n_rows; ++r) {
        const int assigned = rowsol[r];
        if (assigned < 0 || static_cast<std::size_t>(assigned) >= n_cols) {
            continue;
        }
        const std::size_t c = static_cast<std::size_t>(assigned);
        std::cout << "Track " << r << " assigned to detect " << c
                  << ".  Cost: " << costs[r][c] << '\n';
        cost_tot += costs[r][c];
    }
    std::cout << "Total cost = " << cost_tot << std::endl;

    printf("\n\ndim  %4d - lap cost %6.3f - runtime %6.3f ms\n", dim, lapcost, elapsed.count());

    return EXIT_SUCCESS;
}



