1# ----------------------------------------------------------------------
2#   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
3#   http://lammps.sandia.gov, Sandia National Laboratories
4#   Steve Plimpton, sjplimp@sandia.gov
5#
6#   Copyright (2003) Sandia Corporation.  Under the terms of Contract
7#   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
8#   certain rights in this software.  This software is distributed under
9#   the GNU General Public License.
10#
11#   See the README file in the top-level LAMMPS directory.
12# -------------------------------------------------------------------------
13
14# Python wrapper on LAMMPS library via ctypes
15
16import sys,traceback,types
17from ctypes import *
18
class liggghts:
  """Python wrapper on the LIGGGHTS (LAMMPS) library via ctypes.

  Attributes set by __init__:
    pyVersion -- sys.version_info of the running interpreter
    lib       -- the CDLL handle of the loaded shared library
    lmp       -- c_void_p handle filled in by lammps_open_no_mpi(),
                 or None once the instance has been closed
  """

  def __init__(self,name="",cmdargs=None):
    """Load the shared library and open one LAMMPS instance.

    name    -- if empty, load libliggghts.so; otherwise load
               libliggghts_<name>.so (e.g. name = "g++")
    cmdargs -- optional sequence of command-line style arguments;
               "liggghts.py" is prepended as argument 0.  The caller's
               sequence is not modified.

    Raises OSError if the shared library cannot be loaded.
    """

    # Check which python version is being used
    self.pyVersion = sys.version_info

    # define the handle before anything can fail so that __del__ is
    # always safe, even on a half-constructed object
    self.lmp = None

    # load libliggghts.so by default
    # if name = "g++", load libliggghts_g++.so

    try:
      if not name: self.lib = CDLL("libliggghts.so",RTLD_GLOBAL)
      else: self.lib = CDLL("libliggghts_%s.so" % name,RTLD_GLOBAL)
    except Exception:
      # show the underlying loader error, then raise the documented one
      traceback.print_exc()
      raise OSError("Could not load LIGGGHTS dynamic library")

    # create an instance of LAMMPS
    # don't know how to pass an MPI communicator from PyPar
    # no_mpi call lets LAMMPS use MPI_COMM_WORLD
    # cargs = array of C strings from args

    self.lmp = c_void_p()
    if cmdargs:
      # work on a copy so the caller's list is left untouched, and
      # encode each entry because c_char_p only accepts bytes on Python 3
      args = [self._encode(arg) for arg in ["liggghts.py"] + list(cmdargs)]
      narg = len(args)
      cargs = (c_char_p*narg)(*args)
      self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp))
    else:
      self.lib.lammps_open_no_mpi(0,None,byref(self.lmp))
      # could use just this if LAMMPS lib interface supported it
      # self.lmp = self.lib.lammps_open_no_mpi(0,None)

  def _encode(self,text):
    """Return text as bytes when running on Python 3.

    Values that are not str (bytes, None) are returned unchanged, as is
    everything on Python 2.
    """
    if self.pyVersion[0] >= 3 and isinstance(text,str):
      return text.encode()
    return text

  def __del__(self):
    """Close the LAMMPS instance if it is still open."""
    # getattr: lmp may be missing if the object was never initialized
    lmp = getattr(self,"lmp",None)
    if lmp:
      self.lib.lammps_close(lmp)
      self.lmp = None

  def close(self):
    """Close the LAMMPS instance; calling it again is a no-op."""
    if self.lmp: self.lib.lammps_close(self.lmp)
    self.lmp = None

  def file(self,file):
    """Pass the input script named by file to lammps_file()."""
    file = self._encode(file)
    self.lib.lammps_file(self.lmp,file)

  def command(self,cmd):
    """Pass the single command string cmd to lammps_command()."""
    cmd = self._encode(cmd)
    self.lib.lammps_command(self.lmp,cmd)

  def extract_global(self,name,type):
    """Return the value behind the pointer given by lammps_extract_global().

    type -- 0 = read as int, 1 = read as double; any other value
            returns None.
    Returns None as well if the library hands back a NULL pointer.
    """
    name = self._encode(name)
    if type == 0:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_double)
    else: return None
    ptr = self.lib.lammps_extract_global(self.lmp,name)
    if not ptr: return None
    return ptr[0]

  def extract_atom(self,name,type):
    """Return the ctypes pointer given by lammps_extract_atom().

    type -- 0 = pointer to int, 1 = pointer to pointer to int,
            2 = pointer to double, 3 = pointer to pointer to double;
            any other value returns None.
    """
    name = self._encode(name)
    if type == 0:
      self.lib.lammps_extract_atom.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
    elif type == 2:
      self.lib.lammps_extract_atom.restype = POINTER(c_double)
    elif type == 3:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
    else: return None
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  def extract_compute(self,c_id,style,type):
    """Return data from lammps_extract_compute() for compute c_id.

    style and type are forwarded to the library unchanged.
    type -- 0 = return the double behind the pointer (only for
            style 0, otherwise None), 1 = return pointer to double,
            2 = return pointer to pointer to double; any other value
            returns None.
    Returns None for type 0 if the library hands back a NULL pointer.
    """
    c_id = self._encode(c_id)
    if type == 0:
      if style > 0: return None
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,c_id,style,type)
      if not ptr: return None
      return ptr[0]
    if type == 1:
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,c_id,style,type)
      return ptr
    if type == 2:
      self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
      ptr = self.lib.lammps_extract_compute(self.lmp,c_id,style,type)
      return ptr
    return None

  # in case of global datum, free memory for 1 double via lammps_free()
  # double was allocated by library interface function

  def extract_fix(self,f_id,style,type,i=0,j=0):
    """Return data from lammps_extract_fix() for fix f_id.

    style, type, i and j are forwarded to the library unchanged.
    type -- 0 = return the double behind the pointer and free it
            (only for style 0, otherwise None), 1 = return pointer to
            double, 2 = return pointer to pointer to double; any other
            value returns None.
    Returns None for type 0 if the library hands back a NULL pointer.
    """
    f_id = self._encode(f_id)
    if type == 0:
      if style > 0: return None
      self.lib.lammps_extract_fix.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_fix(self.lmp,f_id,style,type,i,j)
      if not ptr: return None
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    if type == 1:
      self.lib.lammps_extract_fix.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_fix(self.lmp,f_id,style,type,i,j)
      return ptr
    if type == 2:
      self.lib.lammps_extract_fix.restype = POINTER(POINTER(c_double))
      ptr = self.lib.lammps_extract_fix(self.lmp,f_id,style,type,i,j)
      return ptr
    return None

  # free memory for 1 double or 1 vector of doubles via lammps_free()
  # for vector, must copy nlocal returned values to local c_double vector
  # memory was allocated by library interface function

  def extract_variable(self,name,group,type):
    """Return the value(s) of variable name via lammps_extract_variable().

    group -- forwarded to the library; may be None
    type  -- 0 = return a single double, 1 = return a c_double array
             holding a copy of nlocal values; any other value
             returns None.
    The library-allocated memory is released with lammps_free().
    Returns None if the library hands back a NULL pointer.
    """
    name = self._encode(name)
    group = self._encode(group)
    if type == 0:
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      if not ptr: return None
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    if type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
      # bytes literal: a str would be passed as a wide string on Python 3
      nlocalptr = self.lib.lammps_extract_global(self.lmp,b"nlocal")
      nlocal = nlocalptr[0]
      result = (c_double*nlocal)()
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      if not ptr: return None
      # range (not xrange) so the copy also works on Python 3
      for i in range(nlocal): result[i] = ptr[i]
      self.lib.lammps_free(ptr)
      return result
    return None

  # return total number of atoms in system

  def get_natoms(self):
    """Return the result of lammps_get_natoms()."""
    return self.lib.lammps_get_natoms(self.lmp)

  # return vector of atom properties gathered across procs, ordered by atom ID

  def gather_atoms(self,name,type,count):
    """Fill and return a ctypes array via lammps_gather_atoms().

    type  -- 0 = array of c_int, 1 = array of c_double; any other
             value returns None.
    count -- number of values per atom; the array holds
             count * natoms entries.
    """
    name = self._encode(name)
    natoms = self.lib.lammps_get_natoms(self.lmp)
    if type == 0:
      data = ((count*natoms)*c_int)()
      self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
    elif type == 1:
      data = ((count*natoms)*c_double)()
      self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
    else: return None
    return data

  # scatter vector of atom properties across procs, ordered by atom ID
  # assume vector is of correct type and length, as created by gather_atoms()

  def scatter_atoms(self,name,type,count,data):
    """Pass data to lammps_scatter_atoms(); type and count are forwarded."""
    name = self._encode(name)
    self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)
190