1
2try:
3    import numpy as np
4    numpy_available = True
5except:
6    numpy_available = False
7#end try
8
9
# default absolute tolerance used by the float/value/object/text comparisons
# (0.0 means the comparison is purely relative by default)
def_atol =  0.0
# default relative tolerance used by the float/value/object/text comparisons
def_rtol = 1e-6
12
13# determine if two floats differ
def float_diff(v1,v2,atol=def_atol,rtol=def_rtol):
    """
    Report whether two floating point values differ beyond tolerance.

    The values are considered different when |v1-v2| exceeds
    atol + rtol*|v2|.  Note that the relative part of the tolerance is
    scaled by v2 only, so the test is not symmetric in its arguments.

    Parameters
    ----------
    v1, v2 : values accepted by np.abs (scalars or arrays)
    atol   : absolute tolerance
    rtol   : relative tolerance (scaled by |v2|)

    Returns
    -------
    The result of the ">" comparison (elementwise for array inputs).
    """
    separation = np.abs(v1-v2)
    tolerance  = atol+rtol*np.abs(v2)
    return separation>tolerance
#end def float_diff
17
18
19# determine if two values differ
def value_diff(v1,v2,atol=def_atol,rtol=def_rtol,int_as_float=False):
    """
    Report whether two values differ.

    Supported values are bools, ints, floats, strings, None, sets, and
    (possibly nested) lists, tuples, numpy arrays and dicts of these.
    Floats are compared with float_diff using atol/rtol; everything else
    is compared exactly.

    Parameters
    ----------
    v1, v2       : the values to compare
    atol, rtol   : tolerances forwarded to float_diff
    int_as_float : if True, int/float pairs (in any combination) are
                   compared as floats rather than being treated as
                   values of different type

    Returns
    -------
    Truthy if the values differ, falsy otherwise.  Values of differing or
    unsupported type are reported as different.

    Notes
    -----
    Lists, tuples and arrays are flattened with ravel() before their
    elements are compared, so only the flattened length and contents
    take part in the comparison.
    """
    # np.float64 and np.bytes_ are the canonical names of the aliases
    # np.float_ and np.string_, which were removed in NumPy 2.0
    v1_bool  = isinstance(v1,(bool,np.bool_))
    v2_bool  = isinstance(v2,(bool,np.bool_))
    # bool is a subclass of int, so bools are excluded explicitly
    v1_int   = isinstance(v1,(int,np.int_)) and not v1_bool
    v2_int   = isinstance(v2,(int,np.int_)) and not v2_bool
    v1_float = isinstance(v1,(float,np.float64))
    v2_float = isinstance(v2,(float,np.float64))
    v1_str   = isinstance(v1,(str,np.bytes_))
    v2_str   = isinstance(v2,(str,np.bytes_))
    diff = False
    if v1 is v2:
        # the very same object cannot differ from itself
        pass
    elif int_as_float and (v1_int or v1_float) and (v2_int or v2_float):
        diff = float_diff(v1,v2,atol=atol,rtol=rtol)
    elif v1_float and v2_float:
        diff = float_diff(v1,v2,atol=atol,rtol=rtol)
    elif v1_int and v2_int:
        diff = v1!=v2
    elif (v1_bool or v1_str) and (v2_bool or v2_str):
        diff = v1!=v2
    elif not isinstance(v1,type(v2)) or not isinstance(v2,type(v1)):
        # remaining values must be of mutually compatible type
        diff = True
    elif isinstance(v1,(list,tuple)):
        # dtype=object keeps the elements as the original python objects
        v1 = np.array(v1,dtype=object).ravel()
        v2 = np.array(v2,dtype=object).ravel()
        if len(v1)!=len(v2):
            diff = True
        else:
            for vv1,vv2 in zip(v1,v2):
                diff |= value_diff(vv1,vv2,atol,rtol,int_as_float)
            #end for
        #end if
    elif isinstance(v1,np.ndarray):
        v1 = v1.ravel()
        v2 = v2.ravel()
        if len(v1)!=len(v2):
            diff = True
        else:
            for vv1,vv2 in zip(v1,v2):
                diff |= value_diff(vv1,vv2,atol,rtol,int_as_float)
            #end for
        #end if
    elif isinstance(v1,dict):
        k1 = v1.keys()
        k2 = v2.keys()
        if set(k1)!=set(k2):
            diff = True
        else:
            for k in k1:
                diff |= value_diff(v1[k],v2[k],atol,rtol,int_as_float)
            #end for
        #end if
    elif isinstance(v1,set):
        diff = v1!=v2
    elif v1 is None and v2 is None:
        diff = False
    elif hasattr(v1,'__len__') and hasattr(v2,'__len__') and len(v1)==0 and len(v2)==0:
        # two empty containers of the same type are considered equal
        pass
    else:
        diff = True # unsupported types
    #end if
    return diff
#end def value_diff
83
84
85# determine if two objects differ
def object_diff(o1,o2,atol=def_atol,rtol=def_rtol,int_as_float=False,full=False,bypass=False):
    """
    Report whether two objects differ, key by key.

    Unless bypass is set, each object is first converted to a flat
    mapping via o._serial().__dict__.  With bypass=True, o1 and o2 are
    used directly and must already support keys() and item lookup.

    Keys present on only one side always count as differences.  Values
    under shared keys are compared with value_diff using atol, rtol and
    int_as_float.

    Returns
    -------
    full=False : whether any difference was found
    full=True  : the tuple (diff,diff1,diff2), where diff1/diff2 are
                 dicts holding the differing entries of o1/o2
    """
    if not bypass:
        o1 = o1._serial().__dict__
        o2 = o2._serial().__dict__
    #end if
    keys1 = set(o1.keys())
    keys2 = set(o2.keys())
    only1  = keys1 - keys2
    only2  = keys2 - keys1
    shared = keys1 & keys2
    # entries unique to one side are differences by definition
    diff1 = {k:o1[k] for k in only1}
    diff2 = {k:o2[k] for k in only2}
    # entries present on both sides differ only if their values do
    for k in shared:
        left  = o1[k]
        right = o2[k]
        if value_diff(left,right,atol,rtol,int_as_float):
            diff1[k] = left
            diff2[k] = right
        #end if
    #end for
    differs = len(diff1)!=0 or len(diff2)!=0
    if full:
        return differs,diff1,diff2
    #end if
    return differs
#end def object_diff
119
120
# determine if two text blocks differ (token parsing helpers, then text_diff)
def read_text_value(s):
    """
    Convert a text token to a number when possible.

    The token is returned as an int if int() accepts it, otherwise as a
    float if float() accepts it, otherwise it is returned unchanged.

    Previously the parsed integer was discarded and integer tokens were
    returned as strings, which prevented text comparisons from treating
    them numerically (e.g. with int_as_float=True).
    """
    for convert in (int,float):
        try:
            return convert(s)
        except (ValueError,TypeError):
            # not representable as this numeric type, try the next one
            continue
        #end try
    #end for
    return s
#end def read_text_value
135
def read_text_tokens(t):
    """
    Split text on whitespace and convert each piece with read_text_value.

    Returns the list of converted tokens, in order of appearance.
    """
    return [read_text_value(piece) for piece in t.split()]
#end def read_text_tokens
143
def text_diff(t1,t2,atol=def_atol,rtol=def_rtol,int_as_float=False,full=False,by_line=False):
    """
    Report whether two blocks of text differ.

    Commas are padded with spaces so they form separate tokens, the text
    is split into tokens with read_text_tokens, and the token lists are
    compared with value_diff using atol, rtol and int_as_float.

    Returns
    -------
    full=False : the result of the token-wise comparison
    full=True  : the tuple (diff,diff1,diff2).  diff1/diff2 map an index
                 to the differing item of t1/t2.  Items are tokens by
                 default and (comma-padded) lines when by_line=True.
                 Where one text has more items than the other, the
                 shorter side holds None at the extra indices.
    """
    # isolate commas so that they are tokenized on their own
    t1 = ' , '.join(t1.split(','))
    t2 = ' , '.join(t2.split(','))
    tokens1 = read_text_tokens(t1)
    tokens2 = read_text_tokens(t2)
    diff = value_diff(tokens1,tokens2,atol,rtol,int_as_float)
    if not full:
        return diff
    #end if

    # select the items to report on: whole lines or individual tokens
    if by_line:
        items1 = t1.splitlines()
        items2 = t2.splitlines()
    else:
        items1 = tokens1
        items2 = tokens2
    #end if

    diff1 = dict()
    diff2 = dict()
    # compare the items present in both texts
    for n,(item1,item2) in enumerate(zip(items1,items2)):
        if by_line:
            differs = value_diff(read_text_tokens(item1),read_text_tokens(item2),atol,rtol,int_as_float)
        else:
            differs = value_diff(item1,item2,atol,rtol,int_as_float)
        #end if
        if differs:
            diff1[n] = item1
            diff2[n] = item2
        #end if
    #end for
    # items beyond the end of the shorter text have no counterpart
    nshared = min(len(items1),len(items2))
    for n in range(nshared,len(items1)):
        diff1[n] = items1[n]
        diff2[n] = None
    #end for
    for n in range(nshared,len(items2)):
        diff1[n] = None
        diff2[n] = items2[n]
    #end for
    return diff,diff1,diff2
#end def text_diff
202
203
204# print the difference between two objects
def print_diff(o1,o2,atol=def_atol,rtol=def_rtol,int_as_float=False,text=False,by_line=False): # used in debugging, not actual tests
    """
    Print two objects (or texts) followed by their differing entries.

    With text=False the inputs are compared with object_diff, otherwise
    with text_diff (by_line is forwarded in that case).  The differing
    entries of each side are wrapped in a generic obj for printing.
    """
    from generic import obj

    def banner(title):
        return '========== {} =========='.format(title)
    #end def banner

    print(banner('left object'))
    print(o1)
    print(banner('right object'))
    print(o2)
    if text:
        _,left,right = text_diff(o1,o2,atol,rtol,int_as_float,full=True,by_line=by_line)
    else:
        _,left,right = object_diff(o1,o2,atol,rtol,int_as_float,full=True)
    #end if
    left_diff  = obj(left)
    right_diff = obj(right)
    print(banner('left diff'))
    print(left_diff)
    print(banner('right diff'))
    print(right_diff)
#end def print_diff
224
225
226# check for value equality and if different, print the difference
def check_value_eq(v1,v2,**kwargs):
    """
    Return whether two values are equal according to value_eq.

    If they are not equal and global_data['verbose'] is set, both values
    are printed.  Keyword arguments are forwarded to value_eq.
    """
    same = value_eq(v1,v2,**kwargs)
    if same or not global_data['verbose']:
        return same
    #end if
    print('\nValues differ, please see below for details')
    for label,value in (('left value',v1),('right value',v2)):
        print()
        print('========== {} =========='.format(label))
        print(value)
    #end for
    print()
    return same
#end def check_value_eq
242
243
244# check for object equality and if different, print the difference
def check_object_eq(o1,o2,**kwargs):
    """
    Return whether two objects are equal according to object_eq.

    If they are not equal and global_data['verbose'] is set, the objects
    and their differences are printed with print_diff (using its default
    tolerances).  Keyword arguments are forwarded to object_eq.
    """
    same = object_eq(o1,o2,**kwargs)
    if same or not global_data['verbose']:
        return same
    #end if
    print('\nObjects differ, please see below for details')
    print_diff(o1,o2)
    return same
#end def check_object_eq
253
254
255
256# additional convenience functions to use value_diff and object_diff
value_neq = value_diff
def value_eq(*args,**kwargs):
    """
    Logical negation of value_neq (i.e. value_diff).

    All arguments are forwarded unchanged.
    """
    differs = value_neq(*args,**kwargs)
    return not differs
#end def value_eq
261
object_neq = object_diff
def object_eq(*args,**kwargs):
    """
    Logical negation of object_neq (i.e. object_diff).

    All arguments are forwarded unchanged.  The result is only
    meaningful when object_diff returns a single value (full=False),
    since the tuple returned for full=True is always truthy.
    """
    differs = object_neq(*args,**kwargs)
    return not differs
#end def object_eq
266
text_neq = text_diff
def text_eq(*args,**kwargs):
    """
    Logical negation of text_neq (i.e. text_diff).

    All arguments are forwarded unchanged.  The result is only
    meaningful when text_diff returns a single value (full=False),
    since the tuple returned for full=True is always truthy.
    """
    differs = text_neq(*args,**kwargs)
    return not differs
#end def text_eq
271
272
273# find the path to the Nexus directory and other internal paths
def nexus_path(append=None,location=None):
    """
    Return the path to the Nexus directory, or to a path inside it.

    The Nexus directory is found relative to this file, which is
    asserted to be <...>/nexus/lib/testing.py*.

    Parameters
    ----------
    append   : optional relative path joined onto the Nexus directory
    location : optional name of a known location ('unit' or 'bin');
               when given, it takes precedence over append

    Raises
    ------
    ValueError     : if location is not a known name
    AssertionError : if this file is not located as expected or the
                     resulting path does not exist
    """
    import os
    testing_path = os.path.realpath(__file__)

    assert isinstance(testing_path,str)
    assert len(testing_path)>0
    assert '/' in testing_path

    # this file is expected to reside at nexus/lib/testing.py
    tokens = testing_path.split('/')
    assert len(tokens)>=3
    assert tokens[-1].startswith('testing.py')
    assert tokens[-2]=='lib'
    assert tokens[-3]=='nexus'

    # go two levels up: testing.py -> lib -> nexus
    lib_path = os.path.dirname(testing_path)
    path     = os.path.dirname(lib_path)
    assert path.endswith('/nexus')

    if location is not None:
        known_locations = (
            ('unit','tests/unit'),
            ('bin' ,'bin'),
            )
        for name,subpath in known_locations:
            if location==name:
                append = subpath
                break
            #end if
        else:
            print('nexus location "{}" is unknown'.format(location))
            raise ValueError
        #end for
    #end if
    if append is not None:
        path = os.path.join(path,append)
    #end if

    assert os.path.exists(path)

    return path
#end def nexus_path
312
313
314
315# find the path to a file associated with a unit test
316def unit_test_file_path(test,file=None):
317    import os
318    unit_path  = nexus_path(location='unit')
319    files_dir  = 'test_{}_files'.format(test)
320    path = os.path.join(unit_path,files_dir)
321    if file is not None:
322        path = os.path.join(path,file)
323    #end if
324    assert(os.path.exists(path))
325    return path
326#end def unit_test_file_path
327
328
329
330# collect paths to all files associated with a unit test
def collect_unit_test_file_paths(test,storage):
    """
    Fill storage with the paths of all reference files of a unit test.

    storage maps file name -> file path.  It is only filled when empty,
    so a populated storage is returned untouched (acting as a cache).
    Entries whose name starts with '.' are skipped.

    Returns the storage object that was passed in.
    """
    import os
    if len(storage)!=0:
        # already collected
        return storage
    #end if
    test_files_dir = unit_test_file_path(test)
    visible = [name for name in os.listdir(test_files_dir) if not name.startswith('.')]
    for name in visible:
        filepath = os.path.join(test_files_dir,name)
        assert os.path.exists(filepath)
        storage[name] = filepath
    #end for
    return storage
#end def collect_unit_test_file_paths
346
347
348
349# find the output path for a test
350def unit_test_output_path(test,subtest=None):
351    import os
352    unit_path  = nexus_path(location='unit')
353    files_dir  = 'test_{}_output'.format(test)
354    path = os.path.join(unit_path,files_dir)
355    if subtest is not None:
356        path = os.path.join(path,subtest)
357    #end if
358    return path
359#end def unit_test_output_path
360
361
362
363# setup the output directory for a test
def setup_unit_test_output_directory(test,subtest,divert=False,file_sets=None,pseudo_dir=None,pseudo_files=None,pseudo_files_create=None):
    """
    Create a fresh output directory for a unit (sub)test.

    Any existing output directory for the subtest is removed first.

    Parameters
    ----------
    test, subtest : names used to locate the output directory
                    (see unit_test_output_path)
    divert        : if True, call divert_nexus() and point the nexus_core
                    local/remote directories and file locations at the
                    output directory.  Forced on when pseudo_dir is given.
    file_sets     : reference files of the test to copy into the output
                    directory; either a list of file names, or a dict
                    mapping a destination subdirectory to such a list
    pseudo_dir    : if given, name of a pseudopotential directory to
                    create inside the output directory and to register
                    with nexus_core/nexus_noncore
    pseudo_files  : list of existing files to copy into pseudo_dir
    pseudo_files_create : list of file names, or (name,contents) tuples,
                    of files to create inside pseudo_dir

    Returns
    -------
    The path to the output directory.

    Notes
    -----
    Directories listed in file_sets are copied with rsync.  The command
    is run without a shell so that paths containing spaces or shell
    metacharacters are handled correctly, and a failed copy is reported
    by assertion rather than being silently ignored.
    """
    import os
    import shutil
    from subprocess import Popen,PIPE

    # pseudopotential setup modifies nexus_core, which requires diversion
    divert |= pseudo_dir is not None

    path = unit_test_output_path(test,subtest)
    # sanity checks guarding the recursive delete below
    assert('nexus' in path)
    assert('unit' in path)
    assert(os.path.basename(path).startswith('test_'))
    assert(path.endswith('/'+subtest))
    if os.path.exists(path):
        shutil.rmtree(path)
    #end if
    os.makedirs(path)
    assert(os.path.exists(path))

    # divert nexus paths and output, if requested
    if divert:
        from nexus_base import nexus_core
        divert_nexus()
        nexus_core.local_directory  = path
        nexus_core.remote_directory = path
        nexus_core.file_locations = nexus_core.file_locations + [path]
    #end if

    # transfer files into output directory, if requested
    if file_sets is not None:
        if isinstance(file_sets,list):
            # a plain list is copied directly into the output directory
            file_sets = {'':file_sets}
        #end if
        assert(isinstance(file_sets,dict))
        filepaths = dict()
        collect_unit_test_file_paths(test,filepaths)
        for fpath,filenames in file_sets.items():
            # all requested files must be reference files of this test
            assert(len(set(filenames)-set(filepaths.keys()))==0)
            dest_path = path
            if fpath is not None:
                dest_path = os.path.join(dest_path,fpath)
                if not os.path.exists(dest_path):
                    os.makedirs(dest_path)
                #end if
            #end if
            assert(os.path.exists(dest_path))
            for filename in filenames:
                source_filepath = filepaths[filename]
                if os.path.isdir(source_filepath):
                    # argument list (no shell) keeps unusual paths intact
                    command = ['rsync','-a',source_filepath,dest_path]
                    process = Popen(command,stdout=PIPE,stderr=PIPE,close_fds=True)
                    out,err = process.communicate()
                    assert process.returncode==0,'rsync failed for {}\n{}'.format(source_filepath,err)
                else:
                    shutil.copy2(source_filepath,dest_path)
                #end if
                # check that the copied item actually arrived
                copied_path = os.path.join(dest_path,os.path.basename(source_filepath))
                assert(os.path.exists(copied_path))
            #end for
        #end for
    #end if

    # create pseudopotential directory and set internal nexus data structures
    if pseudo_dir is not None:
        from nexus_base import nexus_core,nexus_noncore
        pseudo_path = os.path.join(path,pseudo_dir)
        if not os.path.exists(pseudo_path):
            os.makedirs(pseudo_path)
        #end if
        assert(os.path.exists(pseudo_path))
        pseudo_filepaths = []
        if pseudo_files is not None:
            assert(isinstance(pseudo_files,list))
            for src_file in pseudo_files:
                assert(os.path.exists(src_file))
                assert(os.path.isfile(src_file))
                pp_filename = os.path.basename(src_file)
                pp_file = os.path.join(pseudo_path,pp_filename)
                shutil.copy2(src_file,pseudo_path)
                assert(os.path.exists(pp_file))
                assert(os.path.isfile(pp_file))
                pseudo_filepaths.append(pp_file)
            #end for
        #end if
        if pseudo_files_create is not None:
            assert(isinstance(pseudo_files_create,list))
            for pp_filename in pseudo_files_create:
                # entries are either a file name or (file name, contents)
                pp_contents = ''
                if isinstance(pp_filename,tuple):
                    pp_filename,pp_contents = pp_filename
                #end if
                pp_file = os.path.join(pseudo_path,pp_filename)
                # context manager closes the file even if the write fails
                with open(pp_file,'w') as f:
                    f.write(pp_contents)
                #end with
                assert(os.path.exists(pp_file))
                assert(os.path.isfile(pp_file))
                pseudo_filepaths.append(pp_file)
            #end for
        #end if
        if len(pseudo_filepaths)>0:
            from pseudopotential import Pseudopotentials
            for pp_file in pseudo_filepaths:
                assert(os.path.exists(pp_file))
                assert(os.path.isfile(pp_file))
            #end for
            pps = Pseudopotentials(pseudo_filepaths)
            nexus_core.pseudopotentials    = pps
            nexus_noncore.pseudopotentials = pps
        #end if
        nexus_core.pseudo_dir    = pseudo_path
        nexus_noncore.pseudo_dir = pseudo_path
    #end if

    return path
#end def setup_unit_test_output_directory
477
478
479
480# class used to divert log output when desired
class FakeLog:
    """
    Minimal in-memory replacement for a log file.

    Everything written is accumulated in the string attribute s, which
    can be read back with contents() and cleared with reset().
    """

    def __init__(self):
        self.reset()
    #end def __init__

    def reset(self):
        # discard everything written so far
        self.s = ''
    #end def reset

    def write(self,s):
        # append to the accumulated text
        self.s = self.s+s
    #end def write

    def close(self):
        # nothing to release; present for file-like compatibility
        pass
    #end def close

    def contents(self):
        return self.s
    #end def contents
#end class FakeLog
502
503
# dict to temporarily store the original log objects while log output
# is diverted (filled by divert_nexus_log, emptied by restore_nexus_log)
logging_storage = dict()

# dicts to temporarily store nexus core/noncore attributes while diverted
# (filled by divert_nexus_core, emptied by restore_nexus_core)
nexus_core_storage    = dict()
nexus_noncore_storage = dict()
510
511
512# divert nexus log output
def divert_nexus_log():
    """
    Redirect Nexus log output into a FakeLog.

    The current generic_settings.devlog and object_interface._logfile
    are saved in logging_storage (which must be empty) and both are
    replaced by a single new FakeLog instance.

    Returns the FakeLog receiving the diverted output.
    """
    from generic import generic_settings,object_interface
    assert len(logging_storage)==0
    # (owner, attribute holding the log, key used in logging_storage)
    log_slots = (
        (generic_settings,'devlog'  ,'devlog'),
        (object_interface,'_logfile','objlog'),
        )
    for owner,attr,key in log_slots:
        logging_storage[key] = getattr(owner,attr)
    #end for
    fake_log = FakeLog()
    for owner,attr,key in log_slots:
        setattr(owner,attr,fake_log)
    #end for
    return fake_log
#end def divert_nexus_log
523
524
525# restore nexus log output
def restore_nexus_log():
    """
    Undo divert_nexus_log.

    The log objects saved in logging_storage are put back into
    generic_settings.devlog and object_interface._logfile, leaving
    logging_storage empty.
    """
    from generic import generic_settings,object_interface
    assert set(logging_storage.keys())=={'devlog','objlog'}
    # (owner, attribute holding the log, key used in logging_storage)
    log_slots = (
        (generic_settings,'devlog'  ,'devlog'),
        (object_interface,'_logfile','objlog'),
        )
    for owner,attr,key in log_slots:
        setattr(owner,attr,logging_storage.pop(key))
    #end for
    assert len(logging_storage)==0
#end def restore_nexus_log
533
534
# names of the nexus_core entries saved by divert_nexus_core
# and put back by restore_nexus_core
core_keys = [
    'local_directory',
    'remote_directory',
    'mode',
    'stages',
    'stages_set',
    'status',
    'sleep',
    'file_locations',
    'pseudo_dir',
    'pseudopotentials',
    'runs',
    'results',
    ]
# names of the nexus_noncore entries saved (when present) by
# divert_nexus_core and put back or removed by restore_nexus_core
noncore_keys = [
    'pseudo_dir',
    'pseudopotentials',
    ]
553
554# divert nexus core attributes
def divert_nexus_core():
    """
    Save the current nexus_core/nexus_noncore entries.

    All entries named in core_keys are copied from nexus_core into
    nexus_core_storage.  Entries named in noncore_keys are copied from
    nexus_noncore into nexus_noncore_storage when they are present.
    Both storage dicts must be empty beforehand.
    """
    from nexus_base import nexus_core,nexus_noncore
    assert len(nexus_core_storage)==0
    nexus_core_storage.update((key,nexus_core[key]) for key in core_keys)
    assert len(nexus_noncore_storage)==0
    nexus_noncore_storage.update(
        (key,nexus_noncore[key]) for key in noncore_keys if key in nexus_noncore
        )
#end def divert_nexus_core
568
569
570# restore nexus core attributes
def restore_nexus_core():
    """
    Undo divert_nexus_core.

    The entries saved in nexus_core_storage are put back into
    nexus_core.  For nexus_noncore, saved entries are put back and
    entries from noncore_keys that were not saved are removed, after
    which any entry not present in nexus_noncore_defaults is removed.
    nexus_core_noncore.pseudopotentials is reset to None.  Both storage
    dicts are left empty.
    """
    from nexus_base import nexus_core,nexus_noncore,nexus_core_noncore
    from nexus_base import nexus_noncore_defaults
    for key in core_keys:
        nexus_core[key] = nexus_core_storage.pop(key)
    #end for
    assert len(nexus_core_storage)==0
    for key in noncore_keys:
        if key in nexus_noncore_storage:
            nexus_noncore[key] = nexus_noncore_storage.pop(key)
            continue
        #end if
        # not saved at diversion time, so it must not survive restoration
        if key in nexus_noncore:
            del nexus_noncore[key]
        #end if
    #end for
    # drop anything else that is not part of the default noncore state
    extra_keys = [key for key in nexus_noncore.keys() if key not in nexus_noncore_defaults]
    for key in extra_keys:
        del nexus_noncore[key]
    #end for
    nexus_core_noncore.pseudopotentials = None
    assert len(nexus_noncore_storage)==0
#end def restore_nexus_core
593
594
def divert_nexus():
    """
    Divert Nexus log output and save the Nexus core attributes.

    Calls divert_nexus_log followed by divert_nexus_core.
    """
    for divert_step in (divert_nexus_log,divert_nexus_core):
        divert_step()
    #end for
#end def divert_nexus
599
600
def restore_nexus():
    """
    Restore Nexus log output and the Nexus core attributes.

    Calls restore_nexus_log followed by restore_nexus_core.
    """
    for restore_step in (restore_nexus_log,restore_nexus_core):
        restore_step()
    #end for
#end def restore_nexus
605
606
607
608# declare test failure
609#   useful inside try/except blocks
def failed(msg='Test failed.'):
    """
    Declare test failure by raising AssertionError(msg).

    The exception is raised explicitly rather than via an assert
    statement, since assert statements are removed when Python runs
    with optimization (-O) and the failure would then go unnoticed.
    """
    raise AssertionError(msg)
#end def failed
613
614
class FailedTest(Exception):
    """Exception type available for signaling a failed test."""
    pass
#end class FailedTest
618
619
# module-wide settings for the testing utilities
#   verbose       : print details when check_value_eq/check_object_eq fail
#   job_ref_table : not read within this portion of the file
global_data = {
    'verbose'       : False,
    'job_ref_table' : False,
    }
624
625
def divert_nexus_errors():
    """
    Set generic_settings.raise_error to True.

    NOTE(review): presumably this makes Nexus errors raise exceptions
    instead of exiting -- confirm against the generic module.
    """
    import generic
    generic.generic_settings.raise_error = True
#end def divert_nexus_errors
630
631
def clear_all_sims():
    """Call Simulation.clear_all_sims() from the simulation module."""
    import simulation
    simulation.Simulation.clear_all_sims()
#end def clear_all_sims
636
637
638
def check_final_state():
    """
    Assert that the global Nexus state matches its defaults.

    Checks that nexus_core, nexus_noncore and nexus_core_noncore equal
    their respective default objects (after confirming that each default
    object holds an expected key), and that the Simulation class holds
    no simulations or simulation directories.
    """
    from nexus_base import nexus_core,nexus_core_defaults
    from nexus_base import nexus_noncore,nexus_noncore_defaults
    from nexus_base import nexus_core_noncore,nexus_core_noncore_defaults

    # (key expected in the defaults, current state, default state)
    state_checks = (
        ('runs'      ,nexus_core        ,nexus_core_defaults        ),
        ('basis_dir' ,nexus_noncore     ,nexus_noncore_defaults     ),
        ('pseudo_dir',nexus_core_noncore,nexus_core_noncore_defaults),
        )

    for expected_key,current,defaults in state_checks:
        assert expected_key in defaults
    #end for

    for expected_key,current,defaults in state_checks:
        assert object_eq(current,defaults)
    #end for

    from simulation import Simulation

    assert Simulation.sim_count==0
    assert len(Simulation.all_sims)==0
    assert len(Simulation.sim_directories)==0
#end def check_final_state
658
659
660
def executable_path(exe_name):
    """
    Return the path to an executable in the Nexus bin directory.

    The path is asserted to be an existing file with execute permission.
    """
    import os
    bin_dir  = nexus_path(location='bin')
    exe_path = os.path.join(bin_dir,exe_name)
    # must be a regular file ...
    assert os.path.isfile(exe_path)
    # ... that the current user may execute
    assert os.access(exe_path,os.X_OK)
    return exe_path
#end def executable_path
673
674
675
def create_file(filename,path,contents=''):
    """
    Create (or overwrite) a text file and return its path.

    Parameters
    ----------
    filename : name of the file
    path     : existing directory in which to create the file
    contents : text to write into the file (empty by default)

    Returns
    -------
    The path to the file, which is asserted to exist.
    """
    import os
    filepath = os.path.join(path,filename)
    # context manager closes the file even if the write fails
    with open(filepath,'w') as f:
        f.write(contents)
    #end with
    assert(os.path.isfile(filepath))
    return filepath
#end def create_file
685
686
687
def create_path(path,basepath=None):
    """
    Make sure a directory exists, creating it (and parents) if needed.

    If basepath is given, path is taken relative to it.  The resulting
    path is asserted to be a directory.  Nothing is returned.
    """
    import os
    full_path = path if basepath is None else os.path.join(basepath,path)
    already_present = os.path.exists(full_path)
    if not already_present:
        os.makedirs(full_path)
    #end if
    assert os.path.isdir(full_path)
#end def create_path
698
699
700
def execute(command):
    """
    Run a system command through Nexus' execute function.

    If the command returns a non-zero code, the test is failed with a
    report of the command, its stdout, stderr and return code.

    Returns the tuple (out,err,rc) obtained from the Nexus execute call.
    """
    from execute import execute as nexus_execute
    result = nexus_execute(command)
    out,err,rc = result
    if rc==0:
        return out,err,rc
    #end if
    report = '''Executed system command failed.

Command:
========
{}

stdout:
=======
{}

stderr:
=======
{}

Return code:
============
{}

'''.format(command,out,err,rc)
    failed(report)
    return out,err,rc
#end def execute
728