/**
 * router.cpp
 *
 * for Vivado HLS
 */

#ifdef SOFTWARE
#include "ap_int.h"
#else
#include <ap_int.h>
#endif

#include "./router.hpp"

// Set weight
// Weight given to cells of non-target lines in rip-up round `x`.
// Only bits [7:3] of `x` matter: the weight grows by one every 8 rounds,
// runs through 1..32 and repeats with a period of 256 rounds.
ap_uint<8> new_weight(ap_uint<16> x) {
#pragma HLS INLINE
    // K. Terada: y = 1~32 (8bit)
    const ap_uint<5> step = (x & 0xFF) >> 3;   // 0..31
    return (ap_uint<8>)(step + 1);
}


// Global values
// Board dimensions, parsed from boardstr by pynqrouter() and used by search()
// for its in-board bounds check.
static ap_uint<7> size_x;    // X
static ap_uint<7> size_y;    // Y
static ap_uint<4> size_z;    // Z

// Number of lines (terminal pairs) parsed from boardstr; pynqrouter()
// increments it once per record. File-scope static: the "= 0" initialiser
// applies only once, not on every call.
static ap_uint<LINE_BIT> line_num = 0; // #Lines

#ifdef DEBUG_PRINT
// Debug statistics, zeroed by pynqrouter() before routing.
// max_queue_length / max_search_count are updated by search();
// max_buffer_length is updated by pynqrouter().
int max_queue_length;    // Max length of priority queue
int max_search_count;    // Max count of queue pop
int max_buffer_length;   // Max length of line buffer
#endif


// Top-level HLS entry point: parse a board description, route every line with
// A* search followed by rip-up-and-reroute, and write the solution back.
//
// boardstr (in/out):
//   Input  - boardstr[1..2] = size X, boardstr[4..5] = size Y, boardstr[7] = size Z
//            (ASCII digits). From index 8 there is one 11-character record per line:
//            [idx] record marker (a NUL here ends the list), [idx+1..2] start X,
//            [idx+3..4] start Y, [idx+5] start Z (1-based), [idx+6..7] goal X,
//            [idx+8..9] goal Y, [idx+10] goal Z (1-based).
//   Output - only when solved: for every cell id ((x * MAX_WIDTH + y) << BITWIDTH_Z) | z
//            below MAX_CELLS, the 1-based number of the line occupying that cell,
//            or 0 for a blank cell.
// seed:   not used by this implementation.
// status: -1 on entry, 1 if rip-up routing ends with overlaps, 0 on success.
// Returns true when a routing without overlaps was found.
bool pynqrouter(char boardstr[BOARDSTR_SIZE], ap_uint<32> seed, ap_int<32> *status) {
#pragma HLS INTERFACE s_axilite port=boardstr bundle=AXI4LS
#pragma HLS INTERFACE s_axilite port=seed bundle=AXI4LS
#pragma HLS INTERFACE s_axilite port=status bundle=AXI4LS
#pragma HLS INTERFACE s_axilite port=return bundle=AXI4LS

    // status(-1:Initial, 0:Solved, 1:Not solved)
    *status = -1;

    // For all lines
    ap_uint<CELL_BIT> paths[MAX_BUFFER];    // Line buffer

    // For each line
    // Note: Should not partition completely
    bool adjacents[MAX_LINES];              // Line has adjacent terminals?
    ap_uint<CELL_BIT> starts[MAX_LINES];    // Start list
    ap_uint<CELL_BIT> goals[MAX_LINES];     // Goal list
    ap_uint<BUFF_BIT> s_idx[MAX_LINES];     // Start point on line buffer

    ap_uint<8> weights[MAX_CELLS];          // Weight of each cell
    // Note: Should not partition weight array
    // since each element will be accessed in "random" order


    // ================================
    // (Step.0) Initialization (BEGIN)
    // ================================

    // Note: Loop counter -> need an extra bit (for condition determination)

    INIT_WEIGHTS:
    for (ap_uint<CELL_BIT> i = 0; i < (ap_uint<CELL_BIT>)(MAX_CELLS); i++) {
#pragma HLS UNROLL factor=2
        weights[i] = 1;
    }

    /// Parse ///
    size_x = (boardstr[1] - '0') * 10 + (boardstr[2] - '0');
    size_y = (boardstr[4] - '0') * 10 + (boardstr[5] - '0');
    size_z = (boardstr[7] - '0');

    // line_num is a file-scope static that is zero-initialised only once,
    // so it must be cleared here; otherwise a second call would append this
    // board's lines to the previous board's count.
    line_num = 0;

    INIT_BOARDS:
    for (ap_uint<CELL_BIT> idx = 8; idx < (ap_uint<CELL_BIT>)(BOARDSTR_SIZE); idx+=11) {

        // NULL-terminated
        if (boardstr[idx] == 0) break;

        // Start & Goal of each line (Z is 1-based in the string)
        ap_uint<7> s_x = (boardstr[idx+1] - '0') * 10 + (boardstr[idx+2] - '0');
        ap_uint<7> s_y = (boardstr[idx+3] - '0') * 10 + (boardstr[idx+4] - '0');
        ap_uint<3> s_z = (boardstr[idx+5] - '0') - 1;
        ap_uint<7> g_x = (boardstr[idx+6] - '0') * 10 + (boardstr[idx+7] - '0');
        ap_uint<7> g_y = (boardstr[idx+8] - '0') * 10 + (boardstr[idx+9] - '0');
        ap_uint<3> g_z = (boardstr[idx+10] - '0') - 1;

        ap_uint<CELL_BIT> start_id = (((ap_uint<CELL_BIT>)s_x * MAX_WIDTH + (ap_uint<CELL_BIT>)s_y) << BITWIDTH_Z) | (ap_uint<CELL_BIT>)s_z;
        ap_uint<CELL_BIT> goal_id  = (((ap_uint<CELL_BIT>)g_x * MAX_WIDTH + (ap_uint<CELL_BIT>)g_y) << BITWIDTH_Z) | (ap_uint<CELL_BIT>)g_z;
        starts[line_num] = start_id;
        goals[line_num] = goal_id;
        weights[start_id] = MAX_WEIGHT;
        weights[goal_id] = MAX_WEIGHT;

        // Line has adjacent terminals? (such lines need no routing)
        adjacents[line_num] = false;
        ap_int<8> dx = (ap_int<8>)g_x - (ap_int<8>)s_x; // Min: -71, Max: 71 (Signed 8bit)
        ap_int<8> dy = (ap_int<8>)g_y - (ap_int<8>)s_y; // Min: -71, Max: 71 (Signed 8bit)
        ap_int<4> dz = (ap_int<4>)g_z - (ap_int<4>)s_z; // Min:  -7, Max:  7 (Signed 4bit)
        if ((dx == 0 && dy == 0 && (dz == 1 || dz == -1)) || (dx == 0 && (dy == 1 || dy == -1) && dz == 0) || ((dx == 1 || dx == -1) && dy == 0 && dz == 0)) {
            adjacents[line_num] = true;
        }

        line_num++;
    }

    // ================================
    // (Step.0) Initialization (END)
    // ================================

#ifdef DEBUG_PRINT
    max_queue_length = 0;
    max_search_count = 0;
    max_buffer_length = 0;
#endif

    ap_uint<BUFF_BIT> pointer = 0; // Pointer for line buffer

    // ================================
    // (Step.1) Initial Routing (BEGIN)
    // ================================

#ifdef DEBUG_PRINT
    cout << "Initial Routing ..." << endl;
#endif

    FIRST_ROUTING:
    for (ap_uint<LINE_BIT> i = 0; i < (ap_uint<LINE_BIT>)(line_num); i++) {
#pragma HLS LOOP_TRIPCOUNT min=2 max=999

        s_idx[i] = pointer;

        if (adjacents[i] == true) continue; // Skip routing

#ifdef DEBUG_PRINT
        //cout << "LINE #" << (int)(i + 1) << endl;
#endif
        // Routing
        pointer = search(s_idx[i], paths, starts[i], goals[i], weights);
    }

    // ================================
    // (Step.1) Initial Routing (END)
    // ================================


    // Memories for Overlap Check
    ap_uint<1> overlap_checks[MAX_CELLS];
#pragma HLS ARRAY_PARTITION variable=overlap_checks cyclic factor=16 dim=1
    bool has_overlap = false;

    // ================================
    // (Step.2) Rip-up Routing (BEGIN)
    // ================================

#ifdef DEBUG_PRINT
    cout << "Rip-up Routing ..." << endl;
#endif

    ROUTING:
    for (ap_uint<16> round = 0; round < 32768 /* = (2048 * 16) */; round++) {
#pragma HLS LOOP_TRIPCOUNT min=1 max=32768

        // Board without lines: there is nothing to rip up, and the modulo below
        // would divide by zero. has_overlap is still false here, so the (blank)
        // board is reported as solved.
        if (line_num == 0) break;

        // Target line (round-robin over all lines)
        ap_uint<LINE_BIT> target = round % line_num;
        ap_uint<LINE_BIT> next_target = target + 1;
        if (next_target == line_num) next_target = 0;

#ifdef DEBUG_PRINT
        //cout << "(round " << round << ") LINE #" << (int)(target + 1);
        //cout << " -> " << pointer << endl;
#endif
#ifdef DEBUG_PRINT
        int buffer_length = pointer - s_idx[target];
        if (max_buffer_length < buffer_length) { max_buffer_length = buffer_length; }
#endif

        // Skip routing
        if (adjacents[target] == true) {
            s_idx[target] = pointer;  continue;
        }


        // (Step.2-1) Reset weights of target line
        WEIGHT_RESET:
        for (ap_uint<BUFF_BIT> j = s_idx[target]; j != s_idx[next_target]; j++) {
#pragma HLS UNROLL factor=2
#pragma HLS LOOP_TRIPCOUNT min=1 max=256
            weights[paths[j]] = 1;
        }

        // (Step.2-2) Set weights of non-target lines and terminals
        ap_uint<8> current_round_weight = new_weight(round);
        WEIGHT_PATH:
        for (ap_uint<BUFF_BIT> j = s_idx[next_target]; j != pointer; j++) {
#pragma HLS UNROLL factor=2
#pragma HLS LOOP_TRIPCOUNT min=1 max=8192
            weights[paths[j]] = current_round_weight;
        }
        WEIGHT_TERMINAL:
        for (ap_uint<LINE_BIT> i = 0; i < (ap_uint<LINE_BIT>)(line_num); i++) {
#pragma HLS UNROLL factor=2
#pragma HLS LOOP_TRIPCOUNT min=2 max=999
            weights[starts[i]] = MAX_WEIGHT;
            weights[goals[i]] = MAX_WEIGHT;
        }
        // Reset weight of start terminal of target line (bug avoiding)
        // Restore original settings in (*)
        weights[starts[target]] = 1;

        // (Step.2-3) Routing: the re-routed path is appended at the buffer end
        s_idx[target] = pointer;
        pointer = search(s_idx[target], paths, starts[target], goals[target], weights);

        // (*)
        weights[starts[target]] = MAX_WEIGHT;

#ifdef DEBUG_PRINT
        bool ng = false;
        for (ap_uint<LINE_BIT> i = 0; i < (ap_uint<LINE_BIT>)(line_num); i++) {
            if (weights[starts[i]] != 255 || weights[goals[i]] != 255) {
                cout << i << " "; ng = true;
            }
        }
        if(ng) { cout << endl; }
#endif

        // (Step.2-4) Overlap check: mark all terminals, then every path cell;
        // a cell that is already marked means two lines share it.
        has_overlap = false;
        OVERLAP_RESET:
        for (ap_uint<CELL_BIT> i = 0; i < (ap_uint<CELL_BIT>)(MAX_CELLS); i++) {
#pragma HLS UNROLL factor=32
            overlap_checks[i] = 0;
        }
        OVERLAP_CHECK_LINE:
        for (ap_uint<LINE_BIT> i = 0; i < (ap_uint<LINE_BIT>)(line_num); i++) {
#pragma HLS UNROLL factor=2
#pragma HLS LOOP_TRIPCOUNT min=2 max=999
            overlap_checks[starts[i]] = 1;
            overlap_checks[goals[i]] = 1;
        }
        OVERLAP_CHECK_PATH:
        for (ap_uint<BUFF_BIT> j = s_idx[next_target]; j != pointer; j++) {
#pragma HLS UNROLL factor=2
#pragma HLS LOOP_TRIPCOUNT min=1 max=8192
            ap_uint<CELL_BIT> cell_id = paths[j];
            if (overlap_checks[cell_id]) {
                has_overlap = true;  break;
            }
            overlap_checks[cell_id] = 1;
        }
#ifdef DEBUG_PRINT
        if(!has_overlap){ cout << "ROUND: " << round << endl; }
#endif
        if (!has_overlap) break; // Finish routing?
    }

#ifdef DEBUG_PRINT
    cout << "MAX PQ LENGTH: " << max_queue_length << endl;
    cout << "MAX SEARCH COUNT: " << max_search_count << endl;
    cout << "MAX BUFFER: " << max_buffer_length << endl;
#endif

    // Not solved
    if (has_overlap) {
        *status = 1;  return false;
    }

    // ================================
    // (Step.2) Rip-up Routing (END)
    // ================================


    // ================================
    // (Step.3) Output (BEGIN)
    // ================================

#ifdef DEBUG_PRINT
    cout << "Output ..." << endl;
#endif

    // Init: Blank = 0
    OUTPUT_INIT:
    for (ap_uint<CELL_BIT> i = 0; i < (ap_uint<CELL_BIT>)(MAX_CELLS); i++) {
        boardstr[i] = 0;
    }
    // Line
    OUTPUT_LINE:
    for (ap_uint<LINE_BIT> i = 0; i < (ap_uint<LINE_BIT>)(line_num); i++) {
#pragma HLS LOOP_TRIPCOUNT min=2 max=999
        boardstr[starts[i]] = (i + 1);
        boardstr[goals[i]] = (i + 1);

        ap_uint<BUFF_BIT> p1; // p1: s_idx of target
        ap_uint<BUFF_BIT> p2; // p2: s_idx of next target
        p1 = s_idx[i];
        if (i == (ap_uint<LINE_BIT>)(line_num-1)) {
            p2 = s_idx[0];
        }
        else {
            p2 = s_idx[i+1];
        }
        if ((ap_uint<BUFF_BIT>)(p2 - p1) > 8192){
            p2 = pointer;
        }
        OUTPUT_LINE_PATH:
        for (ap_uint<BUFF_BIT> j = p1; j != p2; j++) {
#pragma HLS LOOP_TRIPCOUNT min=1 max=256
            boardstr[paths[j]] = (i + 1);
        }
    }

    // ================================
    // (Step.3) Output (END)
    // ================================

    *status = 0;  return true;
}


// ================================ //
// For Routing
// ================================ //

// |a - b| for two 7-bit unsigned values.
// Max: 71, Min: 0 (7bit) for in-board coordinates.
ap_uint<7> abs_uint7(ap_uint<7> a, ap_uint<7> b) {
#pragma HLS INLINE
    return (b > a) ? (ap_uint<7>)(b - a) : (ap_uint<7>)(a - b);
}
// |a - b| for two 3-bit unsigned values (layer indices).
// Max: 7, Min: 0 (3bit)
ap_uint<3> abs_uint3(ap_uint<3> a, ap_uint<3> b) {
#pragma HLS INLINE
    return (b > a) ? (ap_uint<3>)(b - a) : (ap_uint<3>)(a - b);
}

// A* search for one line from `start` to `goal` on the 3-D grid (6-neighbour moves).
// Reference codes:
// http://lethe2211.hatenablog.com/entry/2014/12/30/011030
// http://www.redblobgames.com/pathfinding/a-star/implementation.html
// Need to modify "array partition factor"
//
// idx:    first free slot of `paths`; the found path is written from here on.
// paths:  line buffer; receives the path cells in goal-to-start order,
//         excluding both terminals.
// start, goal: cell ids ((x * MAX_WIDTH + y) << BITWIDTH_Z) | z.
// w:      per-cell weight; stepping from cell c to a neighbour costs w[c].
// Returns the index one past the last entry written (== idx if nothing written).
ap_uint<BUFF_BIT> search(ap_uint<BUFF_BIT> idx, ap_uint<CELL_BIT> paths[MAX_BUFFER], ap_uint<CELL_BIT> start, ap_uint<CELL_BIT> goal, ap_uint<8> w[MAX_CELLS]) {

    ap_uint<CELL_BIT> dist[MAX_CELLS];
#pragma HLS ARRAY_PARTITION variable=dist cyclic factor=64 dim=1
    ap_uint<CELL_BIT> prev[MAX_CELLS];

    SEARCH_INIT_DIST:
    for (ap_uint<CELL_BIT> i = 0; i < (ap_uint<CELL_BIT>)(MAX_CELLS); i++) {
#pragma HLS UNROLL factor=128
        dist[i] = 65535; // = (2^16 - 1), marks "not reached yet"
    }

    // Priority queue (Circular list)
    ap_uint<PQ_BIT> top = 1, bottom = 0;
    bool is_empty = true;
    ap_uint<32> pq_nodes[MAX_PQ];

#ifdef DEBUG_PRINT
    int queue_length = 0;
    int search_count = 0;
#endif

    // Point of goal terminal
    ap_uint<13> goal_xy = (ap_uint<13>)(goal >> BITWIDTH_Z);
    ap_uint<7> goal_x = (ap_uint<7>)(goal_xy / MAX_WIDTH);
    ap_uint<7> goal_y = (ap_uint<7>)(goal_xy - goal_x * MAX_WIDTH);
    ap_uint<3> goal_z = (ap_uint<3>)(goal & BITMASK_Z);

    dist[start] = 0;
    pq_push(pq_nodes, 0, start, &top, &bottom, &is_empty); // push start terminal

    SEARCH_PQ:
    while (!is_empty) {
#pragma HLS LOOP_TRIPCOUNT min=1 max=1000
#pragma HLS LOOP_FLATTEN off

        ap_uint<16> prev_cost;
        ap_uint<16> src; // target cell
        pq_pop(pq_nodes, &prev_cost, &src, &top, &bottom, &is_empty);
#ifdef DEBUG_PRINT
        search_count++;
#endif


        // End routing
        if (src == goal) break;


        // Target cell
        ap_uint<16> dist_src = dist[src];
        ap_uint<8> cost = w[src];
        // Point of target cell
        ap_uint<13> src_xy = (ap_uint<13>)(src >> BITWIDTH_Z);
        ap_uint<7> src_x = (ap_uint<7>)(src_xy / MAX_WIDTH);
        ap_uint<7> src_y = (ap_uint<7>)(src_xy - src_x * MAX_WIDTH);
        ap_uint<3> src_z = (ap_uint<3>)(src & BITMASK_Z);

        // Search adjacent cells
        SEARCH_ADJACENTS:
        for (ap_uint<3> a = 0; a < 6; a++) {
            ap_int<8> dest_x = (ap_int<8>)src_x; // Min: -1, Max: 72 (Signed 8bit)
            ap_int<8> dest_y = (ap_int<8>)src_y; // Min: -1, Max: 72 (Signed 8bit)
            ap_int<5> dest_z = (ap_int<5>)src_z; // Min: -1, Max:  8 (Signed 5bit)
            if (a == 0) { dest_x -= 1; }
            if (a == 1) { dest_x += 1; }
            if (a == 2) { dest_y -= 1; }
            if (a == 3) { dest_y += 1; }
            if (a == 4) { dest_z -= 1; }
            if (a == 5) { dest_z += 1; }

            // Inside the board ? //
            if (0 <= dest_x && dest_x < (ap_int<8>)size_x && 0 <= dest_y && dest_y < (ap_int<8>)size_y && 0 <= dest_z && dest_z < (ap_int<5>)size_z) {
                // Adjacent cell
                ap_uint<16> dest = (((ap_uint<16>)dest_x * MAX_WIDTH + (ap_uint<16>)dest_y) << BITWIDTH_Z) | (ap_uint<16>)dest_z;
                ap_uint<16> dist_new = dist_src + cost;

                if (dist[dest] > dist_new) {
                    dist[dest] = dist_new;  // Update dist
                    prev[dest] = src;       // Record previous cell
                    dist_new += abs_uint7(dest_x, goal_x) + abs_uint7(dest_y, goal_y) + abs_uint3(dest_z, goal_z); // A* heuristic (Manhattan distance)
                    pq_push(pq_nodes, dist_new, dest, &top, &bottom, &is_empty); // push adjacent cell
                }
            }
        }
#ifdef DEBUG_PRINT
        if (queue_length < (bottom-top+1)) { queue_length = (bottom-top+1); }
#endif
    }

#ifdef DEBUG_PRINT
    if (max_queue_length < queue_length) { max_queue_length = queue_length; }
    if (max_search_count < search_count) { max_search_count = search_count; }
#endif

    // The goal was never relaxed (possible when the bounded priority queue
    // overflowed and evicted entries), so prev[goal] was never written and
    // backtracking would follow uninitialised links: emit no path instead.
    // NOTE(review): the caller cannot distinguish this from an empty path.
    if (dist[goal] == 65535) {
        return idx;
    }

    // Output target path
    // Note: Do not include start & goal terminals
    ap_uint<16> t = prev[goal];

    // Backtracking (goal -> start); terminates because every prev link points
    // to a cell with strictly smaller dist.
    ap_uint<BUFF_BIT> p = idx; // buffer-idx
    SEARCH_BACKTRACK:
    while (t != start) {
#pragma HLS LOOP_TRIPCOUNT min=1 max=256
        paths[p] = t;
        p++;
        t = prev[t];
    }

    return p;
}

// Queue push (Enqueue)
// Need to modify "trip count" (2)
//
// Inserts `data` with key `priority` into the circular priority queue
// pq_nodes. Live entries occupy [*top, *bottom]; indices wrap modulo
// 2^PQ_BIT through the ap_uint<PQ_BIT> type, and entries are kept sorted by
// ascending priority, so pq_pop() always takes the smallest key.
// Each entry packs `data` above PQ_PRIORITY_WIDTH and `priority` in the low
// bits (PQ_PRIORITY_MASK).
// The new entry goes before the first entry whose priority is >= its own,
// so among equal priorities the most recently pushed is popped first.
// If the queue is already full, the entry at *bottom (largest key) is
// overwritten, i.e. discarded.
void pq_push(ap_uint<32> pq_nodes[MAX_PQ], ap_uint<16> priority, ap_uint<16> data, ap_uint<PQ_BIT> *top, ap_uint<PQ_BIT> *bottom, bool *is_empty) {
#pragma HLS INLINE

    (*bottom)++;
    if ((*bottom) == (*top) && !(*is_empty)) { (*bottom)--; } // Queue is full -> Last element is automatically removed

    // Binary search for circular list: find the first slot in [t, b) whose
    // priority is >= `priority` (t == b afterwards; b itself means "append").
    ap_uint<PQ_BIT> t = (*top);
    ap_uint<PQ_BIT> b = (*bottom);
    ap_uint<PQ_BIT> h = ((ap_uint<PQ_BIT>)(b - t) / 2) + t;
    // Note: "h = (t + b) / 2" causes a bug!
    // (indices wrap, so the midpoint must be taken from the wrapped distance b - t)
    PQ_PUSH_BINARY:
    while (t != b) {
#pragma HLS LOOP_TRIPCOUNT min=0 max=13
/** Set!: min=0 max=PQ_BIT **/
    	if ((ap_uint<16>)(pq_nodes[h] & PQ_PRIORITY_MASK) >= priority) {
    		b = h;
    	}
    	else {
    		t = h + 1;
    	}
    	h = ((ap_uint<PQ_BIT>)(b - t) / 2) + t;
    }

    // Shifting: move entries [t, *bottom - 1] one slot towards *bottom
    ap_uint<PQ_BIT> shift_count = (*bottom) - t; // # of shifting

    ap_uint<PQ_BIT> p0 = (*bottom), p1;
    PQ_PUSH_SHIFT:
    for (ap_uint<PQ_BIT> j = 0; j < (ap_uint<PQ_BIT>)(shift_count); j++) {
#pragma HLS LOOP_TRIPCOUNT min=0 max=8191
/** Set!: min=0 max=MAX_PQ-1 **/
        p0 = p0 - 1;
        p1 = p0 + 1;
    	pq_nodes[p1] = pq_nodes[p0];
    }
    // p0 == t here: write the packed (data, priority) entry into the freed slot
    pq_nodes[p0] = ((ap_uint<32>)data << PQ_PRIORITY_WIDTH) | (ap_uint<32>)priority;
    *is_empty = false;
}

// Queue pop (Dequeue)
// Removes the entry at *top (the smallest priority, since pq_push keeps the
// queue sorted ascending) and returns its priority and data. Must only be
// called when !*is_empty; *is_empty is set once the last entry is removed.
void pq_pop(ap_uint<32> pq_nodes[MAX_PQ], ap_uint<16> *ret_priority, ap_uint<16> *ret_data, ap_uint<PQ_BIT> *top, ap_uint<PQ_BIT> *bottom, bool *is_empty) {
#pragma HLS INLINE

    *ret_priority = (ap_uint<16>)(pq_nodes[(*top)] & PQ_PRIORITY_MASK);
    *ret_data     = (ap_uint<16>)(pq_nodes[(*top)] >> PQ_PRIORITY_WIDTH);
    (*top)++;
    // Live entries occupy [top, bottom] on the circular buffer, so the queue
    // is empty exactly when bottom == top - 1 (mod 2^PQ_BIT). ap_uint
    // subtraction widens to a signed result, so the difference must be
    // truncated back to PQ_BIT bits; otherwise the wrap case (top == 0,
    // bottom == 2^PQ_BIT - 1) evaluates to 2^PQ_BIT and emptiness is missed.
    if ((ap_uint<PQ_BIT>)((*bottom) - (*top) + 1) == 0) { *is_empty = true; }
}
