#!/usr/bin/python -E
#
# -E suppresses all PYTHON* environment variables. In particular
# PYTHONHOME is a source of disaster.
#
#
# === README ===
#
# This program replaces function calls of density functionals with the
# corresponding results of a density functional subroutine call. In addition
# the appropriate subroutine call is also spliced in at the right place.
#
#
# BACKGROUND
# ----------
#
# In the normal operation of Maxima generating source code one would specify
# the expression for the density functional in full. The resulting (typically
# very large) expression is then differentiated in every required way, the
# catalogue of expressions is optimized and everything is transformed into
# Fortran. This leads to very large, illegible subroutines that may take
# hours to generate, and tens of minutes to compile (e.g. nwxcm_c_tpss03_d3
# is one subroutine of 20857 lines and 1.25 MB in size). Clearly this is
# undesirable.
#
# In practice the functionals implemented in these routines can be constructed
# from other functionals. For example, the TPSS correlation functional
# mentioned above is expressed in terms of a modified PBE functional, which
# in turn is specified in terms of the PW91 LDA functional. When hand coding
# functionals this structure is maintained to make functionals manageable.
# Ideally we would like to preserve this structure while using symbolic algebra
# to generate the source code.
#
# Using symbolic algebra it is in principle possible to do this as one can
# specify a function in terms of other functions that Maxima does not know.
# For example, one can specify a function as
#
#   f(x,y) := (x+y)^2*h(y);
#
# where Maxima does not know what h(y) is. Differentiating f(x,y) with respect
# to y, Maxima will generate
#
#   diff(f(x,y),y) := 2*(x+y)*h(y)+(x+y)^2*'diff(h(y),y,1)
#
# where 'diff indicates that a derivative is required. The Fortran generated
# by Maxima will reference h(y) and 'diff(h(y),y,1), which is not necessarily
# valid Fortran. If we know what h(y) and 'diff(h(y),y,1) stand for, this
# problem can be resolved by replacing these entities with valid Fortran
# functions or variables. In that fashion Maxima can be used to generate source
# code while preserving the structure of the original functional
# implementation.
#
# This script implements these source code transformations.
#
#
# APPROACH
# --------
#
# The main complication in this program is that the expression optimization
# that Maxima performs messes things up: in the example above it would
# break out the references to h(y) as in
#
#   t1 = h(y)
#   diff(f(x,y),y) := 2*(x+y)*t1+(x+y)^2*'diff(t1,y,1)
#
# Also note that h may be invoked multiple times as in, for example,
#
#   f(x,y) := (x+y)^2*h(x,y) - x^2*h(x,0) - y^2*h(0,y)
#
# so we need to distinguish every invocation based on the arguments passed in.
# In addition, in our case, h is actually implemented as a subroutine that
# returns the function value and the values of the various derivatives. The
# inputs and outputs are stored in arrays. The steps required are:
#
# 1. All Fortran lines need to be unwrapped so that simple text analysis
#    can reliably find references to h(x,y), 'diff(h(x,y)), etc.
# 2. Where h(x,y) is invoked we need to generate arrays x and y, as well as
#    arrays for the required derivatives, and initialize them.
# 3. We need to insert the array declarations in the type declarations.
# 4. We need to insert the subroutine calls at the appropriate places (i.e.
#    right before the first place where results are used, as we need to make
#    sure that the input data is defined).
# 5. Where h(x,y) or its derivatives are referenced we need to substitute
#    the appropriate variables.
# 6. The Fortran lines need to be rewrapped such that we use at most 72 columns.
#
import re
import sys
import string

# Code skeleton types (see find_code_skeleton_type)
type_autoxc = 1
type_autoxcDs = 2

# Kinds of if-branches handled by append_subroutine_call
ifbranch_closedshell = 1
ifbranch_openshell = 2

# Functional kinds. They are ordered so that "funckind >= func_gga" means
# "at least the gradient (gamma) quantities are involved".
func_invalid = -1
func_lda = 1
func_gga = 2
func_mgga = 3

version = "$Id$"
version = version.split()
if len(version) >= 4:
    revision = version[1]+" revision "+version[2]+" "+version[3]
else:
    # The Id keyword was not expanded by the version control system, so the
    # fields indexed above do not exist; report the raw keyword instead of
    # crashing at start up.
    revision = " ".join(version)


def usage(code):
    """
    Print usage information to standard error and exit with status "code".
    """
    sys.stderr.write("Insert subroutine calls to functionals in Maxima generated code.\n")
    sys.stderr.write("\n")
    sys.stderr.write("  %s [-h] [-v|--version] < <filein> > <fileout>\n"%sys.argv[0])
    sys.stderr.write("\n")
    sys.stderr.write("  -h            Print this information\n")
    sys.stderr.write("  -v|--version  Print the version data\n")
    sys.stderr.write("  <filein>      Raw Fortran from Maxima\n")
    sys.stderr.write("  <fileout>     Fortran from Maxima with subroutine calls\n")
    sys.stderr.write("                for functionals\n")
    sys.stderr.write("\n")
    sys.stderr.write("$Id$\n")
    sys.exit(code)


def var_to_int(var):
    """
    Convert a variable name (such as "t5") to an integer (in this case 5).
    Variables like "Amat(iq,D1_RA)" should never show up here.
    """
    return int(var[1:])


def unwrap_lines(lines_in):
    """
    Take a list of lines of fixed form Fortran and unwrap all continuation
    lines. The result is stored in a new list which is returned.
    Input lines may end with a '\n' character (the last line need not); the
    output lines have no newline characters. Empty lines are dropped.
    A continuation line has blanks in columns 1-5 and a continuation
    character in column 6; its statement text starts at column 7.
    """
    # Column 6 (index 5) holds the continuation character, consistent with
    # the line[6:] slice below and with the marker written by rewrap_line.
    pattern = re.compile(r" {5}[0-9:;<=>?@+]")
    lines_out = []
    longline = ""
    for line in lines_in:
        if line.endswith("\n"):
            line = line[:-1]
        if pattern.match(line):
            longline += line[6:].lstrip()
        elif len(line) > 0:
            if len(longline) > 0:
                lines_out.append(longline)
            longline = line
    if len(longline) > 0:
        lines_out.append(longline)
    return lines_out


def find_subroutine(lines,lineno):
    """
    Given the starting line number (lineno) and a list of lines find the
    line numbers where the next subroutine starts and ends. Here we use
    Fortran90 coding styles so the starting and end points can be identified
    by looking for lines starting with blanks followed by
    - subroutine
    - end subroutine
    A tuple with the corresponding starting and end line numbers is returned;
    an entry is -1 if the corresponding line was not found.
    """
    length = len(lines)
    lineno_start = -1
    lineno_end = -1
    pattern_start = re.compile(r" +subroutine")
    pattern_end = re.compile(r" +end subroutine")
    line = lineno
    while line < length and not pattern_start.match(lines[line]):
        line += 1
    if line < length:
        lineno_start = line
    # Keep the search for the end inside the valid range of line numbers
    line = min(line,length-1)
    while 0 <= line < length and not pattern_end.match(lines[line]):
        line += 1
    if 0 <= line < length:
        lineno_end = line
    return (lineno_start,lineno_end)


def find_code_skeleton_type(lines,subr_lines):
    """
    Scan the lines of a subroutine to find out whether it is an autoxc or an
    autoxc-Ds generated subroutine. Only autoxc-Ds code tests taua against
    tol_rho, so the presence of that test decides the type.
    """
    (lineno_start,lineno_end) = subr_lines
    pattern = re.compile(r"if \(taua\.gt\.tol_rho\) then")
    for line in range(lineno_start,lineno_end+1):
        if pattern.search(lines[line]):
            return type_autoxcDs
    return type_autoxc


def _scan_forward(lines,regex,line,lineno_end):
    """
    Return the number of the first line, from "line" up to and including
    "lineno_end", that contains a match for "regex". If there is no such
    line a number greater than lineno_end is returned.
    """
    pattern = re.compile(regex)
    while line <= lineno_end:
        if pattern.search(lines[line]):
            break
        line += 1
    return line


def find_autoxc_code_skeleton(lines,subr_lines):
    """
    Assuming that the subroutine is of the autoxc code skeleton type find all
    the if-branches and return a list of all the (start,end) tuples.
    The first tuple is the closed shell branch, the next three the open shell
    branches (both densities, only rhoa, only rhob). Entries that could not
    be located are -1.
    """
    (lineno_start,lineno_end) = subr_lines
    ifbranches = []
    #
    ifstart = -1
    ifend = -1
    line = _scan_forward(lines,r"if \(rhoa\.gt\.tol_rho\) then",lineno_start,lineno_end)
    if line <= lineno_end:
        ifstart = line+1
    line = _scan_forward(lines,r"endif ! rhoa\.gt\.tol_rho",line,lineno_end)
    if line <= lineno_end:
        ifend = line-1
    ifbranches.append((ifstart,ifend))
    #
    ifstarta = ifenda = -1
    ifstartb = ifendb = -1
    ifstartc = ifendc = -1
    line = _scan_forward(lines,r"if \(rhoa\.gt\.tol_rho\.and\.rhob\.gt\.tol_rho\) then",line,lineno_end)
    if line <= lineno_end:
        ifstarta = line+1
    line = _scan_forward(lines,r"elseif \(rhoa\.gt\.tol_rho\.and\.rhob\.le\.tol_rho\) then",line,lineno_end)
    if line <= lineno_end:
        ifenda = line-1
        ifstartb = line+1
    line = _scan_forward(lines,r"elseif \(rhoa\.le\.tol_rho\.and\.rhob\.gt\.tol_rho\) then",line,lineno_end)
    if line <= lineno_end:
        ifendb = line-1
        ifstartc = line+1
    line = _scan_forward(lines,r"endif ! rhoa\.gt\.tol_rho\.and\.rhob\.gt\.tol_rho",line,lineno_end)
    if line <= lineno_end:
        ifendc = line-1
    #
    ifbranches.append((ifstarta,ifenda))
    ifbranches.append((ifstartb,ifendb))
    ifbranches.append((ifstartc,ifendc))
    return ifbranches


def find_autoxcDs_code_skeleton(lines,subr_lines):
    """
    Assuming that the subroutine is of the autoxc-Ds code skeleton type find all
    the if-branches and return a list of all the (start,end) tuples.
    The first two tuples are the branches of the taua test; then follow three
    groups of four tuples, one per branch of the taua/taub if-elseif-else
    constructs. Entries that could not be located are -1.
    """
    (lineno_start,lineno_end) = subr_lines
    ifbranches = []
    #
    ifstarta = ifenda = -1
    ifstartb = ifendb = -1
    line = _scan_forward(lines,r"if \(taua\.gt\.tol_rho\) then",lineno_start,lineno_end)
    if line <= lineno_end:
        ifstarta = line+1
    line = _scan_forward(lines,"else",line,lineno_end)
    if line <= lineno_end:
        ifenda = line-1
        ifstartb = line+1
    line = _scan_forward(lines,"endif",line,lineno_end)
    if line <= lineno_end:
        ifendb = line-1
    ifbranches.append((ifstarta,ifenda))
    ifbranches.append((ifstartb,ifendb))
    #
    for x in range(0,3):
        ifstarta = ifenda = -1
        ifstartb = ifendb = -1
        ifstartc = ifendc = -1
        ifstartd = ifendd = -1
        line = _scan_forward(lines,r"if \(taua\.gt\.tol_rho\.and\.taub\.gt\.tol_rho\) then",line,lineno_end)
        if line <= lineno_end:
            ifstarta = line+1
        line = _scan_forward(lines,r"elseif \(taua\.gt\.tol_rho\.and\.taub\.le\.tol_rho\) then",line,lineno_end)
        if line <= lineno_end:
            ifenda = line-1
            ifstartb = line+1
        line = _scan_forward(lines,r"elseif \(taua\.le\.tol_rho\.and\.taub\.gt\.tol_rho\) then",line,lineno_end)
        if line <= lineno_end:
            ifendb = line-1
            ifstartc = line+1
        # Needed because "else" is a substring of "elseif ...". Only restart
        # when the last elseif was found; restarting at -1 would scan from
        # the end of the file and then from its very first line.
        if ifstartc != -1:
            line = ifstartc
        line = _scan_forward(lines,"else",line,lineno_end)
        if line <= lineno_end:
            ifendc = line-1
            ifstartd = line+1
        line = _scan_forward(lines,"endif",line,lineno_end)
        if line <= lineno_end:
            ifendd = line-1
        #
        ifbranches.append((ifstarta,ifenda))
        ifbranches.append((ifstartb,ifendb))
        ifbranches.append((ifstartc,ifendc))
        ifbranches.append((ifstartd,ifendd))
    return ifbranches


def find_type_declaration_insertion_point(lines,subr_lines):
    """
    Find the point in the subroutine where the array declarations for the
    subroutine calls can be inserted: the line that includes nwxc_param.fh.
    Returns -1 if there is no such line.
    """
    (lineno_start,lineno_end) = subr_lines
    pattern = re.compile("#include \"nwxc_param.fh\"")
    for line in range(lineno_start,lineno_end+1):
        if pattern.match(lines[line]):
            return line
    return -1


def find_subroutine_call_insertion_point(lines,ifbranch,varname):
    """
    The assumption is that the symbolic algebra optimization will always break
    out the function evaluation. Therefore there always is a line where the
    functional value is assigned to a variable. The actual functional
    subroutine call, at the latest, has to happen on the line before.
    This routine returns the line where the variable is first assigned, or -1
    if it is not assigned within the if-branch.
    """
    (lineno_start,lineno_end) = ifbranch
    for line in range(lineno_start,lineno_end+1):
        if lines[line].split(" = ")[0].lstrip() == varname:
            return line
    return -1


def collect_subroutine_calls(lines,ifbranch_lines):
    """
    Collect a dictionary of all subroutine call instances. In the source code
    on entry the "subroutine call" is given in the form of a function call, e.g.
      t5 = nwxc_c_Mpbe(rhoa,0.0d+0,gammaaa,0.0d+0,0.0d+0)
    we need to know the "nwxc_c_Mpbe(rhoa,0.0d+0,gammaaa,0.0d+0,0.0d+0)" part.
    The key in the dictionary is going to be the variable name (i.e. "t5") as we
    will have to replace those variable instances with an array element
    reference. Lines that mention nwxc but are not assignments are ignored.
    The resulting dictionary is returned.
    """
    (lineno_start,lineno_end) = ifbranch_lines
    calls = {}
    pattern = re.compile("nwxc")
    for line in range(lineno_start,lineno_end+1):
        if not pattern.search(lines[line]):
            continue
        fields = lines[line].split(" = ")
        if len(fields) < 2:
            continue
        calls[fields[0].lstrip()] = fields[1]
    return calls


def delete_lines(lines,ifbranches):
    """
    Collect a list of lines that should be deleted (i.e. skipped) when writing
    the output subroutine. The lines in question are the ones generated from the
    "at" command, e.g.
      t10(1) = (gammaaa = gammaaa)
      t10(2) = (gammaab = gammaaa)
      t10(3) = (gammabb = gammaaa)
      t10(4) = (rhoa = rhoa)
      t10(5) = (rhob = rhoa)
      t10(6) = (taua = taua)
      t10(7) = (taub = taua)
    Such lines are recognized by containing exactly two " = " separators.
    The list of line numbers is returned.
    """
    dlist = []
    for (lineno_start,lineno_end) in ifbranches:
        for line in range(lineno_start,lineno_end+1):
            if len(lines[line].split(" = ")) == 3:
                dlist.append(line)
    return dlist


def find_maxno_calls(lines,ifbranches):
    """
    In every if-branch some subroutine calls will be inserted. For the inputs
    we need one set of arrays for rho, gamma and tau. However, for the outputs
    we need a separate set of arrays for each separate call as the results may
    appear in multiple places throughout the remainder of the if-branch. In
    addition there is no guarantee that the results are not used in overlapping
    code segments. Therefore we need to count how many different sets of
    output variables we need to declare. This routine returns the largest
    number of calls in any single if-branch (0 if there are none).
    """
    varsets = 0
    for ifbranch in ifbranches:
        varsets = max(varsets,len(collect_subroutine_calls(lines,ifbranch)))
    return varsets


def append_declarations(olines,varsets,orderdiff):
    """
    Given the list of output lines so far, the number of output variable sets
    and the maximum order of differentiation append the required array
    declarations. The statements start in column 7 as required by fixed form
    Fortran. The list "olines" is extended in place and returned.
    """
    olines.append("      double precision sr(NCOL_RHO)")
    olines.append("      double precision sg(NCOL_GAMMA)")
    olines.append("      double precision st(NCOL_TAU)")
    for ii in range(1,varsets+1):
        prefix = "      double precision s"+str(ii)
        olines.append(prefix+"f")
        olines.append(prefix+"a(NCOL_AMAT)")
        olines.append(prefix+"c(NCOL_CMAT)")
        olines.append(prefix+"m(NCOL_MMAT)")
        if orderdiff > 1:
            olines.append(prefix+"a2(NCOL_AMAT2)")
            olines.append(prefix+"c2(NCOL_CMAT2)")
            olines.append(prefix+"m2(NCOL_MMAT2)")
        if orderdiff > 2:
            olines.append(prefix+"a3(NCOL_AMAT3)")
            olines.append(prefix+"c3(NCOL_CMAT3)")
            olines.append(prefix+"m3(NCOL_MMAT3)")
    return olines


def _append_openshell_inputs(olines,arglist,funckind,indent):
    """
    Append the statements that copy the spin resolved inputs of "arglist"
    (the entries after the leading parameter argument) into the sr, sg and st
    input arrays, as far as the functional kind requires them.
    """
    olines.append(indent+"sr(R_A) = "+arglist[1])
    olines.append(indent+"sr(R_B) = "+arglist[2])
    if funckind >= func_gga:
        olines.append(indent+"sg(G_AA) = "+arglist[3])
        olines.append(indent+"sg(G_AB) = "+arglist[4])
        olines.append(indent+"sg(G_BB) = "+arglist[5])
    if funckind >= func_mgga:
        olines.append(indent+"st(T_A) = "+arglist[6])
        olines.append(indent+"st(T_B) = "+arglist[7])


def append_subroutine_call(olines,subrname,arglist,orderdiff,ifbranch_kind,funckind,num,indent):
    """
    Given the list of output lines so far, the subroutine name, the input
    argument list, the order of differentiation, the if-branch kind, the
    functional kind as well as the number of the output variable set, generate
    the subroutine call code.
    The subroutine call code involves initializing the input and output
    arguments, and constructing the actual call itself.
    An invalid ifbranch_kind is reported on standard error and terminates the
    program with exit status 20.
    The list "olines" is extended in place and returned.
    """
    if ifbranch_kind != ifbranch_closedshell and ifbranch_kind != ifbranch_openshell:
        sys.stderr.write("append_subroutine_call: invalid ifbranch_kind: %d\n"%ifbranch_kind)
        sys.exit(20)
    snum = "s"+str(num)
    # A zero spin density indicates the open shell term of the Stoll
    # partitioning of the correlation energy.
    stoll = arglist[1] == '0.0d+0' or arglist[2] == '0.0d+0'
    #
    # Input arguments
    #
    if ifbranch_kind == ifbranch_closedshell and not stoll:
        # A regular closed shell call uses the total density quantities
        olines.append(indent+"sr(R_T) = 2.0d0*"+arglist[1])
        if funckind >= func_gga:
            olines.append(indent+"sg(G_TT) = 4.0d0*"+arglist[3])
        if funckind >= func_mgga:
            olines.append(indent+"st(T_T) = 2.0d0*"+arglist[6])
    else:
        # Stoll term of a closed shell branch, or a regular open shell call
        _append_openshell_inputs(olines,arglist,funckind,indent)
    #
    # Output arguments: the functional value plus, per output matrix, one
    # array for each order of differentiation.
    #
    matrices = [("a","AMAT")]
    if funckind >= func_gga:
        matrices.append(("c","CMAT"))
    if funckind >= func_mgga:
        matrices.append(("m","MMAT"))
    orders = [""]
    if orderdiff >= 2:
        orders.append("2")
    if orderdiff >= 3:
        orders.append("3")
    olines.append(indent+snum+"f = 0.0d0")
    for (mchar,mname) in matrices:
        for order in orders:
            olines.append(indent+"call dcopy(NCOL_"+mname+order+",0.0d0,0,"+snum+mchar+order+",1)")
    #
    # The call itself
    #
    line = indent+"call "+subrname
    if orderdiff == 2:
        line += "_d2"
    elif orderdiff == 3:
        line += "_d3"
    line += "("+arglist[0]+",tol_rho,"
    line += "2" if stoll else "ipol"
    line += ",1,1.0d0,sr"
    if funckind >= func_gga:
        line += ",sg"
    if funckind >= func_mgga:
        line += ",st"
    line += ","+snum+"f"
    for (mchar,mname) in matrices:
        for order in orders:
            line += ","+snum+mchar+order
    line += ")"
    olines.append(line)
    return olines


def find_max_order_diff(lines,subr_lines):
    """
    Work out the maximum order of differentiation for a given subroutine from
    the "_d2(" or "_d3(" suffix of the name on its first line.
    The maximum order (1, 2 or 3) is returned.
    """
    (lineno_start,lineno_end) = subr_lines
    aline = lines[lineno_start]
    if re.search(r"_d2\(",aline):
        return 2
    if re.search(r"_d3\(",aline):
        return 3
    return 1


def find_functional_kind(funccall):
    """
    Given the way the functional was invoked as a function work out whether
    the functional is an LDA, GGA or Meta-GGA functional. We establish this
    by counting the number of arguments which is 3 for LDA, 6 for GGA and
    8 for a Meta-GGA functional. The numbers include the "param" argument
    as well as the density dependent quantities, i.e. for LDA we have 3
    arguments: param + rhoa + rhob.
    The functional kind is returned (func_invalid for any other count).
    """
    nargs = len(funccall.split(","))
    return {3: func_lda, 6: func_gga, 8: func_mgga}.get(nargs,func_invalid)


def make_functional_name(funccall):
    """
    Given the way the functional was invoked as a function work out what the
    name of the subroutine to call is: the function name with "nwxc_"
    replaced by "nwxcm_".
    The subroutine name is returned.
    """
    return funccall.split("(")[0].replace("nwxc_","nwxcm_")


def make_input_args_list(funccall):
    """
    Given the way the functional was invoked as a function work out what the
    argument list is. This list will be used to initialize the input data to
    the actual subroutine call.
    The list of arguments (as strings) is returned.
    """
    data = funccall.split("(")
    data = data[1].split(")")
    return data[0].split(",")


# Density variables in the order in which their contributions are spelled out
# in the output array field names (e.g. D3_RB_GAA_GAA), with their suffixes.
_field_suffixes = [("rhoa","_RA"),("rhob","_RB"),
                   ("gammaaa","_GAA"),("gammaab","_GAB"),("gammabb","_GBB"),
                   ("taua","_TA"),("taub","_TB")]


def find_varname(dict,diffstr):
    """
    Given the subroutine call dictionary and the diffstr representing the
    "'diff(...)" command work out which variable is indicated. If
    diffstr is e.g. 'diff(t5,gammaaa,2,rhob,1) then we return e.g.
    s3c3(D3_RB_GAA_GAA). A surrounding %at(...,substitutions) is ignored.
    If diffstr is just a variable name, e.g. t5, then it represents the
    functional value rather than a derivative and we return e.g. s3f.
    The variable set number is the 1-based position of the referenced
    variable among the dictionary keys sorted with var_to_int.
    If the referenced variable is not a key of the dictionary an error is
    written to standard error and the program exits with status 10.
    The indicated variable is returned as a string.
    """
    data = diffstr
    if data[:4] == "%at(":
        # Keep only the expression, dropping the substitution argument
        data = data[4:-1]
        data = data[:data.rfind(",")]
    if data[:6] == "'diff(":
        data = data[6:-1]
    data = data.split(",")
    callref = data[0]
    keys = sorted(dict.keys(),key=var_to_int)
    if callref not in keys:
        sys.stderr.write("entity %s not found\n"%callref)
        for jj in range(0,len(keys)):
            sys.stderr.write("list %d: %s\n"%(jj,keys[jj]))
        sys.exit(10)
    num = keys.index(callref)+1
    if len(data) == 1:
        # this is the energy functional value
        return "s"+str(num)+"f"
    # (variable, order) pairs of the derivative specification; at most
    # three pairs are recognized.
    pairs = []
    for pos in (1,3,5):
        if len(data) >= pos+2:
            pairs.append((data[pos],int(data[pos+1])))
    # The total order selects the array, the "highest" variable kind involved
    # selects the output matrix: a for Amat, c for Cmat, m for Mmat.
    orderdiff = 0
    var_func = func_lda
    for (name,order) in pairs:
        orderdiff += order
        if name.startswith("tau"):
            var_func = max(var_func,func_mgga)
        elif name.startswith("gamma"):
            var_func = max(var_func,func_gga)
    var_char = {func_lda: "a", func_gga: "c", func_mgga: "m"}[var_func]
    # The array field, e.g. D1_RA, D3_GAA_TB_TB, etc.
    var_field = "D"+str(orderdiff)
    for (name,suffix) in _field_suffixes:
        for (pname,order) in pairs:
            if pname == name:
                var_field += suffix*order
    var_name = "s"+str(num)+var_char
    if orderdiff >= 2:
        var_name += str(orderdiff)
    return var_name+"("+var_field+")"


def line_contains_var(line,var):
    """
    Check whether a variable specified by "var" is contained within the line
    specified by "line". E.g. var="t2" and line="a=b+t2*t4" should return
    True whereas line="a=b+t21" or line="a=cmat2+1" should return False.
    """
    # Not preceded by an identifier character and not followed by a digit
    pattern = re.compile(r"(?<![A-Za-z0-9_])"+re.escape(var)+r"[^0-9]")
    return pattern.search(line+" ") is not None


def find_indent(line):
    """
    Given a line of source code work the indentation out and return a string
    containing the leading whitespace of the line.
    """
    pattern = re.compile(r"\s*")
    return pattern.match(line).group()


def find_var_in_line(line,var_name):
    """
    Find all locations to the right of the first " = " where the variable
    given by "var_name" is used.
    The locations are stored in a list of (begin,end) tuples, such that
    line[begin:end] is the variable name, which is returned.
    """
    patterne = re.compile(" = ")
    # If var_name = t2 we need to make sure we do not replace
    # cmat2 as well. Patternw is needed to achieve that. So we look for
    # t2 as well as at2; if the end-points are equal the string found does
    # not match the variable t2.
    patternv = re.compile(var_name+"[^0-9]")
    patternw = re.compile("a"+var_name+"[^0-9]")
    aline = line+" "
    locations = []
    found = patterne.search(aline)
    if not found:
        return locations
    pos = found.end(0)
    obj = patternv.search(aline,pos)
    while obj:
        objw = patternw.search(aline,pos)
        if not objw or objw.end(0) != obj.end(0):
            locations.append((obj.start(0),obj.end(0)-1))
        pos = obj.end(0)
        obj = patternv.search(aline,pos)
    return locations


def expand_var_in_line(line,list):
    """
    The variable references are given in list in the form of tuples so that
    line[begin:end] matches the variable exactly. However, the string we need
    to replace might be larger, e.g. in the case we have an entity such as
    'diff(t5,rhoa,1,gammaaa,2), or worse %at('diff(t5,rhoa,1),t10).
    Here we expand the tuples in the list to cover these strings in full,
    up to and including the next closing parenthesis.
    The list is updated in place and returned.
    """
    pattern = re.compile(r"\)")
    for item in range(len(list)):
        (ibegin,iend) = list[item]
        if line[ibegin-6:ibegin] == "'diff(":
            list[item] = (ibegin-6,pattern.search(line,iend).end(0))
        (ibegin,iend) = list[item]
        if line[ibegin-4:ibegin] == "%at(":
            list[item] = (ibegin-4,pattern.search(line,iend).end(0))
    return list


def find_replace_var_in_range(lines,iline_begin,iline_end,dict,var):
    """
    Find all locations in the range of lines (iline_begin,iline_end) where the
    variable "var" is referenced and replace those references with the proper
    functional output variable (see find_varname).
    The list "lines" is updated in place and returned.
    """
    for iline in range(iline_begin,iline_end+1):
        aline = lines[iline]
        locations = expand_var_in_line(aline,find_var_in_line(aline,var))
        pieces = []
        oend = 0
        for (ibegin,iend) in locations:
            pieces.append(aline[oend:ibegin])
            pieces.append(find_varname(dict,aline[ibegin:iend]))
            oend = iend
        pieces.append(aline[oend:])
        lines[iline] = "".join(pieces)
    return lines


def rewrap_line(longline):
    """
    Break a given long line "longline" up into chunks of at most 72 characters
    that conform to the Fortran77 fixed form standard; continuation lines carry
    a "+" in column 6. The chunks are written to standard output.
    In addition we do not want to break the line in the middle of numbers.
    If no usable break point exists an error is written to standard error and
    the program exits with status 1.
    """
    first = re.search(r"\S",longline)
    if first is None:
        # Whitespace only: there is nothing to wrap
        sys.stdout.write(longline+"\n")
        return
    indent = longline[:first.start()]
    indent = indent[:5]+"+"+indent[7:]+" "
    while len(longline) > 72:
        i = -1
        # wrap after a comma
        j = longline.rfind(",",0,70)
        if j != -1:
            i = max(i,j+1)
        # wrap before * / ( or )
        i = max(i,longline.rfind("*",0,71))
        i = max(i,longline.rfind("/",0,71))
        i = max(i,longline.rfind("(",0,71))
        i = max(i,longline.rfind(")",0,71))
        # wrap before + but not in the middle of a numerical constant...
        j = longline.rfind("+",0,71)
        k = longline.rfind("d+",0,71)
        if j-1 == k:
            j = longline.rfind("+",0,k)
        i = max(i,j)
        # wrap before - but not in the middle of a numerical constant...
        j = longline.rfind("-",0,71)
        k = longline.rfind("d-",0,71)
        if j-1 == k:
            j = longline.rfind("-",0,k)
        i = max(i,j)
        if i == -1:
            sys.stderr.write("No sensible break point found in:\n")
            sys.stderr.write(longline+"\n")
            sys.exit(1)
        elif i <= len(indent):
            # Breaking here makes no progress: the continuation line would be
            # at least as long as the current one.
            sys.stderr.write("Same break point found repeatedly in:\n")
            sys.stderr.write(longline+"\n")
            sys.exit(1)
        sys.stdout.write(longline[:i]+"\n")
        longline = indent + longline[i:]
    sys.stdout.write(longline+"\n")


if len(sys.argv) == 2:
    if sys.argv[1] == "-h":
        usage(0)
    elif sys.argv[1] == "-v" or sys.argv[1] == "--version":
        sys.stdout.write("%s\n"%revision)
        sys.exit(0)
    else:
        usage(1)
elif len(sys.argv) > 2:
    usage(1)

ilines = sys.stdin.readlines()
ilines = unwrap_lines(ilines)
nlines = len(ilines)
#DEBUG
#file = open("junkjunk",'w')
#for line in ilines:
#    file.write("%s\n"%line)
#file.close()
#DEBUG
olines = []
line_start = 0
subr_lines = find_subroutine(ilines,line_start)
(subr_lines_start,subr_lines_end)=subr_lines
#DEBUG
#print "subrs:
start,end:",subr_lines_start,subr_lines_end 1072#DEBUG 1073while subr_lines_start != -1 and subr_lines_end != -1: 1074 # 1075 # Roll forward to the beginning of the subroutine 1076 # 1077 while line_start < subr_lines_start: 1078 olines.append(ilines[line_start]) 1079 line_start += 1 1080 # 1081 # Work the subroutine type out and then work the if-branches out 1082 # 1083 code_skel_type = find_code_skeleton_type(ilines,subr_lines) 1084 if code_skel_type == type_autoxc: 1085 ifbranches = find_autoxc_code_skeleton(ilines,subr_lines) 1086 elif code_skel_type == type_autoxcDs: 1087 ifbranches = find_autoxcDs_code_skeleton(ilines,subr_lines) 1088 else: 1089 sys.stderr.write("Unexpected code_type\n") 1090 exit(20) 1091 # 1092 # Work the order of differentiation out 1093 # 1094 orderdiff = find_max_order_diff(ilines,subr_lines) 1095 #DEBUG 1096 #print "orderdiff:",orderdiff 1097 #DEBUG 1098 # 1099 # Work the declaration insertion point out 1100 # 1101 type_decl_insertion_point = find_type_declaration_insertion_point(ilines,subr_lines) 1102 # 1103 # How many additional variables do we need? 1104 # 1105 max_calls = find_maxno_calls(ilines,ifbranches) 1106 #DEBUG 1107 #print "max_calls:",max_calls 1108 #DEBUG 1109 # 1110 # Which lines do we need to drop? 1111 # 1112 delete_lines_list = delete_lines(ilines,ifbranches) 1113 #DEBUG 1114 #print "delete_list: ",delete_lines_list 1115 #DEBUG 1116 # 1117 # We know what additional variables we need to declare. 1118 # So copy more lines from the input file to the output until we reach 1119 # the declaration insertion point. Then call the function to the inject 1120 # the additional declarations in the output routine. 
1121 # 1122 while line_start <= type_decl_insertion_point: 1123 olines.append(ilines[line_start]) 1124 line_start += 1 1125 olines = append_declarations(olines,max_calls,orderdiff) 1126 # 1127 # Go through branches and mess with subroutine calls 1128 # 1129 i_ifbranch = 0 1130 for ifbranch in ifbranches: 1131 i_ifbranch += 1 1132 if code_skel_type == type_autoxc and i_ifbranch <= 1: 1133 ifbranch_kind = ifbranch_closedshell 1134 elif code_skel_type == type_autoxcDs and i_ifbranch <= 2: 1135 ifbranch_kind = ifbranch_closedshell 1136 else: 1137 ifbranch_kind = ifbranch_openshell 1138 (line_if_start,line_if_end) = ifbranch 1139 call_lines = collect_subroutine_calls(ilines,ifbranch) 1140 #DEBUG 1141 #print "calls: start,end:",line_if_start,line_if_end 1142 #print "calls: call_lines:",call_lines 1143 #DEBUG 1144 call_vars = call_lines.keys() 1145 #DEBUG 1146 #print "calls: call_vars :",call_vars 1147 #DEBUG 1148 call_vars = sorted(call_vars,key=var_to_int) 1149 num = 0 1150 for var in call_vars: 1151 num += 1 1152 #DEBUG 1153 #print "looping over vars:",num,var 1154 #DEBUG 1155 call_insert = find_subroutine_call_insertion_point(ilines,ifbranch,var) 1156 indent = find_indent(ilines[call_insert]) 1157 ilines = find_replace_var_in_range(ilines,call_insert,line_if_end,call_lines,var) 1158 # var_name = find_varname() 1159 #DEBUG 1160 #print "ilines:",call_insert,":",ilines[call_insert] 1161 #DEBUG 1162 ilines[call_insert] = indent+var+" = s"+str(num)+"f" 1163 # Roll forward to just before the insertion point 1164 while line_start < call_insert: 1165 if not (line_start in delete_lines_list): 1166 olines.append(ilines[line_start]) 1167 line_start += 1 1168 funckind = find_functional_kind(call_lines[var]) 1169 funcname = make_functional_name(call_lines[var]) 1170 arglist = make_input_args_list(call_lines[var]) 1171 #DEBUG 1172 #print "funckind: ",funckind 1173 #print "funcname: ",funcname 1174 #print "arglist : ",arglist 1175 #DEBUG 1176 olines = 
append_subroutine_call(olines,funcname,arglist,orderdiff,ifbranch_kind,funckind,num,indent) 1177 #DEBUG 1178 #print "var: ",var,call_insert 1179 #DEBUG 1180 #break 1181 while line_start <= line_if_end: 1182 if not (line_start in delete_lines_list): 1183 olines.append(ilines[line_start]) 1184 line_start += 1 1185 #break 1186 while line_start <= subr_lines_end: 1187 olines.append(ilines[line_start]) 1188 line_start += 1 1189 subr_lines = find_subroutine(ilines,line_start) 1190 (subr_lines_start,subr_lines_end)=subr_lines 1191 #DEBUG 1192 #print "subrt: start,end:",subr_lines_start,subr_lines_end 1193 #DEBUG 1194while line_start < nlines: 1195 olines.append(ilines[line_start]) 1196 line_start += 1 1197for line in olines: 1198 rewrap_line(line) 1199 #sys.stdout.write("%s\n"%line) 1200