Changeset 1805 for trunk


Ignore:
Timestamp:
20/01/12 14:43:50 (4 months ago)
Author:
Ray Osborn
Message:

Refs #322: Added arguments to use the standard matplotlib format strings to
set axis and intensity limits. Improved handling of log scales. Added oplot
and logplot methods.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/bindings/python/nxs/tree.py

    r1804 r1805  
    11#!/usr/bin/env python 
    2 # This program is public domain 
     2# This program is public domain  
     3# Author: Paul Kienzle, Ray Osborn 
    34 
    45""" 
     
    218219class.  The plotter class has one method:: 
    219220 
    220     plot(signal, axes, entry, title) 
     221    plot(signal, axes, entry, title, format, **opts) 
    221222 
    222223where signal is the field containing the data, axes are the fields listing the 
     
    772773        print "yielding",self.nxname,self.nxclass 
    773774        if False: yield 
    774  
    775775 
    776776    def dir(self,attrs=False,recursive=False): 
     
    16511651    """ 
    16521652 
    1653     def plot(self, signal, axes, title, errors, **opts): 
     1653    def plot(self, signal, axes, title, errors, fmt,  
     1654             xmin, xmax, ymin, ymax, zmin, zmax, **opts): 
    16541655        """ 
    16551656        Plot the data entry. 
     
    16581659        """ 
    16591660        try: 
    1660             import pylab 
     1661            import matplotlib.pyplot as plt 
    16611662        except ImportError: 
    16621663            raise NeXusError, "Plotting package not available." 
     1664 
     1665        over = False 
    16631666        if "over" in opts.keys(): 
    1664             over = True 
     1667            if opts["over"]: over = True 
    16651668            del opts["over"] 
    1666         else: 
    1667             over = False 
    1668         if not over: pylab.clf() 
    1669  
     1669 
     1670        log = logx = logy = False 
    16701671        if "log" in opts.keys(): 
    1671             logplot = True 
     1672            if opts["log"]: log = True 
    16721673            del opts["log"] 
    1673         else: 
    1674             logplot = False 
     1674        if "logy" in opts.keys(): 
     1675            if opts["logy"]: logy = True 
     1676            del opts["logy"] 
     1677        if "logx" in opts.keys(): 
     1678            if opts["logx"]: logx = True 
     1679            del opts["logx"] 
     1680 
     1681        if over: 
     1682            plt.autoscale(enable=False) 
     1683        else: 
     1684            plt.autoscale(enable=True) 
     1685            plt.clf() 
    16751686 
    16761687        # Provide a new view of the data if there is a dimension of length 1 
     
    16851696        #One-dimensional Plot 
    16861697        if len(data.shape) == 1: 
     1698            plt.ioff() 
    16871699            if hasattr(signal, 'units'): 
    16881700                if not errors and signal.units == 'counts': 
    16891701                    errors = NXfield(np.sqrt(data)) 
    1690             if logplot: 
    1691                 data = np.log10(np.clip(data,0,1e8)) 
    1692                 if errors: ebars = np.log10(errors) 
    1693             elif errors: 
     1702            if errors: 
    16941703                ebars = errors.nxdata 
    1695             if errors: 
    1696                 myopts=copy(opts) 
    1697                 myopts.setdefault('fmt','o') 
    1698                 myopts.setdefault('linestyle','None') 
    1699                 pylab.scatter(axis_data[0], data, **opts) 
    1700                 pylab.errorbar(axis_data[0], data, ebars, **myopts) 
     1704                plt.errorbar(axis_data[0], data, ebars, fmt=fmt, **opts) 
    17011705            else: 
    1702                 pylab.scatter(axis_data[0], data, **opts) 
     1706                plt.plot(axis_data[0], data, fmt, **opts) 
    17031707            if not over: 
    1704                 pylab.xlabel(label(axes[0])) 
    1705                 pylab.ylabel(label(signal)) 
    1706                 pylab.title(title) 
     1708                ax = plt.gca() 
     1709                xlo, xhi = ax.set_xlim(auto=True)         
     1710                ylo, yhi = ax.set_ylim(auto=True)                 
     1711                if xmin: xlo = xmin 
     1712                if xmax: xhi = xmax 
     1713                ax.set_xlim(xlo, xhi) 
     1714                if ymin: ylo = ymin 
     1715                if ymax: yhi = ymax 
     1716                ax.set_ylim(ylo, yhi) 
     1717                if logx: ax.set_xscale('symlog') 
     1718                if log or logy: ax.set_yscale('symlog') 
     1719                plt.xlabel(label(axes[0])) 
     1720                plt.ylabel(label(signal)) 
     1721                plt.title(title) 
     1722            plt.ion() 
     1723            plt.show() 
    17071724 
    17081725        #Two dimensional plot 
    17091726        else: 
     1727            from matplotlib.image import NonUniformImage 
     1728            from matplotlib.colors import LogNorm 
     1729 
    17101730            if len(data.shape) > 2: 
    17111731                slab = [slice(None), slice(None)] 
     
    17141734                data = data[slab].view().reshape(data.shape[:2]) 
    17151735                print "Warning: Only the top 2D slice of the data is plotted" 
    1716             #from api.nexus import meshgl 
    1717             #gridplot = meshgl.pcolor_gl 
    1718             #gridplot = pylab.pcolormesh 
    1719             gridplot = imshow_irregular 
    1720             if logplot: 
    1721                 gridplot(axis_data[0], axis_data[1], 
    1722                          np.log10(np.clip(data,0.,1e8)+1).T, **opts) 
     1736 
     1737            x = axis_data[0] 
     1738            y = axis_data[1] 
     1739            if not zmin: zmin = np.min(data) 
     1740            if not zmax: zmax = np.max(data) 
     1741            z = np.clip(data,zmin,zmax).T 
     1742             
     1743            ax = plt.gca() 
     1744            extent = (x[0],x[-1],y[0],y[-1]) 
     1745            if log: 
     1746                opts["norm"] = LogNorm() 
     1747                if z.min() < 1e-8: 
     1748                    z = np.clip(z,0.1,zmax) 
     1749                 
     1750            im = NonUniformImage(ax, extent=extent, origin=None, **opts) 
     1751            im.set_data(x,y,z) 
     1752            ax.images.append(im) 
     1753            xlo, xhi = ax.set_xlim(x[0],x[-1]) 
     1754            ylo, yhi = ax.set_ylim(y[0],y[-1]) 
     1755            if xmin:  
     1756                xlo = xmin 
    17231757            else: 
    1724                 gridplot(axis_data[0], axis_data[1], np.clip(data,-1e8,1e8).T, **opts) 
    1725             pylab.xlabel(label(axes[0])) 
    1726             pylab.ylabel(label(axes[1])) 
    1727             pylab.title(title) 
     1758                xlo = x[0] 
     1759            if xmax:  
     1760                xhi = xmax 
     1761            else: 
     1762                xhi = x[-1] 
     1763            if ymin:  
     1764                yhi = ymin 
     1765            else: 
     1766                yhi = y[0] 
     1767            if ymax:  
     1768                yhi = ymax 
     1769            else: 
     1770                yhi = y[-1] 
     1771            ax.set_xlim(xlo, xhi) 
     1772            ax.set_ylim(ylo, yhi) 
     1773            plt.xlabel(label(axes[0])) 
     1774            plt.ylabel(label(axes[1])) 
     1775            plt.title(title) 
     1776            plt.colorbar(im) 
     1777            plt.gcf().canvas.draw_idle() 
     1778         
     1779        return plt.gcf() 
    17281780 
    17291781    @staticmethod 
    17301782    def show(): 
    1731         import pylab 
    1732         pylab.show() 
     1783        import matplotlib.pyplot as plt 
     1784        plt.show()     
    17331785 
    17341786 
     
    18501902 
    18511903        >>> entry.sample.tree = 100.0 
    1852         >>> entry.sample.tree 
     1904        >>> print entry.sample.tree 
    18531905        sample:NXsample 
    18541906          tree = 100.0 
     
    18701922        >>> entry.sample.temperature = 40.0 
    18711923        >>> entry.sample.attrs['tree'] = 10.0 
    1872         >>> entry.sample.tree 
     1924        >>> print entry.sample.tree 
    18731925        sample:NXsample 
    18741926          @tree = 10.0 
     
    18971949 
    18981950    tree: 
    1899         Print the group tree. 
     1951        Return the group tree. 
    19001952 
    19011953        It invokes the 'dir' method with both 'attrs' and 'recursive' 
    1902         set to True. Note that this method is defined as a property attribute and 
    1903         does not require parentheses. 
     1954        set to True. 
    19041955 
    19051956    save(self, filename, format='w5') 
     
    19161967    >>> entry.sample = NXgroup(temperature=NXfield(40.0,units='K'), 
    19171968                               nxclass='NXsample') 
    1918     >>> entry.sample.tree 
     1969    >>> print entry.sample.tree 
    19191970    sample:NXsample 
    19201971      temperature = 40.0 
     
    21632214            raise NeXusError, "Link target must be an NXobject" 
    21642215 
    2165  
    21662216    def sum(self, axis=None): 
    21672217        """ 
     
    21862236            signal.long_name = "Integral from %s to %s %s" % \ 
    21872237                               (summedaxis[0], summedaxis[-1], units) 
    2188             average = NXfield(0.5*(summedaxis.nxdata[0]+summedaxis.nxdata[-1]), name=summedaxis.nxname) 
     2238            average = NXfield(0.5*(summedaxis.nxdata[0]+summedaxis.nxdata[-1]),  
     2239                                   name=summedaxis.nxname) 
    21892240            if units: average.units = units 
    21902241            result = NXdata(signal, axes, average) 
     
    22962347    entries = property(_getentries,doc="NeXus objects within group") 
    22972348 
    2298  
    2299     def plot(self, **opts): 
     2349    def plot(self, fmt='bo', xmin=None, xmax=None, ymin=None, ymax=None, 
     2350             zmin=None, zmax=None, **opts): 
    23002351        """ 
    23012352        Plot data contained within the group. 
     2353 
     2354        The format argument is used to set the color and type of the 
     2355        markers or lines for one-dimensional plots, using the standard  
     2356        matplotlib syntax. The default is set to blue circles. All  
     2357        keyword arguments accepted by matplotlib.pyplot.plot can be 
     2358        used to customize the plot. 
     2359         
     2360        In addition to the matplotlib keyword arguments, the following 
     2361        are defined: 
     2362         
     2363            log = True     - plot the intensity on a log scale 
     2364            logy = True    - plot the y-axis on a log scale 
     2365            logx = True    - plot the x-axis on a log scale 
     2366            over = True    - plot on the current figure 
    23022367 
    23032368        Raises NeXusError if the data could not be plotted. 
     
    23282393 
    23292394        # Plot with the available plotter 
    2330         group._plotter.plot(signal, axes, title, errors, **opts) 
    2331  
     2395        group._plotter.plot(signal, axes, title, errors, fmt,  
     2396                            xmin, xmax, ymin, ymax, zmin, zmax, **opts) 
     2397     
     2398    def oplot(self, fmt='bo', **opts): 
     2399        """ 
     2400        Plot the data contained within the group over the current figure. 
     2401        """ 
     2402        self.plot(fmt=fmt, over=True, **opts) 
     2403 
     2404    def logplot(self, fmt='bo', xmin=None, xmax=None, ymin=None, ymax=None, 
     2405                zmin=None, zmax=None, **opts): 
     2406        """ 
     2407        Plot the data intensity contained within the group on a log scale. 
     2408        """ 
     2409        self.plot(fmt=fmt, log=True, 
     2410                  xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, 
     2411                  zmin=zmin, zmax=zmax, **opts) 
    23322412 
    23332413class NXlink(NXobject): 
     
    28522932        return field.nxname 
    28532933 
    2854 def imshow_irregular(x,y,z): 
    2855     import pylab 
    2856 #    from matplotlib.ticker import LogFormatter 
    2857     ax = pylab.gca() 
    2858     im = pylab.mpl.image.NonUniformImage(ax, extent=(x[0],x[-1],y[0],y[-1]), origin=None) 
    2859     im.set_data(x,y,z) 
    2860     ax.images.append(im) 
    2861     ax.set_xlim(x[0],x[-1]) 
    2862     ax.set_ylim(y[0],y[-1]) 
    2863     pylab.colorbar(im)#, format=LogFormatter()) 
    2864     pylab.gcf().canvas.draw_idle() 
    2865  
    28662934# File level operations 
    28672935def load(filename, mode='r'): 
Note: See TracChangeset for help on using the changeset viewer.