diff --git a/solver.py b/solver.py index 2cbbbdfd51648f5dbf5cb906f52bdc51cb674355..288bc4e9a78de9c03179b44fe68cb2cf8d65718f 100644 --- a/solver.py +++ b/solver.py @@ -9,6 +9,7 @@ import re import glob import time import random +import itertools if __name__ == '__main__': import cv2 as cv # opencv @@ -204,6 +205,15 @@ def print_cnf(nodes, cnfs, filename='p.txt'): #s.append(s_sub) #print(''.join(s)) +def print_cnf2(num, cnfs, filename='p.txt'): + s = [] + s.append(f'p cnf {num} {len(cnfs)}\n') + for cnf in cnfs: + cnf = [str(i) for i in cnf] + s.append((' '.join(cnf))+' 0\n') + with open(filename, 'w') as file: + file.write(''.join(s)) + def read_satA(filename, Q, nodes, TF=None): with open(filename, 'r') as file: str_line = file.readline() @@ -332,6 +342,135 @@ def read_satA(filename, Q, nodes, TF=None): A['w'] = w return A +def read_satA2(filename, Q, nodes, TF=None): + with open(filename, 'r') as file: + str_line = file.readline() + if str_line.startswith('UNSAT'): + print('UNSAT') + return None + else: + print('SAT') + str_line = file.readline() + ss = str_line.split() + if TF is None: + TF = [False] * (len(nodes)+1) + for num_s in ss: + num = int(num_s) + if num>0: + TF[num] = True + else: + TF[-num] = False + w = Q['w'] + h = Q['h'] + A = {'BLOCK':{}} + A['w'] = w + A['h'] = h + A['map'] = np.zeros((h,w), dtype=int) + num_b = len(Q['BLOCK']) + num_l = Q['num_l'] + A['num_b'] = num_b + A['num_l'] = num_l + for i in range(1, num_b+1): + A['BLOCK'][i] = {'index':i,'x':0,'y':0} + for i in range(h): + for j in range(w): + for b in range(1,num_b+1): + key = f'b{b}_{j}_{i}' + if key not in nodes.keys(): + break + index = nodes[key] + if TF[index]: + A['BLOCK'][b]['x'] = j + A['BLOCK'][b]['y'] = i + for l in range(1,num_l+1): + key = f'l{l}_{j}_{i}' + if key not in nodes.keys(): + break + index = nodes[key] + if TF[index]: + A['map'][i][j] = l + return A + ## 閉路の削除 + validation_map = np.zeros((h, w), dtype=int) + check_flag = np.zeros(num_l+1) + map = A['map'] + + # 端点から探索し + for b in range(1, num_b): + for dx,dy in Q['BLOCK'][b]['cells'].keys(): + c = Q['BLOCK'][b]['cells'][(dx,dy)] + + # 配線でないならcontinue + if c == '+' or c == '0': + continue + c = int(c) + + # 探索済みならcontinue + if check_flag[c] == 1: + continue + + x, y = A['BLOCK'][b]['x']+dx, A['BLOCK'][b]['y']+dy + while True: + validation_map[y, x] = 1 + x -= 1 + if x >= 0 and map[y, x] == c and validation_map[y, x] == 0: + continue + x += 2 + if x <= w-1 and map[y, x] == c and validation_map[y, x] == 0: + continue + x -= 1 + y -= 1 + if y >= 0 and map[y, x] == c and validation_map[y, x] == 0: + continue + y += 2 + if y <= h-1 and map[y, x] == c and validation_map[y, x] == 0: + continue + y -= 1 + break + check_flag[c] = 1 + # validation_mapが0 => 端点と接続していない => 閉路 + A['map'] = A['map'] * validation_map + + ## 最小矩形に更新 + map_l = A['map'] + h, w = map_l.shape + map_b = np.zeros((h, w), dtype=int) + for b in range(1, num_b+1): + x = A['BLOCK'][b]['x'] + y = A['BLOCK'][b]['y'] + for dx,dy in Q['BLOCK'][b]['cells'].keys(): + map_b[y+dy,x+dx] = b + i = 0 + while True: + if i >= h: + break + if (map_l[i,:] == 0).all() and (map_b[i,:] == 0).all(): + np.delete(map_l, i, axis=0) + np.delete(map_b, i, axis=0) + for b in range(1, num_b+1): + if A['BLOCK'][b]['y'] > i: + A['BLOCK'][b]['y'] -= 1 + h -= 1 + i -= 1 + i += 1 + j = 0 + while True: + if j >= w: + break + if (map_l[:,j] == 0).all() and (map_b[:,j] == 0).all(): + np.delete(map_l, j, axis=1) + np.delete(map_b, j, axis=1) + for b in range(1, num_b+1): + if A['BLOCK'][b]['x'] > j: + A['BLOCK'][b]['x'] -= 1 + w -= 1 + j -= 1 + j += 1 + A['map'] = map_l + A['h'] = h + A['w'] = w + return A + def debug_satA(filename, Q, nodes, TF=None): TF = [False] * (len(nodes)+1) A = read_satA(filename, Q, nodes, TF) @@ -359,6 +498,52 @@ def debug_satA(filename, Q, nodes, TF=None): print(l_board) return A +def debug_satA2(filename, Q, nodes, TF=None): + TF = [False] * (len(nodes)+1) + A = read_satA(filename, Q, nodes, TF) + w = Q['w'] + h = Q['h'] + num_b = len(Q['BLOCK']) + num_l = Q['num_l'] + v_board = np.zeros((h-1, w),dtype=bool) + for i in range(h-1): + for j in range(w): + key = f'lv_{j}_{i}' + index = nodes[key] + v_board[i, j] = TF[index] + print(f'lv=') + print(v_board) + h_board = np.zeros((h, w-1),dtype=bool) + for i in range(h): + for j in range(w-1): + key = f'lh_{j}_{i}' + index = nodes[key] + h_board[i, j] = TF[index] + print(f'lh=') + print(h_board) + + +# for b in range(1,num_b+1): +# b_board = np.zeros((h,w),dtype=bool) +# for i in range(h): +# for j in range(w): +# key = f'b{b}_{j}_{i}' +# index = nodes[key] +# b_board[i, j] = TF[index] +# print(f'b{b}=') +# print(b_board) + for l in range(1,num_l+1): + l_board = np.zeros((h, w),dtype=bool) + for i in range(h): + for j in range(w): + key = f'l{l}_{j}_{i}' + index = nodes[key] + l_board[i, j] = TF[index] + print(f'l{l}=') + print(l_board) + + + return A def generate_sat(Q, WH=None, seed=0): nodes = {} @@ -657,11 +842,337 @@ def generate_sat(Q, WH=None, seed=0): cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '-1' return nodes,cnfs + +def generate_sat2(Q, WH=None, seed=0): + nodes = {} + cnfs = [] + + num_b = len(Q['BLOCK']) + num_l = Q['num_l'] + w = Q['w'] + h = Q['h'] + + if not WH is None: + w = WH[0] + h = WH[1] + Q['w'] = WH[0] + Q['h'] = WH[1] + + # 端点制約に使用 + bdxdylist = [] + for b in range(1, num_b+1): + cells = Q['BLOCK'][b]['cells'] + for dx,dy in cells.keys(): + v = cells[(dx,dy)] + if v == '+' or v == '0': + continue + else: + bdxdylist.append((b,-dx,-dy)) + + for b in range(1,num_b+1): + for i in range(h): + for j in range(w): + nodes[f'b{b}_{j}_{i}'] = len(nodes) + 1 + for l in range(1, num_l+1): + for i in range(h): + for j in range(w): + nodes[f'l{l}_{j}_{i}'] = len(nodes) + 1 + for i in range(h-1): + for j in range(w): + nodes[f'lv_{j}_{i}'] = len(nodes) + 1 + for i in range(h): + for j in range(w-1): + nodes[f'lh_{j}_{i}'] = len(nodes) + 1 + + # minisatが決定的に動く?ので変数の番号をランダムに割り当て + keys = list(nodes.keys()) + values = list(nodes.values()) + random.seed(seed) + np.random.seed(seed) + random.shuffle(values) + nodes = dict(zip(keys, values)) + + # b one-hot + for b in range(1,num_b+1): + cnf = [] + cnfs.append(cnf) + for i in range(h): + for j in range(w): + cnf.append(nodes[f'b{b}_{j}_{i}']) + for i in range(h): + for j in range(w): + for i2 in range(i, h): + for j2 in range(w): + if i < i2 or j < j2: + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j}_{i}']) + cnf.append(-nodes[f'b{b}_{j2}_{i2}']) + + # line 0 or 1 + for i in range(h): + for j in range(w): + for l in range(1, num_l): + for l2 in range(l+1, num_l+1): + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'l{l}_{j}_{i}']) + cnf.append(-nodes[f'l{l2}_{j}_{i}']) + + # block位置に対してライン位置が決定 + for i in range(h): + for j in range(w): + for b in range(1, num_b+1): + cells = Q['BLOCK'][b]['cells'] + flag = False + for dx,dy in cells: + if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h: + break + l_str = cells[(dx,dy)] + if l_str == '+': + for l in range(1,num_l+1): + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j}_{i}']) + cnf.append(-nodes[f'l{l}_{j+dx}_{i+dy}']) + elif l_str != '0': + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j}_{i}']) + cnf.append(nodes[f'l{int(l_str)}_{j+dx}_{i+dy}']) + + # blockがはみ出る位置には置けない + for i in range(h): + for j in range(w): + for b in range(1, num_b+1): + cells = Q['BLOCK'][b]['cells'] + for dx,dy in cells: + if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h: + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j}_{i}']) + break + + # block同士の衝突禁止 + for i in range(h): + for j in range(w): + for b in range(1, num_b): + cells = Q['BLOCK'][b]['cells'] + for i2 in range(i-3,i+4): + if i2 < 0 or i2 >= h: + continue + for j2 in range(j-3,j+4): + if j2 < 0 or j2 >= w: + continue + for b2 in range(b+1, num_b+1): + cells2 = Q['BLOCK'][b2]['cells'] + flag = False + for dx,dy in cells.keys(): + if flag: + break + if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h: + flag = False + break + for dx2,dy2 in cells2.keys(): + if j2+dx2 < 0 or j2+dx2 >= w or i2+dy2 < 0 or i2+dy2 >= h: + break + if i+dy == i2+dy2 and j+dx == j2+dx2: + if not flag: + flag = True + break + if flag: + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j}_{i}']) + cnf.append(-nodes[f'b{b2}_{j2}_{i2}']) + + + + + # 接続セルの配線は等しい + for i in range(h-1): + for j in range(w): + for l in range(1, num_l+1): + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'lv_{j}_{i}']) + cnf.append(-nodes[f'l{l}_{j}_{i}']) + cnf.append(nodes[f'l{l}_{j}_{i+1}']) + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'lv_{j}_{i}']) + cnf.append(nodes[f'l{l}_{j}_{i}']) + cnf.append(-nodes[f'l{l}_{j}_{i+1}']) + for i in range(h): + for j in range(w-1): + for l in range(1, num_l+1): + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'lh_{j}_{i}']) + cnf.append(-nodes[f'l{l}_{j}_{i}']) + cnf.append(nodes[f'l{l}_{j+1}_{i}']) + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'lh_{j}_{i}']) + cnf.append(nodes[f'l{l}_{j}_{i}']) + cnf.append(-nodes[f'l{l}_{j+1}_{i}']) + + # line孤立禁止 + for i in range(h): + for j in range(w): + for l in range(1, num_l+1): + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'l{l}_{j}_{i}']) + if j+1 <= w-1: + cnf.append(nodes[f'lh_{j}_{i}']) + if j-1 >= 0: + cnf.append(nodes[f'lh_{j-1}_{i}']) + if i+1 <= h-1: + cnf.append(nodes[f'lv_{j}_{i}']) + if i-1 >= 0: + cnf.append(nodes[f'lv_{j}_{i-1}']) + + # line三叉禁止 + for i in range(h): + for j in range(w): + around = [] + if j+1 <= w-1: + around.append(nodes[f'lh_{j}_{i}']) + if j-1 >= 0: + around.append(nodes[f'lh_{j-1}_{i}']) + if i+1 <= h-1: + around.append(nodes[f'lv_{j}_{i}']) + if i-1 >= 0: + around.append(nodes[f'lv_{j}_{i-1}']) + if len(around) == 2: + continue + if len(around) == 3: + cnf = [] + cnfs.append(cnf) + cnf.append(-around[0]) + cnf.append(-around[1]) + cnf.append(-around[2]) + if len(around) == 4: + for ar in itertools.combinations(around, 3): + cnf = [] + cnfs.append(cnf) + cnf.append(-ar[0]) + cnf.append(-ar[1]) + cnf.append(-ar[2]) + + # 非端点は周囲2マス以上に接続 + for i in range(h): + for j in range(w): + around = [] + if j+1 <= w-1: + around.append(nodes[f'lh_{j}_{i}']) + if j-1 >= 0: + around.append(nodes[f'lh_{j-1}_{i}']) + if i+1 <= h-1: + around.append(nodes[f'lv_{j}_{i}']) + if i-1 >= 0: + around.append(nodes[f'lv_{j}_{i-1}']) + if len(around) == 2: + for ar in around: + cnf = [] + cnfs.append(cnf) + for b,mdx,mdy in bdxdylist: + if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w: + continue + cnf.append(nodes[f'b{b}_{j+mdx}_{i+mdy}']) + cnf.append(ar) + elif len(around) == 3: + for ar in itertools.combinations(around, 2): + cnf = [] + cnfs.append(cnf) + for b,mdx,mdy in bdxdylist: + if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w: + continue + cnf.append(nodes[f'b{b}_{j+mdx}_{i+mdy}']) + cnf.append(ar[0]) + cnf.append(ar[1]) + elif len(around) == 4: + for ar in itertools.combinations(around, 3): + cnf = [] + cnfs.append(cnf) + for b,mdx,mdy in bdxdylist: + if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w: + continue + cnf.append(nodes[f'b{b}_{j+mdx}_{i+mdy}']) + cnf.append(ar[0]) + cnf.append(ar[1]) + cnf.append(ar[2]) + # 端点は周囲2マス以上に接続禁止 + for i in range(h): + for j in range(w): + around = [] + if j+1 <= w-1: + around.append(nodes[f'lh_{j}_{i}']) + if j-1 >= 0: + around.append(nodes[f'lh_{j-1}_{i}']) + if i+1 <= h-1: + around.append(nodes[f'lv_{j}_{i}']) + if i-1 >= 0: + around.append(nodes[f'lv_{j}_{i-1}']) + if len(around) == 2: + for b,mdx,mdy in bdxdylist: + if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w: + continue + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j+mdx}_{i+mdy}']) + cnf.append(-around[0]) + cnf.append(-around[1]) + elif len(around) == 3: + for ar in itertools.combinations(around, 2): + for b,mdx,mdy in bdxdylist: + if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w: + continue + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j+mdx}_{i+mdy}']) + cnf.append(-ar[0]) + cnf.append(-ar[1]) + elif len(around) == 4: + for ar in itertools.combinations(around, 2): + for b,mdx,mdy in bdxdylist: + if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w: + continue + cnf = [] + cnfs.append(cnf) + cnf.append(-nodes[f'b{b}_{j+mdx}_{i+mdy}']) + cnf.append(-ar[0]) + cnf.append(-ar[1]) + return nodes,cnfs + ## 不要な制約 + # 配線していないセルには線を接続しない + for i in range(h): + for j in range(w): + around = [] + if j+1 <= w-1: + around.append(nodes[f'lh_{j}_{i}']) + if j-1 >= 0: + around.append(nodes[f'lh_{j-1}_{i}']) + if i+1 <= h-1: + around.append(nodes[f'lv_{j}_{i}']) + if i-1 >= 0: + around.append(nodes[f'lv_{j}_{i-1}']) + for index in around: + cnf = [] + cnfs.append(cnf) + for l in range(1, num_l+1): + cnf.append(nodes[f'l{l}_{j}_{i}']) + cnf.append(-index) + + return nodes,cnfs + + def main(): - Q = readQ('./adc2019problem/Q001_10X10_b8_l11.txt') + #Q = readQ('./adc2019problem/Q001_10X10_b8_l11.txt') #Q = readQ('./adc2019problem/Q002_10X10_b8_l9.txt') - #Q = readQ('./adc2019problem/Q003_10X10_b5_l5.txt') + Q = readQ('./adc2019problem/Q003_10X10_b5_l5.txt') #Q = readQ('./adc2019problem/Q004_10X10_b8_l9.txt') #Q = readQ('./adc2019problem/Q005_10X10_b7_l6.txt') #Q = readQ('./adc2019problem/Q006_10X10_b8_l9.txt') @@ -674,16 +1185,17 @@ def main(): #Q = readQ('./adc2019problem/Q015_10X10_b8_l9.txt') #Q = readQ('./adc2019problem/QRAND327_20X20_b20_l15.txt') #Q = readQ('./adc2019problem/QRAND368_10X10_b10_l15.txt') + #Q = readQ('./adc2019problem/random_problem/QRAND136_30X30_b30_l30.txt') #start_time = time.time() - nodes,cnfs = generate_sat(Q, WH=None, seed=0) - print_cnf(nodes, cnfs) + nodes,cnfs = generate_sat2(Q, WH=(8,5), seed=0) + print_cnf2(len(nodes), cnfs) os.system('minisat p.txt a.txt') - A = read_satA('a.txt', Q, nodes) + A = read_satA2('a.txt', Q, nodes) if A is None: return print(strA(A)) - #A = debug_satA('a.txt', Q, nodes) + debug_satA2('a.txt', Q, nodes) #end_time= time.time() #print(f"経過時間:{end_time - start_time}") #img = QA2img(Q,A)