#include "router.hpp"

// DATA_BIT: bit width of a cell index
// PRIO_BIT: bit width of a cost (priority) value
int router(short int size_x, short int size_y, short int line_num, short int board_str[], ap_uint<32> seed){
	
	ap_uint<DATA_BIT> board_str_len = size_x * size_y;
	ap_uint<DATA_BIT> j;
	
	ap_uint<DATA_BIT> terminals[MAX_LINES][2];
	short int paths_size[MAX_LINES];
	ap_uint<DATA_BIT> paths[MAX_LINES][MAX_PATH];
	bool adjacents[MAX_LINES];
	
	ap_uint<PRIO_BIT> weights[MAX_CELLS];
	
	short int i, x, y, p;
	
	// Line init.
	for(i = 0; i < line_num; i++) {
		terminals[i][0] = DATA_MAX;
		terminals[i][1] = DATA_MAX;
		paths_size[i] = 0;
		adjacents[i] = false;
	}
	for(j = 0; j < board_str_len; j++) {
		weights[j] = 1; // init. value is "1"
		i = board_str[j];
		if(i > 0) {
			if(terminals[i-1][0] == DATA_MAX) { terminals[i-1][0] = j; }
			else if(terminals[i-1][1] == DATA_MAX) { terminals[i-1][1] = j; }
			else {
				cout << "Error: line#" << i << endl;
				exit(1);
			}
			weights[j] = PRIO_MAX;
		}
		else if(i < 0) {
			weights[j] = PRIO_MAX;
		}
	}
	
#ifdef PRINT_BOARD
	cout << "== Print board ==" << endl;
	for(y = 0; y < size_y; y++) {
		for(x = 0; x < size_x; x++) {
			j = y * size_x + x;
			i = board_str[j];
			if(i == 0) { cout << "-"; }
			else if(i == -1) { cout << "X"; }
			else{ cout << i; }
			if(x != size_x-1) { cout << " "; }
		}
		cout << endl;
	}
#endif
	for(i = 0; i < line_num; i++) {
		ap_uint<DATA_BIT> t1 = terminals[i][0];
		ap_uint<DATA_BIT> t2 = terminals[i][1];
		short int t1_x = t1 % size_x;
		short int t1_y = t1 / size_x;
		short int t2_x = t2 % size_x;
		short int t2_y = t2 / size_x;
		short int dist = abs(t1_x - t2_x) + abs(t1_y - t2_y); // Manhattan dist.
		if(dist == 1) {
			adjacents[i] = true;
#ifdef PRINT_BOARD
			cout << "Line #" << i + 1 << " needs no routing" << endl;
#endif
		}
	}
	
	
	lfsr_random_init(seed);
	
	// Step 1
	cout << "1st routing ..." << endl;
	for(i = 0; i < line_num; i++) {
		if(adjacents[i]) continue;
#ifdef PRINT_SEARCH
		cout << "Line #" << i + 1 << endl;
#endif
		ap_uint<DATA_BIT> t1 = terminals[i][0];
		ap_uint<DATA_BIT> t2 = terminals[i][1];
		weights[t1] = 1;
		if(search(size_x, size_y, &paths_size[i], paths[i], t1, t2, weights) < 0) { return 0; }
		weights[t1] = PRIO_MAX;
	}
	
	bool has_overlap;
	ap_uint<1> overlap_checks[MAX_CELLS];
	
	// Step 2
	cout << "rip-up routing ..." << endl;
	short int last_target = -1;
	ap_uint<16> round;
	for(round = 1; round <= ROUND_LIMIT; round++) {
		short int target = lfsr_random() % line_num;
		if(adjacents[target]) continue;
		if(target == last_target) continue;
		last_target = target;
		ap_uint<PRIO_BIT> round_weight = new_weight(round);

#ifdef PRINT_SEARCH		
		cout << "Line #" << target + 1 << "(round: " << round << ", weight: " << round_weight << ")" << endl;
#endif
		
		ap_uint<DATA_BIT> t1 = terminals[target][0];
		ap_uint<DATA_BIT> t2 = terminals[target][1];
		
		for(p = 0; p < paths_size[target]; p++) {
			weights[paths[target][p]] = 1;
		}
		for(i = 0; i < line_num; i++) {
			if(i == target) continue;
			for(p = 0; p < paths_size[i]; p++) {
				weights[paths[i][p]] = round_weight;
			}
		}
		weights[t1] = 1;
		search(size_x, size_y, &paths_size[target], paths[target], t1, t2, weights);
		weights[t1] = PRIO_MAX;
		
		has_overlap = false;
		for(j = 0; j < board_str_len; j++) {
			i = board_str[j];
			if(i > 0 || i < 0) { overlap_checks[j] = 1; }
			else { overlap_checks[j] = 0; }
		}
		for(i = 0; i < line_num; i++) {
			for(p = 0; p < paths_size[i]; p++) {
				ap_uint<DATA_BIT> cell_id = paths[i][p];
				if(overlap_checks[cell_id]) {
					has_overlap = true;
					break;
				}
				overlap_checks[cell_id] = 1;
			}
		}
		if(!has_overlap) break;
	}
	
	/** //debug
	for(j = 0; j < board_str_len; j++) {
		i = board_str[j];
		cout << i << ": " << weights[j] << endl;
	}
	**/
	
	if(has_overlap) { // Cannot solve
		return 0;
	}
	
	int total_wire_length = 0;
	for(i = 0; i < line_num; i++) {
		total_wire_length += paths_size[i];
		for(p = 0; p < paths_size[i]; p++) {
			ap_uint<DATA_BIT> cell_id = paths[i][p];
			board_str[cell_id] = i + 1;
		}
	}
	
#ifdef PRINT_BOARD
	cout << "== Print answer ==" << endl;
	for(y = 0; y < size_y; y++) {
		for(x = 0; x < size_x; x++) {
			j = y * size_x + x;
			i = board_str[j];
			if(i == 0) { cout << "-"; }
			else if(i == -1) { cout << "X"; }
			else{ cout << i; }
			if(x != size_x-1) { cout << " "; }
		}
		cout << endl;
	}
#endif
	
	return total_wire_length;
}

/**
 * Seed the global 32-bit LFSR used by lfsr_random().
 *
 * The all-zero state is a fixed point of the XOR feedback in lfsr_random():
 * every tap reads 0, so the shift re-inserts 0 forever and every draw is 0.
 * A zero seed is therefore replaced by a fixed non-zero value.
 */
void lfsr_random_init(ap_uint<32> seed) {
	// Arbitrary non-zero fallback; any non-zero state avoids the lock-up.
	const ap_uint<32> fallback_seed = 1;
	lfsr = (seed == 0) ? fallback_seed : seed;
}

/**
 * Advance the global Fibonacci-style LFSR by one step and return the new state.
 *
 * The feedback bit is the XOR of taps 32, 22, 2 and 1, where tap n lives at bit
 * (32 - n): bits 0, 10, 30 and 31. The register shifts right by one and the
 * feedback bit enters at bit 31.
 */
ap_uint<32> lfsr_random() {
	const bool feedback = lfsr.get_bit(0) ^ lfsr.get_bit(10)
	                    ^ lfsr.get_bit(30) ^ lfsr.get_bit(31);
	lfsr >>= 1;
	lfsr.set_bit(31, feedback);
	return lfsr.to_uint();
}

/**
 * Map a rip-up round number to the weight given to other lines' route cells.
 *
 * Only the low 8 bits of `x` are used, so the value cycles through 1..256
 * (before conversion to PRIO_BIT bits) every 256 rounds.
 */
ap_uint<PRIO_BIT> new_weight(ap_uint<16> x) {
	ap_uint<16> low_byte = x & 0x00FF;
	ap_uint<16> weight = low_byte + 1;
	return (ap_uint<PRIO_BIT>)weight;
}

int search(short int size_x, short int size_y, short int *path_size, ap_uint<DATA_BIT> path[MAX_PATH], ap_uint<DATA_BIT> start, ap_uint<DATA_BIT> goal, ap_uint<PRIO_BIT> w[MAX_CELLS]){
	
	ap_uint<PRIO_BIT> dist[MAX_CELLS];
	ap_uint<DATA_BIT> prev[MAX_CELLS];
	
	int j;
	for(j = 0; j < MAX_CELLS; j++) {
		dist[j] = PRIO_MAX;
	}
	
	ap_uint<PQ_BIT> pq_len = 0;
    bool is_empty = true;
    ap_uint<ELEM_BIT> pq_nodes[MAX_PQ];
	
	short int goal_x = goal % size_x;
	short int goal_y = goal / size_x;
	
	dist[start] = 0;
	enqueue(pq_nodes, 0, start, &pq_len, &is_empty);
	
	bool find_path = false; // No path exists
	while(!is_empty) {
		
		ap_uint<PRIO_BIT> prev_cost;
		ap_uint<DATA_BIT> s;
		dequeue(pq_nodes, &prev_cost, &s, &pq_len, &is_empty);
		
		ap_uint<PRIO_BIT> dist_s = dist[s];
		
		if(s == goal) {
			find_path = true;
			break;
		}
		
		ap_uint<PRIO_BIT> cost = w[s];
		short int s_x = s % size_x;
		short int s_y = s / size_x;
		
		int a;
		for(a = 0; a < 4; a++) {
			short int d_x = s_x;
			short int d_y = s_y;
			switch(a) {
				case 0: d_x -= 1; break;
				case 1: d_x += 1; break;
				case 2: d_y -= 1; break;
				case 3: d_y += 1; break;
			}
			if(d_x >= 0 && d_x < size_x && d_y >= 0 && d_y < size_y) {
				ap_uint<DATA_BIT> d = d_y * size_x + d_x;
				if(w[d] == PRIO_MAX && d != goal) continue;
				
				ap_uint<PRIO_BIT> dist_d = dist_s + cost;
				if(dist_d < dist[d]) {
					dist[d] = dist_d;
					prev[d] = s;
					dist_d += abs(d_x - goal_x) + abs(d_y - goal_y);
					enqueue(pq_nodes, dist_d, d, &pq_len, &is_empty);
				}
			}
		}
	}
	if(!find_path){ return -1; }
	
	ap_uint<DATA_BIT> t = prev[goal];
	short int p = 0;
#ifdef PRINT_SEARCH
	cout << "Path: ";
#endif
	while(t != start) {
#ifdef PRINT_SEARCH
		cout << "(" << (short int)t % size_x << ", " << (short int)t / size_x << ")";
#endif
		path[p] = t;
		p++;
		t = prev[t];
	}
#ifdef PRINT_SEARCH
	cout << endl;
#endif
	*path_size = p;
	
	return 0;
}

// Enqueue (Insert an element)
/**
 * Insert (priority, data) into the 1-indexed binary min-heap `pq_nodes`.
 *
 * Each element packs `priority` above the low DATA_BIT bits and `data` in the
 * low DATA_BIT bits. *pq_len is the index of the last occupied slot (slot 0 is
 * unused).
 *
 * If growing *pq_len would wrap to 0 (heap full), the length is kept and the new
 * element is sifted up from the current last slot. That slot's old element is
 * discarded, and it is not necessarily the lowest-priority one.
 *
 * Sets *is_empty to false.
 */
void enqueue(ap_uint<ELEM_BIT> pq_nodes[MAX_PQ], ap_uint<PRIO_BIT> priority, ap_uint<DATA_BIT> data, ap_uint<PQ_BIT> *pq_len, bool *is_empty){

	// Claim the next slot; on wrap-around (full heap) stay on the current last slot.
	ap_uint<PQ_BIT> hole = *pq_len + 1;
	if (hole == 0) { hole = *pq_len; }
	*pq_len = hole;

	// Sift up: move down every ancestor whose priority is >= the new priority.
	while (hole > 1) {
		ap_uint<PQ_BIT> parent = hole >> 1;
		if ((ap_uint<PRIO_BIT>)(pq_nodes[parent] >> DATA_BIT) < priority) { break; }
		pq_nodes[hole] = pq_nodes[parent];
		hole = parent;
	}
	pq_nodes[hole] = ((ap_uint<ELEM_BIT>)priority << DATA_BIT) | (ap_uint<ELEM_BIT>)data;
	*is_empty = false;
}

// Dequeue (Extract and remove the top element)
/**
 * Remove the root (minimum-priority element) of the 1-indexed binary min-heap
 * `pq_nodes`. Its unpacked priority goes to *ret_priority and its data to *ret_data.
 *
 * Elements pack the priority above the low DATA_BIT bits and the data in the low
 * DATA_BIT bits. *pq_len is the index of the last occupied slot. It is
 * decremented, and *is_empty is set to true when it reaches 0; *is_empty is
 * never set to false here.
 *
 * NOTE(review): assumes the heap is non-empty (*pq_len >= 1). This is not checked
 * here: with *pq_len == 0 it would read slot 0 and wrap *pq_len. search() only
 * calls it while !is_empty.
 */
void dequeue(ap_uint<ELEM_BIT> pq_nodes[MAX_PQ], ap_uint<PRIO_BIT> *ret_priority, ap_uint<DATA_BIT> *ret_data, ap_uint<PQ_BIT> *pq_len, bool *is_empty){

	*ret_priority = (ap_uint<PRIO_BIT>)(pq_nodes[1] >> DATA_BIT);
	*ret_data     = (ap_uint<DATA_BIT>)(pq_nodes[1] & DATA_MASK);
	
	ap_uint<PQ_BIT> i = 1; // root node
	ap_uint<PRIO_BIT> last_priority = (ap_uint<PRIO_BIT>)(pq_nodes[*pq_len] >> DATA_BIT); // Priority of last element

	// Sift the last element down from the root, pulling smaller children up into the hole.
	// The guard keeps i below 2^(PQ_BIT-1) so that i << 1 cannot overflow PQ_BIT bits.
	while (!(i >> (PQ_BIT-1))) {
		ap_uint<PQ_BIT> c1 = i << 1; // child node (L)
		ap_uint<PQ_BIT> c2 = c1 + 1; // child node (R)
		// Slot *pq_len is excluded (c < *pq_len): it holds the element being re-inserted.
		if (c1 < *pq_len && (ap_uint<PRIO_BIT>)(pq_nodes[c1] >> DATA_BIT) <= last_priority) {
			// Both children may move up; take the smaller one (the right child on a tie).
			if (c2 < *pq_len && (ap_uint<PRIO_BIT>)(pq_nodes[c2] >> DATA_BIT) <= (ap_uint<PRIO_BIT>)(pq_nodes[c1] >> DATA_BIT)) {
				pq_nodes[i] = pq_nodes[c2];
				i = c2;
			}
			else {
				pq_nodes[i] = pq_nodes[c1];
				i = c1;
			}
		}
		else {
			// Left child absent or larger than the last element; only the right child can still move up.
			if (c2 < *pq_len && (ap_uint<PRIO_BIT>)(pq_nodes[c2] >> DATA_BIT) <= last_priority) {
				pq_nodes[i] = pq_nodes[c2];
				i = c2;
			}
			else {
				break;
			}
		}
	}
	// Drop the last element into the final hole and shrink the heap.
	pq_nodes[i] = pq_nodes[*pq_len];
	(*pq_len)--;
	if ((*pq_len) == 0) { *is_empty = true; }
}
