from logic import Formula, Relation, Conjunction, Disjunction, Negation, Equivalence
from subprocess import Popen, PIPE

__all__ = ['UNARY', 'BINARY', 'SAT_variable', 'SAT']


def get_process_cputime(pid):
    """Return the user-mode CPU time of process *pid* as reported by procfs.

    Reads ``/proc/<pid>/stat``, skips everything up to the closing
    parenthesis of the command name (which may itself contain spaces) and
    returns the 12th following field as an int.  Per proc(5) that field is
    ``utime``, measured in clock ticks.
    """
    with open('/proc/' + str(pid) + '/stat', 'r') as stat:
        line = stat.read()
    # The command name is wrapped in parentheses; rindex protects against
    # names that contain ')' themselves.
    line = line[line.rindex(')') + 2:]
    return int(line.split(' ')[11])


def satsolver_write(file, satvars, clauses):
    """Write the problem in DIMACS CNF format to *file* and close it.

    satvars -- number of propositional variables
    clauses -- sequence of clauses, each an iterable of non-zero integer
               literals (a negative number is a negated variable)
    """
    file.write('p cnf ' + str(satvars) + ' ' + str(len(clauses)) + '\n')
    for clause in clauses:
        file.write(' '.join(map(str, clause)) + ' 0\n')
    # Closing signals end of input when *file* is a solver's stdin pipe.
    file.close()


def picosat_solver(satvars, clauses):
    """Solve the CNF problem with the picosat binary.

    Returns a tuple ``(values, cputime)``.  ``values`` is ``False`` when the
    solver does not report ``s SATISFIABLE``; otherwise it is a list indexed
    by variable number (index 0 is a ``None`` placeholder) holding 1 or 0.
    ``cputime`` is the value returned by get_process_cputime for the solver
    process.
    """
    p = Popen(['/usr/local/bin/picosat'], stdin=PIPE, stdout=PIPE,
              universal_newlines=True)
    try:
        satsolver_write(p.stdin, satvars, clauses)
        status = p.stdout.readline()
        # Drain the pipe to EOF so that the solver has finished writing
        # before its CPU time is sampled.
        output = p.stdout.read()
        values = False
        if status == 's SATISFIABLE\n':
            values = [None]
            # Model lines look like "v 1 -2 3\nv -4 0"; strip the leading
            # "v ", cut at the terminating " 0" and join continuation lines.
            model = output[2:output.index(' 0')].replace('\nv', '')
            values.extend(1 if int(tok) > 0 else 0 for tok in model.split())
        # Must happen before wait(): once the child is reaped its /proc
        # entry is gone.
        cputime = get_process_cputime(p.pid)
    finally:
        p.stdin.close()
        p.stdout.close()
        p.wait()
    return (values, cputime)


UNARY = 1
BINARY = 2


class SAT_variable(object):
    """A named boolean or bounded-integer variable mapped to SAT literals.

    With BINARY encoding the offset ``value - lower_bound`` is stored in
    binary, least significant bit first, in ``self.satvar``.  With UNARY
    encoding there is one literal per admissible value and ``satvar[i]``
    stands for ``value == lower_bound + i``.
    """

    def __init__(self, satvar, name, lower_bound=0, upper_bound=1, encoding=BINARY):
        """Allocate literals for the variable.

        satvar -- number of SAT variables already in use; the new literals
                  start at ``satvar + 1``
        """
        # TODO: validate input parameters
        self.name = name
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.encoding = encoding

        satvar += 1
        if encoding == BINARY:
            bitlen = (upper_bound - lower_bound).bit_length()
            self.satvar = range(satvar, satvar + bitlen)
            # True when every bit pattern is a valid value, i.e. no extra
            # constraint is needed to exclude out-of-range codes.
            self.bounded = self.different_values() == (1 << bitlen)
        else:
            self.satvar = range(satvar, satvar + upper_bound - lower_bound + 1)
            # One literal per value: there is no out-of-range code.
            self.bounded = True

    def set_value(self, values):
        """Decode ``self.value`` from a solver model.

        values -- sequence indexed by SAT variable number holding 0 or 1
        """
        value = 0
        if self.encoding == BINARY:
            # satvar[0] is the least significant bit.
            for literal in reversed(self.satvar):
                value <<= 1
                value |= values[literal]
        else:
            # The value is the position of the first literal set to 1.
            for literal in self.satvar:
                if values[literal] == 1:
                    break
                value += 1
        self.value = self.lower_bound + value

    def different_values(self):
        """Return how many distinct values the variable can take."""
        return self.upper_bound - self.lower_bound + 1

    def is_boolean(self):
        """Return True for a binary-encoded variable with bounds 0..1."""
        return self.lower_bound == 0 and self.upper_bound == 1 and self.encoding == BINARY

    def __eq__(self, other):
        """Variables are identified by name only."""
        if not isinstance(other, SAT_variable):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        if not isinstance(other, SAT_variable):
            return NotImplemented
        return self.name != other.name

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name


def satvar_and_int(a, b):
    """Return True when *a* is a SAT_variable and *b* is an int."""
    return isinstance(a, SAT_variable) and isinstance(b, int)


class SAT(object):
    """Collects variables and constraints, encodes them to CNF and solves.

    Attributes:
        satvars -- number of propositional variables allocated so far
        vars    -- mapping of variable name to SAT_variable
        clauses -- list of CNF clauses (sequences of integer literals)
        solved  -- True once a solve() call found a model
    """

    satvars = 0
    solved = False

    def __init__(self, solver=picosat_solver):
        """solver -- callable ``(satvars, clauses) -> (values, cputime)``."""
        self.solver = solver
        # Per-instance containers: class-level mutables would be shared by
        # every SAT object.
        self.vars = {}
        self.clauses = []

    def new_boolean_variable(self, name):
        """Register a boolean variable; raise AttributeError if it exists."""
        if name in self.vars:
            raise AttributeError('variable \'%s\' already exists' % (name,))
        self.vars[name] = SAT_variable(self.satvars, name)
        self.satvars += 1

    def new_integer_variable(self, name, lower_bound, upper_bound,
                             encoding=BINARY):
        """Register an integer variable with inclusive bounds.

        Also adds the clauses that restrict the literals to valid codes.
        Raises AttributeError if the name is already taken.
        """
        if name in self.vars:
            raise AttributeError('variable \'%s\' already exists' % (name,))
        var = SAT_variable(self.satvars, name, lower_bound, upper_bound, encoding)
        self.vars[name] = var
        length = len(var.satvar)
        self.satvars += length
        if encoding == BINARY:
            if not var.bounded:
                # Exclude bit patterns above the upper bound.  The limit is
                # given in the variable's own value space because __lt
                # subtracts lower_bound itself.
                self.add_constraint(Relation('<', var, var.upper_bound + 1, infix=True))
                var.bounded = True
        else:
            # At most one literal is true ...
            for i in range(length):
                for j in range(i + 1, length):
                    self.clauses.append([-var.satvar[i], -var.satvar[j]])
            # ... and at least one is.
            self.clauses.append(list(var.satvar))

    def __getitem__(self, var):
        """Return the solved value of the variable named *var*.

        Raises AttributeError for an unknown name (and, because ``value``
        is only set by set_value, also before a successful solve()).
        """
        if var not in self.vars:
            raise AttributeError('unknown variable \'%s\'' % (var,))
        return self.vars[var].value

    @staticmethod
    def __rel_unsupported(relation):
        return AttributeError('unsupported relation: %s' % (relation,))

    @staticmethod
    def __arg_unsupported(relation):
        return AttributeError('unsupported arguments for relation: %s' % (relation,))

    @staticmethod
    def __rel_always_holds(relation):
        return AttributeError('relation always holds: %s' % (relation,))

    @staticmethod
    def __rel_never_holds(relation):
        return AttributeError('relation never holds: %s' % (relation,))

    def __eq(self, relation):
        """Encode ``a = b`` for variable/variable or variable/int operands."""
        a, b = relation.args
        if isinstance(a, int):
            a, b = b, a
        if isinstance(a, SAT_variable) and isinstance(b, SAT_variable):
            if a.encoding != b.encoding:
                raise AttributeError('relation between variables of different encoding: %s' % (relation,))
            if a.lower_bound != b.lower_bound:
                raise AttributeError('relation between variables of different lower bound: %s' % (relation,))
            # Make a the variable with fewer literals.
            if len(a.satvar) > len(b.satvar):
                a, b = b, a
            con = []
            for i in range(len(a.satvar)):
                con.append(Equivalence(Relation(a.satvar[i]), Relation(b.satvar[i])))
            # Literals that only the wider variable has must be false.
            for i in range(len(a.satvar), len(b.satvar)):
                con.append(Negation(Relation(b.satvar[i])))
            return Conjunction(*con)
        elif satvar_and_int(a, b):
            if a.lower_bound > b or a.upper_bound < b:
                raise AttributeError('relation out of bounds: %s' % (relation,))
            b -= a.lower_bound
            if a.encoding == BINARY:
                # Fix every bit, least significant first.
                con = []
                for sv in a.satvar:
                    rel = Relation(sv)
                    if not (b & 1):
                        rel = Negation(rel)
                    con.append(rel)
                    b >>= 1
                return Conjunction(*con)
            else:
                return Relation(a.satvar[b])
        else:
            raise self.__arg_unsupported(relation)

    def __ne(self, relation):
        """Encode ``a ≠ b`` as the negation of ``a = b``."""
        return Negation(self.__eq(relation))

    def __lt(self, relation, args=None):
        """Encode ``a < b`` where one operand is a variable and one an int.

        args -- optional operand pair overriding ``relation.args`` (used by
                __le/__ge to shift the constant by one)
        """
        if args:
            a, b = args
        else:
            a, b = relation.args
        if satvar_and_int(b, a):
            if relation.rel == '<':
                # int < var  is  var > int
                return self.__gt(relation)
            else:
                # Delegated here from __gt: int > var  is  var < int
                a, b = b, a
        if satvar_and_int(a, b):
            if a.lower_bound >= b:
                raise self.__rel_never_holds(relation)
            elif a.upper_bound < b and (a.encoding != BINARY or a.bounded):
                raise self.__rel_always_holds(relation)
            b -= a.lower_bound
            if a.encoding == BINARY:
                # Example
                # X < 153
                # abcdefgh < 10011001
                # a → (¬b ∧ ¬c ∧ (d → (e → (¬f ∧ ¬g ∧ ¬h)))
                # ¬a ∨ (¬b ∧ ¬c ∧ (¬d ∨ (¬e ∨ (¬f ∧ ¬g ∧ ¬h))))
                # ¬a ∨ (¬b ∧ ¬c ∧ (¬d ∨ ¬e ∨ (¬f ∧ ¬g ∧ ¬h)))
                #
                # X < 152
                # abcdefgh < 10011000
                # a → (¬b ∧ ¬c ∧ (d → ¬e))
                # ¬a ∨ (¬b ∧ ¬c ∧ (¬d ∨ ¬e))

                # Starting from lowest 1 (if it isn't zero in the most inner
                # implication, the relation does not hold).
                result = []
                lastbit = None
                for literal in a.satvar:
                    bit = b & 1
                    b >>= 1
                    rel = Negation(Relation(literal))
                    if not result:
                        if bit == 1:
                            result = [rel]
                        continue
                    elif bit == 1 and lastbit == 0:
                        result = [Conjunction(*result)]
                    elif bit == 0 and lastbit == 1:
                        result = [Disjunction(*result)]
                    result.append(rel)
                    lastbit = bit
                if lastbit == 1:
                    result = Disjunction(*result)
                else:
                    result = Conjunction(*result)
                return result
            else:
                # One literal per value: pick whichever side of the cut
                # needs fewer literals.
                downs = b
                ups = a.different_values() - downs
                if downs < ups:
                    # One of the values below b is taken.
                    return Disjunction(*map(Relation, a.satvar[:downs]))
                else:
                    # None of the values from b upwards is taken.
                    return Conjunction(*map(lambda x: Negation(Relation(x)), a.satvar[downs:]))
        else:
            raise self.__arg_unsupported(relation)

    def __gt(self, relation, args=None):
        """Encode ``a > b`` where one operand is a variable and one an int.

        args -- optional operand pair overriding ``relation.args`` (used by
                __le/__ge to shift the constant by one)
        """
        if args:
            a, b = args
        else:
            a, b = relation.args
        if satvar_and_int(b, a):
            if relation.rel == '>':
                # int > var  is  var < int
                return self.__lt(relation)
            else:
                # Delegated here from __lt: int < var  is  var > int
                a, b = b, a
        if satvar_and_int(a, b):
            if a.upper_bound <= b:
                raise self.__rel_never_holds(relation)
            elif a.lower_bound > b:
                raise self.__rel_always_holds(relation)
            b -= a.lower_bound
            if a.encoding == BINARY:
                # Example
                # X > 153
                # abcdefgh > 10011001
                # a ∧ (¬b → (¬c → (d ∧ e ∧ (¬f → g))))
                # a ∧ (b ∨ (c ∨ (d ∧ e ∧ (f ∨ g))))
                # a ∧ (b ∨ c ∨ (d ∧ e ∧ (f ∨ g)))

                # Starting from lowest 0 (if it isn't one in the most inner
                # implication, the relation does not hold).
                result = []
                lastbit = None
                for literal in a.satvar:
                    bit = b & 1
                    b >>= 1
                    if not result:
                        if bit == 0:
                            result = [Relation(literal)]
                        continue
                    elif bit == 1 and lastbit == 0:
                        result = [Disjunction(*result)]
                    elif bit == 0 and lastbit == 1:
                        result = [Conjunction(*result)]
                    result.append(Relation(literal))
                    lastbit = bit
                if lastbit == 1:
                    result = Conjunction(*result)
                else:
                    result = Disjunction(*result)
                return result
            else:
                # One literal per value: pick whichever side of the cut
                # needs fewer literals.
                downs = b + 1
                ups = a.different_values() - downs
                if downs < ups:
                    # None of the values up to and including b is taken.
                    return Conjunction(*map(lambda x: Negation(Relation(x)), a.satvar[:downs]))
                else:
                    # One of the values above b is taken.
                    return Disjunction(*map(Relation, a.satvar[downs:]))
        else:
            raise self.__arg_unsupported(relation)

    def __le(self, relation):
        """Encode ``a ≤ b`` by reduction to a strict comparison."""
        a, b = relation.args
        if satvar_and_int(b, a):
            # a ≤ b → b > a-1
            return self.__gt(relation, (b, a - 1))
        elif satvar_and_int(a, b):
            # a ≤ b → a < b+1
            return self.__lt(relation, (a, b + 1))
        else:
            raise self.__arg_unsupported(relation)

    def __ge(self, relation):
        """Encode ``a ≥ b`` by reduction to a strict comparison."""
        a, b = relation.args
        if satvar_and_int(b, a):
            # a ≥ b → b < a+1
            return self.__lt(relation, (b, a + 1))
        elif satvar_and_int(a, b):
            # a ≥ b → a > b-1
            return self.__gt(relation, (a, b - 1))
        else:
            raise self.__arg_unsupported(relation)

    def __relation_transformer(self, relation):
        """Rewrite one user-level relation into a formula over SAT literals.

        A nullary relation is a boolean variable name; a binary relation is
        a comparison whose string operands are resolved through self.vars.
        Raises AttributeError for unknown or non-boolean nullary names and
        for unsupported relations.
        """
        if relation.arity() == 0:
            name = relation.name()
            if name not in self.vars:
                raise AttributeError('unknown variable \'%s\'' % (name,))
            # The SAT_variable itself is needed here, not its solved value.
            var = self.vars[name]
            if not var.is_boolean():
                raise AttributeError('variable \'%s\' not boolean' % (name,))
            return Relation(var.satvar[0])
        elif relation.arity() == 2:
            transformers = {
                '=': self.__eq,
                '≠': self.__ne,
                '<': self.__lt,
                '>': self.__gt,
                '≤': self.__le,
                '≥': self.__ge,
            }
            if relation.name() not in transformers:
                raise self.__rel_unsupported(relation)
            transformer = transformers[relation.name()]
            a, b = relation.args
            if isinstance(a, str):
                a = self.vars[a]
            if isinstance(b, str):
                b = self.vars[b]
            return transformer(Relation(relation.name(), a, b, infix=True))
        else:
            raise self.__rel_unsupported(relation)

    @staticmethod
    def __map2lit(x):
        """Turn a (possibly negated) literal relation into a signed int."""
        if isinstance(x, Negation):
            return -x.a.name()
        else:
            return x.name()

    def add_constraint(self, constraint):
        """Add a constraint given as a Formula or as text for Formula.parse.

        The formula is rewritten to SAT literals, converted to CNF and its
        clauses are appended to self.clauses.
        """
        if isinstance(constraint, Formula):
            formula = constraint
        else:
            formula = Formula.parse(constraint)
        formula = formula.transform_relations(self.__relation_transformer)
        formula = formula.cnf()
        if not isinstance(formula, Conjunction):
            formula = [formula]
        for dis in formula:
            if isinstance(dis, Disjunction):
                # Materialise the clause: a lazy map object would be
                # exhausted after the first time the clauses are written.
                self.clauses.append([self.__map2lit(lit) for lit in dis])
            elif isinstance(dis, Negation):
                self.clauses.append((-dis.a.name(),))
            else:
                self.clauses.append((dis.name(),))

    def write_to_file(self, file):
        """Write the problem in DIMACS CNF format.

        file -- a path (opened for writing) or a writable file object; it
                is closed by satsolver_write.
        """
        if isinstance(file, str):
            file = open(file, 'w')
        satsolver_write(file, self.satvars, self.clauses)

    def solve(self):
        """Run the solver and decode the model into the variables.

        Stores the solver's CPU time in self.cputime.  Returns self.solved,
        which becomes True when a model was found.
        """
        values, cputime = self.solver(self.satvars, self.clauses)
        self.cputime = cputime
        if values:
            for name in self.vars:
                self.vars[name].set_value(values)
            self.solved = True
        return self.solved


if __name__ == '__main__':
    sat = SAT()
    for i in 'abcd':
        sat.new_integer_variable(i, 5, 21)
    sat.add_constraint('a < 10 -> b > 9')
    sat.add_constraint('b > 15 -> c < 8')
    sat.add_constraint('c < 6 -> d = 17')
    sat.add_constraint('a < 10 -> (a = d ∨ b = c)')
    sat.add_constraint('a >= 10 -> c < 7')
    sat.add_constraint('¬(c = 5) ∧ ¬(d = 5)')
    sat.solve()
    for i in 'abcd':
        print(i + ' = ' + str(sat[i]))