Python CGDAL类——支持栅格数据的栅格计算/线性增强/滤波增强

来源:互联网 发布:凤凰直播软件下载 编辑:程序博客网 时间:2024/05/19 04:07
# -*- coding: UTF-8 -*-
'''
python version: 2.7.11 (the code below also runs on Python 3)
numpy  ver=1.11.1
gdal   ver=2.0.3
Author: Liuph
Date: 2016/9/9
Description: This is a GDAL class adapted from the Python GDAL/OGR Cookbook
(http://pcjericks.github.io/py-gdalogr-cookbook/raster_layers.html).
The CGDAL class can be used to load an image and obtain descriptive
information about the image file. Basic image processing is included in the
class: a linear enhancement, a spatial filtering operation or a band-math
expression can be performed via CGDAL. Moreover, the module offers functions
that generate a raster image from a numpy array.
Exception handling is not exhaustive. For instance, when an image declares no
nodata value (None), nodata masking is simply skipped; confirm that this is
what your data needs.
Python code is concise and readable, but the author measured plain Python
loops to run at roughly 1/30 of the speed of equivalent C++ code, so the
pixel-level work below is written with vectorized numpy expressions.
May Python be better!
'''
import os
import re
import struct
import sys

import numpy as np
from osgeo import gdal, gdalnumeric, ogr, osr
from osgeo.gdalconst import *
from PIL import Image, ImageDraw

gdal.UseExceptions()

# Human-readable names for the GDAL pixel types reported by printRasterAttr.
_DATA_TYPE_NAMES = {
    GDT_Byte: "Byte",
    GDT_UInt16: "Unsigned Int 16",
    GDT_Int16: "Signed Int 16",
    GDT_UInt32: "Unsigned Int 32",
    GDT_Int32: "Signed Int 32",
    GDT_Float32: "Float 32",
    GDT_Float64: "Float 64",
}


def _invGeoTransform(geoTransform):
    """Return the inverse of a 6-term GDAL geotransform.

    GDAL >= 2.0 returns the inverse directly (None if not invertible);
    GDAL 1.x returns a (success, inverse) pair. Both forms are accepted.

    Raises:
        ValueError: if the geotransform cannot be inverted.
    """
    result = gdal.InvGeoTransform(geoTransform)
    if result is not None and len(result) == 2:
        success, result = result
        if not success:
            result = None
    if result is None:
        raise ValueError("Geotransform %s is not invertible" % (geoTransform,))
    return result


def _nodataMask(arr, nodata):
    """Return a boolean mask of the pixels of *arr* equal to *nodata*.

    A ``None`` nodata masks nothing. A NaN nodata masks NaN pixels, because
    NaN never compares equal to itself.
    """
    if nodata is None:
        return np.zeros(np.shape(arr), dtype=bool)
    if np.isnan(nodata):
        return np.isnan(arr)
    return arr == nodata


class CGDAL(object):
    """Wrapper around a GDAL raster dataset plus a numpy copy of its pixels.

    After loadFrom(), ``mpArray`` has shape (bands, rows, cols). Integer pixel
    types are stored as int64 and floating types as float64.
    """

    def __init__(self):
        # Dataset handle and pixel data (per instance, never shared).
        self.mpoDataset = None
        self.__mpData = None
        self.mpArray = np.array([])
        # Raster description, filled in by loadFrom().
        self.mgDataType = GDT_Byte
        self.msDataType = ""
        self.mnRows = self.mnCols = self.mnBands = -1
        self.mnDatalength = -1
        self.mpGeoTransform = []
        self.msProjectionRef = ""
        self.msFilename = ""
        self.mdInvalidValue = 0.0
        self.mnPerPixSize = 1
        # srcSR: the dataset's (ground) CS; latLongSR: its geographic CS.
        self.srcSR = None
        self.latLongSR = None
        # poTransform: ground -> lat/long; poTransformT: lat/long -> ground.
        self.poTransform = None
        self.poTransformT = None

    def __del__(self):
        # Drop the GDAL handle so the underlying file gets closed.
        self.mpoDataset = None

    def read(self, band, row, col):
        """Return the pixel value at 0-based (band, row, col)."""
        return self.mpArray[band, row, col]

    def printimg(self):
        """Print the whole pixel array."""
        print(self.mpArray)

    def isValid(self):
        """Return True when a dataset is open and its pixels have been read."""
        return self.__mpData is not None and self.mpoDataset is not None

    def world2Pixel(self, lat, lon):
        """Convert geographic (lat, lon) to fractional pixel coordinates.

        The point is projected into the dataset's ground CS first. This step
        is skipped when the dataset has no projection, in which case
        (lon, lat) are used as ground coordinates.

        Returns:
            dict with keys 'x' (column) and 'y' (row).
        """
        gX, gY = lon, lat
        if self.poTransformT is not None:
            gX, gY = self.poTransformT.TransformPoint(lon, lat)[:2]
        x, y = gdal.ApplyGeoTransform(_invGeoTransform(self.mpGeoTransform), gX, gY)
        return {'x': x, 'y': y}

    def pixel2World(self, x, y):
        """Convert pixel coordinates (x=column, y=row) to geographic lon/lat.

        Without a projection, the ground coordinates are returned unchanged.

        Returns:
            dict with keys 'lon' and 'lat'.
        """
        lon, lat = gdal.ApplyGeoTransform(self.mpGeoTransform, x, y)
        if self.poTransform is not None:
            lon, lat = self.poTransform.TransformPoint(lon, lat)[:2]
        return {'lon': lon, 'lat': lat}

    def pixel2Ground(self, x, y):
        """Convert pixel coordinates to ground (dataset CS) coordinates."""
        pX, pY = gdal.ApplyGeoTransform(self.mpGeoTransform, x, y)
        return {'pX': pX, 'pY': pY}

    def ground2Pixel(self, pX, pY):
        """Convert ground (dataset CS) coordinates to fractional pixel coordinates."""
        x, y = gdal.ApplyGeoTransform(_invGeoTransform(self.mpGeoTransform), pX, pY)
        return {'x': x, 'y': y}

    def _initSpatialRefs(self):
        """Build the ground <-> lat/long transformations from msProjectionRef.

        Everything stays None when the dataset has no projection or GDAL
        cannot derive its geographic CS.
        """
        self.srcSR = self.latLongSR = None
        self.poTransform = self.poTransformT = None
        if not self.msProjectionRef:
            return
        try:
            srcSR = osr.SpatialReference(self.msProjectionRef)
            latLongSR = srcSR.CloneGeogCS()
        except RuntimeError as e:
            print(e)
            return
        self.srcSR = srcSR
        if latLongSR is None:
            return
        # GDAL >= 3 defaults to authority axis order (lat, lon for EPSG:4326);
        # keep the x=lon, y=lat order this class passes to TransformPoint.
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            srcSR.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            latLongSR.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        self.latLongSR = latLongSR
        try:
            self.poTransform = osr.CoordinateTransformation(srcSR, latLongSR)
            self.poTransformT = osr.CoordinateTransformation(latLongSR, srcSR)
        except (RuntimeError, ValueError) as e:
            print(e)
            self.poTransform = self.poTransformT = None

    def loadFrom(self, filename):
        """Open *filename* read-only and load all of its bands into mpArray.

        Returns:
            True on success; False if GDAL cannot open the file or it has
            no raster bands.
        """
        # Close any previously opened image.
        self.mpoDataset = None
        try:
            self.mpoDataset = gdal.Open(filename, GA_ReadOnly)
        except RuntimeError as e:
            print('Unable to open %s' % filename)
            print(e)
            return False
        if self.mpoDataset is None:
            print('Unable to open %s' % filename)
            return False
        ds = self.mpoDataset
        if ds.RasterCount < 1:
            print('Unable to open %s' % filename)
            return False
        self.msFilename = filename

        # Attributes; per-band values (type, nodata) come from band 1.
        self.mnRows = ds.RasterYSize
        self.mnCols = ds.RasterXSize
        self.mnBands = ds.RasterCount
        firstBand = ds.GetRasterBand(1)
        self.mgDataType = firstBand.DataType
        self.mdInvalidValue = firstBand.GetNoDataValue()
        self.msDataType = _DATA_TYPE_NAMES.get(
            self.mgDataType, gdal.GetDataTypeName(self.mgDataType))
        self.mnPerPixSize = gdal.GetDataTypeSize(self.mgDataType) // 8
        self.mnDatalength = self.mnBands * self.mnRows * self.mnCols

        # Map info:
        #   GeoTransform[0] top left x     GeoTransform[3] top left y
        #   GeoTransform[1] w-e pixel res  GeoTransform[5] n-s pixel res (negative)
        #   GeoTransform[2], [4] rotation terms (0 for north-up images)
        self.mpGeoTransform = ds.GetGeoTransform()
        self.msProjectionRef = ds.GetProjection()
        self._initSpatialRefs()

        # Pixel data as (bands, rows, cols). Widen integers to int64 and floats
        # to float64, the dtypes the previous struct-based reader produced.
        data = ds.ReadAsArray().reshape((self.mnBands, self.mnRows, self.mnCols))
        if data.dtype.kind in 'biu':
            data = data.astype(np.int64)
        elif data.dtype.kind == 'f':
            data = data.astype(np.float64)
        self.mpArray = data
        self.__mpData = self.mpArray
        return True

    def getRasterBand(self, band_num):
        """Return the GDAL band object for 1-based *band_num* (exit(1) if missing)."""
        try:
            return self.mpoDataset.GetRasterBand(band_num)
        except RuntimeError as e:
            print('Band ( %i ) not found' % band_num)
            print(e)
            # Non-zero status: this is a failure, not a normal exit.
            sys.exit(1)

    def getRasterBand2Array(self, band_num):
        """Return 1-based band *band_num* read freshly from disk as a numpy array."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        return srcband.ReadAsArray()

    def getRasterBandStas(self, band_num):
        """Print and return GDAL statistics (min, max, mean, stddev) of a band."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        if srcband is None:
            print("Band %i is NULL" % band_num)
            sys.exit(1)
        stats = srcband.GetStatistics(True, True)
        if stats is None:
            print("Statistics of Band %i is NULL" % band_num)
            sys.exit(1)
        print("[ STATS ] =  Minimum=%.3f, Maximum=%.3f, Mean=%.3f, StdDev=%.3f" % (
            stats[0], stats[1], stats[2], stats[3]))
        return stats

    def getRasterBandInfo(self, band_num):
        """Print descriptive metadata of a band, including its colour table if any."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        if srcband is None:
            print("Band %i is NULL" % band_num)
            sys.exit(1)
        print("[ NO DATA VALUE ] =  %s" % (srcband.GetNoDataValue(),))
        print("[ MIN ] =  %s" % (srcband.GetMinimum(),))
        print("[ MAX ] =  %s" % (srcband.GetMaximum(),))
        print("[ SCALE ] =  %s" % (srcband.GetScale(),))
        print("[ UNIT TYPE ] =  %s" % (srcband.GetUnitType(),))
        ctable = srcband.GetColorTable()
        if ctable is None:
            # Having no colour table is normal; report it instead of exiting.
            print('No ColorTable found')
            return
        print("[ COLOR TABLE COUNT ] =  %s" % (ctable.GetCount(),))
        for i in range(ctable.GetCount()):
            entry = ctable.GetColorEntry(i)
            if not entry:
                continue
            print("[ COLOR ENTRY RGB ] =  %s" % (entry,))

    def _bandNoData(self, band_num):
        """Return the nodata value declared for 1-based *band_num*, or None."""
        return self.mpoDataset.GetRasterBand(band_num).GetNoDataValue()

    def _validBandArray(self, band_num):
        """Return a float64 copy of a band with its nodata pixels set to NaN.

        The copy keeps mpArray untouched. The cast lets integer rasters hold NaN.
        """
        arr = np.array(self.mpArray[band_num - 1, :, :], dtype=np.float64)
        arr[_nodataMask(arr, self._bandNoData(band_num))] = np.nan
        return arr

    def getRasterBandMinVal(self, band_num):
        """Minimum of a band, ignoring nodata pixels."""
        return np.nanmin(self._validBandArray(band_num))

    def getRasterBandMaxVal(self, band_num):
        """Maximum of a band, ignoring nodata pixels."""
        return np.nanmax(self._validBandArray(band_num))

    def getRasterBandMeanVal(self, band_num):
        """Mean of a band, ignoring nodata pixels."""
        return np.nanmean(self._validBandArray(band_num))

    def getRasterBandStdVal(self, band_num):
        """Standard deviation of a band, ignoring nodata pixels."""
        return np.nanstd(self._validBandArray(band_num))

    def getRasterBandVarVal(self, band_num):
        """Variance of a band, ignoring nodata pixels."""
        return np.nanvar(self._validBandArray(band_num))

    def raster2shp(self, band_num, dst_layername):
        """Polygonize a band into shapefile '<dst_layername>.shp' (use with care).

        No mask band is passed, so every pixel (nodata included) is polygonized.
        """
        srcband = self.mpoDataset.GetRasterBand(band_num)
        drv = ogr.GetDriverByName("ESRI Shapefile")
        dst_ds = drv.CreateDataSource(dst_layername + ".shp")
        dst_layer = dst_ds.CreateLayer(dst_layername, srs=self.srcSR)
        gdal.Polygonize(srcband, None, dst_layer, -1, [], callback=None)
        # Dereferencing the data source closes it and flushes features to disk.
        dst_layer = None
        dst_ds = None

    def replaceNoData2New(self, ds_fn, new_NoData):
        """Write *ds_fn* with each band's nodata pixels replaced by *new_NoData*.

        The band's declared nodata is used; -9999 (the previous hard-coded
        value) is used when a band declares none. The source dataset is not
        modified. The output declares *new_NoData* as its nodata value.
        """
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            org_Nodata = self._bandNoData(band_num)
            if org_Nodata is None:
                org_Nodata = -9999
            # Float copy so e.g. -9999 does not wrap around in a Byte band.
            rasterArray = self.getRasterBand2Array(band_num).astype(np.float64)
            rasterArray[_nodataMask(rasterArray, org_Nodata)] = new_NoData
            outArr[band_num - 1, :, :] = rasterArray
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, new_NoData)

    def linearEnhance(self, ds_fn, _MinValue, _MaxValue):
        """Linearly stretch every band to [_MinValue, _MaxValue] and save as *ds_fn*.

        Valid pixels (not nodata, within the band's [min, max]) are mapped
        linearly. All other pixels get the band's nodata value, or NaN when
        the band declares none. A constant band maps to _MinValue. Output is
        a Float32 GeoTIFF whose nodata is band 1's nodata.
        """
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            print("Linear Cal %i/%i" % (band_num, self.mnBands))
            _nodata = self._bandNoData(band_num)
            _array = self.getRasterBand2Array(band_num).astype(np.float64)
            _min = self.getRasterBandMinVal(band_num)
            _max = self.getRasterBandMaxVal(band_num)
            valid = ~_nodataMask(_array, _nodata) & (_array >= _min) & (_array <= _max)
            fill = np.nan if _nodata is None else _nodata
            _newarray = np.full(_array.shape, fill, dtype=np.float64)
            if _max > _min:
                _newarray[valid] = (_array[valid] - _min) / ((_max - _min) * 1.0) * (
                    _MaxValue - _MinValue) + _MinValue
            else:
                # Constant band: avoid 0/0.
                _newarray[valid] = _MinValue
            outArr[band_num - 1, :, :] = _newarray
        print("Writing output data...")
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue)

    def spatialFiltering(self, ds_fn, sAlgorithm="MeanFiltering"):
        """Apply a named spatial filter to every band and save as *ds_fn*.

        The kernel is applied as a correlation (not flipped), as the original
        per-pixel loop did. Pixels closer than half a window to an edge keep
        their original value. Unknown algorithm names print a message and
        exit(1). "HorizontalMaskFiltering" and "VerticalMaskFiltering" are
        accepted as correctly spelled aliases.
        """
        horizontal = np.array([[3.0, 3.0, 3.0], [-6.0, -6.0, -6.0], [3.0, 3.0, 3.0]])
        vertical = np.array([[3.0, -6.0, 3.0], [3.0, -6.0, 3.0], [3.0, -6.0, 3.0]])
        kernels = {
            "MeanFiltering": np.ones((3, 3)) / 9.0,
            "LaplaceFiltering": np.array([[-1.0, -1.0, -1.0], [-1.0, 9, -1], [-1, -1, -1]]),
            "WallisFiltering": np.array([[0, -0.25, 0], [-0.25, 1, -0.25], [0, -0.25, 0]]),
            "SobelXFiltering": np.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]),
            "SobelYFiltering": np.array([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]),
            "LogFiltering": np.array([[-2., -4., -4., -4., -2.],
                                      [-4., 0., 8., 0., -4.],
                                      [-4., 8., 24., 8., -4.],
                                      [-4., 0., 8., 0., -4.],
                                      [-2., -4., -4., -4., -2.]]),
            "RelievoFiltering": np.array([[-3.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]]),
            "HorizonalMaskFiltering": horizontal,
            "HorizontalMaskFiltering": horizontal,
            "VerticalnMaskFiltering": vertical,
            "VerticalMaskFiltering": vertical,
            "DiagonalMaskFiltering": np.array([[3.0, 3.0, -6.0], [3.0, -6.0, 3.0], [-6.0, 3.0, 3.0]]),
            # NOTE(review): the final +1.0 makes this kernel sum to 2; a
            # zero-sum edge detector would use -1.0 there. Kept as published.
            "QualcommEdgeDec": np.array([[-1.0, 0.0, -1.0], [0.0, 4.0, 0.0], [-1.0, 0.0, 1.0]]),
        }
        if sAlgorithm not in kernels:
            print("There is no such filtering algorithm called %s" % sAlgorithm)
            sys.exit(1)
        algori = kernels[sAlgorithm]
        window_size = algori.shape[0]
        subsize = (window_size - 1) // 2
        print("Filtering Algorithm: \n" + str(algori))

        rows, cols = self.mnRows, self.mnCols
        innerRows, innerCols = rows - 2 * subsize, cols - 2 * subsize
        outArr = np.zeros((self.mnBands, rows, cols))
        for band_num in range(1, self.mnBands + 1):
            src = self.mpArray[band_num - 1].astype(np.float64)
            # Edge pixels keep their original values.
            _arr = src.copy()
            if innerRows > 0 and innerCols > 0:
                acc = np.zeros((innerRows, innerCols))
                # Each kernel tap adds one shifted copy of the image: pixel
                # (i, j) accumulates src[i - subsize + x, j - subsize + y] * k[x, y].
                for x in range(window_size):
                    for y in range(window_size):
                        acc += src[x:x + innerRows, y:y + innerCols] * algori[x, y]
                _arr[subsize:rows - subsize, subsize:cols - subsize] = acc
            outArr[band_num - 1, :, :] = _arr
            print("Filtered %i/%i" % (band_num, self.mnBands))
        print("Writing output file...")
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue)

    def rasterCalculation(self, ds_fn, expr="(band3-band2)/(band3+band2)"):
        """Evaluate a band-math expression and save it as a 1-band raster.

        Bands are referenced as band1, band2, ... (1-based). Each reference
        becomes a float copy of that band, so integer bands divide as floats.
        Other numbers in the expression are left untouched.

        Raises:
            ValueError: if the expression references a band that does not exist.
        """
        def _bandRef(match):
            num = int(match.group(1))
            if not 1 <= num <= self.mnBands:
                raise ValueError("band%i does not exist (image has %i bands)" % (num, self.mnBands))
            return '(1.0*self.mpArray[%i,:,:])' % (num - 1)

        expr = re.sub(r'band(\d+)', _bandRef, expr)
        print(expr)
        # NOTE(security): eval() executes arbitrary Python; never pass an
        # expression that comes from an untrusted source.
        resultArr = eval(expr, globals(), {'self': self})
        array2MultiBandsrasterfn(self.msFilename, ds_fn, resultArr, 1, self.mdInvalidValue)

    def printRasterAttr(self):
        """Print the file name, size, pixel size, data type, nodata and georeferencing."""
        print("File Name: %s" % self.msFilename)
        print("Rows: %i   Cols: %i   Bands: %i   Pixel Size: %.2f*%.2f" % (
            self.mnRows, self.mnCols, self.mnBands,
            self.mpGeoTransform[1], -self.mpGeoTransform[5]))
        print("Data Type: %s   No-Data Value:  %s" % (self.msDataType, self.mdInvalidValue))
        print("GeoTransform: %s    \nProjection: %s" % (self.mpGeoTransform, self.msProjectionRef))


def array2MultiBandsrasterfn(rasterfn, newRasterfn, array, bandCount, nodata=None):
    """Write *array* as a Float32 GeoTIFF georeferenced like *rasterfn*.

    Args:
        rasterfn: template raster supplying size, geotransform and projection.
        newRasterfn: output GeoTIFF path.
        array: pixel values, reshaped to (bandCount, rows, cols) without
            modifying the caller's array object.
        bandCount: number of output bands.
        nodata: nodata value set on every output band; None leaves it unset.
    """
    raster = gdal.Open(rasterfn)
    # Keep the full geotransform, rotation terms included.
    geotransform = raster.GetGeoTransform()
    projection = raster.GetProjectionRef()
    cols = raster.RasterXSize
    rows = raster.RasterYSize
    raster = None
    data = np.asarray(array).reshape((bandCount, rows, cols))

    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, bandCount, gdal.GDT_Float32)
    outRaster.SetGeoTransform(geotransform)
    if projection:
        outRaster.SetProjection(projection)
    for band_num in range(1, bandCount + 1):
        outband = outRaster.GetRasterBand(band_num)
        if nodata is not None:
            # SetNoDataValue(None) raises TypeError, so only set a real value.
            outband.SetNoDataValue(float(nodata))
        outband.WriteArray(data[band_num - 1, :, :])
        outband.FlushCache()
    # Dereferencing the dataset closes it and finishes writing the file.
    outband = None
    outRaster = None
    print("write output file -- %s success!" % newRasterfn)


def creatraster(newRasterfn, GeoTransform, projection, datatype, imgdata, cols, rows, bands):
    """Create GeoTIFF *newRasterfn* from numpy data.

    *imgdata* must hold bands*rows*cols values. It is reshaped to
    (bands, rows, cols) without modifying the caller's array object, and
    written with the GDAL pixel type *datatype*.
    """
    data = np.asarray(imgdata).reshape((bands, rows, cols))
    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, bands, datatype)
    outRaster.SetGeoTransform(GeoTransform)
    outRaster.SetProjection(projection)
    for i in range(bands):
        outband = outRaster.GetRasterBand(i + 1)
        outband.WriteArray(data[i, :, :])
        outband.FlushCache()
    # Close the dataset so the file is complete on disk.
    outband = None
    outRaster = None
    print("write data succeed!")


def rasterCalculations(ds_fn, expr):
    """Evaluate an expression over single-band .tif files and save the result.

    Files are referenced by name including the ".tif" suffix, e.g.
    "(nir.tif-red.tif)/(nir.tif+red.tif)". Every input must have exactly one
    band. Each input's nodata pixels become NaN before evaluation. The
    output copies the georeferencing of the alphabetically first file and
    declares that file's nodata value. NaN results are written as that
    value when it is defined.
    """
    filePattern = re.compile(r'[A-Za-z_][\w.]*?\.tif\b')
    unim = sorted(set(filePattern.findall(expr)))
    print(unim)
    if not unim:
        print("No .tif file found in expression: %s" % expr)
        sys.exit(1)

    mArrs = []
    indexOf = {}
    no_data = None
    for i, item in enumerate(unim):
        Cgdal = CGDAL()
        if not Cgdal.loadFrom(item):
            sys.exit(1)
        if Cgdal.mnBands != 1:
            print("The input raster is not useful. Only 1 band is required, %i is given." % Cgdal.mnBands)
            sys.exit(1)
        if i == 0:
            no_data = Cgdal.mdInvalidValue
        # Float copy so integer rasters can hold NaN for their nodata pixels.
        arr = Cgdal.mpArray.astype(np.float64)
        arr[_nodataMask(arr, Cgdal.mdInvalidValue)] = np.nan
        mArrs.append(arr)
        indexOf[item] = i

    # Substitute each whole filename match at once, so a name that is a
    # suffix of another (a.tif / aa.tif) cannot corrupt it.
    expr = filePattern.sub(lambda m: '(1.0*mArrs[%i])' % indexOf[m.group(0)], expr)
    print(expr)
    # NOTE(security): eval() executes arbitrary Python; never pass an
    # expression that comes from an untrusted source.
    resultArr = np.array(eval(expr, globals(), {'mArrs': mArrs}), dtype=np.float64)
    if no_data is not None:
        resultArr[np.isnan(resultArr)] = no_data
    array2MultiBandsrasterfn(unim[0], ds_fn, resultArr, 1, nodata=no_data)

0 0