1#!/usr/bin/python -E
2#
3# -E  Suppresses all PYTHON* environment variables. In particular
4#     PYTHONHOME is a source of disaster.
5#
6#
7# === README ===
8#
9# This program replaces function calls of density functionals with the
10# corresponding results of a density functional subroutine call. In addition
11# the appropriate subroutine call is also spliced in at the right place.
12#
13#
14# BACKGROUND
15# ----------
16#
17# In the normal operation of Maxima generating source code one would specify
18# the expression for the density functional in full. The resulting (typically
19# very large) expression is then differentiated in every required way, the
20# catalogue of expressions is optimized and everything is transformed into
21# Fortran. This leads to very large, illegible, subroutines that may take
22# hours to generate, and tens of minutes to compile (e.g. nwxcm_c_tpss03_d3
23# is one subroutine of 20857 lines and 1.25 MB in size). Clearly this is
24# undesirable.
25#
26# In practice the functionals implemented in these routines can be constructed
27# from other functionals. For example, the TPSS correlation functional
28# mentioned above is expressed in terms of a modified PBE functional, which
29# in turn is specified in terms of the PW91 LDA functional. When hand coding
30# functionals this structure is maintained to make functionals manageable.
31# Ideally we would like to preserve this structure while using symbolic algebra
32# to generate the source code.
33#
34# Using symbolic algebra it is in principle possible to do this as one can
35# specify a function in terms of other functions that Maxima does not know.
36# For example, one can specify a function as
37#
38#    f(x,y) := (x+y)^2*h(y);
39#
40# where Maxima does not know what h(y) is. Differentiating f(x,y) wrt. y
41# Maxima will generate
42#
43#    diff(f(x,y),y) := 2*(x+y)*h(y)+(x+y)^2*'diff(h(y),y,1)
44#
45# where 'diff indicates that a derivative is required. The Fortran generated
46# by Maxima will reference h(y) and 'diff(h(y),y,1) which is not necessarily
47# valid Fortran. If we know what h(y) and 'diff(h(y),y,1) stand for this problem
# can be resolved by replacing these entities with valid Fortran functions or
49# variables. In that fashion Maxima can be used to generate source code while
50# preserving the structure of the original functional implementation.
51#
52# This script implements these source code transformations.
53#
54#
55# APPROACH
56# --------
57#
58# The main complication in this program is that the expression optimization
59# that Maxima performs messes things up as in the example above it would
60# break out the references to h(y) as in
61#
62#    t1 = h(y)
63#    diff(f(x,y),y) := 2*(x+y)*t1+(x+y)^2*'diff(t1,y,1)
64#
65# Also note that h may be invoked multiple times as in, for example,
66#
67#    f(x,y) := (x+y)^2*h(x,y) - x^2*h(x,0) - y^2*h(0,y)
68#
69# so we need to distinguish every invocation based on the arguments passed in.
70# In addition, in our case, h is actually implemented as a subroutine that
71# returns the function value and the values of the various derivatives. The
72# inputs and outputs are stored in arrays. The steps required are:
73#
74# 1. All Fortran lines need to be unwrapped so that simple text analysis
75#    can reliably find references to h(x,y), 'diff(h(x,y)), etc.
76# 2. Where h(x,y) is invoked we need to generate arrays x and y, as well as
77#    arrays for the required derivatives and initialize them.
78# 3. We need to insert the array declarations in the type declarations.
79# 4. We need to insert the subroutine calls at the appropriate places (i.e.
80#    right before the first place where results are used as we need to make
81#    sure that the input data is defined).
82# 5. Where h(x,y) or its derivatives are referenced we need to substitute
83#    the appropriate variables.
84# 6. The Fortran lines need to be rewrapped such that we use at most 72 columns.
85#
86import re
87import sys
88import string
89
# Code skeleton kinds, as reported by find_code_skeleton_type.
type_autoxc   = 1
type_autoxcDs = 2

# Kinds of if-branch distinguished by append_subroutine_call.
ifbranch_closedshell = 1
ifbranch_openshell   = 2

# Functional kinds, as reported by find_functional_kind. The ordering
# LDA < GGA < Meta-GGA is relied upon by ">=" comparisons elsewhere.
func_invalid  = -1
func_lda      = 1
func_gga      = 2
func_mgga     = 3

# The version control system may expand the keyword below into something like
# "$Id: <file> <revision> <date> ... $". When the keyword has not been
# expanded the string consists of a single word, and indexing the split
# result would raise an IndexError while the module is being loaded.
version = "$Id$"
version = version.split()
if len(version) >= 4:
   revision = version[1]+" revision "+version[2]+" "+version[3]
else:
   revision = "unknown revision"
104
def usage(code):
   """
   Print usage information on standard error and terminate the program.

   code is passed to sys.exit and becomes the exit status.

   Every line of the help text is terminated with a newline; the empty
   entries in the list below produce the blank separator lines.
   """
   text = [
      "Insert subroutine calls to functionals in Maxima generated code.",
      "",
      "  %s [-h] [-v|--version] < <filein> > <fileout>"%sys.argv[0],
      "",
      "-h            Print this information",
      "-v|--version  Print the version data",
      "<filein>      Raw Fortran from Maxima",
      "<fileout>     Fortran from Maxima with subroutine calls",
      "              for functionals",
      "",
      "$Id$",
   ]
   sys.stderr.write("\n".join(text)+"\n")
   sys.exit(code)
121
def var_to_int(var):
   """
   Return the number embedded in a temporary variable name.

   The first character of var is dropped and the remainder is converted to
   an integer, e.g. "t5" gives 5. Names such as "Amat(iq,D1_RA)" are not
   expected here; int() raises ValueError for them.
   """
   digits = var[1:]
   return int(digits)
128
def unwrap_lines(lines_in):
   """
   Join Fortran continuation lines onto the statement they continue.

   A continuation line is one that starts with five blanks followed by one
   of the characters 0-9 : ; < = > ? @ + in column six. Its text from column
   seven onwards, with leading blanks removed, is glued onto the statement
   collected so far. Lines of at most one character are dropped.

   Every input line is expected to end with a newline character, which is
   cut off; the lines in the returned list carry no newline.
   """
   continuation = re.compile("     [0-9:;<=>?@+]")
   unwrapped = []
   current = ""
   for raw in lines_in:
      if continuation.match(raw):
         # skip the six label/continuation columns and the trailing newline
         current = current+raw[6:-1].lstrip()
         continue
      if len(raw) <= 1:
         # empty line: neither starts nor continues a statement
         continue
      if current:
         unwrapped.append(current)
      current = raw[:-1]
   if current:
      unwrapped.append(current)
   return unwrapped
152
def find_subroutine(lines,lineno):
   """
   Given the starting line number (lineno) and a list of lines find the
   line numbers where the next subroutine starts and ends. Fortran90 coding
   style is assumed, so the start and end are the lines beginning with
   - "      subroutine"
   - "      end subroutine"
   A tuple (lineno_start,lineno_end) is returned; an entry is -1 when the
   corresponding line could not be found.

   If no subroutine start is found the end marker is still looked for on the
   last line of the list. If lineno lies beyond the last line (which includes
   the case of an empty list) there is nothing to scan and (-1,-1) is
   returned instead of indexing outside the list.
   """
   length = len(lines)
   if lineno >= length:
      return (-1,-1)
   start_pattern = re.compile("      subroutine")
   end_pattern   = re.compile("      end subroutine")
   lineno_start = -1
   lineno_end   = -1
   line = lineno
   while line < length and not start_pattern.match(lines[line]):
      line += 1
   if line < length:
      lineno_start = line
   else:
      # keep the scan position in the valid range for the end marker search
      line = length-1
   while line < length and not end_pattern.match(lines[line]):
      line += 1
   if line < length:
      lineno_end = line
   return (lineno_start,lineno_end)
185
def find_code_skeleton_type(lines,subr_lines):
   """
   Scan the lines of a subroutine to find out whether it is an autoxc or an
   autoxc-Ds generated subroutine.

   subr_lines is the (start,end) tuple of line numbers of the subroutine,
   both inclusive. The subroutine is taken to be of the autoxc-Ds kind as
   soon as one of its lines contains "if (taua.gt.tol_rho) then".
   type_autoxcDs or type_autoxc is returned accordingly.
   """
   (lineno_start,lineno_end) = subr_lines
   # raw string: "\(" and "\." are regular expression escapes, not string escapes
   pattern = re.compile(r"if \(taua\.gt\.tol_rho\) then")
   for line in range(lineno_start,lineno_end+1):
      if pattern.search(lines[line]):
         return type_autoxcDs
   return type_autoxc
201
def find_autoxc_code_skeleton(lines,subr_lines):
   """
   Assuming that the subroutine is of the autoxc code skeleton type find all
   the if-branches and return a list of all the (start,end) tuples.

   subr_lines is the (start,end) tuple of line numbers of the subroutine,
   both inclusive. The markers are looked for one after the other, each
   search resuming on the line where the previous marker was found:
   1. "if (rhoa.gt.tol_rho) then" ... "endif ! rhoa.gt.tol_rho"
   2. "if (rhoa.gt.tol_rho.and.rhob.gt.tol_rho) then"
      "elseif (rhoa.gt.tol_rho.and.rhob.le.tol_rho) then"
      "elseif (rhoa.le.tol_rho.and.rhob.gt.tol_rho) then"
      "endif ! rhoa.gt.tol_rho.and.rhob.gt.tol_rho"
   The returned list holds four tuples: the body of the first if-block
   followed by the three branches of the second one. A branch starts on the
   line after its opening marker and ends on the line before its closing
   marker. An entry is -1 when the marker it derives from was not found.
   """
   (lineno_start,lineno_end) = subr_lines
   # one-element list so that the helper below can advance the shared position
   cursor = [lineno_start]

   def seek(regex):
      # Move forward to the next line matching regex and return its number;
      # return None (leaving the position past lineno_end) if there is none.
      pattern = re.compile(regex)
      while cursor[0] <= lineno_end:
         if pattern.search(lines[cursor[0]]):
            return cursor[0]
         cursor[0] += 1
      return None

   def after(hit):
      # first line of the branch opened by the marker on line hit
      if hit is None:
         return -1
      return hit+1

   def before(hit):
      # last line of the branch closed by the marker on line hit
      if hit is None:
         return -1
      return hit-1

   single_if  = seek(r"if \(rhoa\.gt\.tol_rho\) then")
   single_end = seek(r"endif ! rhoa\.gt\.tol_rho")
   both_if    = seek(r"if \(rhoa\.gt\.tol_rho\.and\.rhob\.gt\.tol_rho\) then")
   alpha_if   = seek(r"elseif \(rhoa\.gt\.tol_rho\.and\.rhob\.le\.tol_rho\) then")
   beta_if    = seek(r"elseif \(rhoa\.le\.tol_rho\.and\.rhob\.gt\.tol_rho\) then")
   both_end   = seek(r"endif ! rhoa\.gt\.tol_rho\.and\.rhob\.gt\.tol_rho")
   return [
      (after(single_if),before(single_end)),
      (after(both_if),before(alpha_if)),
      (after(alpha_if),before(beta_if)),
      (after(beta_if),before(both_end)),
   ]
270
def find_autoxcDs_code_skeleton(lines,subr_lines):
   """
   Assuming that the subroutine is of the autoxc-Ds code skeleton type find all
   the if-branches and return a list of all the (start,end) tuples.

   subr_lines is the (start,end) tuple of line numbers of the subroutine,
   both inclusive. The markers are looked for one after the other, each
   search resuming on the line where the previous marker was found:
   1. "if (taua.gt.tol_rho) then" ... "else" ... "endif"
   2. three times in a row:
      "if (taua.gt.tol_rho.and.taub.gt.tol_rho) then"
      "elseif (taua.gt.tol_rho.and.taub.le.tol_rho) then"
      "elseif (taua.le.tol_rho.and.taub.gt.tol_rho) then"
      "else"
      "endif"
   The returned list holds 2+3*4 = 14 tuples in that order. A branch starts
   on the line after its opening marker and ends on the line before its
   closing marker. An entry is -1 when the marker it derives from was not
   found; once a marker is missing all later ones are reported as missing
   too, the scan is never restarted from another position.
   """
   (lineno_start,lineno_end) = subr_lines
   # one-element list so that the helper below can advance the shared position
   cursor = [lineno_start]

   def seek(regex):
      # Move forward to the next line matching regex and return its number;
      # return None (leaving the position past lineno_end) if there is none.
      pattern = re.compile(regex)
      while cursor[0] <= lineno_end:
         if pattern.search(lines[cursor[0]]):
            return cursor[0]
         cursor[0] += 1
      return None

   def after(hit):
      # first line of the branch opened by the marker on line hit
      if hit is None:
         return -1
      return hit+1

   def before(hit):
      # last line of the branch closed by the marker on line hit
      if hit is None:
         return -1
      return hit-1

   ifbranches = []
   tau_if   = seek(r"if \(taua\.gt\.tol_rho\) then")
   tau_else = seek(r"else")
   tau_end  = seek(r"endif")
   ifbranches.append((after(tau_if),before(tau_else)))
   ifbranches.append((after(tau_else),before(tau_end)))
   #
   for x in range(0,3):
      both_if  = seek(r"if \(taua\.gt\.tol_rho\.and\.taub\.gt\.tol_rho\) then")
      alpha_if = seek(r"elseif \(taua\.gt\.tol_rho\.and\.taub\.le\.tol_rho\) then")
      beta_if  = seek(r"elseif \(taua\.le\.tol_rho\.and\.taub\.gt\.tol_rho\) then")
      if beta_if is not None:
         # "else" is a substring of "elseif ...", so step past the elseif line
         # before looking for the plain else. This is only done when the
         # elseif was actually found; otherwise the position would be moved
         # to -1 and the scan would restart from the end of the list.
         cursor[0] = beta_if+1
      none_else = seek(r"else")
      block_end = seek(r"endif")
      ifbranches.append((after(both_if),before(alpha_if)))
      ifbranches.append((after(alpha_if),before(beta_if)))
      ifbranches.append((after(beta_if),before(none_else)))
      ifbranches.append((after(none_else),before(block_end)))
   return ifbranches
364
def find_type_declaration_insertion_point(lines,subr_lines):
   """
   Locate the line at which the array declarations needed for the
   subroutine calls can be inserted.

   The insertion point is the first line within subr_lines (a (start,end)
   tuple, both inclusive) that begins with the nwxc_param.fh include
   directive. Its line number is returned, or -1 if there is no such line.
   """
   (first,last) = subr_lines
   include = re.compile('#include "nwxc_param.fh"')
   for lineno in range(first,last+1):
      if include.match(lines[lineno]):
         return lineno
   return -1
380
def find_subroutine_call_insertion_point(lines,ifbranch,varname):
   """
   Return the line on which the variable varname is assigned first.

   The symbolic algebra optimization is assumed to always break the
   functional evaluation out into an assignment to a variable of its own.
   The subroutine call for the functional has to be in place, at the latest,
   right before that assignment.

   ifbranch is the (start,end) tuple of line numbers to scan, both
   inclusive. A line counts as the assignment if the text in front of its
   first " = ", with leading blanks removed, equals varname. -1 is returned
   if no such line exists.
   """
   (first,last) = ifbranch
   for lineno in range(first,last+1):
      target = lines[lineno].split(" = ")[0].lstrip()
      if target == varname:
         return lineno
   return -1
401
def collect_subroutine_calls(lines,ifbranch_lines):
   """
   Gather all subroutine call instances of an if-branch in a dictionary.

   In the incoming source code a "subroutine call" still has the form of a
   function call, e.g.
   t5 = nwxc_c_Mpbe(rhoa,0.0d+0,gammaaa,0.0d+0,0.0d+0)
   Every line within ifbranch_lines (a (start,end) tuple, both inclusive)
   that contains "nwxc" is split at " = ". The left hand side with leading
   blanks removed (here "t5") becomes the key, as those variable instances
   have to be replaced with array element references later on, and the
   right hand side (here "nwxc_c_Mpbe(...)") becomes the value.
   The dictionary is returned.
   """
   (first,last) = ifbranch_lines
   marker = re.compile("nwxc")
   calls = {}
   for lineno in range(first,last+1):
      text = lines[lineno]
      if not marker.search(text):
         continue
      pieces = text.split(" = ")
      calls[pieces[0].lstrip()] = pieces[1]
   return calls
425
def delete_lines(lines,ifbranches):
   """
   List the lines that are to be left out when the output subroutine is
   written. These are the lines generated from the "at" command, e.g.
   t10(1) = (gammaaa = gammaaa)
   t10(2) = (gammaab = gammaaa)
   t10(3) = (gammabb = gammaaa)
   t10(4) = (rhoa = rhoa)
   t10(5) = (rhob = rhoa)
   t10(6) = (taua = taua)
   t10(7) = (taub = taua)
   They are recognized by the fact that splitting them at " = " gives
   exactly three pieces.

   ifbranches is a list of (start,end) tuples of line numbers, both
   inclusive. The line numbers found are returned as a list, in the order
   in which the branches are given.
   """
   return [lineno
           for (first,last) in ifbranches
           for lineno in range(first,last+1)
           if len(lines[lineno].split(" = ")) == 3]
452
def find_maxno_calls(lines,ifbranches):
   """
   Work out how many sets of output variables have to be declared.

   Subroutine calls are going to be inserted in every if-branch. A single
   set of input arrays for rho, gamma and tau suffices. The outputs however
   need a set of arrays per call, as the results may be used in multiple
   places throughout the remainder of the if-branch and nothing guarantees
   that the code segments using them do not overlap.

   The number returned is the largest number of calls that
   collect_subroutine_calls finds in any one of the if-branches, and 0 if
   there are no if-branches.
   """
   counts = [len(collect_subroutine_calls(lines,ifbranch))
             for ifbranch in ifbranches]
   return max(counts+[0])
471
def append_declarations(olines,varsets,orderdiff):
   """
   Append the array declarations needed for the subroutine calls.

   olines is the list of output lines so far, varsets the number of output
   variable sets and orderdiff the maximum order of differentiation.
   First the input arrays sr, sg and st are declared. Then for every set
   ii = 1..varsets follow the scalar s<ii>f and the arrays s<ii>a, s<ii>c
   and s<ii>m, the arrays s<ii>a2, s<ii>c2 and s<ii>m2 if orderdiff exceeds
   1, and the arrays s<ii>a3, s<ii>c3 and s<ii>m3 if orderdiff exceeds 2.
   The list olines is extended in place and returned.
   """
   declare = "      double precision "
   olines.append(declare+"sr(NCOL_RHO)")
   olines.append(declare+"sg(NCOL_GAMMA)")
   olines.append(declare+"st(NCOL_TAU)")
   # suffixes of the derivative orders for which arrays are required
   orders = [""]
   if orderdiff > 1:
      orders.append("2")
   if orderdiff > 2:
      orders.append("3")
   for ii in range(1,varsets+1):
      stem = declare+"s"+str(ii)
      olines.append(stem+"f")
      for order in orders:
         for (letter,matrix) in (("a","AMAT"),("c","CMAT"),("m","MMAT")):
            olines.append(stem+letter+order+"(NCOL_"+matrix+order+")")
   return olines
505
def append_subroutine_call(olines,subrname,arglist,orderdiff,ifbranch_kind,funckind,num,indent):
   """
   Append the Fortran code for one functional subroutine call.

   olines        the list of output lines so far
   subrname      the name of the subroutine, "_d2" or "_d3" is added when
                 orderdiff is 2 or 3
   arglist       the arguments of the original function call; entry 0 is
                 passed on as first argument of the call, entries 1-2 are
                 the densities, 3-5 the gammas and 6-7 the taus
   orderdiff     the order of differentiation
   ifbranch_kind ifbranch_closedshell or ifbranch_openshell
   funckind      the functional kind (func_lda, func_gga or func_mgga)
   num           the number of the output variable set
   indent        the string put in front of every generated line

   The generated code consists of
   1. the initialization of the input arrays sr, and depending on funckind
      sg and st. In an open shell branch, and in a closed shell branch if
      arglist[1] or arglist[2] equals '0.0d+0' (the open shell term of the
      Stoll partitioning of the correlation energy), the alpha and beta
      quantities are set one by one. In a regular closed shell call only
      the total quantities are set.
   2. zeroing the output variable s<num>f and the output arrays with dcopy
   3. the subroutine call itself, passing 2 for the number of spin channels
      in case of the Stoll partitioning and ipol otherwise.
   The list olines is extended in place and returned. An invalid
   ifbranch_kind is reported on standard error and ends the program with
   exit status 20.
   """
   def emit(text):
      olines.append(indent+text)

   if ifbranch_kind == ifbranch_closedshell:
      # a zero density marks the open shell term of the Stoll partitioning
      spin_resolved = (arglist[1] == '0.0d+0' or arglist[2] == '0.0d+0')
   elif ifbranch_kind == ifbranch_openshell:
      spin_resolved = True
   else:
      sys.stderr.write("append_subroutine_call: invalid ifbranch_kind: %d\n"%ifbranch_kind)
      sys.exit(20)
   with_gamma = funckind >= func_gga
   with_tau   = funckind >= func_mgga
   #
   # 1. the input arrays
   #
   if spin_resolved:
      emit("sr(R_A) = "+arglist[1])
      emit("sr(R_B) = "+arglist[2])
      if with_gamma:
         emit("sg(G_AA) = "+arglist[3])
         emit("sg(G_AB) = "+arglist[4])
         emit("sg(G_BB) = "+arglist[5])
      if with_tau:
         emit("st(T_A) = "+arglist[6])
         emit("st(T_B) = "+arglist[7])
   else:
      emit("sr(R_T) = 2.0d0*"+arglist[1])
      if with_gamma:
         emit("sg(G_TT) = 4.0d0*"+arglist[3])
      if with_tau:
         emit("st(T_T) = 2.0d0*"+arglist[6])
   #
   # 2. the output variables; the first order arrays are always present
   #
   tag = "s"+str(num)
   orders = [""]
   if orderdiff >= 2:
      orders.append("2")
   if orderdiff >= 3:
      orders.append("3")
   matrices = [("a","AMAT")]
   if with_gamma:
      matrices.append(("c","CMAT"))
   if with_tau:
      matrices.append(("m","MMAT"))
   emit(tag+"f = 0.0d0")
   for (letter,matrix) in matrices:
      for order in orders:
         emit("call dcopy(NCOL_"+matrix+order+",0.0d0,0,"+tag+letter+order+",1)")
   #
   # 3. the call itself
   #
   pieces = ["call ",subrname]
   if orderdiff == 2:
      pieces.append("_d2")
   elif orderdiff == 3:
      pieces.append("_d3")
   pieces.append("("+arglist[0]+",tol_rho,")
   if arglist[1] == '0.0d+0' or arglist[2] == '0.0d+0':
      # Stoll partitioning of the correlation energy
      pieces.append("2")
   else:
      pieces.append("ipol")
   pieces.append(",1,1.0d0,sr")
   if with_gamma:
      pieces.append(",sg")
   if with_tau:
      pieces.append(",st")
   pieces.append(","+tag+"f")
   for (letter,matrix) in matrices:
      for order in orders:
         pieces.append(","+tag+letter+order)
   pieces.append(")")
   emit("".join(pieces))
   return olines
633
634
def find_max_order_diff(lines,subr_lines):
   """
   Work out the maximum order of differentiation for a given subroutine.

   Only the first line of the subroutine (subr_lines is its (start,end)
   tuple of line numbers) is inspected: if it contains "_d2(" the order is
   2, otherwise if it contains "_d3(" the order is 3, and otherwise it is 1.
   The maximum order is returned.
   """
   (lineno_start,lineno_end) = subr_lines
   header = lines[lineno_start]
   # raw strings: "\(" is a regular expression escape, not a string escape
   if re.search(r"_d2\(",header):
      return 2
   if re.search(r"_d3\(",header):
      return 3
   return 1
652
def find_functional_kind(funccall):
   """
   Classify a functional as LDA, GGA or Meta-GGA from the way it was invoked
   as a function.

   The classification rests on the number of arguments, which is obtained
   by splitting funccall at the commas. The count includes the "param"
   argument next to the density dependent quantities, so an LDA functional
   has 3 arguments (param + rhoa + rhob), a GGA functional has 6 and a
   Meta-GGA functional has 8.
   The functional kind is returned, func_invalid for any other count.
   """
   nargs = len(funccall.split(","))
   if nargs == 8:
      return func_mgga
   if nargs == 6:
      return func_gga
   if nargs == 3:
      return func_lda
   return func_invalid
674
def make_functional_name(funccall):
   """
   Derive the name of the subroutine to call from the way the functional
   was invoked as a function.

   The name is the text in front of the first "(" with every "nwxc_"
   replaced by "nwxcm_". The subroutine name is returned.
   """
   function_name = funccall.partition("(")[0]
   return function_name.replace("nwxc_","nwxcm_")
686
def make_input_args_list(funccall):
   """
   Extract the argument list from the way the functional was invoked as a
   function. The list is used to initialize the input data of the actual
   subroutine call.

   The arguments are the comma separated pieces of the text between the
   first "(" and the first ")" or "(" following it.
   The list of arguments is returned.
   """
   behind_paren = funccall.split("(")[1]
   inside = behind_paren.partition(")")[0]
   return inside.split(",")
698
def find_varname(dict,diffstr):
   """
   Given the subroutine call dictionary and the diffstr representing the
   "'diff(...)" command work out which variable is indicated. If
   diffstr is e.g. 'diff(t5,gammaaa,2,rhob,1) then we return e.g.
   s3c3(D3_RB_GAA_GAA).
   If diffstr is just a variable name, e.g. t5, then it represents the
   functional value rather than a derivative and we return e.g. s3f.
   An expression of the form %at('diff(...),...) is handled by dropping
   the %at wrapper together with its last argument.

   The number of the output variable set is the position (counting from 1)
   of the referenced variable among the keys of dict, sorted by the number
   in their names. If dict has keys but the variable is not among them a
   message is written to standard error and the program ends with exit
   status 10. If dict is empty the set number used is 0.

   At most three (variable,order) pairs are taken into account. The array
   is "a" if only densities occur, "c" if a gamma occurs and "m" if a tau
   occurs. The field name lists the variables in the fixed order rhoa, rhob,
   gammaaa, gammaab, gammabb, taua, taub, each repeated as many times as
   its order of differentiation.

   The indicated variable is returned as a string.
   """
   data = diffstr
   if "%at(" == data[:4]:
      # strip "%at(" and ")" as well as the last, comma separated, argument
      data = data[4:-1]
      data = data[:data.rfind(",")]
   if "'diff(" == data[:6]:
      data = data[6:-1]
   data = data.split(",")
   callref = data[0]
   names = sorted(dict.keys(),key=var_to_int)
   num = -1
   if names:
      if callref not in names:
         # Diagnostics go to standard error: standard output carries the
         # generated Fortran.
         sys.stderr.write("entity %s not found\n"%callref)
         for (jj,name) in enumerate(names):
            sys.stderr.write("list %d: %s\n"%(jj,name))
         sys.exit(10)
      num = names.index(callref)
   # num+1 is now the variable set number
   if len(data) == 1:
      # this is the energy functional value
      return "s"+str(num+1)+"f"
   # The (variable,order) pairs live at data[1:3], data[3:5] and data[5:7];
   # a pair is only used if both of its entries are present.
   pairs = []
   for pos in (1,3,5):
      if len(data) >= pos+2:
         pairs.append((data[pos],int(data[pos+1])))
   orderdiff = 0
   for (name,count) in pairs:
      orderdiff += count
   # orderdiff is now the order of differentiation
   var_func = func_lda
   for (name,count) in pairs:
      if name.startswith("tau"):
         kind = func_mgga
      elif name.startswith("gamma"):
         kind = func_gga
      else:
         kind = func_lda
      if kind > var_func:
         var_func = kind
   var_char = "invalid"
   if var_func == func_lda:
      var_char = "a"
   elif var_func == func_gga:
      var_char = "c"
   elif var_func == func_mgga:
      var_char = "m"
   # var_char is now the character representing the output matrix,
   # a for Amat, c for Cmat, and m for Mmat
   var_field = "D"+str(orderdiff)
   suffixes = (("rhoa","_RA"),("rhob","_RB"),
               ("gammaaa","_GAA"),("gammaab","_GAB"),("gammabb","_GBB"),
               ("taua","_TA"),("taub","_TB"))
   for (variable,suffix) in suffixes:
      for (name,count) in pairs:
         if name == variable:
            var_field = var_field+suffix*count
   # var_field now contains the array field, e.g. D1_RA, D3_GAA_TB_TB, etc.
   var_name = "s"+str(num+1)+var_char
   if orderdiff >= 2:
      var_name = var_name+str(orderdiff)
   return var_name+"("+var_field+")"
894
def line_contains_var(line,var):
   """
   Check whether the variable specified by "var" is referenced anywhere in
   the line specified by "line".

   A reference is the variable name that is neither followed by a digit nor
   preceded by a letter, digit or underscore. E.g. var="t2" and
   line="a=b+t2*t4" returns True, whereas line="a=b+t21" and line="a=cmat2"
   both return False.
   """
   # The name is escaped so that it is always taken literally, and the
   # look-behind rejects longer identifiers that merely end in the name.
   regular  = "(?<![A-Za-z0-9_])"+re.escape(var)+"[^0-9]"
   pattern  = re.compile(regular)
   # The trailing blank allows a name at the very end of the line to match
   # the "[^0-9]" part of the pattern.
   longline = line+" "
   # search() scans the whole line; match() would only test the first column
   # and therefore miss every reference further along the line.
   result   = pattern.search(longline)
   return (result is not None)
906
def find_indent(line):
   """
   Return the leading whitespace of "line", i.e. the string that indents the
   source code on that line. A line without indentation yields "".
   """
   # "\s*" also matches the empty string, so a match object always exists.
   return re.match(r"\s*",line).group()
916
def find_var_in_line(line,var_name):
   """
   Locate every reference to the variable "var_name" that appears after the
   first " = " in "line".

   The result is a list of (begin,end) tuples with line[begin:end] equal to
   the variable name. The list is empty if the line contains no " = " or no
   reference to the variable.
   """
   # A reference is the name followed by a non-digit, so that t2 does not
   # match t21. The name glued to a leading "a" (e.g. cmat2 when looking for
   # t2) is a different variable; such a hit is recognised by the fact that
   # the "a"-prefixed pattern ends at exactly the same position.
   name_re   = re.compile(var_name+"[^0-9]")
   prefix_re = re.compile("a"+var_name+"[^0-9]")
   # The extra blank lets a name at the very end of the line match as well.
   padded = line+" "
   hits = []
   assignment = re.search(" = ",padded)
   if not assignment:
      return hits
   start = assignment.end(0)
   while True:
      hit = name_re.search(padded,start)
      if not hit:
         break
      shadow = prefix_re.search(padded,start)
      if not (shadow and shadow.end(0) == hit.end(0)):
         # The match includes the character after the name, hence the -1.
         hits.append((hit.start(0),hit.end(0)-1))
      start = hit.end(0)
   return hits
945
def expand_var_in_line(line,list):
   """
   Widen the variable references in "list" so that they cover the complete
   expression that has to be replaced.

   On entry every tuple (begin,end) in "list" satisfies that line[begin:end]
   is exactly the variable name. If the name is immediately preceded by
   "'diff(" the tuple is widened to run from that prefix up to and including
   the next ")", e.g. 'diff(t5,rhoa,1,gammaaa,2). If the (possibly widened)
   reference is in turn preceded by "%at(" it is widened once more in the
   same way, e.g. %at('diff(t5,rhoa,1),t10).

   The tuples are updated in place and the same list is returned.
   """
   closing = re.compile(r"\)")
   for item in range(len(list)):
      # The order matters: the derivative has to be absorbed first before
      # an enclosing %at( can be recognised directly in front of it.
      for prefix in ("'diff(","%at("):
         (ibegin,iend) = list[item]
         iibegin = ibegin-len(prefix)
         if line[iibegin:ibegin] == prefix:
            list[item] = (iibegin,closing.search(line,iend).end(0))
   return list
971
def find_replace_var_in_range(lines,iline_begin,iline_end,dict,var):
   """
   Replace the references to the variable "var" in the lines iline_begin up
   to and including iline_end.

   On every line the references are located with find_var_in_line and
   widened with expand_var_in_line. Each resulting piece of text is handed,
   together with "dict", to find_varname and substituted by the name that
   function returns.

   The entries of "lines" are updated in place and "lines" is returned.
   """
   for iline in range(iline_begin,iline_end+1):
      source = lines[iline]
      spans  = expand_var_in_line(source,find_var_in_line(source,var))
      pieces = []
      cursor = 0
      for (ibegin,iend) in spans:
         # Keep the text in front of the reference and swap the reference
         # itself for the replacement name.
         pieces.append(source[cursor:ibegin])
         pieces.append(find_varname(dict,source[ibegin:iend]))
         cursor = iend
      # Whatever follows the last reference is copied unchanged.
      pieces.append(source[cursor:])
      lines[iline] = "".join(pieces)
   return lines
1002
def rewrap_line(longline):
  """
  Break a given long line "longline" up into chunks of at most 72 characters
  that conform to the Fortran77 fixed format. The chunks are written to
  standard output, each terminated by a newline.

  A line is broken after a "," or before one of * / ( ) + - but never
  between the "d" and the sign of the exponent of a numerical constant
  (e.g. 1.0d+1). Continuation lines carry the indentation of the original
  line with a "+" continuation marker inserted.

  Lines of 72 characters or less, as well as blank lines, are written
  unchanged. If a line cannot be broken so that the remainder gets shorter
  an error message is written to standard error and the program is
  terminated with exit code 1.
  """
  first = re.search(r"\S",longline)
  if first is None:
    # Blank line: there is nothing to break, pass it through as is.
    sys.stdout.write(longline+"\n")
    return
  indent = longline[:first.start()]
  indent = indent[:5]+"+"+indent[7:]+"   "
  # "lower" is the position where the statement text starts on the current
  # line. Everything in front of it is indentation or the continuation
  # marker, which must never be selected as a break point.
  lower = first.start()
  while len(longline) > 72:
    i = -1
    # wrap after ,
    j = longline.rfind(",",lower,70)
    if j != -1:
      i = max(i,j+1)
    # wrap before * / ( )
    for char in "*/()":
      i = max(i,longline.rfind(char,lower,71))
    # wrap before + or - but not in the middle of a numerical constant...
    for sign in "+-":
      j = longline.rfind(sign,lower,71)
      k = longline.rfind("d"+sign,lower,71)
      if k != -1 and j == k+1:
        j = longline.rfind(sign,lower,k)
      i = max(i,j)
    if i == -1:
      sys.stderr.write("No sensible break point found in:\n")
      sys.stderr.write(longline)
      sys.exit(1)
    elif i <= lower:
      # Breaking here would emit an empty chunk and leave the remainder
      # just as long as it was, i.e. we would loop forever.
      sys.stderr.write("Same break point found repeatedly in:\n")
      sys.stderr.write(longline)
      sys.exit(1)
    sys.stdout.write(longline[:i]+"\n")
    longline = indent + longline[i:]
    lower = len(indent)
  sys.stdout.write(longline+"\n")
1044
1045
# Command line handling: at most one option is accepted.
#   -h              calls usage(0)
#   -v, --version   writes the revision to standard output and exits
# Any other option, or more than one argument, results in usage(1).
if len(sys.argv) > 2:
   usage(1)
elif len(sys.argv) == 2:
   if sys.argv[1] == "-h":
      usage(0)
   elif sys.argv[1] in ("-v","--version"):
      sys.stdout.write("%s\n"%revision)
      sys.exit(0)
   else:
      usage(1)
1056
#
# Main driver: read the source code from standard input, process it one
# subroutine at a time and write the result to standard output.
#
# ilines     holds the input lines (and is edited in place as we go along),
# olines     collects the lines that will be written out,
# line_start is the index of the next input line that still has to be copied.
#
ilines = sys.stdin.readlines()
# NOTE(review): unwrap_lines presumably joins the Fortran continuation lines
# into single long lines (rewrap_line undoes that at the end) -- confirm.
ilines = unwrap_lines(ilines)
nlines = len(ilines)
#DEBUG
#file = open("junkjunk",'w')
#for line in ilines:
#   file.write("%s\n"%line)
#file.close()
#DEBUG
olines = []
line_start = 0
subr_lines = find_subroutine(ilines,line_start)
(subr_lines_start,subr_lines_end)=subr_lines
#DEBUG
#print "subrs: start,end:",subr_lines_start,subr_lines_end
#DEBUG
# find_subroutine signals "no further subroutine" with -1 for the line numbers
while subr_lines_start != -1 and subr_lines_end != -1:
   #
   # Roll forward to the beginning of the subroutine
   #
   while line_start < subr_lines_start:
      olines.append(ilines[line_start])
      line_start += 1
   #
   # Work the subroutine type out and then work the if-branches out
   #
   code_skel_type = find_code_skeleton_type(ilines,subr_lines)
   if   code_skel_type == type_autoxc:
      ifbranches = find_autoxc_code_skeleton(ilines,subr_lines)
   elif code_skel_type == type_autoxcDs:
      ifbranches = find_autoxcDs_code_skeleton(ilines,subr_lines)
   else:
      sys.stderr.write("Unexpected code_type\n")
      # sys.exit rather than the exit() helper of the site module: the
      # latter is meant for the interactive interpreter and is not
      # guaranteed to be available.
      sys.exit(20)
   #
   # Work the order of differentiation out
   #
   orderdiff = find_max_order_diff(ilines,subr_lines)
   #DEBUG
   #print "orderdiff:",orderdiff
   #DEBUG
   #
   # Work the declaration insertion point out
   #
   type_decl_insertion_point = find_type_declaration_insertion_point(ilines,subr_lines)
   #
   # How many additional variables do we need?
   #
   max_calls = find_maxno_calls(ilines,ifbranches)
   #DEBUG
   #print "max_calls:",max_calls
   #DEBUG
   #
   # Which lines do we need to drop?
   #
   delete_lines_list = delete_lines(ilines,ifbranches)
   #DEBUG
   #print "delete_list: ",delete_lines_list
   #DEBUG
   #
   # We know what additional variables we need to declare.
   # So copy more lines from the input file to the output until we reach
   # the declaration insertion point. Then call the function to inject
   # the additional declarations in the output routine.
   #
   while line_start <= type_decl_insertion_point:
      olines.append(ilines[line_start])
      line_start += 1
   olines = append_declarations(olines,max_calls,orderdiff)
   #
   # Go through branches and mess with subroutine calls
   #
   i_ifbranch = 0
   for ifbranch in ifbranches:
      i_ifbranch += 1
      # The first branch of an autoxc skeleton and the first two branches
      # of an autoxcDs skeleton are treated as closed shell, all remaining
      # branches as open shell.
      if   code_skel_type == type_autoxc and i_ifbranch <= 1:
         ifbranch_kind = ifbranch_closedshell
      elif code_skel_type == type_autoxcDs and i_ifbranch <= 2:
         ifbranch_kind = ifbranch_closedshell
      else:
         ifbranch_kind = ifbranch_openshell
      (line_if_start,line_if_end) = ifbranch
      # call_lines is indexed by variable name; the variables are handled
      # in the order defined by var_to_int.
      call_lines = collect_subroutine_calls(ilines,ifbranch)
      #DEBUG
      #print "calls: start,end:",line_if_start,line_if_end
      #print "calls: call_lines:",call_lines
      #DEBUG
      call_vars = call_lines.keys()
      #DEBUG
      #print "calls: call_vars :",call_vars
      #DEBUG
      call_vars = sorted(call_vars,key=var_to_int)
      num = 0
      for var in call_vars:
         num += 1
         #DEBUG
         #print "looping over vars:",num,var
         #DEBUG
         call_insert = find_subroutine_call_insertion_point(ilines,ifbranch,var)
         indent = find_indent(ilines[call_insert])
         # Replace the references to var from the insertion point to the
         # end of the branch. This edits ilines itself, so that the lines
         # copied to olines below already contain the replacements.
         ilines = find_replace_var_in_range(ilines,call_insert,line_if_end,call_lines,var)
         # var_name = find_varname()
         #DEBUG
         #print "ilines:",call_insert,":",ilines[call_insert]
         #DEBUG
         # The line at the insertion point becomes an assignment of the
         # s<num>f variable to var.
         ilines[call_insert] = indent+var+" = s"+str(num)+"f"
         # Roll forward to just before the insertion point
         while line_start < call_insert:
            if not (line_start in delete_lines_list):
               olines.append(ilines[line_start])
            line_start += 1
         funckind = find_functional_kind(call_lines[var])
         funcname = make_functional_name(call_lines[var])
         arglist  = make_input_args_list(call_lines[var])
         #DEBUG
         #print "funckind: ",funckind
         #print "funcname: ",funcname
         #print "arglist : ",arglist
         #DEBUG
         olines = append_subroutine_call(olines,funcname,arglist,orderdiff,ifbranch_kind,funckind,num,indent)
         #DEBUG
         #print "var: ",var,call_insert
         #DEBUG
      #break
      # Copy the remainder of this branch, skipping the lines to be dropped
      while line_start <= line_if_end:
         if not (line_start in delete_lines_list):
            olines.append(ilines[line_start])
         line_start += 1
   #break
   # Copy the remainder of the subroutine and look for the next one
   while line_start <= subr_lines_end:
      olines.append(ilines[line_start])
      line_start += 1
   subr_lines = find_subroutine(ilines,line_start)
   (subr_lines_start,subr_lines_end)=subr_lines
   #DEBUG
   #print "subrt: start,end:",subr_lines_start,subr_lines_end
   #DEBUG
# Copy whatever follows the last subroutine
while line_start < nlines:
    olines.append(ilines[line_start])
    line_start += 1
# Write the result, breaking the long lines up again
for line in olines:
   rewrap_line(line)
   #sys.stdout.write("%s\n"%line)
1200