# Source of highlighter (page links: plain | download)
    1 __all__ = ['Formula', 'Relation', 'Conjunction', 'Disjunction', 'Negation',
    2            'Implication', 'Equivalence']
    3 
    4 class Formula(object):
    5     def __eq__(self, other):
    6         return (self.__class__, self.__dict__) == (other.__class__, other.__dict__)
    7 
    8     def __ne__(self, other):
    9         return (self.__class__, self.__dict__) != (other.__class__, other.__dict__)
   10 
   11     def copy(self):
   12         res = self.__class__.__new__(self.__class__)
   13         res.__dict__ = self.__dict__
   14         return res
   15 
   16     def flatten(self):
   17         return self.copy()
   18 
   19     def basic(self):
   20         return self.copy()
   21 
   22     def demorgan(self):
   23         return self.copy()
   24 
   25     @staticmethod
   26     def _transform_relations(x, transformer):
   27         if isinstance(x, Relation):
   28             return transformer(x)
   29         else:
   30             return x.transform_relations()
   31 
   32     def is_atomic(self):
   33         return False
   34 
   35     def __distribute_ors(self):
   36         disjunction = list(self)
   37         conjunction = None
   38         for arg in disjunction:
   39             if isinstance(arg, Conjunction):
   40                 conjunction = arg
   41                 break
   42         if conjunction is not None:
   43             disjunction.remove(conjunction)
   44             result = []
   45             for arg in conjunction:
   46                 if isinstance(arg, Disjunction):
   47                     result.append(Disjunction(*(disjunction + list(arg))))
   48                 else:
   49                     result.append(Disjunction(*(disjunction + [arg])))
   50             return Conjunction(*result)
   51         else:
   52             return Disjunction(*disjunction)
   53 
   54     def cnf(self):
   55         result = self.basic().demorgan().flatten()
   56         if result.is_atomic():
   57             return result
   58         if isinstance(result, Disjunction):
   59             result = result.__distribute_ors()
   60             if isinstance(result, Disjunction):
   61                 return result
   62         again = True
   63         while again:
   64             again = False
   65             old_result = result
   66             result = []
   67             for arg in old_result:
   68                 if isinstance(arg, Disjunction):
   69                     arg = arg.__distribute_ors()
   70                 if isinstance(arg, Conjunction):
   71                     result += list(arg)
   72                     again = True
   73                 else:
   74                     result.append(arg)
   75             result = Conjunction(*result).demorgan().flatten()
   76         return result
   77 
   78 class Relation(Formula):
   79     def __init__(self, rel, *args, infix=False):
   80         self.rel = rel
   81         self.args = tuple(args)
   82         self.infix = len(self.args) == 2 and infix
   83 
   84     def arity(self):
   85         return len(self.args)
   86 
   87     def name(self):
   88         return self.rel
   89 
   90     def is_atomic(self):
   91         return True
   92 
   93     def transform_relations(self, transformer):
   94         return transformer(self)
   95 
   96     def __hash__(self):
   97         return hash((self.__class__, self.rel, self.args))
   98 
   99     def __str__(self):
  100         if not self.args:
  101             return str(self.rel)
  102         elif self.infix:
  103             return '(%s %s %s)' % (self.args[0], self.rel, self.args[1])
  104         else:
  105             return '%s(%s)' % (self.rel, ', '.join(map(str, self.args)))
  106 
  107 class NFormula(Formula):
  108     def __init__(self, *args):
  109         if not len(args):
  110             raise AttributeError('at least one argument is required')
  111         for arg in args:
  112             if not isinstance(arg, Formula):
  113                 raise TypeError('arguments must be instances of Formula')
  114         self.args = frozenset(args)
  115         if len(self.args) == 1:
  116             arg = list(args)[0]
  117             self.__class__ = arg.__class__
  118             self.__dict__ = arg.__dict__
  119 
  120     def __hash__(self):
  121         return hash((self.__class__, self.args))
  122 
  123     def __iter__(self):
  124         return iter(self.args)
  125 
  126     def flatten(self):
  127         args = []
  128         for arg in self:
  129             arg = arg.flatten()
  130             if isinstance(arg, self.__class__):
  131                 args += list(arg)
  132             else:
  133                 args.append(arg)
  134         return self.__class__(*args)
  135 
  136     def basic(self):
  137         return self.__class__(*map(lambda x: x.basic(), self))
  138 
  139     def demorgan(self):
  140         return self.__class__(*map(lambda x: x.demorgan(), self))
  141 
  142     def transform_relations(self, transformer):
  143         return self.__class__(*map(lambda x: x.transform_relations(transformer), self))
  144 
  145 class Conjunction(NFormula):
  146     def __str__(self):
  147         return '(' + (' ∧ '.join(map(str, self))) + ')'
  148 
  149 class Disjunction(NFormula):
  150     def __str__(self):
  151         return '(' + (' ∨ '.join(map(str, self))) + ')'
  152 
  153 class Negation(Formula):
  154     def __init__(self, a):
  155         if not isinstance(a, Formula):
  156             raise TypeError('argument must be instance of Formula')
  157         self.a = a
  158 
  159     def __hash__(self):
  160         return hash((self.__class__, self.a))
  161 
  162     def __str__(self):
  163         return '¬' + str(self.a)
  164 
  165     def is_atomic(self):
  166         return self.a.is_atomic()
  167 
  168     def flatten(self):
  169         if isinstance(self.a, Negation):
  170             return self.a.a.flatten()
  171         else:
  172             return Negation(self.a)
  173 
  174     def basic(self):
  175         return Negation(self.a.basic())
  176 
  177     def demorgan(self):
  178         res = self.a.demorgan()
  179         if isinstance(res, Conjunction):
  180             res = Disjunction(*map(lambda x: Negation(x).demorgan(), res))
  181         elif isinstance(res, Disjunction):
  182             res = Conjunction(*map(lambda x: Negation(x).demorgan(), res))
  183         else:
  184             res = Negation(res)
  185         return res
  186 
  187     def transform_relations(self, transformer):
  188         return Negation(self.a.transform_relations(transformer))
  189 
  190 class BinFormula(Formula):
  191     def __init__(self, a, b):
  192         if not isinstance(a, Formula) or not isinstance(b, Formula):
  193             return TypeError('arguments must be instances of Formula')
  194         self.a = a
  195         self.b = b
  196 
  197     def __hash__(self):
  198         return hash((self.__class__, self.a, self.b))
  199 
  200     def flatten(self):
  201         return self.__class__(self.a.flatten(), self.b.flatten())
  202 
  203     def demorgan(self):
  204         return self.__class__(self.a.demorgan(), self.b.demorgan())
  205 
  206     def transform_relations(self, transformer):
  207         return self.__class__(
  208             self.a.transform_relations(transformer),
  209             self.b.transform_relations(transformer)
  210         )
  211 
  212 class Implication(BinFormula):
  213     def __str__(self):
  214         return '(' + str(self.a) + ' → ' + str(self.b) + ')'
  215 
  216     def basic(self):
  217         return Disjunction(Negation(self.a.basic()), self.b.basic())
  218 
  219 class Equivalence(BinFormula):
  220     def __str__(self):
  221         return '(' + str(self.a) + ' ↔ ' + str(self.b) + ')'
  222 
  223     def basic(self):
  224         return Conjunction(
  225             Disjunction(Negation(self.a.basic()), self.b.basic()),
  226             Disjunction(Negation(self.b.basic()), self.a.basic())
  227         )
  228 
  229 def isvarname(a):
  230     if not a[0].isalpha() and a[0] != '_':
  231         return False
  232     for c in a[1:]:
  233         if not c.isalnum() and c != '_':
  234             return False
  235     return True
  236 
  237 def next_token(input):
  238     ops = {
  239         '(': '(',
  240         ')': ')',
  241         '|': '∨',
  242         '∨': '∨',
  243         '&': '∧',
  244         '∧': '∧',
  245         '!': '¬',
  246         '~': '¬',
  247         '¬': '¬',
  248         '→': '→',
  249         '↔': '↔',
  250         '=': '=',
  251         '≠': '≠',
  252         '<': '<',
  253         '>': '>',
  254         '≤': '≤',
  255         '≥': '≥',
  256         '->': '→',
  257         '!=': '≠',
  258         '>=': '≥',
  259         '<=': '≤',
  260         '<->': '↔',
  261     }
  262 
  263     if input == '':
  264         return None, None
  265     input = input.lstrip()
  266 
  267     if input[0:3] in ops:
  268         return ops[input[0:3]], input[3:]
  269     elif input[0:2] in ops:
  270         return ops[input[0:2]], input[2:]
  271     elif input[0] in ops:
  272         return ops[input[0]], input[1:]
  273     elif input[0].isdigit() or input[0] == '-':
  274         v = input[0]
  275         for c in input[1:]:
  276             if not c.isdigit():
  277                 break
  278             v += c
  279         return int(v), input[len(v):]
  280     elif input[0].isalpha() or input[0] == '_':
  281         v = input[0]
  282         for c in input[1:]:
  283             if not c.isalnum() and c != '_':
  284                 break
  285             v += c
  286         return v, input[len(v):]
  287     else:
  288         raise SyntaxError('unknown token at ' + input)
  289 
  290 # Formula parsing by shunting yard algorithm
  291 def parse(input):
  292     prec = {
  293         '∧': 1,
  294         '∨': 1,
  295         '→': 2,
  296         '↔': 2,
  297         '=': 3,
  298         '≠': 3,
  299         '<': 3,
  300         '>': 3,
  301         '≤': 3,
  302         '≥': 3,
  303     }
  304     formula = {
  305         '∧': Conjunction,
  306         '∨': Disjunction,
  307         '→': Implication,
  308         '↔': Equivalence,
  309         '=': Relation,
  310         '≠': Relation,
  311         '<': Relation,
  312         '>': Relation,
  313         '≤': Relation,
  314         '≥': Relation,
  315     }
  316     stack = []
  317     output = []
  318 
  319     def negate():
  320         while stack and stack[-1] == '¬':
  321             stack.pop()
  322             output[-1] = Negation(output[-1])
  323 
  324     def apply():
  325         try:
  326             b = output.pop()
  327             a = output.pop()
  328             op = stack.pop()
  329             cls = formula[op]
  330             if cls == Relation:
  331                 if isinstance(a, Relation):
  332                     a = a.name()
  333                 if isinstance(b, Relation):
  334                     b = b.name()
  335                 obj = cls(op, a, b, infix=True)
  336             else:
  337                 obj = cls(a, b)
  338             output.append(obj)
  339             negate()
  340         except:
  341             raise SyntaxError('invalid input')
  342 
  343     input = input.strip()
  344     while True:
  345         token, input = next_token(input)
  346         if token is None:
  347             break
  348 
  349         if isinstance(token, int):
  350             output.append(token)
  351         elif isvarname(token):
  352             res = Relation(token)
  353             while stack and stack[-1] == '¬':
  354                 stack.pop()
  355                 res = Negation(res)
  356             output.append(res)
  357         elif token == '(':
  358             stack.append('(')
  359         elif token == ')':
  360             negate()
  361             while stack[-1] != '(':
  362                 apply()
  363             if stack[-1] != '(':
  364                 raise SyntaxError('parentheses mismatch')
  365             stack.pop()
  366         else:
  367             if stack and token not in '(¬' and stack[-1] not in '(¬' and prec[token] <= prec[stack[-1]]:
  368                 apply()
  369             stack.append(token)
  370 
  371     negate()
  372     while stack:
  373         apply()
  374 
  375     if len(output) > 1:
  376         raise SyntaxError('invalid input')
  377     return output[0]
  378 
  379 Formula.parse = staticmethod(parse)
  380 
  381 if __name__ == '__main__':
  382     from sys import argv
  383 
  384     if len(argv) < 2:
  385         formula = '(a->b->c)<->~(d&f|a)'
  386     else:
  387         formula = argv[1]
  388 
  389     f = Formula.parse(formula)
  390     print(f.flatten())
  391