1
# Generate reference B-spline Jastrow values (u, du/dr, d2u/dr2) to test against
3
4# Cut and paste the relevant part of the output into test_bspline_jastrow.cpp
5
6from sympy import *
7from collections import defaultdict
8import numpy as np
9
def put_variable_on_lhs(cond, var_name):
  """Return ``cond`` rewritten so that ``var_name`` is its left-hand argument.

  Only the four inequality classes are mirrored: ``a > x`` becomes ``x < a``,
  ``a >= x`` becomes ``x <= a``, ``a < x`` becomes ``x > a`` and ``a <= x``
  becomes ``x >= a``.  When ``var_name`` is not the second argument of
  ``cond``, or ``cond`` is some other kind of relation, ``cond`` itself is
  returned unchanged.
  """
  var_on_rhs = cond.args[1] == var_name
  if not var_on_rhs:
    return cond

  # (relation as written, relation with the arguments swapped)
  mirrored = (
    (StrictGreaterThan, StrictLessThan),
    (GreaterThan, LessThan),
    (StrictLessThan, StrictGreaterThan),
    (LessThan, GreaterThan),
  )
  for written_as, swapped in mirrored:
    if isinstance(cond, written_as):
      return swapped(var_name, cond.args[0])
  return cond
23
24
def to_interval(ival, var_name):
    """Convert a conjunction of inequalities on ``var_name`` into an Interval.

    ``ival`` is expected to be an ``And`` of relationals such as
    ``(x >= a) & (x < b)``.  Each relational is first normalised with
    ``put_variable_on_lhs`` so that ``var_name`` is on the left; ``>``/``>=``
    then give the lower bound and ``<``/``<=`` the upper bound.  Relations of
    any other type are reported with a print and skipped.

    The returned interval is always closed on the left and open on the right
    ([min, max)), regardless of whether the original relations were strict.
    NOTE(review): presumably this is so identical spans coming from different
    basis functions compare equal and share one key in
    transpose_interval_and_coefficients -- confirm.

    Args:
        ival: the condition of one Piecewise piece.
        var_name: the sympy Symbol the condition constrains.

    Returns:
        sympy Interval [min, max).

    Raises:
        ValueError: if no lower or no upper bound was found (for example when
            ``ival`` is not an ``And``).
    """
    min_val = None
    max_val = None
    if isinstance(ival, And):
        for rel in ival.args:
            rel = put_variable_on_lhs(rel, var_name)
            if isinstance(rel, (StrictGreaterThan, GreaterThan)):
                min_val = rel.args[1]
            elif isinstance(rel, (StrictLessThan, LessThan)):
                max_val = rel.args[1]
            else:
                print('unhandled ',rel)

    if min_val is None or max_val is None:
        # Previously this printed and then failed inside Interval(None, ...)
        # with an unrelated sympify error; fail here with the actual cause.
        raise ValueError('cannot convert %s to an interval in %s' % (ival, var_name))

    # Fixed [min, max) form; see the docstring.
    return Interval(min_val, max_val, False, True)
53
54# Transpose the interval and coefficients
55#  Note that interval [0,1) has the polynomial coefficients found in the einspline code
56#  The other intervals could be shifted, and they would also have the same polynomials
def transpose_interval_and_coefficients(sym_basis, var_name, eval_interval=None):
    """Group the pieces of all basis functions by the interval they apply on.

    Args:
        sym_basis: sequence of Piecewise basis functions of ``var_name``.
        var_name: the sympy Symbol the basis functions depend on.
        eval_interval: only pieces whose interval is not disjoint from this
            one are kept.  Defaults to the closed interval [0, 5].  Pass a
            wider interval when the basis has numeric knots, since pieces
            beyond x = 5 would otherwise be dropped.

    Returns:
        defaultdict(list) mapping each Interval (built by ``to_interval``) to
        a list of ``(basis_index, expression)`` pairs valid on that interval.
        The catch-all ``True`` piece of every Piecewise is skipped.
    """
    if eval_interval is None:
        # NOTE(review): with knots that are multiples of a positive symbolic
        # Delta, this presumably only discards pieces lying entirely at
        # negative x, because disjointness cannot be proven otherwise -- confirm.
        eval_interval = Interval(0, 5, False, False)

    cond_map = defaultdict(list)
    for idx, basis_fn in enumerate(sym_basis):
        for expr, cond in basis_fn.args:
            if cond != True:
                piece_interval = to_interval(cond, var_name)
                if not eval_interval.is_disjoint(piece_interval):
                    cond_map[piece_interval].append((idx, expr))
    return cond_map
68
69# Create piecewise expression from the transposed intervals
70# basis_map - map of interval to list of spline expressions for that interval
71#         c - coefficient symbol (needs to allow indexing)
72#        xs - symbol for the position variable ('x')
def recreate_piecewise(basis_map, c, xs):
    """Assemble one Piecewise spline from an interval -> basis-pieces map.

    For each ``interval, [(idx, expr), ...]`` entry of ``basis_map``, the
    piece ``c[idx0]*expr0 + c[idx1]*expr1 + ...`` is valid where
    ``interval.as_relational(xs)`` holds.  A final ``(0, True)`` piece makes
    the spline zero everywhere else.
    """
    pieces = [
        (sum(c[idx] * expr for idx, expr in exprs), interval.as_relational(xs))
        for interval, exprs in basis_map.items()
    ]
    pieces.append((0, True))
    return Piecewise(*pieces)
82
83
84
85
86
def gen_bspline_jastrow(nknots, rcut_val, param, cusp_val, npoints=20, dr=0.6):
  """Print a C++ table of cubic B-spline Jastrow values u, du/dr, d2u/dr2.

  The spline is built symbolically with sympy on uniform knots spaced
  Delta = rcut_val/(nknots+1).  Its coefficients are filled from ``param`` and
  ``cusp_val`` using the layout shown below.  NOTE(review): presumably this is
  the same layout the C++ functor under test uses -- confirm.  The spline is
  evaluated at r = i*dr for i in range(npoints) and printed as a
  ``JValues Vals[N]`` initializer.

  Args:
    nknots: number of spline parameters (the ``size`` attribute in the Jastrow
      XML); must equal ``len(param)`` and be at least 2.
    rcut_val: cutoff radius.
    param: the ``nknots`` spline parameters (the XML ``coefficients``).
    cusp_val: cusp value; it only enters the first spline coefficient.
    npoints: number of evaluation points (default 20).
    dr: spacing of the evaluation points (default 0.6).

  Raises:
    ValueError: if ``nknots < 2`` or ``len(param) != nknots``.  Without this
      check a mismatched ``param`` could be silently broadcast into the
      coefficient array.
  """
  if nknots < 2 or len(param) != nknots:
    raise ValueError('need nknots >= 2 and len(param) == nknots; got nknots=%d, len(param)=%d'
                     % (nknots, len(param)))

  xs = Symbol('x')

  # Workaround for Sympy issue
  #   bspline_basis_set depends on ordering of relational items.
  #   The order for <something>*Delta is okay, the order for bare Delta is not (that is, when i = 1)
  #   Workaround is to multiply Delta by a dummy variable so there is always a multiplication.
  # Issue
  #  https://github.com/sympy/sympy/issues/19262
  #  It has been fixed in mainline, but not a released version (1.6 series as of this writing)

  dummy = Symbol('y', positive=True)
  Delta = Symbol('Delta', positive=True)
  # nknots+6 uniform knots from -3*Delta to (nknots+2)*Delta
  all_knots = [i*Delta*dummy for i in range(-3, nknots+3)]

  # Third-order bspline
  jastrow_sym_basis1 = bspline_basis_set(3, all_knots, xs)
  # Remove the dummy variable after it has served its purpose
  jastrow_sym_basis = [s.subs(dummy, 1) for s in jastrow_sym_basis1]

  print("Number of basis functions = ",len(jastrow_sym_basis))

  # Rearrange the basis and conditionals into a more useful form
  jastrow_cond_map = transpose_interval_and_coefficients(jastrow_sym_basis, xs)
  c = IndexedBase('c', shape=(nknots+3))
  #c = MatrixSymbol('c',nknots+2,1)  # better for code-gen
  jastrow_spline = recreate_piecewise(jastrow_cond_map, c, xs)

  Delta_val = rcut_val*1.0/(nknots+1)

  # Coefficient layout (nknots+4 entries):
  #   c[0] = param[1] - 2*cusp_val*Delta  (the only place cusp_val enters)
  #   c[1] = param[0], c[2] = param[1], c[3:nknots+1] = param[2:]
  #   the last three coefficients stay zero.
  coeffs = np.zeros(nknots+4)
  coeffs[0] = -2*cusp_val*Delta_val + param[1]
  coeffs[1] = param[0]
  coeffs[2] = param[1]
  coeffs[3:-3] = param[2:]

  deriv_jastrow_spline = diff(jastrow_spline, xs)
  deriv2_jastrow_spline = diff(jastrow_spline, xs, 2)

  # Substitute every coefficient.  A fixed count would leave c[k] symbolic for
  # larger nknots; entries that do not appear in the spline are harmless.
  coeff_subs = {c[k]: coeffs[k] for k in range(len(coeffs))}

  vals = []
  for i in range(npoints):
    x = dr*i
    at_x = {xs: x, Delta: Delta_val}
    jv = jastrow_spline.subs(at_x)
    jd = deriv_jastrow_spline.subs(at_x)
    jdd = deriv2_jastrow_spline.subs(at_x)
    vals.append((x, jv.subs(coeff_subs), jd.subs(coeff_subs), jdd.subs(coeff_subs)))

  # Assumes
  # struct JValues
  # {
  #  double r;
  #  double u;
  #  double du;
  #  double ddu;
  # };
  tmpl = """
 const int N = {N};
 JValues Vals[N] = {{
   {values}
 }};
"""
  fmt_values = ',\n  '.join("{%.2f, %15.10g, %15.10g, %15.10g}"%(r,u,du,ddu) for r,u,du,ddu in vals)
  s = tmpl.format(N=len(vals), values=fmt_values)
  print(s)
157
158
159# Generate output for these parameters
160
161#<jastrow name=\"J2\" type=\"Two-Body\" function=\"Bspline\" print=\"yes\"> \
162#   <correlation rcut=\"10\" size=\"10\" speciesA=\"u\" speciesB=\"d\"> \
163#      <coefficients id=\"ud\" type=\"Array\"> 0.02904699284 -0.1004179 -0.1752703883 -0.2232576505 -0.2728029201 -0.3253286875 -0.3624525145 -0.3958223107 -0.4268582166 -0.4394531176</coefficients> \
164#    </correlation> \
165#</jastrow> \
166
def gen_case_two_body():
  """Print reference values for the two-body (u-d) Jastrow: size 10, rcut 10, cusp -0.5."""
  coefficients = np.array([0.02904699284, -0.1004179, -0.1752703883, -0.2232576505,
                           -0.2728029201, -0.3253286875, -0.3624525145, -0.3958223107,
                           -0.4268582166, -0.4394531176])
  gen_bspline_jastrow(nknots=10, rcut_val=10, param=coefficients, cusp_val=-0.5)
173
174
175#   <jastrow type=\"One-Body\" name=\"J1\" function=\"bspline\" source=\"ion0\" print=\"yes\"> \
176#       <correlation elementType=\"C\" size=\"8\" cusp=\"0.0\"> \
177#               <coefficients id=\"eC\" type=\"Array\"> \
178#-0.2032153051 -0.1625595974 -0.143124599 -0.1216434956 -0.09919771951 -0.07111729038 \
179#-0.04445345869 -0.02135082917 \
180#               </coefficients> \
181#            </correlation> \
182#         </jastrow> \
183
184
def gen_case_one_body():
  """Print reference values for the one-body (element C) Jastrow: size 8, rcut 10, cusp 0."""
  coefficients = np.array([-0.2032153051, -0.1625595974, -0.143124599, -0.1216434956,
                           -0.09919771951, -0.07111729038, -0.04445345869,
                           -0.02135082917])
  gen_bspline_jastrow(nknots=8, rcut_val=10, param=coefficients, cusp_val=0.0)
191
# Same one-body input as for gen_case_one_body, but with a nonzero cusp (cusp_val = 2.0):
#   <jastrow type=\"One-Body\" name=\"J1\" function=\"bspline\" source=\"ion0\" print=\"yes\"> \
#       <correlation elementType=\"C\" size=\"8\" cusp=\"2.0\"> \
194#               <coefficients id=\"eC\" type=\"Array\"> \
195#-0.2032153051 -0.1625595974 -0.143124599 -0.1216434956 -0.09919771951 -0.07111729038 \
196#-0.04445345869 -0.02135082917 \
197#               </coefficients> \
198#            </correlation> \
199#         </jastrow> \
200
def gen_case_one_body_cusp():
  """Print reference values for the one-body (element C) Jastrow: size 8, rcut 10, cusp 2.0."""
  coefficients = np.array([-0.2032153051, -0.1625595974, -0.143124599, -0.1216434956,
                           -0.09919771951, -0.07111729038, -0.04445345869,
                           -0.02135082917])
  gen_bspline_jastrow(nknots=8, rcut_val=10, param=coefficients, cusp_val=2.0)
207
208
if __name__ == '__main__':
  import sys

  # Cases that can be requested by name on the command line, e.g.
  #   python <this script> two_body one_body
  # With no arguments only the one_body_cusp case is printed.
  cases = {
    'two_body': gen_case_two_body,
    'one_body': gen_case_one_body,
    'one_body_cusp': gen_case_one_body_cusp,
  }
  requested = sys.argv[1:] or ['one_body_cusp']
  # Validate every name before printing anything.
  unknown = [name for name in requested if name not in cases]
  if unknown:
    sys.exit('unknown case(s) %s; choose from %s'
             % (', '.join(unknown), ', '.join(sorted(cases))))
  for name in requested:
    cases[name]()
213
214