1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4# Copyright (C) 2014  Mate Soos
5#
6# This program is free software; you can redistribute it and/or
7# modify it under the terms of the GNU General Public License
8# as published by the Free Software Foundation; version 2
9# of the License.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program; if not, write to the Free Software
18# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
19
20from __future__ import print_function
21import re
22
23
class XorToCNF:
    """Convert a CNF file that may contain XOR lines into plain DIMACS CNF.

    Input lines starting with 'p' or 'c' (after stripping) are skipped,
    ordinary clause lines are copied verbatim, and lines starting with 'x'
    (e.g. "x1 -2 3 0") are XOR constraints requiring the XOR of the listed
    literals to be true; these are rewritten as ordinary clauses.

    The direct encoding of an n-literal XOR needs 2**(n-1) clauses, so XORs
    longer than ``cutsize`` are first cut into a chain of shorter XORs
    linked together by fresh auxiliary variables.
    """

    def __init__(self):
        # Largest XOR (including link variables) encoded directly; convert()
        # refuses values below 3.
        self.cutsize = 4

    def _iter_clause_lines(self, fileobj):
        """Yield stripped lines of ``fileobj`` that hold a clause or an XOR.

        Empty lines, the "p" header and "c" comment lines are skipped.
        """
        for line in fileobj:
            line = line.strip()
            if not line or line[0] in ('c', 'p'):
                continue
            yield line

    def get_max_var(self, clause):
        """Return the largest variable index in a clause or XOR line.

        The optional leading 'x' is ignored and the terminating 0 never
        raises the result.  Returns 0 for a blank line.
        """
        tmp = clause.strip()
        if not tmp:
            return 0

        assert re.search(r'^x?\s*-?\d+', tmp)

        if tmp[0] == 'x':
            tmp = tmp[1:]

        maxvar = 0
        for lit in tmp.split():
            maxvar = max(maxvar, abs(int(lit)))
        return maxvar

    def convert(self, infilename, outfilename, num_extra_cls=0):
        """Write ``infilename`` to ``outfilename`` with XORs encoded as CNF.

        :param infilename: path of the input file (clauses and XOR lines).
        :param outfilename: path of the plain CNF file to (over)write.
        :param num_extra_cls: additional clauses counted in the "p cnf"
            header but not written here; presumably the caller appends
            them afterwards (TODO confirm).
        Prints an error and exits with status -1 if ``cutsize`` < 3.
        """
        assert isinstance(self.cutsize, int)
        if self.cutsize <= 2:
            print("ERROR: The cut size MUST be larger or equal to 3")
            # Same as sys.exit(-1); the bare exit() builtin comes from the
            # site module and is not guaranteed to exist in programs.
            raise SystemExit(-1)

        maxvar, numcls, extravars_needed, extracls_needed = \
            self.get_stats(infilename)

        # 'with' closes (and flushes) both files even if an assertion fires.
        with open(outfilename, "w") as fout, open(infilename, "r") as fin:
            fout.write("p cnf %d %d\n" %
                       (maxvar + extravars_needed,
                        numcls + extracls_needed + num_extra_cls))
            atvar = maxvar
            for line in self._iter_clause_lines(fin):
                if line[0] != 'x':
                    # normal clause: copy verbatim
                    fout.write(line + "\n")
                    continue

                # XOR: cut into short XORs, then encode each one directly
                xorclauses, atvar = self.cut_up_xor_to_n(line, atvar)
                for xorcl in xorclauses:
                    for cl in self.xor_to_cnf_simple(xorcl):
                        fout.write(cl + "\n")

            # fresh variables actually used must match the header's count
            assert atvar == maxvar + extravars_needed

    def popcount(self, x):
        """Return the number of set bits in the non-negative int ``x``."""
        return bin(x).count('1')

    def parse_xor(self, xorclause):
        """Return the literals of an "x ... 0" line, without the final 0.

        Literals may be separated by any whitespace.  A 0 is only allowed
        as the terminator: inside the XOR it would end up in the middle of
        a generated DIMACS clause.
        """
        assert xorclause.startswith('x')
        tokens = xorclause[1:].split()
        assert tokens and tokens[-1] == '0'
        for tok in tokens:
            assert re.match(r'-?\d+$', tok)

        lits = [int(tok) for tok in tokens[:-1]]
        assert 0 not in lits
        return lits

    def xor_to_cnf_simple(self, xorclause, equals=True):
        """Directly encode one XOR line as a list of DIMACS clause strings.

        ``equals`` is the required value of the XOR (True: an odd number of
        the literals must be true).  Each returned clause forbids exactly
        one assignment of the wrong parity, so an n-literal XOR yields
        2**(n-1) clauses such as "1 -2 0".
        """
        assert equals is True or equals is False
        # parity (count of true literals mod 2) of the assignments to forbid
        forbidden_parity = 0 if equals else 1

        lits = self.parse_xor(xorclause)

        # NOTE(review): an empty XOR has parity 0, so with equals=True it
        # looks unsatisfiable, yet no clause is emitted.
        # num_extra_vars_cls_needed also counts 0 clauses for it, so change
        # both together if this is revisited.
        if not lits:
            return []

        ret = []
        for i in range(2 ** len(lits)):
            if self.popcount(i) % 2 != forbidden_parity:
                continue
            # bit 'at' of i set -> that literal appears negated
            parts = ["%d" % (-lit if (i >> at) & 1 else lit)
                     for at, lit in enumerate(lits)]
            ret.append(" ".join(parts) + " 0")

        return ret

    def cut_up_xor_to_n(self, xorclause, oldmaxvar):
        """Split an XOR line into a chain of XOR lines of <= cutsize literals.

        Fresh link variables are numbered from ``oldmaxvar + 1``.  Each link
        variable v appears as v at the end of one piece and as -v in the
        next, so the pieces together are equisatisfiable with the original.
        Returns ``[xor_lines, new_max_var]``.  Short XORs are returned
        (re-formatted) as a single piece with no new variables.
        """
        assert self.cutsize > 2

        lits = self.parse_xor(xorclause)

        def fmt(chunk):
            # "x<l1> <l2> ... 0", the format parse_xor accepts
            return "x" + "".join("%d " % lit for lit in chunk) + "0"

        # xor clause that doesn't need to be cut up
        if len(lits) <= self.cutsize:
            return [[fmt(lits)], oldmaxvar]

        xors = []
        newmaxvar = oldmaxvar
        at = 0
        while at < len(lits):
            until = min(at + self.cutsize - 1, len(lits))
            # middle pieces need room for two link variables
            if at > 0 and until < len(lits):
                until -= 1

            links = []
            if at > 0:
                links.append(-newmaxvar)      # link to the previous piece
            if until < len(lits):
                newmaxvar += 1
                links.append(newmaxvar)       # link to the next piece

            xors.append(fmt(lits[at:until] + links))
            at = until

        return [xors, newmaxvar]

    def num_extra_vars_cls_needed(self, numlits):
        """Predict what converting an XOR of ``numlits`` literals produces.

        Mirrors the piece sizes chosen by cut_up_xor_to_n and returns
        ``[new_link_variables, clauses]``, where ``clauses`` counts every
        clause emitted for the XOR.  Returns [0, 0] for numlits == 0, in
        line with xor_to_cnf_simple emitting nothing for an empty XOR.
        """
        def cls_for_plain_xor(numlits):
            # clauses in the direct encoding of a numlits-literal XOR
            return 2**(numlits-1)

        varsneeded = 0
        clsneeded = 0

        at = 0
        while at < numlits:
            if at == 0:
                if numlits > self.cutsize:
                    # first piece: cutsize-1 literals + 1 new link variable
                    at += self.cutsize-1
                    varsneeded += 1
                    clsneeded += cls_for_plain_xor(self.cutsize)
                else:
                    # short enough to encode without cutting
                    at = numlits
                    clsneeded += cls_for_plain_xor(numlits)
            elif at + (self.cutsize-1) < numlits:
                # middle piece: cutsize-2 literals + previous and new link
                at += self.cutsize-2
                varsneeded += 1
                clsneeded += cls_for_plain_xor(self.cutsize)
            else:
                # last piece: remaining literals + previous link variable
                clsneeded += cls_for_plain_xor(numlits-at+1)
                at = numlits

        return [varsneeded, clsneeded]

    def get_stats(self, infilename):
        """Scan ``infilename`` and return the counts needed for the header.

        Returns ``[maxvar, numcls, extravars_needed, extracls_needed]``:
        the largest variable index used, the number of normal clauses, and
        the link variables / clauses that encoding the XOR lines will add.
        """
        maxvar = 0
        numcls = 0
        extravars_needed = 0
        extracls_needed = 0

        with open(infilename, "r") as infile:
            for line in self._iter_clause_lines(infile):
                maxvar = max(self.get_max_var(line), maxvar)

                if line[0] == 'x':
                    e_var, e_clause = self.num_extra_vars_cls_needed(
                        len(self.parse_xor(line)))
                    extravars_needed += e_var
                    extracls_needed += e_clause
                else:
                    numcls += 1

        return [maxvar, numcls, extravars_needed, extracls_needed]
241