diff --git a/solver.py b/solver.py index 2cbbbdfd51648f5dbf5cb906f52bdc51cb674355..a36f0524af2ad5ec4143c6270d60a9c900393606 100644 --- a/solver.py +++ b/solver.py @@ -291,7 +291,7 @@ def read_satA(filename, Q, nodes, TF=None): check_flag[c] = 1 # validation_mapが0 => 端点と接続していない => 閉路 A['map'] = A['map'] * validation_map - + ## 最小矩形に更新 map_l = A['map'] h, w = map_l.shape @@ -306,8 +306,8 @@ def read_satA(filename, Q, nodes, TF=None): if i >= h: break if (map_l[i,:] == 0).all() and (map_b[i,:] == 0).all(): - np.delete(map_l, i, axis=0) - np.delete(map_b, i, axis=0) + map_l = np.delete(map_l, i, axis=0) + map_b = np.delete(map_b, i, axis=0) for b in range(1, num_b+1): if A['BLOCK'][b]['y'] > i: A['BLOCK'][b]['y'] -= 1 @@ -319,8 +319,8 @@ def read_satA(filename, Q, nodes, TF=None): if j >= w: break if (map_l[:,j] == 0).all() and (map_b[:,j] == 0).all(): - np.delete(map_l, j, axis=1) - np.delete(map_b, j, axis=1) + map_l = np.delete(map_l, j, axis=1) + map_b = np.delete(map_b, j, axis=1) for b in range(1, num_b+1): if A['BLOCK'][b]['x'] > j: A['BLOCK'][b]['x'] -= 1 @@ -331,7 +331,7 @@ def read_satA(filename, Q, nodes, TF=None): A['h'] = h A['w'] = w return A - + def debug_satA(filename, Q, nodes, TF=None): TF = [False] * (len(nodes)+1) A = read_satA(filename, Q, nodes, TF) @@ -340,25 +340,24 @@ def debug_satA(filename, Q, nodes, TF=None): num_b = len(Q['BLOCK']) num_l = Q['num_l'] for b in range(1,num_b+1): - b_board = np.zeros((w,h),dtype=bool) + b_board = np.zeros((h,w),dtype=bool) for i in range(h): for j in range(w): key = f'b{b}_{j}_{i}' index = nodes[key] - b_board[j,i] = TF[index] + b_board[i, j] = TF[index] print(f'b{b}=') print(b_board) for l in range(1,num_l+1): - l_board = np.zeros((w,h),dtype=bool) + l_board = np.zeros((h,w),dtype=bool) for i in range(h): for j in range(w): key = f'l{l}_{j}_{i}' index = nodes[key] - l_board[j,i] = TF[index] + l_board[i, j] = TF[index] print(f'l{l}=') print(l_board) return A - def generate_sat(Q, WH=None, seed=0): nodes = {} @@ -478,9 +477,6 @@ def generate_sat(Q, WH=None, seed=0): cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '-1' cnf[f'l{l}_{j+dx3}_{i+dy3}'] = '-1' - - - # blockがはみ出る位置には置けない for i in range(h): for j in range(w): @@ -520,27 +516,27 @@ def generate_sat(Q, WH=None, seed=0): cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1' # block同士の衝突禁止 - for i in range(h): - for j in range(w): - for b in range(1, num_b+1): - cells = Q['BLOCK'][b]['cells'] - for i2 in range(i-4,i+5): - if i2 < 0 or i2 >= h: - continue - for j2 in range(j-4,j+5): - if j2 < 0 or j2 >= w: + for b in range(1, num_b+1): + cells = Q['BLOCK'][b]['cells'] + wb = Q['BLOCK'][b]['w'] + hb = Q['BLOCK'][b]['h'] + for b2 in range(b+1, num_b+1): + cells2 = Q['BLOCK'][b2]['cells'] + wb2 = Q['BLOCK'][b2]['w'] + hb2 = Q['BLOCK'][b2]['h'] + for i in range(h-hb+1): + for j in range(w-wb+1): + for i2 in range(i-hb2+1,i+hb): + if i2 < 0 or i2 >= h: continue - for b2 in range(b+1, num_b+1): - cells2 = Q['BLOCK'][b2]['cells'] + for j2 in range(j-wb2+1,j+wb): + if j2 < 0 or j2 >= w: + continue flag = False for dx,dy in cells.keys(): if flag: break - if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h: - continue for dx2,dy2 in cells2.keys(): - if j2+dx2 < 0 or j2+dx2 >= w or i2+dy2 < 0 or i2+dy2 >= h: - continue if i+dy == i2+dy2 and j+dx == j2+dx2: if not flag: flag = True @@ -548,6 +544,7 @@ def generate_sat(Q, WH=None, seed=0): cnfs.append(cnf) cnf[f'b{b}_{j}_{i}'] = '-1' cnf[f'b{b2}_{j2}_{i2}'] = '-1' + break # block上にないlineは周囲2マスに接続 for i in range(h):