Python CGDAL类——支持栅格数据的栅格计算/线性增强/滤波增强
来源:互联网 发布:凤凰直播软件下载 编辑:程序博客网 时间:2024/05/19 04:07
# -*- coding: UTF-8 -*-'''python version: 2.7.11numpy ver=1.11.1gdal ver=2.0.3Author: LiuphDate: 2016/9/9Description: This is a GDAL Class adapted from the Python GDAL_OGR Cookbook documentation(http://pcjericks.github.io/py-gdalogr-cookbook/raster_layers.html). The CGDAL Class can used toload Image and obtain description information of the image file. Usually, basic image processingis included in the class and a linear enhancement or a spatial filtering operation can well performedvia CGDAL Class. Moreover, this class offers some functions to generate a raster image via a numpy array.However,it' does not process perfection in exception handling, for instance, while a image with "None"as nodata value might lead to some puzzles, or, it wont't happen.It need to be confirmed. Addition, as all know, codes in python formats is very concise and distinct. However, the ratio of running speed of python codes and C++ codes is about first in thirty ,which is a serious and longstanding problem, or shortcut. May python be better!'''from osgeo import gdal, gdalnumeric, ogr, osrfrom PIL import Image, ImageDrawimport os, sysfrom gdalconst import *import structimport numpy as npimport regdal.UseExceptions()class CGDAL: #数据部分 mpoDataset = None __mpData = None mpArray = np.array([]) mgDataType = GDT_Byte mnRows = mnCols = mnBands = -1 mnDatalength = -1 mpGeoTransfor = [] msProjectionRef = "" msFilename = "" mdInvalidValue = 0.0 mnPerPixSize = 1 srcSR = None latLongSR = None poTransform = None poTransformT = None #函数部分 def __init__(self): pass def __del__(self): self.mpoDataset = None self.__mpData = None self.mpArray = np.array([]) self.mgDataType = GDT_Byte self.mnRows = self.mnCols = self.mnBands = -1 self.mnDatalength = -1 self.mpGeoTransform = [] self.msProjectionRef = "" self.msFilename = "" self.mdInvalidValue = 0.0 self.mnPerPixSize = 1 self.srcSR = None self.latLongSR = None self.poTransform = None self.poTransformT = None def read(self, band, row, col): return self.mpArray[band, row, col] def printimg(self): print self.mpArray def isValid(self): if self.__mpData == None or self.mpoDataset == None: return False return True def world2Pixel(self, lat, lon): if self.poTransformT is not None: CST = osr.CoordinateTransformation(self.poTransformT) CST.TransformPoint(lon, lat) adfInverseGeoTransform = [] x = y = 0.0 gdal.InvGeoTransform(self.mpGeoTransform, adfInverseGeoTransform) gdal.ApplyGeoTransform(adfInverseGeoTransform, lon, lat, x, y) return {'x': x, 'y': y} def pixel2World(self, x, y): if self.poTransform is not None: self.poTransform = None self.poTransform = osr.CoordinateTransformation(self.latLongSR, self.srcSR) lon = lat = 0.0 gdal.ApplyGeoTransform(self.mpGeoTransform, x, y, lon, lat) if self.poTransform is not None: CST = osr.CoordinateTransformation(self.poTransform) CST.TransformPoint(lon, lat) return {'lon': lon, 'lat': lat} def pixel2Ground(self, x, y): pX = pY = 0.0 gdal.ApplyGeoTransform(self.mpGeoTransform, x, y, pX, pY) return {'pX': pX, 'pY': pY} def ground2Pixel(self, pX, pY): x = y = 0.0 adfInverseGeoTransform = [] gdal.InvGeoTransform(self.mpGeoTransform, adfInverseGeoTransform) gdal.ApplyGeoTransform(adfInverseGeoTransform, pX, pY, x, y) return {'x': x, 'y': y} def loadFrom(self,filename): #close fore image self.mpoDataset = None #open image try: self.mpoDataset = gdal.Open( filename, GA_ReadOnly ) except RuntimeError, e: print 'Unable to open %s' % filename print e return False self.msFilename = filename #get attribute self.mnRows = self.mpoDataset.RasterYSize self.mnCols = self.mpoDataset.RasterXSize self.mnBands = self.mpoDataset.RasterCount self.mgDataType = self.mpoDataset.GetRasterBand(1).DataType self.mdInvalidValue = self.mpoDataset.GetRasterBand(1).GetNoDataValue() #mapinfo ''' GeoTransform[0] /* top left x */ GeoTransform[1] /* w-e pixel resolution */ GeoTransform[2] /* 0 */ GeoTransform[3] /* top left y */ GeoTransform[4] /* 0 */ GeoTransform[5] /* n-s pixel resolution (negative value) */ ''' self.mpGeoTransform = self.mpoDataset.GetGeoTransform() self.msProjectionRef = self.mpoDataset.GetProjection() self.srcSR = osr.SpatialReference(self.msProjectionRef) #ground self.latLongSR = osr.SpatialReference() self.latLongSR = osr.SpatialReference.CloneGeogCS(self.srcSR ) #geo self.poTransform = osr.CoordinateTransformation(self.srcSR, self.latLongSR) self.poTransformT = osr.CoordinateTransformation(self.latLongSR, self.srcSR) #get data self.msDataType = "Byte" typeformat = "B" if self.mgDataType == GDT_Byte: typeformat = "B" self.msDataType = "Byte" elif self.mgDataType == GDT_UInt16: typeformat = "H" self.msDataType = "Unsigned Int 16" elif self.mgDataType == GDT_Int16: typeformat = "h" self.msDataType = "Signed Int 16" elif self.mgDataType == GDT_UInt32: typeformat = "I" self.msDataType = "Unsigned Int 32" elif self.mgDataType == GDT_Int32: typeformat = "i" self.msDataType = "Signed Int 32" elif self.mgDataType == GDT_Float32: typeformat = "f" self.msDataType = "Float 32" elif self.mgDataType == GDT_Float64: typeformat = "d" self.msDataType = "Float 64" self.__mpData = struct.unpack(typeformat*self.mnBands*self.mnCols*self.mnRows, self.mpoDataset.ReadRaster()) self.mpArray = np.array(self.__mpData) self.mpArray.shape = (self.mnBands, self.mnRows, self.mnCols) return True def getRasterBand(self, band_num): """获取特定波段的数据 """ try: srcband = self.mpoDataset.GetRasterBand(band_num) return srcband except RuntimeError, e: print 'Band ( %i ) not found' % band_num print e sys.exit(0) def getRasterBand2Array(self, band_num): """获取特定波段的数据,存储为数组""" srcband = self.mpoDataset.GetRasterBand(band_num) return srcband.ReadAsArray() def getRasterBandStas(self, band_num): """获取特定波段的统计量(最小值,最大值,均值,标准差)""" srcband = self.mpoDataset.GetRasterBand(band_num) if srcband is None: print "Band %i is NULL" % band_num sys.exit(1) stats = srcband.GetStatistics(True, True) if stats is None: print "Statistics of Band %i is NULL" % band_num sys.exit(1) print "[ STATS ] = Minimum=%.3f, Maximum=%.3f, Mean=%.3f, StdDev=%.3f" % ( stats[0], stats[1], stats[2], stats[3]) def getRasterBandInfo(self, band_num): """获取特定波段的描述数据""" srcband = self.mpoDataset.GetRasterBand(band_num) if srcband is None: print "Band %i is NULL" % band_num sys.exit(1) print "[ NO DATA VALUE ] = ", srcband.GetNoDataValue() print "[ MIN ] = ", srcband.GetMinimum() print "[ MAX ] = ", srcband.GetMaximum() print "[ SCALE ] = ", srcband.GetScale() print "[ UNIT TYPE ] = ", srcband.GetUnitType() ctable = srcband.GetColorTable() if ctable is None: print 'No ColorTable found' sys.exit(1) print "[ COLOR TABLE COUNT ] = ", ctable.GetCount() for i in range(0, ctable.GetCount()): entry = ctable.GetColorEntry(i) if not entry: continue print "[ COLOR ENTRY RGB ] = ", ctable.GetColorEntryAsRGB(i, entry) def getRasterBandMinVal(self, band_num): """获取某个波段的最小值""" _arr = self.mpArray[band_num-1,:,:] if self.mdInvalidValue != None: _arr[_arr == self.mdInvalidValue] = np.nan return np.nanmin(_arr) def getRasterBandMaxVal(self, band_num): """由于精度问题,显示一位小数,但计算不出错""" _arr = self.mpArray[band_num - 1, :,:] if self.mdInvalidValue != None: _arr[_arr == self.mdInvalidValue] = np.nan return np.nanmax(_arr) def getRasterBandMeanVal(self, band_num): """均值""" _arr = self.mpArray[band_num - 1, :,:] if self.mdInvalidValue != None: _arr[_arr == self.mdInvalidValue] = np.nan return np.nanmean(_arr) def getRasterBandStdVal(self, band_num): """标准差""" _arr = self.mpArray[band_num - 1, :,:] if self.mdInvalidValue != None: _arr[_arr == self.mdInvalidValue] = np.nan return np.nanstd(_arr) def getRasterBandVarVal(self, band_num): """方差""" _arr = self.mpArray[band_num - 1, :,:] if self.mdInvalidValue != None: _arr[_arr == self.mdInvalidValue] = np.nan return np.nanvar(_arr) def raster2shp(self, band_num, dst_layername): """栅格转矢量,慎用""" srcband = self.mpoDataset.GetRasterBand(band_num) drv = ogr.GetDriverByName("ESRI Shapefile") dst_ds = drv.CreateDataSource(dst_layername + ".shp") dst_layer = dst_ds.CreateLayer(dst_layername, srs=None) gdal.Polygonize(srcband, None, dst_layer, -1, [], callback=None) def replaceNoData2New(self, ds_fn, new_NoData): """用新的值替代原先的nodata值""" outArr = np.zeros(self.mnBands * self.mnRows * self.mnCols) outArr.shape = (self.mnBands, self.mnRows, self.mnCols) for band_num in range(1, self.mnBands + 1): self.mpoDataset.GetRasterBand(band_num).SetNoDataValue(-9999) org_Nodata = -9999 rasterArray = self.getRasterBand2Array(band_num) rasterArray[rasterArray == org_Nodata] = new_NoData outArr[band_num - 1, :, :] = rasterArray array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands) def linearEnhance(self, ds_fn, _MinValue, _MaxValue): """线性增强处理,指定拉伸后的最大最小值,float64型""" outArr = np.zeros(self.mnBands * self.mnRows * self.mnCols) outArr.shape = (self.mnBands, self.mnRows, self.mnCols) for band_num in range(1, self.mnBands + 1): print "Linear Cal %i/%i"%(band_num, self.mnBands) srcband = self.mpoDataset.GetRasterBand(band_num) _nodata = srcband.GetNoDataValue() _array = self.getRasterBand2Array(band_num) _newarray = _array.astype(np.float32) _min = self.getRasterBandMinVal(band_num) _max = self.getRasterBandMaxVal(band_num) #print _min, _max for i in range(self.mnRows): for j in range(self.mnCols): if _array[i][j] >= _min and _array[i][j] <= _max: _newarray[i][j] = (_array[i][j] - _min) / ((_max - _min) * 1.0) * ( _MaxValue - _MinValue) + _MinValue else: _newarray[i][j] = _nodata outArr[band_num - 1, :, :] = _newarray print "Writing output data..." array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue) def spatialFiltering(self, ds_fn, sAlgorithm = "MeanFiltering"): """空间滤波增强""" window_size = 3 if window_size%2 == 0: print "Please input a uneven number for the window size!" sys.exit(1) subsize = (window_size-1)/2 #输出文件 outArr = np.zeros(self.mnBands * self.mnRows * self.mnCols) outArr.shape = (self.mnBands, self.mnRows, self.mnCols) algori = np.ones(1 * window_size * window_size, dtype=float) algori.shape = (window_size, window_size) # 选择算子 if sAlgorithm == "MeanFiltering": algori /= (window_size*window_size) elif sAlgorithm == "LaplaceFiltering": algori = np.array([[-1.0,-1.0,-1.0],[-1.0,9,-1],[-1,-1,-1]]) elif sAlgorithm == "WallisFiltering": algori = np.array([[0,-0.25,0],[-0.25,1,-0.25],[0,-0.25,0]]) elif sAlgorithm == "SobelXFiltering": algori = np.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) elif sAlgorithm == "SobelYFiltering": algori = np.array([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]) elif sAlgorithm == "LogFiltering": window_size = 5 subsize = (window_size - 1) / 2 algori = np.ones(1 * window_size * window_size, dtype=float) algori.shape = (window_size, window_size) algori = np.array([[-2.,-4.,-4.,-4.,-2.], [-4.,0.,8.,0.,-4.], [-4.,8.,24.,8.,-4.], [-4., 0., 8., 0., -4.], [-2., -4., -4., -4., -2.] ]) elif sAlgorithm == "RelievoFiltering": algori = np.array([[-3.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]]) elif sAlgorithm == "HorizonalMaskFiltering": algori = np.array([[3.0, 3.0, 3.0], [-6.0, -6.0, -6.0], [3.0, 3.0, 3.0]]) elif sAlgorithm == "VerticalnMaskFiltering": algori = np.array([[3.0, -6.0, 3.0], [3.0, -6.0, 3.0], [3.0, -6.0, 3.0]]) elif sAlgorithm == "DiagonalMaskFiltering": algori = np.array([[3.0, 3.0, -6.0], [3.0, -6.0, 3.0], [-6.0, 3.0, 3.0]]) elif sAlgorithm== "QualcommEdgeDec": algori = np.array([[-1.0, 0.0, -1.0], [0.0, 4.0, 0.0], [-1.0, 0.0, 1.0]]) else: print "There is no such filtering algorithm called %s"%sAlgorithm sys.exit(1) print "Filtering Algorithm: \n", algori #波段迭代循环 for band_num in range(1, self.mnBands +1): _arr = np.zeros(1 * self.mnRows * self.mnCols, dtype=float) _arr.shape = (self.mnRows, self.mnCols) for i in range(0, self.mnRows): for j in range(0, self.mnCols): #边缘维持原像元值 if i<=subsize-1 or j<=subsize-1 or i>= self.mnRows-subsize or j >= self.mnCols-subsize: _arr[i][j] = self.mpArray[band_num-1][i][j] else: for x in range(0, window_size ): for y in range(0, window_size): _arr[i][j] += self.mpArray[band_num-1][i - subsize + x][j - subsize + y] * algori[x][y] outArr[band_num - 1, :, :] = _arr print "Filtered %i/%i"%(band_num, self.mnBands) print "Writing output file..." array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands,self.mdInvalidValue) def rasterCalculation(self, ds_fn, expr = "(band3-band2)/(band3+band2)"): """栅格计算,暂时只支持用band1,2之类的形式表示各个波段""" mode = re.compile(r'\d+') m = mode.findall(expr) nums = np.unique(np.array(m)) sortedNums = np.sort(nums) for num in sortedNums: expr = expr.replace(num, str(int(num)-1)+',:,:]') expr = expr.replace('band','1.0*self.mpArray[') print expr resultArr = eval(expr) array2MultiBandsrasterfn(self.msFilename,ds_fn,resultArr,1,self.mdInvalidValue) def printRasterAttr(self): """显示图像信息""" print "File Name: %s"%self.msFilename print "Rows: %i Cols: %i Bands: %i Pixel Size: %.2f*%.2f"%(self.mnRows, self.mnCols, self.mnBands, self.mpGeoTransform[1],-self.mpGeoTransform[5]) print "Data Type: %s No-Data Value: "%(self.msDataType),self.mdInvalidValue print "SpatialRef: %s \nProjection: %s"%(self.mpGeoTransform, self.msProjectionRef)def array2MultiBandsrasterfn(rasterfn, newRasterfn, array, bandCount, nodata = None): """文件尺度上数组生成栅格文件,前者栅格文件提供描述信息(多波段)""" raster = gdal.Open(rasterfn) geotransform = raster.GetGeoTransform() originX = geotransform[0] originY = geotransform[3] pixelWidth = geotransform[1] pixelHeight = geotransform[5] cols = raster.RasterXSize rows = raster.RasterYSize array.shape = (bandCount, rows, cols) driver = gdal.GetDriverByName('GTiff') outRaster = driver.Create(newRasterfn, cols, rows, bandCount, gdal.GDT_Float32) outRaster.SetGeoTransform((originX, pixelWidth, 0, originY, 0, pixelHeight)) outRasterSRS = osr.SpatialReference() outRasterSRS.ImportFromWkt(raster.GetProjectionRef()) outRaster.SetProjection(outRasterSRS.ExportToWkt()) for band_num in range(1, bandCount + 1): outband = outRaster.GetRasterBand(band_num) outband.SetNoDataValue(nodata) outband.WriteArray(array[band_num - 1, :, :]) outband.FlushCache() print "write output file -- %s success!"%newRasterfndef creatraster(newRasterfn, GeoTransform, projection, datatype, imgdata, cols, rows, bands): #必须使用numpy下的numpy.array作为imgdata if bands == 1: imgdata.shape = (bands, rows, cols) driver = gdal.GetDriverByName('GTiff') outRaster = driver.Create(newRasterfn, cols, rows, bands, datatype) outRaster.SetGeoTransform(GeoTransform) outRaster.SetProjection(projection) for i in range(bands): array = imgdata[i, :, :] outband = outRaster.GetRasterBand(i+1) outband.WriteArray(array) print "write data succeed!"def rasterCalculations(ds_fn, expr): """仅支持tif格式",表达式中要写文件后缀""" m = re.findall(r'([a-z,A-Z,_]+[1-9,a-z,A-Z,_]*.tif)', expr) unim = np.unique(np.array(m)) print unim i = 0 mArrs = [] for item in unim: Cgdal = CGDAL() Cgdal.loadFrom(item) if i ==0: no_data = Cgdal.mdInvalidValue if Cgdal.mnBands != 1: print "The input raster is not useful. Only 1 band is required, %i is given."%Cgdal.mnBands sys.exit(1) Cgdal.mpArray[Cgdal.mpArray == no_data] = np.nan mArrs.append(Cgdal.mpArray) expr = expr.replace(item,'1.0*mArrs[%i]'%i) i = i + 1 print expr resultArr = eval(expr) array2MultiBandsrasterfn(unim[0],ds_fn,resultArr,1,nodata=no_data)
0 0
- Python CGDAL类——支持栅格数据的栅格计算/线性增强/滤波增强
- QGis二次开发基础 -- 栅格图像增强显示
- QGis二次开发基础 -- 栅格图像增强显示
- Python gdal 读取栅格数据
- 栅格数据
- 栅格数据
- 三种利用Python批量处理地理数据的方法——以栅格数据投影转换为例
- 栅格
- 栅格数据的属性信息
- CityEngine支持多少种栅格数据?
- 【技术类】【栅格那点儿事(四B)】多波段栅格数据的显示
- arcgis server 获取SDE中栅格数据的栅格值
- 图像增强——同态滤波
- 图像增强之——同态滤波
- 矢量数据向栅格数据的转换
- web——栅格系统
- 【03】Bootstrap — 栅格系统
- 矢量数据栅格化——多边形填充
- PDB文件头中时间格式解析
- QT 2D绘图学习文档
- IOS导航栏颜色渐变与常用属性
- 如何理解hadoop的安装方式
- Web服务器工作原理概述
- Python CGDAL类——支持栅格数据的栅格计算/线性增强/滤波增强
- 异常:Null value was assigned to a property of primitive type setter of···
- 169. Majority Element
- 中兴2017面试总结
- PAT甲 1007. Maximum Subsequence Sum (25)
- 我的小一步,争取是人类的一大步
- jQuery基础
- sql 多表连接与子查询
- Docker基础教程——数据管理