source of highlighter
plain | download
    1 from logic import Formula, Relation, Conjunction, Disjunction, Negation, Equivalence
    2 from subprocess import Popen, PIPE
    3 
    4 __all__ = ['UNARY', 'BINARY', 'SAT_variable', 'SAT']
    5 
    6 def get_process_cputime(pid):
    7     line = open('/proc/'+str(pid)+'/stat', 'r').read()
    8     line = line[line.rindex(')') + 2:]
    9     return int(line.split(' ')[11])
   10 
   11 def satsolver_write(file, satvars, clauses):
   12     file.write('p cnf '+str(satvars)+' '+str(len(clauses))+'\n')
   13     for clause in clauses:
   14         file.write((' '.join(map(str, clause))) + ' 0\n')
   15     file.close()
   16 
   17 def picosat_solver(satvars, clauses):
   18     p = Popen(['/usr/local/bin/picosat'], stdin=PIPE, stdout=PIPE, universal_newlines=True)
   19     satsolver_write(p.stdin, satvars, clauses)
   20     values = False
   21     if p.stdout.readline() == 's SATISFIABLE\n':
   22         values = [None]
   23         line = p.stdout.read()
   24         line = line[2:line.index(' 0')].replace('\nv', '')
   25         for i in line.split(' '):
   26             values.append(1 if int(i) > 0 else 0)
   27     cputime = get_process_cputime(p.pid)
   28     p.wait()
   29     return (values, cputime)
   30 
   31 UNARY = 1
   32 BINARY = 2
   33 
   34 class SAT_variable(object):
   35     def __init__(self, satvar, name, lower_bound=0, upper_bound=1, encoding=BINARY):
   36         # osetrit vstupne parametre
   37         self.name = name
   38         self.lower_bound = lower_bound
   39         self.upper_bound = upper_bound
   40         self.encoding = encoding
   41 
   42         satvar += 1
   43         if encoding == BINARY:
   44             bitlen = (upper_bound - lower_bound).bit_length()
   45             self.satvar = range(satvar, satvar + bitlen)
   46             self.bounded = self.different_values() == (1 << bitlen)
   47         else:
   48             self.satvar = range(satvar, satvar + upper_bound - lower_bound + 1)
   49 
   50     def set_value(self, values):
   51         if self.encoding == BINARY:
   52             value = 0
   53             for literal in reversed(self.satvar):
   54                 value <<= 1
   55                 value |= values[literal]
   56         else:
   57             value = 0
   58             for literal in self.satvar:
   59                 if values[literal] == 1:
   60                     break
   61                 value += 1
   62         self.value = self.lower_bound + value
   63 
   64     def different_values(self):
   65         return self.upper_bound - self.lower_bound + 1
   66 
   67     def is_boolean(self):
   68         return self.lower_bound == 0 and self.upper_bound == 1 and self.encoding == BINARY
   69 
   70     def __eq__(self, other):
   71         return self.name == other.name
   72 
   73     def __ne__(self, other):
   74         return self.name != other.name
   75 
   76     def __hash__(self):
   77         return hash(self.name)
   78 
   79     def __str__(self):
   80         return self.name
   81 
   82 def satvar_and_int(a, b):
   83     return isinstance(a, SAT_variable) and isinstance(b, int)
   84 
   85 class SAT(object):
   86     satvars = 0
   87     vars = {}
   88     clauses = []
   89     solved = False
   90 
   91     def __init__(self, solver=picosat_solver):
   92         self.solver = solver
   93 
   94     def new_boolean_variable(self, name):
   95         if name in self.vars:
   96             raise AttributeError('variable \'%s\' already exists' % (name,))
   97         self.vars[name] = SAT_variable(self.satvars, name)
   98         self.satvars += 1
   99 
  100     def new_integer_variable(self, name, lower_bound, upper_bound,
  101                              encoding=BINARY):
  102         if name in self.vars:
  103             raise AttributeError('variable \'%s\' already exists' % (name,))
  104         var = SAT_variable(self.satvars, name, lower_bound, upper_bound, encoding)
  105         self.vars[name] = var
  106         length = len(var.satvar)
  107         self.satvars += length
  108         if encoding == BINARY:
  109             if not var.bounded:
  110                 self.add_constraint(Relation('<', var, var.different_values(), infix=True))
  111                 var.bounded = True
  112         else:
  113             for i in range(length):
  114                 for j in range(i+1, length):
  115                     self.clauses.append([-var.satvar[i], -var.satvar[j]])
  116             self.clauses.append(list(var.satvar))
  117 
  118     def __getitem__(self, var):
  119         if var not in self.vars:
  120             raise AttributeError('unknown variable \'%s\'' % (name,))
  121         return self.vars[var].value
  122 
  123     @staticmethod
  124     def __rel_unsupported(relation):
  125         return AttributeError('unsupported relation: %s' % (relation,))
  126 
  127     @staticmethod
  128     def __arg_unsupported(relation):
  129         return AttributeError('unsupported arguments for relation: %s' % (relation,))
  130 
  131     @staticmethod
  132     def __rel_always_holds(relation):
  133         return AttributeError('relation always holds: %s' % (relation,))
  134 
  135     @staticmethod
  136     def __rel_never_holds(relation):
  137         return AttributeError('relation never holds: %s' % (relation,))
  138 
  139     def __eq(self, relation):
  140         a, b = relation.args
  141         if isinstance(a, int):
  142             a, b = b, a
  143         if isinstance(a, SAT_variable) and isinstance(b, SAT_variable):
  144             if a.encoding != b.encoding:
  145                 raise AttributeError('relation between variables of different encoding: %s' % (relation,))
  146             if a.lower_bound != b.lower_bound:
  147                 raise AttributeError('relation between variables of different lower bound: %s' % (relation,))
  148             if len(a.satvar) > len(b.satvar):
  149                 a, b = b, a
  150             con = []
  151             for i in range(len(a.satvar)):
  152                 con.append(Equivalence(Relation(a.satvar[i]), Relation(b.satvar[i])))
  153             for i in range(len(a.satvar), len(b.satvar)):
  154                 con.append(Negation(Relation(b.satvar[i])))
  155             return Conjunction(*con)
  156         elif satvar_and_int(a, b):
  157             if a.lower_bound > b or a.upper_bound < b:
  158                 raise AttributeError('relation out of bounds: ' % (relation,))
  159             b -= a.lower_bound
  160             if a.encoding == BINARY:
  161                 con = []
  162                 for sv in a.satvar:
  163                     rel = Relation(sv)
  164                     if not (b & 1):
  165                         rel = Negation(rel)
  166                     con.append(rel)
  167                     b >>= 1
  168                 return Conjunction(*con)
  169             else:
  170                 return Relation(a.satvar[b])
  171         else:
  172             raise self.__arg_unsupported(relation)
  173 
  174     def __ne(self, relation):
  175         return Negation(self.__eq(relation))
  176 
  177     def __lt(self, relation, args=None):
  178         if args:
  179             a, b = args
  180         else:
  181             a, b = relation.args
  182             if satvar_and_int(b, a):
  183                 if relation.rel == '<':
  184                     return self.__gt(relation)
  185                 else:
  186                     a, b = b, a
  187         if satvar_and_int(a, b):
  188             if a.lower_bound >= b:
  189                 raise self.__rel_never_holds(relation)
  190             elif a.upper_bound < b and (a.encoding != BINARY or a.bounded):
  191                 raise self.__rel_always_holds(relation)
  192             b -= a.lower_bound
  193             if a.encoding == BINARY:
  194                 # Example
  195                 #           X < 153
  196                 #    abcdefgh < 10011001
  197                 # a → (¬b ∧ ¬c ∧ (d → (e → (¬f ∧ ¬g ∧ ¬h)))
  198                 # ¬a ∨ (¬b ∧ ¬c ∧ (¬d ∨ (¬e ∨ (¬f ∧ ¬g ∧ ¬h))))
  199                 # ¬a ∨ (¬b ∧ ¬c ∧ (¬d ∨ ¬e ∨ (¬f ∧ ¬g ∧ ¬h)))
  200                 #
  201                 #           X < 152
  202                 #    abcdefgh < 10011000
  203                 # a → (¬b ∧ ¬c ∧ (d → ¬e))
  204                 # ¬a ∨ (¬b ∧ ¬c ∧ (¬d ∨ ¬e))
  205 
  206                 # Starting from lowest 1 (if it isn't zero in the most inner
  207                 # implication, the relation does not hold).
  208                 result = []
  209                 lastbit = None
  210                 for literal in a.satvar:
  211                     bit = b & 1
  212                     b >>= 1
  213                     rel = Negation(Relation(literal))
  214                     if not result:
  215                         if bit == 1:
  216                             result = [rel]
  217                         continue
  218                     elif bit == 1 and lastbit == 0:
  219                         result = [Conjunction(*result)]
  220                     elif bit == 0 and lastbit == 1:
  221                         result = [Disjunction(*result)]
  222                     result.append(rel)
  223                     lastbit = bit
  224                 if lastbit == 1:
  225                     result = Disjunction(*result)
  226                 else:
  227                     result = Conjunction(*result)
  228                 return result
  229             else:
  230                 downs = b
  231                 ups = a.different_values() - downs
  232                 if downs < ups:
  233                     return Conjunction(*map(Relation, a.satvar[:downs]))
  234                 else:
  235                     return Conjunction(*map(lambda x: Negation(Relation(x)), a.satvar[downs:]))
  236         else:
  237             raise self.__arg_unsupported(relation)
  238 
  239     def __gt(self, relation, args=None):
  240         if args:
  241             a, b = args
  242         else:
  243             a, b = relation.args
  244             if satvar_and_int(b, a):
  245                 if relation.rel == '>':
  246                     return self.__lt(relation)
  247                 else:
  248                     a, b = b, a
  249         if satvar_and_int(a, b):
  250             if a.upper_bound <= b:
  251                 raise self.__rel_never_holds(relation)
  252             elif a.lower_bound > b:
  253                 raise self.__rel_always_holds(relation)
  254             b -= a.lower_bound
  255             if a.encoding == BINARY:
  256                 # Example
  257                 #           X > 153
  258                 #    abcdefgh > 10011001
  259                 # a ∧ (¬b → (¬c → (d ∧ e ∧ (¬f → g))))
  260                 # a ∧ (b ∨ (c ∨ (d ∧ e ∧ (f ∨ g))))
  261                 # a ∧ (b ∨ c ∨ (d ∧ e ∧ (f ∨ g)))
  262 
  263                 # Starting from lowest 0 (if it isn't one in the most inner
  264                 # implication, the relation does not hold).
  265                 result = []
  266                 lastbit = None
  267                 for literal in a.satvar:
  268                     bit = b & 1
  269                     b >>= 1
  270                     if not result:
  271                         if bit == 0:
  272                             result = [Relation(literal)]
  273                         continue
  274                     elif bit == 1 and lastbit == 0:
  275                         result = [Disjunction(*result)]
  276                     elif bit == 0 and lastbit == 1:
  277                         result = [Conjunction(*result)]
  278                     result.append(Relation(literal))
  279                     lastbit = bit
  280                 if lastbit == 1:
  281                     result = Conjunction(*result)
  282                 else:
  283                     result = Disjunction(*result)
  284                 return result
  285             else:
  286                 downs = b + 1
  287                 ups = a.different_values() - downs
  288                 if downs < ups:
  289                     return Conjunction(*map(lambda x: Negation(Relation(x)), a.satvar[downs:]))
  290                 else:
  291                     return Conjunction(*map(Relation, a.satvar[:downs]))
  292         else:
  293             raise self.__arg_unsupported(relation)
  294 
  295     def __le(self, relation):
  296         a, b = relation.args
  297         if satvar_and_int(b, a):
  298             # a ≤ b → b > a-1
  299             return self.__gt(relation, (b, a-1))
  300         elif satvar_and_int(a, b):
  301             # a ≤ b → a < b+1
  302             return self.__lt(relation, (a, b+1))
  303         else:
  304             raise self.__arg_unsupported(relation)
  305 
  306     def __ge(self, relation):
  307         a, b = relation.args
  308         if satvar_and_int(b, a):
  309             # a ≥ b → b < a+1
  310             return self.__lt(relation, (b, a+1))
  311         elif satvar_and_int(a, b):
  312             # a ≥ b → a > b-1
  313             return self.__gt(relation, (a, b-1))
  314         else:
  315             raise self.__arg_unsupported(relation)
  316 
  317     def __relation_transformer(self, relation):
  318         if relation.arity() == 0:
  319             name = relation.name()
  320             var = self[name]
  321             if not var.is_boolean():
  322                 raise AttributeError('variable \'%s\' not boolean' % (name,))
  323             return Relation(var.satvar[0])
  324         elif relation.arity() == 2:
  325             transformers = {
  326                 '=': self.__eq,
  327                 '≠': self.__ne,
  328                 '<': self.__lt,
  329                 '>': self.__gt,
  330                 '≤': self.__le,
  331                 '≥': self.__ge,
  332             }
  333             if relation.name() not in transformers:
  334                 raise self.__rel_unsupported(relation)
  335             transformer = transformers[relation.name()]
  336             a, b = relation.args
  337             if isinstance(a, str):
  338                 a = self.vars[a]
  339             if isinstance(b, str):
  340                 b = self.vars[b]
  341             return transformer(Relation(relation.name(), a, b, infix=True))
  342         else:
  343             raise self.__rel_unsupported(relation)
  344 
  345     @staticmethod
  346     def __map2lit(x):
  347         if isinstance(x, Negation):
  348             return -x.a.name()
  349         else:
  350             return x.name()
  351 
  352     def add_constraint(self, constraint):
  353         if isinstance(constraint, Formula):
  354             formula = constraint
  355         else:
  356             formula = Formula.parse(constraint)
  357         formula = formula.transform_relations(self.__relation_transformer)
  358         formula = formula.cnf()
  359         if not isinstance(formula, Conjunction):
  360             formula = [formula]
  361         for dis in formula:
  362             if isinstance(dis, Disjunction):
  363                 self.clauses.append(map(self.__map2lit, dis))
  364             elif isinstance(dis, Negation):
  365                 self.clauses.append((-dis.a.name(),))
  366             else:
  367                 self.clauses.append((dis.name(),))
  368 
  369     def write_to_file(file):
  370         if isinstance(file, str):
  371             file = open(file, 'w')
  372         satsolver_write(file, self.satvars, self.clauses)
  373 
  374     def solve(self):
  375         values, cputime = self.solver(self.satvars, self.clauses)
  376         self.cputime = cputime
  377         if values:
  378             for name in self.vars:
  379                 self.vars[name].set_value(values)
  380             self.solved = True
  381         return self.solved
  382 
  383 if __name__ == '__main__':
  384     sat = SAT()
  385     for i in 'abcd':
  386         sat.new_integer_variable(i, 5, 21)
  387     sat.add_constraint('a < 10 -> b > 9')
  388     sat.add_constraint('b > 15 -> c < 8')
  389     sat.add_constraint('c < 6 -> d = 17')
  390     sat.add_constraint('a < 10 -> (a = d ∨ b = c)')
  391     sat.add_constraint('a >= 10 -> c < 7')
  392     sat.add_constraint('¬(c = 5) ∧ ¬(d = 5)')
  393     sat.solve()
  394     for i in 'abcd':
  395         print(i+' = '+str(sat[i]))
  396