# Generate Bspline Jastrow values to test against

# Cut and paste the relevant part of the output into test_bspline_jastrow.cpp

from sympy import *
from collections import defaultdict
import numpy as np


def put_variable_on_lhs(cond, var_name):
    """Return a conditional with the variable always on the left hand side.

    A relation whose right-hand argument is ``var_name`` (e.g. ``a > x``) is
    rewritten as the equivalent relation with ``var_name`` first
    (``x < a``).  Any other relation is returned unchanged.
    """
    if cond.args[1] == var_name:
        bound = cond.args[0]
        if isinstance(cond, StrictGreaterThan):
            return StrictLessThan(var_name, bound)
        if isinstance(cond, GreaterThan):
            return LessThan(var_name, bound)
        if isinstance(cond, StrictLessThan):
            return StrictGreaterThan(var_name, bound)
        if isinstance(cond, LessThan):
            return GreaterThan(var_name, bound)

    return cond


def to_interval(ival, var_name):
    """Convert a relational expression on ``var_name`` to an Interval.

    ``ival`` is expected to be an ``And`` of a lower-bound and an upper-bound
    relation on ``var_name``, e.g. ``(x >= a) & (x < b)``, as found in the
    conditions of the Piecewise basis functions from ``bspline_basis_set``.

    The result is always the half-open interval ``[a, b)``, whether or not
    the original bounds were strict.  Normalizing this way makes the same
    knot span produce an identical (hashable) Interval no matter which basis
    function it came from.  ``transpose_interval_and_coefficients`` relies
    on that to group pieces by interval.

    Raises:
        ValueError: if a lower or upper bound cannot be extracted.
    """
    min_val = None
    max_val = None
    if isinstance(ival, And):
        for rel in ival.args:
            rel = put_variable_on_lhs(rel, var_name)
            # Strictness is deliberately ignored (see docstring).
            if isinstance(rel, (StrictGreaterThan, GreaterThan)):
                min_val = rel.args[1]
            elif isinstance(rel, (StrictLessThan, LessThan)):
                max_val = rel.args[1]
            else:
                print('unhandled ',rel)

    if min_val is None or max_val is None:
        raise ValueError('could not determine interval bounds from %s' % (ival,))

    return Interval(min_val, max_val, False, True)


# Transpose the interval and coefficients.
# Note that interval [0,1) has the polynomial coefficients found in the einspline code.
# The other intervals could be shifted, and they would also have the same polynomials.
def transpose_interval_and_coefficients(sym_basis, var_name):
    """Group the pieces of every basis function by the interval they cover.

    Args:
        sym_basis: list of Piecewise basis functions; the list index is the
            coefficient index of that basis function.
        var_name: the position symbol used in the Piecewise conditions.

    Returns:
        defaultdict mapping each Interval (built by ``to_interval``) to a list
        of ``(basis_index, expr)`` pairs: the polynomial of every basis
        function that is defined on that interval.

    Pieces whose interval is provably disjoint from [0, 5] are dropped.  With
    knots expressed via a positive symbolic Delta, this presumably discards
    only the spans at negative x.  NOTE(review): revisit the [0, 5] window
    if the knots are ever made numeric.
    """
    cond_map = defaultdict(list)

    i1 = Interval(0, 5, False, False)  # interval for evaluation
    for idx, s0 in enumerate(sym_basis):
        for expr, cond in s0.args:
            # Skip the catch-all "otherwise" piece, whose condition is True.
            if cond != True:
                i2 = to_interval(cond, var_name)
                if not i1.is_disjoint(i2):
                    cond_map[i2].append((idx, expr))
    return cond_map


def recreate_piecewise(basis_map, c, xs):
    """Create a single Piecewise spline from the transposed intervals.

    Args:
        basis_map: map of Interval to a list of ``(index, expr)`` spline
            pieces on that interval (output of
            ``transpose_interval_and_coefficients``).
        c: coefficient symbol; must support indexing (e.g. IndexedBase).
        xs: symbol for the position variable ('x').

    Returns:
        Piecewise equal to ``sum(c[index] * expr)`` on each interval and 0
        everywhere else.
    """
    args = []
    for interval, exprs in basis_map.items():
        e = sum((c[idx] * b for idx, b in exprs), 0)
        args.append((e, interval.as_relational(xs)))
    args.append((0, True))
    return Piecewise(*args)


def gen_bspline_jastrow(nknots, rcut_val, param, cusp_val, npoints=20, dr=0.6):
    """Print a C++ table of B-spline Jastrow values u, du/dr, d2u/dr2.

    Args:
        nknots: number of spline parameters (the ``size`` attribute of the
            Jastrow ``<correlation>`` element).
        rcut_val: cutoff radius.  The knot spacing is
            ``rcut_val / (nknots + 1)``.
        param: sequence of ``nknots`` spline parameters.
        cusp_val: cusp value.  It sets the first coefficient to
            ``-2*cusp_val*Delta + param[1]``.  For a uniform cubic B-spline
            this gives ``du/dr(0) = (c[2] - c[0]) / (2*Delta) = cusp_val``.
        npoints: number of evaluation points (default 20).
        dr: spacing of the evaluation points; points are ``0, dr, 2*dr, ...``
            (default 0.6).

    Raises:
        ValueError: if ``len(param) != nknots`` or ``nknots < 2``.
    """
    xs = Symbol('x')

    # Workaround for Sympy issue
    # bspline_basis_set depends on ordering of relational items.
    # The order for <something>*Delta is okay, the order for bare Delta is not (that is, when i = 1)
    # Workaround is to multiply Delta by a dummy variable so there is always a multiplication.
    # Issue
    # https://github.com/sympy/sympy/issues/19262
    # It has been fixed in mainline, but not a released version (1.6 series as of this writing)

    dummy = Symbol('y', positive=True)
    Delta = Symbol('Delta', positive=True)
    # Uniform knots from -3*Delta to (nknots+2)*Delta.
    all_knots = [i*Delta*dummy for i in range(-3, nknots+3)]

    # Third-order bspline
    jastrow_sym_basis1 = bspline_basis_set(3, all_knots, xs)
    # Remove the dummy variable after it has served its purpose
    jastrow_sym_basis = [s.subs(dummy, 1) for s in jastrow_sym_basis1]

    print("Number of basis functions = ",len(jastrow_sym_basis))

    # Rearrange the basis and conditionals into a more useful form
    jastrow_cond_map = transpose_interval_and_coefficients(jastrow_sym_basis, xs)
    # A MatrixSymbol('c', nknots+2, 1) would be better for code generation.
    c = IndexedBase('c', shape=(nknots+3))
    jastrow_spline = recreate_piecewise(jastrow_cond_map, c, xs)

    if nknots < 2 or len(param) != nknots:
        raise ValueError('expected %d (>= 2) spline parameters, got %d'
                         % (nknots, len(param)))

    Delta_val = rcut_val*1.0/(nknots+1)

    # Coefficient layout: [cusp-constrained, param[0], param[1], param[2:]...,
    # 0, 0, 0].  The trailing zeros make the spline vanish at the cutoff.
    coeffs = np.zeros(nknots+4)
    coeffs[0] = -2*cusp_val*Delta_val + param[1]
    coeffs[1] = param[0]
    coeffs[2] = param[1]
    coeffs[3:-3] = param[2:]

    deriv_jastrow_spline = diff(jastrow_spline, xs)
    deriv2_jastrow_spline = diff(jastrow_spline, xs, 2)

    # Substitute every coefficient, not a fixed count, so that larger
    # nknots do not leave c[k] symbols unevaluated.
    coeff_subs = {c[k]: coeffs[k] for k in range(len(coeffs))}

    vals = []
    for i in range(npoints):
        x = dr*i
        point = {xs: x, Delta: Delta_val}
        jv = jastrow_spline.subs(point)
        jd = deriv_jastrow_spline.subs(point)
        jdd = deriv2_jastrow_spline.subs(point)
        vals.append((x, jv.subs(coeff_subs), jd.subs(coeff_subs), jdd.subs(coeff_subs)))

    # Assumes
    # struct JValues
    # {
    #   double r;
    #   double u;
    #   double du;
    #   double ddu;
    # };
    tmpl = """
  const int N = {N};
  JValues Vals[N] = {{
    {values}
  }};
"""
    fmt_values = ',\n    '.join("{%.2f, %15.10g, %15.10g, %15.10g}"%(r,u,du,ddu) for r,u,du,ddu in vals)
    s = tmpl.format(N=len(vals), values=fmt_values)
    print(s)


# Generate output for these parameters

#<jastrow name=\"J2\" type=\"Two-Body\" function=\"Bspline\" print=\"yes\"> \
#   <correlation rcut=\"10\" size=\"10\" speciesA=\"u\" speciesB=\"d\"> \
#      <coefficients id=\"ud\" type=\"Array\"> 0.02904699284 -0.1004179 -0.1752703883 -0.2232576505 -0.2728029201 -0.3253286875 -0.3624525145 -0.3958223107 -0.4268582166 -0.4394531176</coefficients> \
#    </correlation> \
#</jastrow> \

def gen_case_two_body():
    """Print reference values for the two-body (u-d) Jastrow shown above."""
    rcut_val = 10
    param = np.array([0.02904699284, -0.1004179, -0.1752703883, -0.2232576505, -0.2728029201, -0.3253286875, -0.3624525145, -0.3958223107, -0.4268582166, -0.4394531176])
    nknots = 10
    cusp_val = -0.5
    gen_bspline_jastrow(nknots, rcut_val, param, cusp_val)
# Spline parameters shared by the two one-body cases below; only the cusp differs.
_ONE_BODY_PARAM = (-0.2032153051, -0.1625595974, -0.143124599, -0.1216434956,
                   -0.09919771951, -0.07111729038, -0.04445345869, -0.02135082917)


# <jastrow type=\"One-Body\" name=\"J1\" function=\"bspline\" source=\"ion0\" print=\"yes\"> \
#   <correlation elementType=\"C\" size=\"8\" cusp=\"0.0\"> \
#     <coefficients id=\"eC\" type=\"Array\"> \
#-0.2032153051 -0.1625595974 -0.143124599 -0.1216434956 -0.09919771951 -0.07111729038 \
#-0.04445345869 -0.02135082917 \
#     </coefficients> \
#   </correlation> \
# </jastrow> \

def gen_case_one_body():
    """Print reference values for the one-body Jastrow above (cusp 0.0)."""
    rcut_val = 10
    param = np.array(_ONE_BODY_PARAM)
    nknots = 8
    cusp_val = 0.0
    gen_bspline_jastrow(nknots, rcut_val, param, cusp_val)


# Same as the one-body case above, but with a non-zero cusp:
# <jastrow type=\"One-Body\" name=\"J1\" function=\"bspline\" source=\"ion0\" print=\"yes\"> \
#   <correlation elementType=\"C\" size=\"8\" cusp=\"2.0\"> \
#     <coefficients id=\"eC\" type=\"Array\"> \
#-0.2032153051 -0.1625595974 -0.143124599 -0.1216434956 -0.09919771951 -0.07111729038 \
#-0.04445345869 -0.02135082917 \
#     </coefficients> \
#   </correlation> \
# </jastrow> \

def gen_case_one_body_cusp():
    """Print reference values for the one-body Jastrow above (cusp 2.0)."""
    rcut_val = 10
    param = np.array(_ONE_BODY_PARAM)
    nknots = 8
    cusp_val = 2.0
    gen_bspline_jastrow(nknots, rcut_val, param, cusp_val)


if __name__ == '__main__':
    import sys

    # Cases can be selected by name on the command line, e.g.
    #   python <this script> two_body one_body
    # With no arguments only the one-body cusp case is generated.
    cases = {
        'two_body': gen_case_two_body,
        'one_body': gen_case_one_body,
        'one_body_cusp': gen_case_one_body_cusp,
    }
    for case_name in sys.argv[1:] or ['one_body_cusp']:
        if case_name not in cases:
            sys.exit('unknown case %r; choose from: %s'
                     % (case_name, ', '.join(sorted(cases))))
        cases[case_name]()