root/trunk/bindings/python/nxstest.py

Revision 1222, 12.9 KB (checked in by Paul Kienzle, 17 months ago)

python binding requires matched open/close data to maintain path. Refs #101.

Line 
1# This program is public domain
2
3# Author: Paul Kienzle
4
5"""
6NeXus tests converted to python.
7"""
8
9import nxs,os,numpy,sys
10
11def memfootprint():
12    import gc
13    objs = gc.get_objects()
14    classes = set( c.__class__ for c in gc.get_objects() if hasattr(c,'__class__') )
15    # print "\n".join([c.__name__ for c in classes])
16    print "#objects=",len(objs)
17    print "#classes=",len(classes)
18
19def leak_test1(n = 1000, mode='w5'):
20#    import gc
21#    gc.enable()
22#    gc.set_debug(gc.DEBUG_LEAK)
23    filename = "leak_test1.nxs"
24    try: os.unlink(filename)
25    except OSError: pass
26    file = nxs.open(filename,mode)
27    file.close()
28    print "File should exist now"
29    for i in range(n):
30        if i%100 == 0: 
31            print "loop count %d"%i
32            memfootprint()
33        file.open()
34        file.close()
35#        gc.collect()
36    os.unlink(filename)
37
38def _show(file, indent=0):
39    prefix = ' '*indent
40    link = file.link()
41    if link:
42        print "%(prefix)s-> %(link)s" % locals()
43        return
44    for attr,value in file.attrs():
45        print "%(prefix)s@%(attr)s: %(value)s" % locals()
46    for name,nxclass in file.entries():
47        if nxclass == "SDS":
48            shape,dtype = file.getinfo()
49            dims = "x".join([str(x) for x in shape])
50            print "%(prefix)s%(name)s %(dtype)s %(dims)s" % locals()
51            link = file.link()
52            if link:
53                print %(prefix)s-> %(link)s" % locals()
54            else:
55                for attr,value in file.attrs():
56                    print %(prefix)s@%(attr)s: %(value)s" % locals()
57                if numpy.prod(shape) < 8:
58                    value = file.getdata()
59                    print %s%s"%(prefix,str(value))
60        else:
61            print "%(prefix)s%(name)s %(nxclass)s" % locals()
62            _show(file, indent+2)
63
64def show_structure(filename):
65    file = nxs.open(filename)
66    print "=== File",file.inquirefile()
67    _show(file)
68   
69
70def populate(filename,mode):
71    i1 = numpy.arange(4,dtype='uint8')
72    i2 = numpy.arange(4,dtype='int16')*1000
73    i4 = numpy.arange(4,dtype='int32')*1000000
74    i8 = numpy.arange(4,dtype='int64')*1000000000000
75    r4 = numpy.arange(20,dtype='float32').reshape((5,4))
76    r8 = numpy.arange(20,dtype='float64').reshape((5,4))
77    comp_array=numpy.ones((100,20),dtype='int32')
78    for i in range(100): comp_array[i,:] *= i
79
80    file = nxs.open(filename,mode)
81    file.setnumberformat('float32','%9.3g')
82    file.makegroup("entry","NXentry")
83    file.opengroup("entry","NXentry")
84    file.putattr("hugo","namenlos")
85    file.putattr("cucumber","passion")
86    #file.putattr("embedded_null","embedded\000null")
87
88    # Write character data
89    file.makedata("ch_data",'char',[10])
90    file.opendata("ch_data")
91    file.putdata("NeXus data")
92    file.closedata()
93   
94    # Write numeric data
95    for var in ['i1','i2','i4','i8','r4']:
96        if mode == 'w4' and var == 'i8': continue
97        name = var+'_data'
98        val = locals()[var]
99        file.makedata(name,val.dtype,val.shape)
100        file.opendata(name)
101        file.putdata(val)
102        file.closedata()
103   
104    # Write r8_data
105    file.makedata('r8_data','float64',[5,4])
106    file.opendata('r8_data')
107    file.putslab(r8[4,:],[4,0],[1,4])
108    file.putslab(r8[0:4,:],[0,0],[4,4])
109    file.putattr("ch_attribute","NeXus")
110    file.putattr("i4_attribute",42,dtype='int32')
111    file.putattr("r4_attribute",3.14159265,dtype='float32')
112    ## Oops... NAPI doesn't support array attributes
113    #file.putattr("i4_array",[3,2],dtype='int32')
114    #file.putattr("r4_array",[3.14159265,2.718281828],dtype='float32')
115    dataID = file.getdataID()
116    file.closedata()
117
118    # Create the NXdata group
119    file.makegroup("data","NXdata")
120    file.opengroup("data","NXdata")
121   
122    # .. demonstrate linking
123    file.makelink(dataID)
124
125    # .. demonstrate compressed data
126    file.compmakedata("comp_data",'int32',[100,20],'lzw',[20,20])
127    file.opendata('comp_data')
128    file.putdata(comp_array)
129    file.closedata()
130    file.flush()
131
132    # .. demonstrate extensible data
133    file.makedata('flush_data','int32',[nxs.UNLIMITED])
134    file.opendata('flush_data')
135    for i in range(7):
136        file.putslab(i,[i],[1])
137    file.closedata()
138    file.flush()
139    file.closegroup()
140
141    # Create NXsample group
142    file.makegroup('sample','NXsample')
143    file.opengroup('sample','NXsample')
144    file.makedata('ch_data','char',[20])
145    file.opendata('ch_data')
146    file.putdata('NeXus sample')
147    file.closedata()
148    sampleID = file.getgroupID()
149    file.closegroup()
150    file.closegroup()
151
152    # Demonstrate named links
153    file.makegroup('link','NXentry')
154    file.opengroup('link','NXentry')
155    file.makelink(sampleID)
156    file.makenamedlink('renLinkGroup',sampleID)
157    file.makenamedlink('renLinkData',dataID)
158    file.closegroup()
159   
160    file.close()
161    return filename
162
163failures = 0
164def fail(msg):
165    global failures
166    print "FAIL:",msg
167    failures += 1
168
169def dicteq(a,b):
170    """
171    Compare two dictionaries printing how they differ.
172    """
173    for k,v in a.iteritems():
174        if k not in b:
175            print k,"not in",b
176            return False
177        if v != b[k]: 
178            print v,"not equal",b[k]
179            return False
180    for k,v in b.iteritems():
181        if k not in a: 
182            print k,"not in",a
183            return False
184    return True
185
186def check(filename, mode):
187    global failures
188    failures = 0
189    file = nxs.open(filename,'rw')
190    if filename != file.inquirefile(): fail("Files don't match")
191
192    # check headers
193    num_attrs = file.getattrinfo()
194    wxattrs = ['xmlns','xmlns:xsi','xsi:schemaLocation', 'XML_version']
195    w4attrs = ['HDF_version']
196    w5attrs = ['HDF5_Version']
197    extras = dict(wx=wxattrs,w4=w4attrs,w5=w5attrs)
198    expected_attrs = ['NeXus_version','file_name','file_time']+extras[mode]
199    for i in range(num_attrs):
200        name,dims,type = file.getnextattr()
201        if name not in expected_attrs:
202            fail("attribute %s unexpected"%(name))
203    if num_attrs != len(expected_attrs): 
204        fail("Expected %d root attributes but got %d"
205             % (len(expected_attrs),num_attrs))
206   
207    file.opengroup('entry','NXentry')
208   
209    expect = dict(hugo='namenlos',cucumber='passion')
210    #expect['embedded_null'] = "embedded\000null"
211    get = dict((k,v) for k,v in file.attrs())
212    same = dicteq(get,expect)
213    if not same: fail("/entry attributes are %s"%(get))
214
215    # Check that the numbers are written correctly
216    for name,dtype,shape,scale in \
217        [('i1','int8',(4),1),
218         ('i2','int16',(4),1000),
219         ('i4','int32',(4),1000000),
220         ('i8','int64',(4),1000000000000),
221         ('r4','float32',(5,4),1),
222         ('r8','float64',(5,4),1)
223         ]:
224        if mode == 'w4' and name == 'i8': continue
225        n = numpy.prod(shape)
226        expected = numpy.arange(n,dtype=dtype).reshape(shape)*scale
227        file.opendata(name+'_data')
228        get = file.getdata()
229        file.closedata()
230        if not (get == expected).all(): 
231            fail("%s retrieved %s"%(dtype,get))
232
233
234    # Check attribute types
235    file.opendata('r8_data')
236    get = file.getattr("ch_attribute",5,'char')
237    if not get == "NeXus": fail("ch_attribute retrieved %s"%(get))
238    get = file.getattr("i4_attribute",1,'int32')
239    if not get == numpy.int32(42): fail("i4_attribute retrieved %s"%(get))
240    get = file.getattr("r4_attribute",1,'float32')
241    if ((mode=='wx' and not abs(get-3.14159265) < 1e-6) or
242        (mode!='wx' and not get == numpy.float32(3.14159265))):
243        fail("r4_attribute retrieved %s"%(get))
244    ## Oops... NAPI doesn't support array attributes
245    #expect = numpy.array([3,2],dtype='int32')
246    #get = file.getattr("i4_array",2,'int32')
247    #if not (get==expect).all(): fail('i4_array retrieved %s'%(get))
248    #expect = numpy.array([3.14159265,2.718281828],dtype='float32')
249    #get = file.getattr("r4_array",2,dtype='float32')
250    #if not (get==expect).all(): fail("r4_array retrieved %s"%(get))
251    file.closedata()
252
253
254    # Check reading from compressed datasets
255    comp_array=numpy.ones((100,20),dtype='int32')
256    for i in range(100): comp_array[i,:] *= i
257    expected = comp_array
258    file.opengroup('data','NXdata') #/entry/data
259    file.opendata('comp_data')      #/entry/data/comp_data
260    get = file.getdata()
261    file.closedata()                #/entry/data/comp_data
262    file.closegroup()               #/entry/data
263    if not (get == expected).all():
264        fail("compressed data differs")
265        print get
266       
267    # Check strings
268    file.opengroup('sample','NXsample') #/entry/sample
269    file.opendata('ch_data')            #/entry/sample/ch_data
270    rawshape,rawdtype = file.getrawinfo()
271    shape,dtype = file.getinfo()
272    get = file.getdata()
273    file.closedata()                    #/entry/sample/ch_data
274    file.closegroup()                   #/entry/sample
275    if not (shape[0]==12 and dtype=='char'):
276        fail("returned string info is incorrect")
277        print shape,dtype
278    if not (rawshape[0]==20 and rawdtype=='char'):
279        fail("returned string storage info is incorrect")
280        print shape,dtype
281    if not (get == "NeXus sample"):
282        fail("returned string is incorrect")
283        print shape,dtype
284
285
286    file.closegroup() #/entry
287
288    # Check read slab (e.g., from extensible)
289
290    # Check links
291    file.opengroup('entry','NXentry')
292    file.opengroup('sample','NXsample')
293    sampleid = file.getgroupID()
294    file.closegroup() #/entry/sample
295    file.opengroup('data','NXdata') #/entry/data
296    file.opendata('r8_data') #/entry/data/r8_data
297    dataid = file.getdataID()
298    file.closedata() #/entry/data/r8_data
299    file.closegroup() #/entry/data
300    file.opendata('r8_data')
301    data2id = file.getdataID()
302    file.closedata()
303    file.closegroup() #/entry
304    if not (file.sameID(dataid,data2id)):
305        fail("/entry/data/r8_data not linked to /entry/r8_data")
306   
307    # Check openpath and getslab
308    file.openpath('/entry/data/comp_data')
309    get = file.getslab([4,4],[5,3])
310    expected = comp_array[4:(4+5),4:(4+3)]
311    if not (get == expected).all():
312        fail("retrieved compressed slabs differ")
313        print get
314    file.openpath('/entry/data/comp_data')
315    get = file.getslab([4,4],[5,3])
316    expected = comp_array[4:(4+5),4:(4+3)]
317    if not (get == expected).all():
318        fail("after reopen: retrieved compressed slabs differ")
319        print get
320    file.openpath('../r8_data')
321    for k,v in file.attrs():
322        if k == 'target' and v != '/entry/r8_data':
323            fail("relative openpath was not successful")
324
325    return failures == 0
326
327def populate_external(filename,mode):
328    ext = dict(w5='.h5',w4='.hdf',wx='.xml')[mode]
329    file = nxs.open(filename,mode)
330    file.makegroup('entry1','NXentry')
331    file.linkexternal('entry1','NXentry','nxfile://data/dmc01'+ext)
332    file.makegroup('entry2','NXentry')
333    file.linkexternal('entry2','NXentry','nxfile://data/dmc02'+ext)
334    file.makegroup('entry3','NXentry')
335    file.close()
336
337def check_external(filename,mode):
338    ext = dict(w5='.h5',w4='.hdf',wx='.xml')[mode]
339    file = nxs.open(filename,'rw')
340   
341    file.openpath('/entry1/start_time')
342    time = file.getdata()
343   
344    get = file.inquirefile()
345    expected = 'nxfile://data/dmc01'+ext
346    if expected != get: fail("first external file returned %s"%(get))
347   
348    file.openpath('/entry2/sample/sample_name')
349    sample = file.getdata()
350
351    get = file.inquirefile()
352    expected = 'nxfile://data/dmc02'+ext
353    if expected != get: fail("second external file returned %s"%(get))
354
355    file.openpath('/')
356    remote = file.isexternalgroup('entry1','NXentry')
357    if remote is None:
358        fail("failed to identify /entry1 as external")
359    remote = file.isexternalgroup('entry3','NXentry')
360    if remote is not None: 
361        fail('incorrectly identified /entry3 as external')
362   
363    file.close()
364
365def test_external(mode,quiet=True):
366    ext = dict(w5='.h5',w4='.hdf',wx='.xml')[mode]
367    filename = 'nxext'+ext
368    populate_external(external,mode)
369    if not quiet:
370        show_structure(external)
371    failures = check_external(filename,mode)
372    return failures
373
374def test_mode(mode,quiet=True,external=False):
375    ext = dict(w5='.h5',w4='.hdf',wx='.xml')[mode]
376    filename = 'NXtest'+ext
377    populate(filename,mode=mode)
378    if not quiet and 'NX_LOAD_PATH' in os.environ:
379        show_structure('dmc01'+ext)
380    if not quiet:
381        show_structure(filename)
382    failures = check(filename,mode)
383    if external: failures += test_external(mode,quiet)
384    return failures
385
386def test():
387    tests = 0
388    if '-q' in sys.argv:
389        quiet = True
390    else:
391        quiet = False
392    if '-x' in sys.argv:
393        external = True
394       
395    else:
396        external = False
397    if 'hdf4' in sys.argv: 
398        test_mode('w4',quiet,external)
399        tests += 1
400    if 'xml' in sys.argv:
401        test_mode('wx',quiet,external)
402        tests += 1
403    if 'hdf5' in sys.argv: 
404        test_mode('w5',quiet,external)
405        tests += 1
406    if tests == 0: test_mode('w5',quiet,external)
407
408if __name__ == "__main__":
409    test()
410    #leak_test1(n=10000)
411   
Note: See TracBrowser for help on using the browser.