Commit b63ef575 authored by  tawada's avatar tawada

add solver

parent 202fd317
# パッケージ
%matplotlib inline
import numpy as np
import cv2 as cv # opencv
import matplotlib
import matplotlib.pyplot as plt # matplotlibの描画系
from matplotlib import cm
from collections import OrderedDict
import os
import math
import re
import glob
import time
def split(string):
re_list = re.split('[ X#@,()\t\n]',string)
while '' in re_list:
re_list.remove('')
return re_list
def readQ(filename):
Q = {'BLOCK':{}}
max_l = 0
with open(filename, 'r') as file:
while True:
str_line = file.readline()
if not str_line:
break
result = split(str_line)
if result == []:
continue
if result[0] == 'SIZE':
Q['w'] = int(result[1])
Q['h'] = int(result[2])
elif result[0] == 'BLOCK_NUM':
Q['num_b'] = int(result[1])
Q['n'] = Q['num_b']
elif result[0] == 'BLOCK':
index = int(result[1])
w = int(result[2])
h = int(result[3])
bs = Q['BLOCK']
cells = {}
bs[index] = {'index':index,'w':w,'h':h,'cells':cells}
for i in range(h):
subline = file.readline()
subr = split(subline)
for j in range(w):
if subr[j] != '0':
cells[(j,i)] = subr[j]
if subr[j] != '+':
num = int(subr[j])
if num > max_l:
max_l = num
Q['num_l'] = max_l
return Q
def readA(filename):
A = {'BLOCK':{}}
with open(filename, 'r') as file:
i = 0
while True:
str_line = file.readline()
if not str_line:
break
result = split(str_line)
if result == []:
continue
if result[0] == 'SIZE':
A['w'] = int(result[1])
A['h'] = int(result[2])
A['map'] = np.zeros((A['h'],A['w']), dtype=int)
elif result[0] == 'BLOCK':
index = int(result[1])
x = int(result[2])
y = int(result[3])
bs = A['BLOCK']
cells = {}
bs[index] = {'index':index,'x':x,'y':y}
else:
for j in range(A['w']):
A['map'][i,j] = int(result[j])
i += 1
return A
def index2color(index, mode='cell'):
if mode == 'cell':
if index == -1:
R,G,B,Alpha = 1,1,1,1
else:
R,G,B,Alpha = cm.tab20((index*2)%20)
else:
R,G,B,Alpha = cm.tab20((index*2+1)%20)
return np.array([R,G,B])
def QA2img(Q,A):
w_grid = 5
w_line = 25
len_cell = 50
margin = (len_cell + w_grid - w_line)//2
white = [1,1,1]
black = [0,0,0]
red = [1,0,0]
w_max, h_max = Q['w'],Q['h']
w,h,n = A['w'],A['h'],Q['n']
map_A = A['map']
bs_q = Q['BLOCK']
bs_a = A['BLOCK']
img = np.zeros((h_max*len_cell+w_grid,w_max*len_cell+w_grid,3))
img[:,:] = white
img[[i for i in range(h_max*len_cell+w_grid) if (i%len_cell)<w_grid],:] = black
img[:,[j for j in range(w_max*len_cell+w_grid) if (j%len_cell)<w_grid]] = black
img[0:h*len_cell,0:w_grid] = red
img[0:h*len_cell,w*len_cell:w*len_cell+w_grid] = red
img[0:w_grid,0:w*len_cell] = red
img[h*len_cell:h*len_cell+w_grid,0:w*len_cell+w_grid] = red
for index in range(1, n+1):
x, y = bs_a[index]['x'], bs_a[index]['y']
cells = bs_q[index]['cells']
for dx,dy in cells.keys():
X = (x+dx)*len_cell
Y = (y+dy)*len_cell
img[Y+w_grid:Y+len_cell,X+w_grid:X+len_cell] = index2color(index-1, 'cell')
for i in range(h):
for j in range(w):
index_line = map_A[i][j]
if index_line == 0:
continue
X = j*len_cell
Y = i*len_cell
Xp = X+len_cell
Yp = Y+len_cell
img[Y+margin:Y+margin+w_line,X+margin:X+margin+w_line] = index2color(index_line, 'line')
if j+1<w and map_A[i][j+1] == index_line:
img[Y+margin:Y+margin+w_line,X+margin+w_line:X+margin+len_cell] = index2color(index_line, 'line')
if i+1<h and map_A[i+1][j] == index_line:
img[Y+margin+w_line:Y+margin+len_cell,X+margin:X+margin+w_line] = index2color(index_line, 'line')
cv.putText(img, f'{index_line}', (X+margin,Y+margin+w_line-3), cv.FONT_HERSHEY_PLAIN, 1.5, (.1,.1,.1), 2)
return img
def show(index, read=None):
paths = sorted(glob.glob(f"../Q*.txt"))
Q = readQ(paths[index])
if read is None:
A = readA(f'../A{paths[index][4:7]}.txt')
else:
A = read
print(f'A{paths[index][4:7]}.txt')
img = QA2img(Q,A)
wQ = Q['w']
hQ = Q['h']
wA = A['w']
hA = A['h']
print(f'問題サイズ{wQ}×{hQ},解答サイズ{wA}×{hA}')
print(f'矩形領域の割合{(wA*hA)/(wQ*hQ)}')
img_uint8 = (img*255).astype(int)
R = np.array(img_uint8[:,:,2])
B = np.array(img_uint8[:,:,0])
img_uint8[:,:,0] = R
img_uint8[:,:,2] = B
cv.imwrite(f'A{index+1:03}.png', img_uint8)
plt.imshow(img)
def add(e, _dict):
if e not in _dict.keys():
_dict[e] = len(_dict) + 1
def print_cnf(nodes, cnfs):
s = []
filename = 'p.txt'
with open(filename, 'w') as file:
s.append(f'p cnf {len(nodes)} {len(cnfs)}\n')
file.write(''.join(s))
for cnf in cnfs:
s_sub = ''
for e in cnf:
s_sub += ('' if cnf[e] == '1' else '-')+f'{nodes[e]} '
s_sub += '0\n'
file.write(s_sub)
#s.append(s_sub)
#print(''.join(s))
def read_satA(filename, Q, nodes, TF=None):
with open(filename, 'r') as file:
str_line = file.readline()
if str_line.startswith('UNSAT'):
print('UNSAT')
else:
print('SAT')
str_line = file.readline()
ss = str_line.split()
if TF is None:
TF = [False] * (len(nodes)+1)
for num_s in ss:
num = int(num_s)
if num>0:
TF[num] = True
else:
TF[-num] = False
w = Q['w']
h = Q['h']
A = {'BLOCK':{}}
A['w'] = w
A['h'] = h
A['map'] = np.zeros((h,w), dtype=int)
num_b = len(Q['BLOCK'])
num_l = Q['num_l']
for i in range(1, num_b+1):
A['BLOCK'][i] = {'index':i,'x':0,'y':0}
for i in range(h):
for j in range(w):
for b in range(1,num_b+1):
key = f'b{b}_{j}_{i}'
if key not in nodes.keys():
break
index = nodes[key]
if TF[index]:
A['BLOCK'][b]['x'] = j
A['BLOCK'][b]['y'] = i
for l in range(1,num_l+1):
key = f'l{l}_{j}_{i}'
if key not in nodes.keys():
break
index = nodes[key]
if TF[index]:
A['map'][i][j] = l
return A
def debug_satA(filename, Q, nodes, TF=None):
TF = [False] * (len(nodes)+1)
A = read_satA(filename, Q, nodes, TF)
w = Q['w']
h = Q['h']
num_b = len(Q['BLOCK'])
num_l = Q['num_l']
for b in range(1,num_b+1):
b_board = np.zeros((w,h),dtype=bool)
for i in range(h):
for j in range(w):
key = f'b{b}_{j}_{i}'
index = nodes[key]
b_board[j,i] = TF[index]
print(f'b{b}=')
print(b_board)
for l in range(1,num_l+1):
l_board = np.zeros((w,h),dtype=bool)
for i in range(h):
for j in range(w):
key = f'l{l}_{j}_{i}'
index = nodes[key]
l_board[j,i] = TF[index]
print(f'l{l}=')
print(l_board)
return A
def generate_sat(Q):
nodes = {}
cnfs = []
num_b = len(Q['BLOCK'])
num_l = Q['num_l']
w = Q['w'] = 20
h = Q['h'] = 20
# ラインからブロック番号, 位置を逆引き
l2b = [[] for _ in range(num_l+1)]
for b in range(1, num_b+1):
cells = Q['BLOCK'][b]['cells']
for dx,dy in cells.keys():
v = cells[(dx,dy)]
if v == '+' or v == '0':
l2b[0].append((b,-dx,-dy))
else:
index = int(v)
l2b[index].append((b,-dx,-dy))
for b in range(1,num_b+1):
for i in range(h):
for j in range(w):
nodes[f'b{b}_{j}_{i}'] = len(nodes) + 1
for l in range(1, num_l+1):
for i in range(h):
for j in range(w):
nodes[f'l{l}_{j}_{i}'] = len(nodes) + 1
# b one-hot
for b in range(1,num_b+1):
cnf = {}
cnfs.append(cnf)
for i in range(h):
for j in range(w):
cnf[f'b{b}_{j}_{i}'] = '1'
for i in range(h):
for j in range(w):
for i2 in range(i, h):
for j2 in range(w):
if i < i2 or j < j2:
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j}_{i}'] = '-1'
cnf[f'b{b}_{j2}_{i2}'] = '-1'
# line 0 or 1
for i in range(h):
for j in range(w):
for l in range(1, num_l):
for l2 in range(l+1, num_l+1):
cnf = {}
cnfs.append(cnf)
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l2}_{j}_{i}'] = '-1'
# line孤立禁止
for i in range(h):
for j in range(w):
for l in range(1, num_l+1):
around = [(1,0),(-1,0),(0,1),(0,-1)]
cnf = {}
cnfs.append(cnf)
cnf[f'l{l}_{j}_{i}'] = '-1'
for dx,dy in around:
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
continue
cnf[f'l{l}_{j+dx}_{i+dy}'] = '1'
# line三叉禁止
for i in range(h):
for j in range(w):
for l in range(1, num_l+1):
around = [(1,0),(-1,0),(0,1),(0,-1)]
new_around = []
for dx,dy in around:
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
continue
new_around.append((dx,dy))
if len(new_around) <= 2:
continue
elif len(new_around) <= 3:
cnf = {}
cnfs.append(cnf)
cnf[f'l{l}_{j}_{i}'] = '-1'
for dx,dy in new_around:
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
elif len(new_around) <= 4:
for dx,dy in new_around:
for dx2,dy2 in new_around:
if dx*3+dy >= dx2*3+dy2:
continue
for dx3,dy3 in new_around:
if dx2*3+dy2 >= dx3*3+dy3:
continue
cnf = {}
cnfs.append(cnf)
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '-1'
cnf[f'l{l}_{j+dx3}_{i+dy3}'] = '-1'
# blockがはみ出る位置には置けない
for i in range(h):
for j in range(w):
for b in range(1, num_b+1):
cells = Q['BLOCK'][b]['cells']
for dx,dy in cells:
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j}_{i}'] = '-1'
break
# block位置に対してライン位置が決定
for i in range(h):
for j in range(w):
for b in range(1, num_b+1):
cells = Q['BLOCK'][b]['cells']
flag = False
for dx,dy in cells:
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
break
l_str = cells[(dx,dy)]
if l_str == '+':
for l in range(1,num_l+1):
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
elif l_str != '0':
for l in range(1,num_l+1):
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j}_{i}'] = '-1'
if l == int(l_str):
cnf[f'l{l}_{j+dx}_{i+dy}'] = '1'
else:
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
# block同士の衝突禁止
for i in range(h):
for j in range(w):
for b in range(1, num_b+1):
cells = Q['BLOCK'][b]['cells']
for i2 in range(i-4,i+5):
if i2 < 0 or i2 >= h:
continue
for j2 in range(j-4,j+5):
if j2 < 0 or j2 >= w:
continue
for b2 in range(b+1, num_b+1):
cells2 = Q['BLOCK'][b2]['cells']
flag = False
for dx,dy in cells.keys():
if flag:
break
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
continue
for dx2,dy2 in cells2.keys():
if j2+dx2 < 0 or j2+dx2 >= w or i2+dy2 < 0 or i2+dy2 >= h:
continue
if i+dy == i2+dy2 and j+dx == j2+dx2:
if not flag:
flag = True
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j}_{i}'] = '-1'
cnf[f'b{b2}_{j2}_{i2}'] = '-1'
# block上にないlineは周囲2マスに接続
for i in range(h):
for j in range(w):
for l in range(1, num_l+1):
around = [(1,0),(-1,0),(0,1),(0,-1)]
new_around = []
for dx,dy in around:
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
continue
new_around.append((dx,dy))
bdxdylist = l2b[l]
if len(new_around) <= 2:
for dx,dy in new_around:
cnf = {}
cnfs.append(cnf)
for b,mdx,mdy in bdxdylist:
if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w:
continue
cnf[f'b{b}_{j+mdx}_{i+mdy}'] = '1'
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '1'
elif len(new_around) <= 3:
for dx,dy in new_around:
for dx2,dy2 in new_around:
if dx*3+dy <= dx2*3+dy2:
continue
cnf = {}
cnfs.append(cnf)
for b,mdx,mdy in bdxdylist:
if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w:
continue
cnf[f'b{b}_{j+mdx}_{i+mdy}'] = '1'
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '1'
cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '1'
elif len(new_around) == 4:
for dx,dy in new_around:
for dx2,dy2 in new_around:
if dx*3+dy >= dx2*3+dy2:
continue
for dx3,dy3 in new_around:
if dx2*3+dy2 >= dx3*3+dy3:
continue
cnf = {}
cnfs.append(cnf)
for b,mdx,mdy in bdxdylist:
if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w:
continue
cnf[f'b{b}_{j+mdx}_{i+mdy}'] = '1'
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '1'
cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '1'
cnf[f'l{l}_{j+dx3}_{i+dy3}'] = '1'
# block上にあるlineは周囲2マスに接続禁止
for i in range(h):
for j in range(w):
for l in range(1, num_l+1):
around = [(1,0),(-1,0),(0,1),(0,-1)]
new_around = []
for dx,dy in around:
if j+dx < 0 or j+dx >= w or i+dy < 0 or i+dy >= h:
continue
new_around.append((dx,dy))
bdxdylist = l2b[l]
if len(new_around) <= 2:
dx,dy = new_around[0]
dx2,dy2 = new_around[1]
for b,mdx,mdy in bdxdylist:
if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w:
continue
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j+mdx}_{i+mdy}'] = '-1'
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '-1'
elif len(new_around) <= 3:
for dx,dy in new_around:
for dx2,dy2 in new_around:
if dx*3+dy >= dx2*3+dy2:
continue
for b,mdx,mdy in bdxdylist:
if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w:
continue
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j+mdx}_{i+mdy}'] = '-1'
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '-1'
elif len(new_around) == 4:
for dx,dy in new_around:
for dx2,dy2 in new_around:
if dx*3+dy >= dx2*3+dy2:
continue
for b,mdx,mdy in bdxdylist:
if i+mdy < 0 or i+mdy >= h or j+mdx < 0 or j+mdx >= w:
continue
cnf = {}
cnfs.append(cnf)
cnf[f'b{b}_{j+mdx}_{i+mdy}'] = '-1'
cnf[f'l{l}_{j}_{i}'] = '-1'
cnf[f'l{l}_{j+dx}_{i+dy}'] = '-1'
cnf[f'l{l}_{j+dx2}_{i+dy2}'] = '-1'
return nodes,cnfs
def main():
Q = readQ('./adc2019problem/Q001_10X10_b8_l11.txt')
#Q = readQ('./adc2019problem/Q002_10X10_b8_l9.txt')
#Q = readQ('./adc2019problem/Q003_10X10_b5_l5.txt')
#Q = readQ('./adc2019problem/Q004_10X10_b8_l9.txt')
#Q = readQ('./adc2019problem/Q005_10X10_b7_l6.txt')
#Q = readQ('./adc2019problem/Q006_10X10_b8_l9.txt')
#Q = readQ('./adc2019problem/Q007_10X10_b7_l8.txt')
#Q = readQ('./adc2019problem/Q008_10X10_b10_l0.txt')
#Q = readQ('./adc2019problem/Q009_10X10_b7_l9.txt')
#Q = readQ('./adc2019problem/Q010_10X10_b20_l24.txt')
#Q = readQ('./adc2019problem/Q013_10X10_b8_l9.txt')
#Q = readQ('./adc2019problem/Q014_10X10_b9_l9.txt')
#Q = readQ('./adc2019problem/Q015_10X10_b8_l9.txt')
start_time = time.time()
nodes,cnfs = generate_sat(Q)
print_cnf(nodes, cnfs)
os.system('minisat p.txt a.txt')
A = read_satA('a.txt', Q, nodes)
#A = debug_satA('a.txt', Q, nodes)
end_time= time.time()
print(f"経過時間:{end_time - start_time}")
img = QA2img(Q,A)
plt.imshow(img)
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment