1from ctypes import *
2
def _load_isl():
  """Load the isl shared library and return the ctypes CDLL handle.

  The unversioned name ``libisl.so`` is tried first (the historical
  behaviour).  If that fails, fall back to whatever
  ``ctypes.util.find_library("isl")`` locates, e.g. a versioned
  ``libisl.so.N`` or a macOS ``.dylib``.

  Raises:
    OSError: if neither lookup yields a loadable library.
  """
  try:
    return cdll.LoadLibrary("libisl.so")
  except OSError as err:
    from ctypes.util import find_library
    path = find_library("isl")
    if path is None:
      raise err
    return cdll.LoadLibrary(path)

isl = _load_isl()
4
class Context:
  """Python owner of an ``isl_ctx``.

  Every constructed context is registered in ``Context.instances`` under its
  raw pointer.  ``from_ptr`` uses this registry to map pointers returned by
  the ``get_ctx`` entry points back to their wrapper.

  NOTE(review): ``instances`` holds a strong reference to every context, so
  a registered context is never collected while the module is alive and
  ``__del__`` effectively only runs at interpreter shutdown.
  NOTE(review): ``isl_ctx_alloc``'s restype is never set here, so ctypes'
  default C ``int`` applies.  That truncates the pointer on 64-bit
  platforms; confirm the target platform.
  """
  defaultInstance = None
  instances = {}

  def __init__(self):
    ptr = isl.isl_ctx_alloc()
    self.ptr = ptr
    Context.instances[ptr] = self

  def __del__(self):
    # Only free what was actually allocated.  If __init__ never stored a
    # pointer, from_param would raise inside the destructor.
    if getattr(self, "ptr", None) is not None:
      isl.isl_ctx_free(self)

  def from_param(self):
    """Return the raw pointer.

    ctypes invokes ``Context.from_param(instance)`` when Context appears in
    an argtypes list.
    """
    return self.ptr

  @staticmethod
  def from_ptr(ptr):
    """Return the registered Context for raw pointer *ptr*.

    Raises:
      KeyError: if no Context was registered under *ptr*.
    """
    return Context.instances[ptr]

  @staticmethod
  def getDefaultInstance():
    """Return the shared default Context, creating it on first use."""
    if Context.defaultInstance is None:
      Context.defaultInstance = Context()

    return Context.defaultInstance
30
class IslObject:
  """Base class for the Python wrappers around isl objects.

  Subclasses provide ``isl_name()`` (e.g. ``"set"``).  The matching C entry
  points are looked up on the loaded library as
  ``isl_<isl_name>_<operation>``.  Each instance holds one isl pointer in
  ``self.ptr``, which ``__del__`` releases with ``isl_<isl_name>_free``.

  NOTE(review): raw pointers travel through ctypes as C ``int`` in this
  module (``c_int`` restypes, ``from_param`` returning a plain int).  That
  truncates them on 64-bit platforms; confirm the target platform.
  """

  def __init__(self, string = "", ctx = None, ptr = None):
    """Adopt *ptr*, or parse *string* into a new isl object.

    Args:
      string: isl textual representation; only used when *ptr* is None.
        Text (unicode) strings are encoded to ASCII bytes for ``c_char_p``.
      ctx: Context to parse in; defaults to ``Context.getDefaultInstance()``.
      ptr: raw isl pointer to adopt.  It is released in ``__del__``, and the
        context is then queried from the object itself.
    """
    self.initialize_isl_methods()
    if ptr is not None:
      self.ptr = ptr
      self.ctx = self.get_isl_method("get_ctx")(self)
      return

    if ctx is None:
      ctx = Context.getDefaultInstance()

    if isinstance(string, type(u"")):
      # c_char_p only accepts bytes on Python 3.
      string = string.encode("ascii")

    self.ctx = ctx
    # NOTE(review): a parse failure presumably yields a NULL (0) pointer
    # here rather than an exception; confirm before relying on it.
    self.ptr = self.get_isl_method("read_from_str")(ctx, string, -1)

  def __del__(self):
    # Skip partially-constructed objects (e.g. __init__ raised before
    # self.ptr was assigned) instead of failing inside from_param.
    if getattr(self, "ptr", None) is not None:
      self.get_isl_method("free")(self)

  def from_param(self):
    """Return the raw pointer.

    ctypes calls ``Class.from_param(obj)`` when the class appears in an
    argtypes list.
    """
    return self.ptr

  @property
  def context(self):
    """The Context this object belongs to."""
    return self.ctx

  def __repr__(self):
    return self.__str__()

  def __str__(self):
    """Render the object as text using an isl string printer."""
    p = Printer(self.ctx)
    self.to_printer(p)
    text = p.getString()
    if isinstance(text, bytes) and not isinstance(text, str):
      # Python 3: c_char_p results are bytes, but __str__ must return str.
      text = text.decode("utf-8")
    return text

  @staticmethod
  def isl_name():
    """Name fragment used to build isl symbol names; subclasses override it."""
    return "No isl name available"

  def initialize_isl_methods(self):
    """Configure ctypes signatures of this class's isl entry points, once per class."""
    cls = self.__class__
    # Consult only this class's own dict, so a subclass with a different
    # isl_name() cannot inherit the flag and skip its own setup.
    if "initialized" in cls.__dict__:
      return

    self.get_isl_method("read_from_str").argtypes = [Context, c_char_p, c_int]
    self.get_isl_method("copy").argtypes = [cls]
    self.get_isl_method("copy").restype = c_int
    self.get_isl_method("free").argtypes = [cls]
    self.get_isl_method("get_ctx").argtypes = [cls]
    self.get_isl_method("get_ctx").restype = Context.from_ptr
    getattr(isl, "isl_printer_print_" + self.isl_name()).argtypes = [Printer, cls]
    # Set the flag last, spelled exactly as the check above reads it.  The
    # previous misspelling ("initalized") meant this setup re-ran on every
    # construction.  A failure above is now retried on the next construction.
    cls.initialized = True

  def get_isl_method(self, name):
    """Return the isl function ``isl_<isl_name>_<name>``."""
    return getattr(isl, "isl_" + self.isl_name() + "_" + name)

  def to_printer(self, printer):
    """Print this object into *printer* via ``isl_printer_print_<isl_name>``."""
    getattr(isl, "isl_printer_print_" + self.isl_name())(printer, self)
87
class BSet(IslObject):
  """Wrapper for ``isl_basic_set``."""

  @staticmethod
  def isl_name():
    return "basic_set"

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    return BSet(ptr = ptr) if ptr else None
98
class Set(IslObject):
  """Wrapper for ``isl_set``."""

  @staticmethod
  def isl_name():
    return "set"

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    return Set(ptr = ptr) if ptr else None
109
class USet(IslObject):
  """Wrapper for ``isl_union_set``."""

  @staticmethod
  def isl_name():
    return "union_set"

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    return USet(ptr = ptr) if ptr else None
120
121
class BMap(IslObject):
  """Wrapper for ``isl_basic_map``."""

  @staticmethod
  def isl_name():
    return "basic_map"

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    return BMap(ptr = ptr) if ptr else None

  def __mul__(self, set):
    """``bmap * bset`` is shorthand for ``bmap.intersect_domain(bset)``."""
    return self.intersect_domain(set)
135
class Map(IslObject):
  """Wrapper for ``isl_map``."""

  @staticmethod
  def isl_name():
    return "map"

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    return Map(ptr = ptr) if ptr else None

  def __mul__(self, set):
    """``m * s`` is shorthand for ``m.intersect_domain(s)``."""
    return self.intersect_domain(set)

  @staticmethod
  def _lex(islFunction, dim):
    # Hand *islFunction* a new reference to *dim*.  dim is presumably a Dim
    # wrapper (e.g. from get_dim()); confirm against callers.
    return islFunction(isl.isl_dim_copy(dim))

  @staticmethod
  def lex_lt(dim):
    """Return ``isl_map_lex_lt`` applied to a copy of *dim*."""
    return Map._lex(isl.isl_map_lex_lt, dim)

  @staticmethod
  def lex_le(dim):
    """Return ``isl_map_lex_le`` applied to a copy of *dim*."""
    return Map._lex(isl.isl_map_lex_le, dim)

  @staticmethod
  def lex_gt(dim):
    """Return ``isl_map_lex_gt`` applied to a copy of *dim*."""
    return Map._lex(isl.isl_map_lex_gt, dim)

  @staticmethod
  def lex_ge(dim):
    """Return ``isl_map_lex_ge`` applied to a copy of *dim*."""
    return Map._lex(isl.isl_map_lex_ge, dim)
169
class UMap(IslObject):
  """Wrapper for ``isl_union_map``."""

  @staticmethod
  def isl_name():
    return "union_map"

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    return UMap(ptr = ptr) if ptr else None
180
class Dim(IslObject):
  """Wrapper for ``isl_dim`` objects, as returned by the ``get_dim`` methods."""

  # Type selectors passed to isl_dim_size() by __str__.  The names follow
  # how the counts are labelled there (parameters / inputs / outputs).
  _DIM_PARAM = 1
  _DIM_IN = 2
  _DIM_OUT = 3

  @staticmethod
  def from_ptr(ptr):
    """Wrap raw pointer *ptr*; a NULL/falsy pointer yields None."""
    if not ptr:
      return
    return Dim(ptr = ptr)

  @staticmethod
  def isl_name():
    return "dim"

  def initialize_isl_methods(self):
    """Configure ctypes signatures of the isl_dim entry points, once per class.

    Unlike the base-class version, no read_from_str or printer hooks are set
    up; Dim renders itself in __str__ instead.
    """
    cls = self.__class__
    if "initialized" in cls.__dict__:
      return

    self.get_isl_method("copy").argtypes = [cls]
    self.get_isl_method("copy").restype = c_int
    self.get_isl_method("free").argtypes = [cls]
    self.get_isl_method("get_ctx").argtypes = [cls]
    self.get_isl_method("get_ctx").restype = Context.from_ptr
    # Set last and spelled as the check above reads it.  It was previously
    # misspelled "initalized", so this setup re-ran on every construction.
    cls.initialized = True

  def __repr__(self):
    return str(self)

  def __str__(self):
    """Summarise the dimension counts.

    If there are input dimensions, the form is ``<dim In:.., Out:.., Param:..>``.
    Otherwise it is ``<dim Set:.., Param:..>``, so a map with zero input
    dimensions is rendered in the set form.
    """
    dimParam = isl.isl_dim_size(self, self._DIM_PARAM)
    dimIn = isl.isl_dim_size(self, self._DIM_IN)
    dimOut = isl.isl_dim_size(self, self._DIM_OUT)

    if dimIn:
      return "<dim In:%s, Out:%s, Param:%s>" % (dimIn, dimOut, dimParam)

    return "<dim Set:%s, Param:%s>" % (dimOut, dimParam)
216
class Printer:
  """String printer for isl objects, created with ``isl_printer_to_str``."""
  # Output-format selectors accepted by setFormat().
  FORMAT_ISL = 0
  FORMAT_POLYLIB = 1
  FORMAT_POLYLIB_CONSTRAINTS = 2
  FORMAT_OMEGA = 3
  FORMAT_C = 4
  FORMAT_LATEX = 5
  FORMAT_EXT_POLYLIB = 6

  def __init__(self, ctx = None):
    """Create a string printer in *ctx* (default: the shared default Context)."""
    if ctx is None:
      ctx = Context.getDefaultInstance()

    self.ctx = ctx
    self.ptr = isl.isl_printer_to_str(ctx)

  def setFormat(self, format):
    """Select the output format (one of the FORMAT_* constants).

    The result of ``isl_printer_set_output_format`` replaces the stored
    printer pointer.
    """
    self.ptr = isl.isl_printer_set_output_format(self, format)

  def from_param(self):
    """Return the raw printer pointer (used by ctypes argtypes conversion)."""
    return self.ptr

  def __del__(self):
    # Skip partially-constructed printers instead of failing in from_param.
    if getattr(self, "ptr", None) is not None:
      isl.isl_printer_free(self)

  def getString(self):
    """Return the text printed so far.

    The value comes back through the module-level ``c_char_p`` restype.
    NOTE(review): isl documents this string as caller-owned; the c_char_p
    conversion copies it and the C buffer is never freed. Confirm, and
    consider freeing it explicitly.
    """
    return isl.isl_printer_get_str(self)
244
# Table of isl operations attached as methods by the loop further below.
# Each entry is (operation, wrapper class, operand classes, result class or
# c_int); each attached method passes fresh copies of its operands.
# The result class must match the C return type exactly.  Wrappers release
# their pointer with isl_<isl_name>_free, so e.g. an isl_basic_set wrapped
# as Set would later be released with isl_set_free.
functions = [
    # Unary properties
    ("is_empty", BSet, [BSet], c_int),
    ("is_empty", Set, [Set], c_int),
    ("is_empty", USet, [USet], c_int),
    ("is_empty", BMap, [BMap], c_int),
    ("is_empty", Map, [Map], c_int),
    ("is_empty", UMap, [UMap], c_int),

    # ("is_universe", Set, [Set], c_int),
    # ("is_universe", Map, [Map], c_int),

    ("is_single_valued", Map, [Map], c_int),

    ("is_bijective", Map, [Map], c_int),

    ("is_wrapping", BSet, [BSet], c_int),
    ("is_wrapping", Set, [Set], c_int),

    # Binary properties
    ("is_equal", BSet, [BSet, BSet], c_int),
    ("is_equal", Set, [Set, Set], c_int),
    ("is_equal", USet, [USet, USet], c_int),
    ("is_equal", BMap, [BMap, BMap], c_int),
    ("is_equal", Map, [Map, Map], c_int),
    ("is_equal", UMap, [UMap, UMap], c_int),

    # is_disjoint missing

    # ("is_subset", BSet, [BSet, BSet], c_int),
    ("is_subset", Set, [Set, Set], c_int),
    ("is_subset", USet, [USet, USet], c_int),
    ("is_subset", BMap, [BMap, BMap], c_int),
    ("is_subset", Map, [Map, Map], c_int),
    ("is_subset", UMap, [UMap, UMap], c_int),
    # ("is_strict_subset", BSet, [BSet, BSet], c_int),
    ("is_strict_subset", Set, [Set, Set], c_int),
    ("is_strict_subset", USet, [USet, USet], c_int),
    ("is_strict_subset", BMap, [BMap, BMap], c_int),
    ("is_strict_subset", Map, [Map, Map], c_int),
    ("is_strict_subset", UMap, [UMap, UMap], c_int),

    # Unary Operations
    ("complement", Set, [Set], Set),
    ("reverse", BMap, [BMap], BMap),
    ("reverse", Map, [Map], Map),
    ("reverse", UMap, [UMap], UMap),

    # Projection missing
    ("range", BMap, [BMap], BSet),
    ("range", Map, [Map], Set),
    ("range", UMap, [UMap], USet),
    ("domain", BMap, [BMap], BSet),
    ("domain", Map, [Map], Set),
    ("domain", UMap, [UMap], USet),

    ("identity", Set, [Set], Map),
    ("identity", USet, [USet], UMap),

    ("deltas", BMap, [BMap], BSet),
    ("deltas", Map, [Map], Set),
    ("deltas", UMap, [UMap], USet),

    ("coalesce", Set, [Set], Set),
    ("coalesce", USet, [USet], USet),
    ("coalesce", Map, [Map], Map),
    ("coalesce", UMap, [UMap], UMap),

    ("detect_equalities", BSet, [BSet], BSet),
    ("detect_equalities", Set, [Set], Set),
    ("detect_equalities", USet, [USet], USet),
    ("detect_equalities", BMap, [BMap], BMap),
    ("detect_equalities", Map, [Map], Map),
    ("detect_equalities", UMap, [UMap], UMap),

    # The convex, simple and polyhedral hulls of a set/map are single
    # (basic) objects: isl returns isl_basic_set / isl_basic_map here.
    ("convex_hull", Set, [Set], BSet),
    ("convex_hull", Map, [Map], BMap),

    ("simple_hull", Set, [Set], BSet),
    ("simple_hull", Map, [Map], BMap),

    ("affine_hull", BSet, [BSet], BSet),
    ("affine_hull", Set, [Set], BSet),
    ("affine_hull", USet, [USet], USet),
    ("affine_hull", BMap, [BMap], BMap),
    ("affine_hull", Map, [Map], BMap),
    ("affine_hull", UMap, [UMap], UMap),

    ("polyhedral_hull", Set, [Set], BSet),
    ("polyhedral_hull", USet, [USet], USet),
    ("polyhedral_hull", Map, [Map], BMap),
    ("polyhedral_hull", UMap, [UMap], UMap),

    # Power missing
    # Transitive closure missing
    # Reaching path lengths missing

    ("wrap", BMap, [BMap], BSet),
    ("wrap", Map, [Map], Set),
    ("wrap", UMap, [UMap], USet),
    # unwrap takes the (wrapped) set and returns the map.
    ("unwrap", BSet, [BSet], BMap),
    ("unwrap", Set, [Set], Map),
    ("unwrap", USet, [USet], UMap),

    ("flatten", Set, [Set], Set),
    ("flatten", Map, [Map], Map),
    ("flatten_map", Set, [Set], Map),

    # Dimension manipulation missing

    # Binary Operations
    ("intersect", BSet, [BSet, BSet], BSet),
    ("intersect", Set, [Set, Set], Set),
    ("intersect", USet, [USet, USet], USet),
    ("intersect", BMap, [BMap, BMap], BMap),
    ("intersect", Map, [Map, Map], Map),
    ("intersect", UMap, [UMap, UMap], UMap),
    ("intersect_domain", BMap, [BMap, BSet], BMap),
    ("intersect_domain", Map, [Map, Set], Map),
    ("intersect_domain", UMap, [UMap, USet], UMap),
    ("intersect_range", BMap, [BMap, BSet], BMap),
    ("intersect_range", Map, [Map, Set], Map),
    ("intersect_range", UMap, [UMap, USet], UMap),

    ("union", BSet, [BSet, BSet], Set),
    ("union", Set, [Set, Set], Set),
    ("union", USet, [USet, USet], USet),
    ("union", BMap, [BMap, BMap], Map),
    ("union", Map, [Map, Map], Map),
    ("union", UMap, [UMap, UMap], UMap),

    ("subtract", Set, [Set, Set], Set),
    ("subtract", Map, [Map, Map], Map),
    ("subtract", USet, [USet, USet], USet),
    ("subtract", UMap, [UMap, UMap], UMap),

    ("apply", BSet, [BSet, BMap], BSet),
    ("apply", Set, [Set, Map], Set),
    ("apply", USet, [USet, UMap], USet),
    ("apply_domain", BMap, [BMap, BMap], BMap),
    ("apply_domain", Map, [Map, Map], Map),
    ("apply_domain", UMap, [UMap, UMap], UMap),
    ("apply_range", BMap, [BMap, BMap], BMap),
    ("apply_range", Map, [Map, Map], Map),
    ("apply_range", UMap, [UMap, UMap], UMap),

    ("gist", BSet, [BSet, BSet], BSet),
    ("gist", Set, [Set, Set], Set),
    ("gist", USet, [USet, USet], USet),
    ("gist", BMap, [BMap, BMap], BMap),
    ("gist", Map, [Map, Map], Map),
    ("gist", UMap, [UMap, UMap], UMap),

    # Lexicographic Optimizations
    # partial_lexmin missing
    # The lexicographic optimum of a basic set/map need not be basic:
    # isl returns isl_set / isl_map for the basic variants.
    ("lexmin", BSet, [BSet], Set),
    ("lexmin", Set, [Set], Set),
    ("lexmin", USet, [USet], USet),
    ("lexmin", BMap, [BMap], Map),
    ("lexmin", Map, [Map], Map),
    ("lexmin", UMap, [UMap], UMap),

    ("lexmax", BSet, [BSet], Set),
    ("lexmax", Set, [Set], Set),
    ("lexmax", USet, [USet], USet),
    ("lexmax", BMap, [BMap], Map),
    ("lexmax", Map, [Map], Map),
    ("lexmax", UMap, [UMap], UMap),

    # Undocumented
    ("lex_lt_union_set", USet, [USet, USet], UMap),
    ("lex_le_union_set", USet, [USet, USet], UMap),
    ("lex_gt_union_set", USet, [USet, USet], UMap),
    ("lex_ge_union_set", USet, [USet, USet], UMap),
    ]
# Table of isl operations whose attached methods pass the Python operands
# straight through, without copying them first.  Same tuple layout as
# `functions`: (operation, wrapper class, operand classes, result).
keep_functions = [
    # Accessors
    ("get_dim", BSet, [BSet], Dim),
    ("get_dim", Set, [Set], Dim),
    ("get_dim", USet, [USet], Dim),
    ("get_dim", BMap, [BMap], Dim),
    ("get_dim", Map, [Map], Dim),
    ("get_dim", UMap, [UMap], Dim),

    # Predicates.  isl takes their operands as __isl_keep (borrowed), so
    # passing fresh copies would leak one reference per call.  They also
    # appear in `functions`.  keep_functions is processed after it, so the
    # non-copying method is the one that remains on each class.
    ("is_empty", BSet, [BSet], c_int),
    ("is_empty", Set, [Set], c_int),
    ("is_empty", USet, [USet], c_int),
    ("is_empty", BMap, [BMap], c_int),
    ("is_empty", Map, [Map], c_int),
    ("is_empty", UMap, [UMap], c_int),
    ("is_single_valued", Map, [Map], c_int),
    ("is_bijective", Map, [Map], c_int),
    ("is_wrapping", BSet, [BSet], c_int),
    ("is_wrapping", Set, [Set], c_int),
    ("is_equal", BSet, [BSet, BSet], c_int),
    ("is_equal", Set, [Set, Set], c_int),
    ("is_equal", USet, [USet, USet], c_int),
    ("is_equal", BMap, [BMap, BMap], c_int),
    ("is_equal", Map, [Map, Map], c_int),
    ("is_equal", UMap, [UMap, UMap], c_int),
    ("is_subset", Set, [Set, Set], c_int),
    ("is_subset", USet, [USet, USet], c_int),
    ("is_subset", BMap, [BMap, BMap], c_int),
    ("is_subset", Map, [Map, Map], c_int),
    ("is_subset", UMap, [UMap, UMap], c_int),
    ("is_strict_subset", Set, [Set, Set], c_int),
    ("is_strict_subset", USet, [USet, USet], c_int),
    ("is_strict_subset", BMap, [BMap, BMap], c_int),
    ("is_strict_subset", Map, [Map, Map], c_int),
    ("is_strict_subset", UMap, [UMap, UMap], c_int),
    ]
430
def _islCopy(op):
  """Return a new isl reference to wrapper *op* via ``isl_<isl_name>_copy``."""
  return getattr(isl, "isl_" + op.isl_name() + "_copy")(op)


def addIslFunction(object, name):
  """Attach ``isl_<isl_name>_<name>`` to class *object* as method *name*.

  The attached method passes fresh copies of its operands.  Presumably this
  is because these isl entry points take ownership of their arguments while
  the Python wrappers keep owning theirs.  The operand count is read from
  the function's argtypes, which must already be configured.

  Raises:
    ValueError: if argtypes does not describe exactly one or two operands.
  """
  functionName = "isl_" + object.isl_name() + "_" + name
  islFunction = getattr(isl, functionName)
  arity = len(islFunction.argtypes or ())
  if arity == 1:
    f = lambda a: islFunctionOneOp(islFunction, a)
  elif arity == 2:
    f = lambda a, b: islFunctionTwoOp(islFunction, a, b)
  else:
    raise ValueError("%s: expected 1 or 2 operands, got %d" % (functionName, arity))
  # setattr works for classic and new-style classes alike.  Assigning into
  # __dict__ fails on new-style classes, whose __dict__ is read-only.
  setattr(object, name, f)


def islFunctionOneOp(islFunction, ops):
  """Call *islFunction* on a copy of *ops* and return its result."""
  return islFunction(_islCopy(ops))


def islFunctionTwoOp(islFunction, opOne, opTwo):
  """Call *islFunction* on copies of *opOne* and *opTwo*, copied in that order."""
  return islFunction(_islCopy(opOne), _islCopy(opTwo))
449
# Configure the ctypes signature of every entry in `functions`, then attach
# it to its wrapper class as a copying method.
for (operation, base, operands, ret) in functions:
  functionName = "isl_%s_%s" % (base.isl_name(), operation)
  islFunction = getattr(isl, functionName)
  # Operand pointers are passed as plain C ints (one per operand).
  if len(operands) in (1, 2):
    islFunction.argtypes = [c_int] * len(operands)
  # Predicates return a C int; everything else is wrapped by from_ptr.
  islFunction.restype = ret if ret == c_int else ret.from_ptr
  addIslFunction(base, operation)
464
def addIslFunctionKeep(object, name):
  """Attach ``isl_<isl_name>_<name>`` to class *object* as method *name*, without copying.

  Unlike addIslFunction, the attached method passes its Python operands
  straight through.  This is for isl functions that only borrow their
  arguments.  The operand count is read from the function's argtypes.

  Raises:
    ValueError: if argtypes does not describe exactly one or two operands.
  """
  functionName = "isl_" + object.isl_name() + "_" + name
  islFunction = getattr(isl, functionName)
  arity = len(islFunction.argtypes or ())
  if arity == 1:
    f = lambda a: islFunctionOneOpKeep(islFunction, a)
  elif arity == 2:
    f = lambda a, b: islFunctionTwoOpKeep(islFunction, a, b)
  else:
    raise ValueError("%s: expected 1 or 2 operands, got %d" % (functionName, arity))
  # setattr works for classic and new-style classes alike.  Assigning into
  # __dict__ fails on new-style classes, whose __dict__ is read-only.
  setattr(object, name, f)


def islFunctionOneOpKeep(islFunction, ops):
  """Call *islFunction* on *ops* as-is (no copy)."""
  return islFunction(ops)


def islFunctionTwoOpKeep(islFunction, opOne, opTwo):
  """Call *islFunction* on *opOne* and *opTwo* as-is (no copies)."""
  return islFunction(opOne, opTwo)
479
# Configure the ctypes signature of every entry in `keep_functions`, then
# attach it to its wrapper class as a non-copying method.
for (operation, base, operands, ret) in keep_functions:
  functionName = "isl_{0}_{1}".format(base.isl_name(), operation)
  islFunction = getattr(isl, functionName)
  if len(operands) == 1 or len(operands) == 2:
    # One C int per operand pointer.
    islFunction.argtypes = len(operands) * [c_int]

  if ret != c_int:
    islFunction.restype = ret.from_ptr
  else:
    islFunction.restype = ret

  addIslFunctionKeep(base, operation)
494
# ctypes signatures for isl entry points used directly by the wrappers.
# The copy/free/get_ctx/read_from_str/printer_print hooks below are also
# configured lazily by initialize_isl_methods, with the same values.
isl.isl_ctx_free.argtypes = [Context]

# basic_set / set / union_set
isl.isl_basic_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_set_copy.argtypes = [BSet]
isl.isl_basic_set_copy.restype = c_int
isl.isl_set_copy.argtypes = [Set]
isl.isl_set_copy.restype = c_int
isl.isl_basic_set_free.argtypes = [BSet]
isl.isl_set_free.argtypes = [Set]
isl.isl_basic_set_get_ctx.argtypes = [BSet]
isl.isl_basic_set_get_ctx.restype = Context.from_ptr
isl.isl_set_get_ctx.argtypes = [Set]
isl.isl_set_get_ctx.restype = Context.from_ptr
# get_dim: these replace the [c_int] argtypes set by the keep_functions loop.
isl.isl_basic_set_get_dim.argtypes = [BSet]
isl.isl_basic_set_get_dim.restype = Dim.from_ptr
isl.isl_set_get_dim.argtypes = [Set]
isl.isl_set_get_dim.restype = Dim.from_ptr
isl.isl_union_set_get_dim.argtypes = [USet]
isl.isl_union_set_get_dim.restype = Dim.from_ptr

# basic_map / map / union_map
isl.isl_basic_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_map_free.argtypes = [BMap]
isl.isl_map_free.argtypes = [Map]
isl.isl_basic_map_copy.argtypes = [BMap]
isl.isl_basic_map_copy.restype = c_int
isl.isl_map_copy.argtypes = [Map]
isl.isl_map_copy.restype = c_int
isl.isl_basic_map_get_ctx.argtypes = [BMap]
isl.isl_basic_map_get_ctx.restype = Context.from_ptr
isl.isl_map_get_ctx.argtypes = [Map]
isl.isl_map_get_ctx.restype = Context.from_ptr
isl.isl_basic_map_get_dim.argtypes = [BMap]
isl.isl_basic_map_get_dim.restype = Dim.from_ptr
isl.isl_map_get_dim.argtypes = [Map]
isl.isl_map_get_dim.restype = Dim.from_ptr
isl.isl_union_map_get_dim.argtypes = [UMap]
isl.isl_union_map_get_dim.restype = Dim.from_ptr

# printer
isl.isl_printer_free.argtypes = [Printer]
isl.isl_printer_to_str.argtypes = [Context]
isl.isl_printer_print_basic_set.argtypes = [Printer, BSet]
isl.isl_printer_print_set.argtypes = [Printer, Set]
isl.isl_printer_print_basic_map.argtypes = [Printer, BMap]
isl.isl_printer_print_map.argtypes = [Printer, Map]
isl.isl_printer_get_str.argtypes = [Printer]
isl.isl_printer_get_str.restype = c_char_p
isl.isl_printer_set_output_format.argtypes = [Printer, c_int]
isl.isl_printer_set_output_format.restype = c_int

# dim
isl.isl_dim_size.argtypes = [Dim, c_int]
isl.isl_dim_size.restype = c_int
547
# isl_map_lex_{lt,le,gt,ge}: take a dim pointer (as a C int) and return a
# Map wrapper.
for _lexSuffix in ("lt", "le", "gt", "ge"):
  _lexFunction = getattr(isl, "isl_map_lex_" + _lexSuffix)
  _lexFunction.argtypes = [c_int]
  _lexFunction.restype = Map.from_ptr
del _lexSuffix, _lexFunction

# isl_union_map_compute_flow: four operand pointers (C ints), then four
# output slots that dependences() passes via byref().
isl.isl_union_map_compute_flow.argtypes = [c_int] * 4 + [c_void_p] * 4
559
def dependences(sink, must_source, may_source, schedule):
  """Run ``isl_union_map_compute_flow`` and wrap its four outputs.

  The four inputs are passed as fresh copies (copied in argument order), so
  the caller's wrappers stay valid.

  Returns:
    A tuple ``(must_dep, may_dep, must_no_source, may_no_source)`` of
    UMap, UMap, USet and USet.  An entry is None when isl left that output
    NULL.
  """
  def _copy(obj):
    return getattr(isl, "isl_" + obj.isl_name() + "_copy")(obj)

  inputs = [_copy(obj) for obj in (sink, must_source, may_source, schedule)]

  # isl stores a pointer into each output slot, so the slots must be
  # pointer-sized.  A 4-byte c_int would be overrun on 64-bit platforms.
  must_dep = c_void_p()
  may_dep = c_void_p()
  must_no_source = c_void_p()
  may_no_source = c_void_p()
  isl.isl_union_map_compute_flow(inputs[0], inputs[1], inputs[2], inputs[3],
                                 byref(must_dep), byref(may_dep),
                                 byref(must_no_source),
                                 byref(may_no_source))

  # .value is the raw address as an int (None for NULL), which is what
  # from_ptr expects; it then becomes the wrapper's ptr.
  return (UMap.from_ptr(must_dep.value), UMap.from_ptr(may_dep.value),
          USet.from_ptr(must_no_source.value),
          USet.from_ptr(may_no_source.value))
576
577
# Public API for ``from <module> import *``.  This includes the element
# wrappers returned by the attached isl methods and by dependences().
__all__ = ['Set', 'Map', 'Printer', 'Context',
           'BSet', 'USet', 'BMap', 'UMap', 'Dim', 'dependences']
579