Extract and display raster windows from 16-bit satellite imagery using pixel coordinates or geo coordinates. We first presented this tutorial as part of a three-hour session on Working with Geographic Information Systems in Python during the 2009 Python Conference in Chicago, Illinois.
Make sure all dependencies are installed.
yum install gdal gdal-python numpy python-imaging ipython wxPython qgis
Download the code and data and run the scripts.
wget http://invisibleroads.com/tutorials/_downloads/gdal-raster-extract.zip
unzip gdal-raster-extract.zip
cd gdal-raster-extract
cd examples
python ../extractSamples.py multispectral.tif panchromatic.tif locations.shp
python ../browseSamples.py
Open the generated SQLite database samples.db and browse the extracted samples. Below you can see the first sample with the four low-resolution multispectral bands on the left and the high-resolution panchromatic band on the right.
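For a quick look without the GUI, the database can also be queried directly. This is a minimal sketch, assuming the `samples` table layout created in sample_store.py (columns hasRoof, xGeo, yGeo plus two pickled blobs). The helper name `summarizeSamples` and the coordinates are hypothetical, and the in-memory table only stands in for samples.db so that the example is self-contained.

```python
import sqlite3

def summarizeSamples(connection):
    """Return (sample count, list of (hasRoof, xGeo, yGeo)) from a samples table."""
    cursor = connection.cursor()
    cursor.execute('SELECT COUNT(*) FROM samples')
    count = cursor.fetchone()[0]
    cursor.execute('SELECT hasRoof, xGeo, yGeo FROM samples ORDER BY rowid')
    return count, cursor.fetchall()

# Stand-in for sqlite3.connect('samples.db'): an in-memory table with the same columns
connection = sqlite3.connect(':memory:')
connection.execute(
    'CREATE TABLE samples (hasRoof INTEGER, xGeo REAL, yGeo REAL, '
    'multispectralData BLOB, panchromaticData BLOB)')
# Fabricated row; real rows hold pickled window data in the two blob columns
connection.execute('INSERT INTO samples VALUES (?,?,?,?,?)',
    (1, 476000.0, 4429000.0, b'', b''))
sampleCount, sampleRows = summarizeSamples(connection)
print(sampleCount, sampleRows)
```

The blob columns are left out of the query because they contain pickled data that is only meaningful to the library code.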
To run browseSamples.py, you will also need the following:
Make sure all dependencies are installed.
yum install gdal gdal-python numpy python-imaging ipython wxPython qgis
Download the code and data and start IPython.
wget http://invisibleroads.com/tutorials/_downloads/gdal-raster-extract.zip
unzip gdal-raster-extract.zip
cd gdal-raster-extract
ipython
The package contains two scripts and example data as well as supporting library modules.
In [1]: ls
browseSamples.py* examples/ extractSamples.py* libraries/
In [2]: ls examples/
locations.shp
locations.shx
locations.prj
locations.dbf
multispectral.tif
multispectral.tif.aux.xml
panchromatic.tif
In [3]: ls libraries/
__init__.py
image_store.py
point_store.py
sample_store.py
sample_store_lush.py
sequence.py
store.py
window_process.py
view.py
The example data contains a low-resolution multispectral image, a high-resolution panchromatic image and several hand-marked locations of buildings in the image. The imagery shows a region of Boulder, Colorado cropped from DigitalGlobe’s sample standard 16-bit imagery. Note that for higher geo-referencing accuracy, you would probably use ortho-ready imagery instead of the coarsely-orthorectified standard imagery and perform orthorectification yourself using a high-resolution digital elevation model (DEM) as described in the tutorial Orthorectify satellite images with ENVI.
Multispectral image with locations
Panchromatic image with locations
Load panchromatic image.
import osgeo.gdal
imageDataset = osgeo.gdal.Open('examples/panchromatic.tif')
Define the functions that we will use later. Note that we display with pylab.imshow() because PIL’s Image class has difficulty handling 16-bit image data. We want to keep the full 16-bit values because every bit of the raw pixel values can carry information in remote sensing.
import struct, numpy, pylab
def extractWindow(pixelX, pixelY, pixelWidth, pixelHeight):
    # Extract raw data
    band = imageDataset.GetRasterBand(1)
    byteString = band.ReadRaster(pixelX, pixelY, pixelWidth, pixelHeight)
    # Convert to a matrix; the raw data arrives row by row, so the shape is (rows, columns)
    valueType = {osgeo.gdal.GDT_Byte: 'B', osgeo.gdal.GDT_UInt16: 'H'}[band.DataType]
    values = struct.unpack('%d%s' % (pixelWidth * pixelHeight, valueType), byteString)
    matrix = numpy.reshape(values, (pixelHeight, pixelWidth))
    # Display matrix
    pylab.imshow(matrix, cmap=pylab.cm.gray)
    pylab.show()
    # Return
    return matrix
def extractCenteredWindow(pixelX, pixelY, pixelWidth, pixelHeight):
    # Use integer division so the offsets stay whole pixels
    centeredPixelX = pixelX - pixelWidth // 2
    centeredPixelY = pixelY - pixelHeight // 2
    return extractWindow(centeredPixelX, centeredPixelY, pixelWidth, pixelHeight)
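As an aside, numpy can decode the byte string directly, which avoids building a large tuple with struct. The following is a sketch rather than part of the original scripts. `decodeWindow` is a hypothetical helper, and the byte string is fabricated with struct.pack so that the example runs without GDAL. It assumes unsigned 16-bit samples in native byte order, stored row by row.

```python
import struct
import numpy

def decodeWindow(byteString, pixelWidth, pixelHeight, dtype=numpy.uint16):
    """Decode a row-major byte string into a (rows, columns) matrix."""
    values = numpy.frombuffer(byteString, dtype=dtype)
    return values.reshape(pixelHeight, pixelWidth)

# Fabricated window that is 3 pixels wide and 2 pixels high, native byte order
byteString = struct.pack('6H', 0, 1, 2, 1000, 2000, 65535)
matrix = decodeWindow(byteString, pixelWidth=3, pixelHeight=2)
print(matrix.shape)
print(matrix[1, 2])
```

Note that the matrix is indexed as matrix[row, column], so the pixel height comes first in the shape.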
Extract a 100x100 pixel window near the middle of the image.
extractCenteredWindow(imageDataset.RasterXSize // 2, imageDataset.RasterYSize // 2, 100, 100)
To extract raster windows using geocoordinates, we must convert the geocoordinates to their corresponding pixel locations in the image. Each georeferenced image has a set of six numbers called the GeoTransform that tells us how to convert between geo locations and pixel locations.
Get image georeferencing information.
g0, g1, g2, g3, g4, g5 = imageDataset.GetGeoTransform()
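In terms of the six coefficients unpacked above, the affine transform maps a pixel location to a geo location as follows. The symbols $x_p, y_p$ (pixel column and row) and $x_g, y_g$ (geo coordinates) are notation introduced here for illustration.

```latex
\begin{aligned}
x_g &= g_0 + x_p\,g_1 + y_p\,g_2 \\
y_g &= g_3 + x_p\,g_4 + y_p\,g_5
\end{aligned}
```

For a north-up image $g_2 = g_4 = 0$, so $g_1$ is the pixel width and $g_5$ is the pixel height (usually negative, because rows count downward). Solving the two equations for $x_p$ and $y_p$ gives the geo-to-pixel conversion.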
Define conversion methods.
def convertGeoLocationToPixelLocation(geoLocation):
    xGeo, yGeo = geoLocation
    if g2 == 0:
        xPixel = (xGeo - g0) / float(g1)
        yPixel = (yGeo - g3 - xPixel*g4) / float(g5)
    else:
        xPixel = (yGeo*g2 - xGeo*g5 + g0*g5 - g2*g3) / float(g2*g4 - g1*g5)
        yPixel = (xGeo - g0 - xPixel*g1) / float(g2)
    return int(round(xPixel)), int(round(yPixel))
def convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight):
    return int(round(abs(float(geoWidth) / g1))), int(round(abs(float(geoHeight) / g5)))
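One way to sanity-check the conversion is a round trip: map a pixel location to geo coordinates and back, and confirm that the same pixel comes out. The sketch below does this for two hypothetical GeoTransforms (neither is taken from the tutorial imagery). `pixelToGeo` and `geoToPixel` are shortened stand-in names that take the transform as an argument instead of reading the global coefficients.

```python
def pixelToGeo(pixelLocation, geoTransform):
    """Apply the affine transform to a (column, row) pixel location."""
    g0, g1, g2, g3, g4, g5 = geoTransform
    xPixel, yPixel = pixelLocation
    return g0 + xPixel * g1 + yPixel * g2, g3 + xPixel * g4 + yPixel * g5

def geoToPixel(geoLocation, geoTransform):
    """Invert the affine transform and round to the nearest pixel."""
    g0, g1, g2, g3, g4, g5 = geoTransform
    xGeo, yGeo = geoLocation
    if g2 == 0:
        xPixel = (xGeo - g0) / float(g1)
        yPixel = (yGeo - g3 - xPixel * g4) / float(g5)
    else:
        xPixel = (yGeo * g2 - xGeo * g5 + g0 * g5 - g2 * g3) / float(g2 * g4 - g1 * g5)
        yPixel = (xGeo - g0 - xPixel * g1) / float(g2)
    return int(round(xPixel)), int(round(yPixel))

# Hypothetical north-up transform: origin (1000, 2000), 0.5-unit pixels
northUp = (1000.0, 0.5, 0.0, 2000.0, 0.0, -0.5)
# Hypothetical sheared transform that exercises the g2 != 0 branch
sheared = (1000.0, 0.5, 0.25, 2000.0, 0.125, -0.5)
roundTrips = []
for geoTransform in (northUp, sheared):
    geoLocation = pixelToGeo((120, 45), geoTransform)
    roundTrips.append(geoToPixel(geoLocation, geoTransform))
print(roundTrips)
```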
Load locations from shapefile. For details, please see the tutorial Load points from a shapefile with GDAL.
from libraries import point_store
geoLocations, spatialReference = point_store.load('examples/locations.shp')
Convert the first geo location to a pixel location.
windowPixelX, windowPixelY = convertGeoLocationToPixelLocation(geoLocations[0])
Convert window dimensions from 25 meters to their equivalent in pixels.
windowPixelWidth, windowPixelHeight = convertGeoDimensionsToPixelDimensions(25, 25)
Now you can use the method you defined in Extract raster window from satellite image using a pixel location.
extractCenteredWindow(windowPixelX, windowPixelY, windowPixelWidth, windowPixelHeight)
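A window centered on a location near the image border can extend past the raster, in which case there is nothing valid to read. The GeoImage class in image_store.py guards against this with its isWindowInside method. The following standalone sketch shows the same check, with `rasterWidth` and `rasterHeight` as hypothetical stand-ins for the dataset's RasterXSize and RasterYSize, and `centerWindow` as a hypothetical helper name.

```python
def isWindowInside(left, top, width, height, rasterWidth, rasterHeight):
    """Return True when the window lies entirely within the raster."""
    right = left + width
    bottom = top + height
    return left >= 0 and top >= 0 and right <= rasterWidth and bottom <= rasterHeight

def centerWindow(pixelX, pixelY, pixelWidth, pixelHeight):
    """Return the upper-left corner of a window centered on (pixelX, pixelY)."""
    return pixelX - pixelWidth // 2, pixelY - pixelHeight // 2

# Hypothetical raster of 500 columns by 400 rows
left, top = centerWindow(250, 200, 100, 100)
insideCenter = isWindowInside(left, top, 100, 100, 500, 400)
left, top = centerWindow(10, 10, 100, 100)
insideCorner = isWindowInside(left, top, 100, 100, 500, 400)
print(insideCenter, insideCorner)
```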
#!/usr/bin/env python
# Import system modules
import optparse
# Import custom modules
from libraries import image_store, point_store, window_process
# If we are running the script from the command-line,
if __name__ == '__main__':
    # Parse options and arguments
    optionParser = optparse.OptionParser(
        usage='%prog MULTISPECTRAL-PATH PANCHROMATIC-PATH SHAPE-PATH',
        epilog=(
            'Extracts raster windows from MULTISPECTRAL-PATH and '
            'PANCHROMATIC-PATH using locations from SHAPE-PATH and '
            'saves results in OUTPUT-PATH.'
        )
    )
    optionParser.add_option('-o', '--output-path', dest='outputPath',
        metavar='OUTPUT-PATH', default='samples.db',
        help='save results in OUTPUT-PATH')
    optionParser.add_option('-m', '--meters', dest='windowGeoLength',
        metavar='LENGTH', default=25, type='int',
        help='specify centered window length in meters')
    optionParser.add_option('-l', '--label', dest='windowLabel',
        metavar='INTEGER', default=1, type='int',
        help='specify window label')
    options, arguments = optionParser.parse_args()
    # Verify
    if len(arguments) == 3:
        # Extract
        multispectralImagePath, panchromaticImagePath, locationPath = arguments
        window_process.extract(options.outputPath,
            options.windowLabel, options.windowGeoLength,
            image_store.load(multispectralImagePath),
            image_store.load(panchromaticImagePath),
            point_store.load(locationPath)[0])
    else:
        # Show help
        optionParser.print_help()
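Python 2.7 and 3.2 added the argparse module, which can declare the three required paths as positional arguments so that the argument count no longer needs to be checked by hand. The sketch below shows how the same command line might look with argparse. It is an assumption about a possible port and not part of the original package, and `buildParser` is a hypothetical name.

```python
import argparse

def buildParser():
    """Build a parser with the same options and defaults as extractSamples.py."""
    parser = argparse.ArgumentParser(
        description='Extract raster windows around locations and save them as samples.')
    parser.add_argument('multispectralPath', metavar='MULTISPECTRAL-PATH')
    parser.add_argument('panchromaticPath', metavar='PANCHROMATIC-PATH')
    parser.add_argument('shapePath', metavar='SHAPE-PATH')
    parser.add_argument('-o', '--output-path', dest='outputPath',
        metavar='OUTPUT-PATH', default='samples.db',
        help='save results in OUTPUT-PATH')
    parser.add_argument('-m', '--meters', dest='windowGeoLength',
        metavar='LENGTH', default=25, type=int,
        help='specify centered window length in meters')
    parser.add_argument('-l', '--label', dest='windowLabel',
        metavar='INTEGER', default=1, type=int,
        help='specify window label')
    return parser

# Parse an explicit argument list instead of sys.argv for demonstration
options = buildParser().parse_args(
    ['-m', '50', 'multispectral.tif', 'panchromatic.tif', 'locations.shp'])
print(options.windowGeoLength, options.outputPath, options.shapePath)
```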
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# generated by wxGlade 0.6.3 on Sat Sep 27 02:31:00 2008
# Import system modules
import sys
import itertools
import wx
import os
# Import custom modules
from libraries import sample_store, sample_store_lush, image_store, store
# begin wxGlade: extracode
# end wxGlade
# Set
size_small = 100, 100
size_large = 400, 400
class MainFrame(wx.Frame):
    def __init__(self, *args, **kwds):
        # begin wxGlade: MainFrame.__init__
        kwds["style"] = wx.DEFAULT_FRAME_STYLE
        wx.Frame.__init__(self, *args, **kwds)
        self.button_open = wx.Button(self, -1, "Open")
        self.slider = wx.Slider(self, -1, 0, 0, 10)
        self.checkbox_hasRoof = wx.CheckBox(self, -1, "Has Roof")
        self.label_location = wx.StaticText(self, -1, "")
        self.label_red = wx.StaticText(self, -1, "Red", style=wx.ALIGN_CENTRE)
        self.bitmap_red = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_green = wx.StaticText(self, -1, "Green", style=wx.ALIGN_CENTRE)
        self.bitmap_green = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_blue = wx.StaticText(self, -1, "Blue", style=wx.ALIGN_CENTRE)
        self.bitmap_blue = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_infrared = wx.StaticText(self, -1, "Infrared", style=wx.ALIGN_CENTRE)
        self.bitmap_infrared = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_panchromatic = wx.StaticText(self, -1, "Panchromatic", style=wx.ALIGN_CENTRE)
        self.bitmap_panchromatic = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.button_1 = wx.Button(self, -1, "Save")
        self.__set_properties()
        self.__do_layout()
        self.Bind(wx.EVT_BUTTON, self.onClickOpen, self.button_open)
        self.Bind(wx.EVT_COMMAND_SCROLL, self.onScroll, self.slider)
        self.Bind(wx.EVT_BUTTON, self.onClickSave, self.button_1)
        # end wxGlade
        self.samples = []
        self.flag_changed = False
    def __set_properties(self):
        # begin wxGlade: MainFrame.__set_properties
        self.SetTitle("Sample Browser")
        self.SetSize((800, 600))
        self.checkbox_hasRoof.Enable(False)
        # end wxGlade
    def __do_layout(self):
        # begin wxGlade: MainFrame.__do_layout
        sizer_1 = wx.BoxSizer(wx.VERTICAL)
        sizer_4 = wx.BoxSizer(wx.HORIZONTAL)
        sizer_6 = wx.BoxSizer(wx.VERTICAL)
        grid_sizer_1 = wx.GridSizer(2, 2, 0, 0)
        sizer_8 = wx.BoxSizer(wx.VERTICAL)
        sizer_5 = wx.BoxSizer(wx.VERTICAL)
        sizer_3 = wx.BoxSizer(wx.VERTICAL)
        sizer_2 = wx.BoxSizer(wx.VERTICAL)
        sizer_7 = wx.BoxSizer(wx.HORIZONTAL)
        sizer_1.Add(self.button_open, 0, wx.EXPAND, 0)
        sizer_1.Add(self.slider, 0, wx.EXPAND, 0)
        sizer_7.Add(self.checkbox_hasRoof, 0, 0, 0)
        sizer_7.Add((20, 20), 0, 0, 0)
        sizer_7.Add(self.label_location, 1, wx.EXPAND, 0)
        sizer_1.Add(sizer_7, 0, wx.EXPAND, 0)
        sizer_2.Add(self.label_red, 0, wx.EXPAND, 0)
        sizer_2.Add(self.bitmap_red, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_2, 1, wx.EXPAND, 0)
        sizer_3.Add(self.label_green, 0, wx.EXPAND, 0)
        sizer_3.Add(self.bitmap_green, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_3, 1, wx.EXPAND, 0)
        sizer_5.Add(self.label_blue, 0, wx.EXPAND, 0)
        sizer_5.Add(self.bitmap_blue, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_5, 1, wx.EXPAND, 0)
        sizer_8.Add(self.label_infrared, 0, wx.EXPAND, 0)
        sizer_8.Add(self.bitmap_infrared, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_8, 1, wx.EXPAND, 0)
        sizer_4.Add(grid_sizer_1, 2, wx.EXPAND, 0)
        sizer_6.Add(self.label_panchromatic, 0, wx.EXPAND, 0)
        sizer_6.Add(self.bitmap_panchromatic, 1, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_4.Add(sizer_6, 3, wx.EXPAND, 0)
        sizer_1.Add(sizer_4, 1, wx.EXPAND, 0)
        sizer_1.Add(self.button_1, 0, wx.EXPAND, 0)
        self.SetSizer(sizer_1)
        self.Layout()
        # end wxGlade
    def onClickOpen(self, event): # wxGlade: MainFrame.<event_handler>
        # Get file information
        fileTypes = 'Sample databases (*.db)|*.db', 'Lush samples (*Samples)|*Samples'
        dialog = wx.FileDialog(self, 'Open', style=wx.OPEN, wildcard='|'.join(fileTypes))
        if dialog.ShowModal() != wx.ID_OK: return
        filePath = dialog.GetPath()
        self.fileType = dialog.GetFilterIndex()
        dialog.Destroy()
        # If the user wants to open a database,
        if self.fileType == 0:
            self.samples = sample_store.load(filePath).getSamples()
        # If the user wants to open Lush samples,
        else:
            filePath = filePath.replace('Samples', '')
            self.samples = [(label, sample) for label, sample in itertools.izip(sample_store_lush.makeLabelGeneratorFromLushDataset(filePath), sample_store_lush.makeSampleGeneratorFromLushDataset(filePath))]
        # Set slider
        self.slider.SetRange(0, len(self.samples) - 1)
        # Show first sample
        self.refresh()
    def onScroll(self, event): # wxGlade: MainFrame.<event_handler>
        self.refresh()
    def onClickSave(self, event): # wxGlade: MainFrame.<event_handler>
        # Get file information
        fileTypes = 'Sample databases (*.db)|*.db', 'Matlab matrices (*.mat)|*.mat'
        dialog = wx.FileDialog(self, 'Save', style=wx.SAVE, wildcard='|'.join(fileTypes))
        if dialog.ShowModal() != wx.ID_OK: return
        targetPath = dialog.GetPath()
        fileType = dialog.GetFilterIndex()
        dialog.Destroy()
        # If the user wants to save the data as a database, save it
        if fileType == 0:
            sample_store.save(self.samples, store.replaceFileExtension(targetPath, 'db'))
        # If the user wants to save the data in Matlab format, save it
        else:
            sample_store.saveForMatlab(self.samples, store.replaceFileExtension(targetPath, 'mat'))
    def refresh(self):
        self.showSample(self.slider.GetValue())
        self.Fit()
    def showSample(self, index):
        # Get sample
        sample = self.samples[index]
        # If we have samples from a database
        if self.fileType == 0:
            hasRoof, geoLocation, multispectralWindow, panchromaticWindow = sample
            # Display
            self.checkbox_hasRoof.SetValue(hasRoof)
            self.label_location.SetLabel('%s, %s' % geoLocation)
            # Display multispectral images
            if multispectralWindow:
                label_red = 'Red (%dx%d pixels)' % (multispectralWindow.width, multispectralWindow.height)
                red, green, blue, infrared = multispectralWindow.getImages_pylab(*size_small)
                setBitmapFromPIL(self.bitmap_red, red)
                setBitmapFromPIL(self.bitmap_green, green)
                setBitmapFromPIL(self.bitmap_blue, blue)
                setBitmapFromPIL(self.bitmap_infrared, infrared)
            else:
                label_red = 'Red'
                self.bitmap_red.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                self.bitmap_green.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                self.bitmap_blue.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                self.bitmap_infrared.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                print 'Empty multispectral'
            # Display panchromatic image
            if panchromaticWindow:
                label_panchromatic = 'Panchromatic (%dx%d pixels)' % (panchromaticWindow.width, panchromaticWindow.height)
                panchromatic = panchromaticWindow.getImages_pylab(*size_large)[0]
                setBitmapFromPIL(self.bitmap_panchromatic, panchromatic)
            else:
                label_panchromatic = 'Panchromatic'
                self.bitmap_panchromatic.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                print 'Empty panchromatic'
        # Otherwise,
        else:
            # Expand
            lushLabel, lushSample = sample
            # Set label
            self.checkbox_hasRoof.SetValue(lushLabel)
            self.label_location.SetLabel('')
            # Set multispectral
            red, green, blue, infrared = image_store.getImages_pylab('imshow', lushSample[:4], *size_small)
            setBitmapFromPIL(self.bitmap_red, red)
            setBitmapFromPIL(self.bitmap_green, green)
            setBitmapFromPIL(self.bitmap_blue, blue)
            setBitmapFromPIL(self.bitmap_infrared, infrared)
            # Set panchromatic
            panchromatic = image_store.getImages_pylab('imshow', [lushSample[4]], *size_large)[0]
            setBitmapFromPIL(self.bitmap_panchromatic, panchromatic)
# end of class MainFrame
def setBitmapFromPIL(bitmap, pilImage):
    width, height = pilImage.size
    wxImage = wx.EmptyImage(width, height)
    wxImage.SetData(pilImage.convert('RGB').tostring())
    bitmap.SetBitmap(wxImage.ConvertToBitmap())
if __name__ == "__main__":
    app = wx.PySimpleApp(0)
    wx.InitAllImageHandlers()
    mainFrame = MainFrame(None, -1, "")
    app.SetTopWindow(mainFrame)
    mainFrame.Show()
    app.MainLoop()
# Import system modules
# Import custom modules
import sample_store
import view
def extract(targetDatasetPath, label, windowGeoLength, multispectralImage, panchromaticImage, geoCenters):
    # Initialize
    dataset = sample_store.create(targetDatasetPath)
    windowCount = len(geoCenters)
    # For each geoCenter,
    for windowIndex, geoCenter in enumerate(geoCenters):
        window = [x.extractCenteredGeoWindow(geoCenter, windowGeoLength, windowGeoLength) for x in multispectralImage, panchromaticImage]
        if window[0] and window[1]: dataset.addSample(label, geoCenter, *window)
        if windowIndex % 100 == 0:
            view.printPercentUpdate(windowIndex + 1, windowCount)
    view.printPercentFinal(windowCount)
    # Return
    return dataset
# Import system modules
import osgeo.gdal
import osgeo.osr
import struct
import numpy
import Image
import cPickle as pickle
import cStringIO as StringIO
# Import custom modules
import store
import view
# Types used by struct.unpack
typeByGDT = {
    osgeo.gdal.GDT_Byte: 'B',
    osgeo.gdal.GDT_UInt16: 'H',
}
modeByType = {
    osgeo.gdal.GDT_Byte: 'L',
    osgeo.gdal.GDT_UInt16: 'I;16',
}
# Define shortcuts
def load(imagePath):
    return GeoImage(imagePath)
def convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight, geoTransform):
    g0, g1, g2, g3, g4, g5 = geoTransform
    return int(round(abs(float(geoWidth) / g1))), int(round(abs(float(geoHeight) / g5)))
def convertPixelLocationsToGeoLocations(pixelLocations, geoTransform):
    return [convertPixelLocationToGeoLocation(x, geoTransform) for x in pixelLocations]
def convertPixelLocationToGeoLocation(pixelLocation, geoTransform):
    g0, g1, g2, g3, g4, g5 = geoTransform
    xPixel, yPixel = pixelLocation
    xGeo = g0 + xPixel*g1 + yPixel*g2
    yGeo = g3 + xPixel*g4 + yPixel*g5
    return xGeo, yGeo
# Define core
class GeoImage(object):
    # Constructor
    def __init__(self, imagePath):
        # Initialize
        self.imagePath = imagePath
        self.dataset = osgeo.gdal.Open(imagePath)
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.geoTransform = self.dataset.GetGeoTransform()
        # Get spatialReference
        spatialReference = osgeo.osr.SpatialReference()
        spatialReference.ImportFromWkt(self.dataset.GetProjectionRef())
        self.spatialReferenceAsProj4 = spatialReference.ExportToProj4()
    # Get
    def getPath(self):
        return self.imagePath
    def getPixelWidth(self):
        return self.width
    def getPixelHeight(self):
        return self.height
    def getGeoTransform(self):
        return self.geoTransform
    def getSpatialReference(self):
        return self.spatialReferenceAsProj4
    # Extract
    def extractCenteredGeoWindow(self, geoCenter, geoWidth, geoHeight):
        pixelCenter = self.convertGeoLocationToPixelLocation(geoCenter)
        pixelWidth, pixelHeight = self.convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight)
        return self.extractCenteredPixelWindow(pixelCenter, pixelWidth, pixelHeight)
    def extractCenteredPixelWindow(self, pixelCenter, pixelWidth, pixelHeight):
        pixelUpperLeft = centerPixelFrame(pixelCenter, pixelWidth, pixelHeight)[:2]
        return self.extractPixelWindow(pixelUpperLeft, pixelWidth, pixelHeight)
    def extractPixelWindow(self, pixelUpperLeft, pixelWidth, pixelHeight):
        # Set
        iLeft = int(pixelUpperLeft[0])
        iTop = int(pixelUpperLeft[1])
        iWidth = int(pixelWidth)
        iHeight = int(pixelHeight)
        # If the box is outside, return
        if not self.isWindowInside(iLeft, iTop, iWidth, iHeight): return
        # Extract
        bands = map(self.dataset.GetRasterBand, xrange(1, self.dataset.RasterCount + 1))
        packs = [(x.DataType, x.ReadRaster(iLeft, iTop, iWidth, iHeight)) for x in bands]
        # Return
        return Window(iLeft, iTop, iWidth, iHeight, packs)
    # Is
    def isWindowInside(self, left, top, width, height):
        right = left + width
        bottom = top + height
        if left >= 0 and top >= 0 and right <= self.width and bottom <= self.height: return True
    # Convert
    def convertGeoFrameToPixelFrame(self, geoFrame):
        geoLeft, geoTop, geoRight, geoBottom = geoFrame
        pixelLeft, pixelTop = self.convertGeoLocationToPixelLocation((geoLeft, geoTop))
        pixelRight, pixelBottom = self.convertGeoLocationToPixelLocation((geoRight, geoBottom))
        pixelLeft, pixelRight = sorted((pixelLeft, pixelRight))
        pixelTop, pixelBottom = sorted((pixelTop, pixelBottom))
        return pixelLeft, pixelTop, pixelRight, pixelBottom
    def convertPixelFrameToGeoFrame(self, pixelFrame):
        pixelLeft, pixelTop, pixelRight, pixelBottom = pixelFrame
        geoLeft, geoTop = self.convertPixelLocationToGeoLocation((pixelLeft, pixelTop))
        geoRight, geoBottom = self.convertPixelLocationToGeoLocation((pixelRight, pixelBottom))
        geoLeft, geoRight = sorted((geoLeft, geoRight))
        geoTop, geoBottom = sorted((geoTop, geoBottom))
        return geoLeft, geoTop, geoRight, geoBottom
    def convertGeoLocationToPixelLocation(self, geoLocation):
        g0, g1, g2, g3, g4, g5 = self.geoTransform
        xGeo, yGeo = geoLocation
        if g2 == 0:
            xPixel = (xGeo - g0) / float(g1)
            yPixel = (yGeo - g3 - xPixel*g4) / float(g5)
        else:
            xPixel = (yGeo*g2 - xGeo*g5 + g0*g5 - g2*g3) / float(g2*g4 - g1*g5)
            yPixel = (xGeo - g0 - xPixel*g1) / float(g2)
        return int(round(xPixel)), int(round(yPixel))
    def convertPixelLocationToGeoLocation(self, pixelLocation):
        return convertPixelLocationToGeoLocation(pixelLocation, self.geoTransform)
    def convertGeoLocationsToPixelLocations(self, geoLocations):
        return map(self.convertGeoLocationToPixelLocation, geoLocations)
    def convertPixelLocationsToGeoLocations(self, pixelLocations):
        return map(self.convertPixelLocationToGeoLocation, pixelLocations)
    def convertGeoDimensionsToPixelDimensions(self, geoWidth, geoHeight):
        return convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight, self.geoTransform)
    def convertPixelDimensionsToGeoDimensions(self, pixelWidth, pixelHeight):
        g0, g1, g2, g3, g4, g5 = self.geoTransform
        return abs(pixelWidth * g1), abs(pixelHeight * g5)
class Window(object):
    values = None
    matrices = None
    matrixImages = None, None
    pairwiseMatrices = None
    images_pil = None
    images_pylab = None, None
    def __init__(self, left, top, width, height, packs):
        # Set
        self.width = width
        self.height = height
        self.packs = packs
        # Initialize
        self.frame = left, top, left + width, top + height
    def getFrame(self):
        return self.frame
    def getValues(self):
        if not self.values:
            length = self.width * self.height
            self.values = [struct.unpack('%d%s' % (length, typeByGDT[x[0]]), x[1]) for x in self.packs]
        return self.values
    def getMatrices(self):
        if not self.matrices:
            self.matrices = [numpy.reshape(x, (self.height, self.width)) for x in self.getValues()]
        return self.matrices
    def getMatrixImages(self, imageWidthInPixels, imageHeightInPixels):
        imageDimensions = imageWidthInPixels, imageHeightInPixels
        if not self.matrixImages[0] or self.matrixImages[1] != imageDimensions:
            self.matrixImages = getImages_pylab('matshow', self.getMatrices(), *imageDimensions), imageDimensions
        return self.matrixImages[0]
    def getImages_pil(self):
        if not self.images_pil:
            imageDimensions = self.width, self.height
            self.images_pil = [Image.fromstring(modeByType[x[0]], imageDimensions, x[1]) for x in self.packs]
        return self.images_pil
    def getImages_pylab(self, imageWidthInPixels, imageHeightInPixels):
        imageDimensions = imageWidthInPixels, imageHeightInPixels
        if not self.images_pylab[0] or self.images_pylab[1] != imageDimensions:
            self.images_pylab = getImages_pylab('imshow', self.getMatrices(), *imageDimensions), imageDimensions
        return self.images_pylab[0]
    def pickle(self):
        return pickle.dumps((self.width, self.height, self.packs))
def unpickle(data):
    if not data: return None
    width, height, packs = pickle.loads(data)
    return Window(0, 0, width, height, packs)
# Get helpers
def centerPixelFrame(pixelCenter, pixelWidth, pixelHeight):
    # Compute frame
    pixelLeft = pixelCenter[0] - pixelWidth / 2
    pixelTop = pixelCenter[1] - pixelHeight / 2
    pixelRight = pixelLeft + pixelWidth
    pixelBottom = pixelTop + pixelHeight
    # Return
    return pixelLeft, pixelTop, pixelRight, pixelBottom
def getCenter(frame):
    left, top, right, bottom = frame
    widthHalved = (right - left) / 2
    heightHalved = (bottom - top) / 2
    return int(left + widthHalved), int(top + heightHalved)
def getWindowCenter(windowLocation, windowPixelWidth, windowPixelHeight):
    left, top = windowLocation
    widthHalved = windowPixelWidth / 2
    heightHalved = windowPixelHeight / 2
    return int(left + widthHalved), int(top + heightHalved)
def getPixelsInFrame(frame):
    left, top, right, bottom = frame
    xRange = xrange(left, right)
    yRange = xrange(top, bottom)
    return [(x, y) for x in xRange for y in yRange]
def getImages_pylab(string_method, matrices, imageWidthInPixels, imageHeightInPixels):
    # Import system modules
    import matplotlib
    import pylab
    # Initialize
    images = []
    convertPixelLengthToInchLength = lambda x: x / float(matplotlib.rc_params()['figure.dpi'])
    imageWidthInInches = convertPixelLengthToInchLength(imageWidthInPixels)
    imageHeightInInches = convertPixelLengthToInchLength(imageHeightInPixels)
    imageDimensions = imageWidthInInches, imageHeightInInches
    # For each matrix,
    for matrix in matrices:
        # Create a new figure
        pylab.figure(figsize=imageDimensions)
        axes = pylab.axes()
        # Plot grayscale matrix
        method = getattr(axes, string_method)
        method(matrix, cmap=pylab.cm.gray)
        # Clear tick marks
        axes.set_xticks([])
        axes.set_yticks([])
        # Save image
        imageFile = StringIO.StringIO()
        pylab.savefig(imageFile, format='png')
        imageFile.seek(0)
        # Append
        images.append(Image.open(imageFile))
    # Return
    return images
class Information(object):
    # Constructor
    def __init__(self, informationPath):
        self.information = store.loadInformation(informationPath)
    # Positive location
    def getPositiveLocationPath(self):
        return self.information['positive location']['path']
    # Multispectral
    def getMultispectralImagePath(self):
        return self.information['multispectral image']['path']
    def getMultispectralImage(self):
        return load(self.getMultispectralImagePath())
    # Panchromatic
    def getPanchromaticImagePath(self):
        return self.information['panchromatic image']['path']
    def getPanchromaticImage(self):
        return load(self.getPanchromaticImagePath())
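Rendering each matrix through a matplotlib figure, as getImages_pylab does, is flexible but relatively heavy. A lighter alternative is to linearly stretch the 16-bit values to 8 bits with numpy for display only, while keeping the original 16-bit matrix for analysis. The sketch below is an assumption for illustration and not something the package does, and `stretchTo8Bit` is a hypothetical helper.

```python
import numpy

def stretchTo8Bit(matrix):
    """Linearly rescale a matrix to 0..255 for display; the input is left unchanged."""
    matrix = numpy.asarray(matrix, dtype=numpy.float64)
    low, high = matrix.min(), matrix.max()
    if high == low:
        # A constant window has no contrast to stretch
        return numpy.zeros(matrix.shape, dtype=numpy.uint8)
    scaled = (matrix - low) / (high - low) * 255
    return scaled.round().astype(numpy.uint8)

# Fabricated 16-bit window
window16 = numpy.array([[100, 600], [1600, 2100]], dtype=numpy.uint16)
window8 = stretchTo8Bit(window16)
print(window8)
```

The stretch discards precision, so it is only suitable for viewing; any measurement should still use the 16-bit values.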
# Import system modules
import sqlite3
import os
import numpy
import random
import scipy.io
import itertools
import cPickle as pickle
# Import custom modules
import view
import store
import image_store
import sequence
# Set SQL
sql_getSample = 'SELECT hasRoof, xGeo, yGeo, multispectralData, panchromaticData FROM samples'
# Shortcuts
def create(datasetPath):
    datasetPath = store.replaceFileExtension(datasetPath, 'db')
    if os.path.exists(datasetPath): os.remove(datasetPath)
    return Store(datasetPath)
def load(datasetPath):
    datasetPath = store.replaceFileExtension(datasetPath, 'db')
    if not os.path.exists(datasetPath): raise IOError('Dataset does not exist: ' + datasetPath)
    return Store(datasetPath)
# Restore
def restoreResult(result):
    # Extract
    hasRoof, xGeo, yGeo, multispectralData, panchromaticData = result
    geoCenter = xGeo, yGeo
    # Restore
    multispectralWindow = image_store.unpickle(multispectralData)
    panchromaticWindow = image_store.unpickle(panchromaticData)
    # Return
    return hasRoof, geoCenter, multispectralWindow, panchromaticWindow
class Store(object):
    # Constructor
    def __init__(self, datasetPath):
        # Fix extension
        datasetPath = store.replaceFileExtension(datasetPath, 'db')
        # Check whether the dataset exists
        flag_exists = True if os.path.exists(datasetPath) else False
        # Connect
        self.connection = sqlite3.connect(datasetPath)
        self.connection.text_factory = str
        self.cursor = self.connection.cursor()
        # If the dataset doesn't exist, create it
        if not flag_exists:
            self.cursor.execute('CREATE TABLE samples (hasRoof INTEGER, xGeo REAL, yGeo REAL, multispectralData BLOB, panchromaticData BLOB)')
            self.connection.commit()
        # Remember
        self.datasetPath = datasetPath
    # Destructor
    def __del__(self):
        self.connection.close()
    # Add
    @store.commit
    def addSample(self, hasRoof, geoCenter, multispectralWindow, panchromaticWindow):
        xGeo, yGeo = geoCenter
        multispectralData = multispectralWindow.pickle()
        panchromaticData = panchromaticWindow.pickle()
        return 'INSERT INTO samples (hasRoof, xGeo, yGeo, multispectralData, panchromaticData) VALUES (?,?,?,?,?)', (hasRoof, xGeo, yGeo, multispectralData, panchromaticData)
    # Delete
    @store.commit
    def deleteSample(self, sampleID):
        return 'DELETE FROM samples WHERE rowid=?', [sampleID]
    # Get
    def getDatasetPath(self):
        return self.datasetPath
    @store.fetchAll
    def getSampleIDs(self):
        return 'SELECT rowid FROM samples', None, store.pullFirst
    @store.fetchAll
    def getRandomSampleIDs(self):
        return 'SELECT rowid FROM samples ORDER BY RANDOM()', None, store.pullFirst
    @store.fetchAll
    def getPositiveSampleIDs(self):
        return 'SELECT rowid FROM samples WHERE hasRoof=1', None, store.pullFirst
    @store.fetchAll
    def getNegativeSampleIDs(self):
        return 'SELECT rowid FROM samples WHERE hasRoof=0', None, store.pullFirst
    @store.fetchAll
    def getSamplesByIDs(self, sampleIDs):
        return sql_getSample + ' WHERE samples.rowID IN (%s)' % ','.join(map(str, sampleIDs)), None, restoreResult
    @store.fetchAll
    def getSamples(self):
        return sql_getSample, None, restoreResult
    @store.fetchOne
    def getSample(self, sampleID):
        return sql_getSample + ' WHERE samples.rowID=?', (sampleID,), restoreResult
    @store.fetchAll
    def getGeoCenters(self):
        return 'SELECT xGeo, yGeo FROM samples', None
    # Count
    @store.fetchOne
    def countSamples(self):
        return 'SELECT COUNT(*) FROM samples', None, store.pullFirst
    @store.fetchOne
    def countPositiveSamples(self):
        return 'SELECT COUNT(*) FROM samples WHERE hasRoof=1', None, store.pullFirst
    @store.fetchOne
    def countNegativeSamples(self):
        return 'SELECT COUNT(*) FROM samples WHERE hasRoof=0', None, store.pullFirst
    def getStatistics(self):
        return {
            'total': self.countSamples(),
            'positive': self.countPositiveSamples(),
            'negative': self.countNegativeSamples(),
        }
    # Cut
    def cutIDs(self, testFraction, withRandomization=True):
        # For each cut,
        for cutPack in sequence.cut(self.getSampleIDs(), testFraction, withRandomization):
            # Yield
            yield cutPack
    # Save
    def saveForMatlab(self, targetPath):
        # Initialize
        samples = self.getSamples(); labels = []; geoCenters = []
        reds = []; greens = []; blues = []; infrareds = []; panchromatics = []
        # Assemble
        for index in xrange(len(samples)):
            # Extract
            hasRoof, geoCenter, multispectralWindow, panchromaticWindow = samples[index]
            # Gather
            labels.append(hasRoof); geoCenters.append(geoCenter)
            # Gather multispectral
            red, green, blue, infrared = multispectralWindow.getMatrices()
            reds.append(red); greens.append(green); blues.append(blue); infrareds.append(infrared)
            # Gather panchromatic
            panchromatics.append(panchromaticWindow.getMatrices()[0])
        # Save
        matrixDictionary = {
            'labels': numpy.dstack(labels), 'geoCenters': numpy.dstack(geoCenters),
            'reds': numpy.dstack(reds), 'greens': numpy.dstack(greens),
            'blues': numpy.dstack(blues), 'infrareds': numpy.dstack(infrareds),
            'panchromatics': numpy.dstack(panchromatics),
        }
        scipy.io.savemat(targetPath, matrixDictionary, True)
# Save
def save(targetPath, datasetSampleIDs):
    # Open targetDataset
    targetDataset = Store(targetPath)
    # Save samples
    sampleCount = len(datasetSampleIDs)
    for sampleIndex, (sourceDataset, sampleID) in enumerate(datasetSampleIDs):
        targetDataset.addSample(*sourceDataset.getSample(sampleID))
        if sampleIndex % 100 == 0:
            view.printPercentUpdate(sampleIndex + 1, sampleCount)
    view.printPercentFinal(sampleCount)
    # Return
    return targetDataset
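The store.commit, store.fetchAll and store.fetchOne decorators come from libraries/store.py, which is not listed here. Judging only from how the methods above return an SQL string, its parameters and an optional row converter, the decorators might work roughly like the following sketch. Everything in it, including `TinyStore`, is an assumption for illustration and not the package's actual code.

```python
import functools
import sqlite3

def commit(method):
    """Execute the (sql, parameters) pair returned by the method, then commit."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        sql, parameters = method(self, *args, **kwargs)
        self.cursor.execute(sql, parameters or ())
        self.connection.commit()
    return wrapper

def fetchAll(method):
    """Execute the query returned by the method and convert every row."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        pack = method(self, *args, **kwargs)
        sql, parameters = pack[0], pack[1]
        # The row converter is optional, as in getGeoCenters
        convert = pack[2] if len(pack) > 2 else (lambda row: row)
        self.cursor.execute(sql, parameters or ())
        return [convert(row) for row in self.cursor.fetchall()]
    return wrapper

def pullFirst(row):
    """Return the first column of a result row."""
    return row[0]

class TinyStore(object):
    """Hypothetical stand-in for Store with a single-column table."""
    def __init__(self):
        self.connection = sqlite3.connect(':memory:')
        self.cursor = self.connection.cursor()
        self.cursor.execute('CREATE TABLE samples (hasRoof INTEGER)')
    @commit
    def addSample(self, hasRoof):
        return 'INSERT INTO samples (hasRoof) VALUES (?)', (hasRoof,)
    @fetchAll
    def getPositiveSampleIDs(self):
        return 'SELECT rowid FROM samples WHERE hasRoof=1', None, pullFirst

tinyStore = TinyStore()
for hasRoof in (1, 0, 1):
    tinyStore.addSample(hasRoof)
positiveIDs = tinyStore.getPositiveSampleIDs()
print(positiveIDs)
```

Keeping the SQL in the method and the execution in the decorator means each query method stays a one-line declaration.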
# Import system modules
import numpy
def makeSampleLabelPaths(basePath):
    return basePath + '-samples', basePath + '-labels'
def makeLabelGeneratorFromLushDataset(filePath):
    # Set paths
    labelPath = makeSampleLabelPaths(filePath)[1]
    # Initialize
    labelFile = open(labelPath)
    header = labelFile.next()
    terms = []
    # For each line,
    for line in labelFile:
        # Extend terms
        terms.extend(int(x) for x in line.split())
        # Pop each term
        for termIndex in xrange(len(terms)):
            yield terms.pop(0)
def makeSampleGeneratorFromLushDataset(filePath):
    # Set paths
    samplePath = makeSampleLabelPaths(filePath)[0]
    # Initialize
    sampleFile = open(samplePath)
    header = sampleFile.next()
    sampleShape = [int(x) for x in header.split()[3:]]
    sampleSize = numpy.product(sampleShape)
    terms = []
    # For each line,
    for line in sampleFile:
        # Extend terms
        terms.extend(float(x) for x in line.split())
        # If we have enough terms,
        while len(terms) >= sampleSize:
            # Yield sample
            yield numpy.array(terms[:sampleSize]).reshape(sampleShape)
            # Set terms
            terms = terms[sampleSize:]
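Both generators follow the same buffering pattern: read whitespace-separated numbers line by line and emit a value as soon as enough terms have accumulated, regardless of where the line breaks fall. Here is a simplified sketch of that pattern, using an in-memory file and plain lists instead of numpy arrays. The file contents and the name `iterateSamples` are hypothetical.

```python
import io

def iterateSamples(lines, sampleSize):
    """Yield lists of sampleSize floats from whitespace-separated lines."""
    terms = []
    for line in lines:
        terms.extend(float(x) for x in line.split())
        # Emit every complete sample currently in the buffer
        while len(terms) >= sampleSize:
            yield terms[:sampleSize]
            terms = terms[sampleSize:]

# Fabricated file: one header line, then six values split unevenly across lines
sampleFile = io.StringIO(u'header\n1 2 3 4\n5\n6\n')
next(sampleFile)  # skip the header line
samples = list(iterateSamples(sampleFile, 3))
print(samples)
```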
'Save and load points to a shapefile'
# Import system modules
import osgeo.ogr
import osgeo.osr
import os
# Core
def save(shapePath, geoLocations, proj4):
    'Save points in the given shapePath'
    # Get driver
    driver = osgeo.ogr.GetDriverByName('ESRI Shapefile')
    # Create shapeData
    shapePath = validateShapePath(shapePath)
    if os.path.exists(shapePath):
        os.remove(shapePath)
    shapeData = driver.CreateDataSource(shapePath)
    # Create spatialReference
    spatialReference = getSpatialReferenceFromProj4(proj4)
    # Create layer
    layerName = os.path.splitext(os.path.split(shapePath)[1])[0]
    layer = shapeData.CreateLayer(layerName, spatialReference, osgeo.ogr.wkbPoint)
    layerDefinition = layer.GetLayerDefn()
    # For each point,
    for pointIndex, geoLocation in enumerate(geoLocations):
        # Create point
        geometry = osgeo.ogr.Geometry(osgeo.ogr.wkbPoint)
        geometry.SetPoint(0, geoLocation[0], geoLocation[1])
        # Create feature
        feature = osgeo.ogr.Feature(layerDefinition)
        feature.SetGeometry(geometry)
        feature.SetFID(pointIndex)
        # Save feature
        layer.CreateFeature(feature)
        # Cleanup
        geometry.Destroy()
        feature.Destroy()
    # Cleanup
    shapeData.Destroy()
    # Return
    return shapePath
def load(shapePath):
    'Given a shapePath, return a list of points in GIS coordinates'
    # Open shapeData
    shapeData = osgeo.ogr.Open(validateShapePath(shapePath))
    # Validate shapeData
    validateShapeData(shapeData)
    # Get the first layer
    layer = shapeData.GetLayer()
    # Initialize
    points = []
    # For each point,
    for index in xrange(layer.GetFeatureCount()):
        # Get
        feature = layer.GetFeature(index)
        geometry = feature.GetGeometryRef()
        # Make sure that it is a point
        if geometry.GetGeometryType() != osgeo.ogr.wkbPoint:
            raise ShapeDataError('This module can only load points; use geometry_store.py')
        # Get pointCoordinates
        pointCoordinates = geometry.GetX(), geometry.GetY()
        # Append
        points.append(pointCoordinates)
        # Cleanup
        feature.Destroy()
    # Get spatial reference as proj4
    proj4 = layer.GetSpatialRef().ExportToProj4()
    # Cleanup
    shapeData.Destroy()
    # Return
    return points, proj4
def merge(sourcePaths, targetPath):
    'Merge a list of shapefiles into a single shapefile'
    # Load
    items = [load(validateShapePath(x)) for x in sourcePaths]
    pointLists = [x[0] for x in items]
    points = reduce(lambda x, y: x + y, pointLists)
    spatialReferences = [x[1] for x in items]
    # Make sure that all the spatial references are the same
    if len(set(spatialReferences)) != 1:
        raise ShapeDataError('The shapefiles must have the same spatial reference')
    spatialReference = spatialReferences[0]
    # Save
    save(validateShapePath(targetPath), points, spatialReference)
def getSpatialReferenceFromProj4(proj4):
    'Return GDAL spatial reference object from proj4 string'
    spatialReference = osgeo.osr.SpatialReference()
    spatialReference.ImportFromProj4(proj4)
    return spatialReference
# Validate
def validateShapePath(shapePath):
    'Validate shapefile extension'
    return os.path.splitext(str(shapePath))[0] + '.shp'
def validateShapeData(shapeData):
    'Make sure we can access the shapefile'
    # Make sure the shapefile exists
    if not shapeData:
        raise ShapeDataError('The shapefile is invalid')
    # Make sure there is exactly one layer
    if shapeData.GetLayerCount() != 1:
        raise ShapeDataError('The shapefile must have exactly one layer')
# Error
class ShapeDataError(Exception):
    pass
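Two rules in this module are easy to miss: `validateShapePath` does not check that the file exists but only forces a `.shp` extension, and `merge` refuses to combine shapefiles whose proj4 strings differ. The sketch below restates those two rules in plain Python 3 with no GDAL dependency, so it is a stand-in for illustration rather than the module's own code. The names `normalizeShapePath` and `mergePointLists`, the coordinates and the projection string are all hypothetical.

```python
import os

def normalizeShapePath(shapePath):
    'Replace whatever extension the path has with .shp'
    return os.path.splitext(str(shapePath))[0] + '.shp'

def mergePointLists(items):
    'Combine (points, proj4) pairs that share one spatial reference'
    proj4s = set(proj4 for points, proj4 in items)
    if len(proj4s) != 1:
        raise ValueError('The point lists must have the same spatial reference')
    merged = []
    for points, proj4 in items:
        merged.extend(points)
    return merged, proj4s.pop()

# Hypothetical coordinates with an example UTM zone 13 projection string
utm13 = '+proj=utm +zone=13 +datum=WGS84 +units=m +no_defs'
first = [(476000.0, 4429000.0)], utm13
second = [(476100.0, 4429050.0), (476200.0, 4429100.0)], utm13
points, proj4 = mergePointLists([first, second])
print(normalizeShapePath('locations.dbf'), len(points))
```

Because the check compares strings, two proj4 strings that describe the same projection with different spelling or parameter order count as different spatial references.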
# Import system modules
import random
import numpy
def uniquify(sequence):
    seen = set()
    return [x for x in sequence if x not in seen and not seen.add(x)]
def cut(sequence, fraction, withRandomization=True):
    # Copy
    sequence = list(sequence)
    # Randomize
    if withRandomization: random.shuffle(sequence)
    # Count
    totalCount = len(sequence)
    partialCount = int(numpy.ceil(totalCount * fraction))
    # Split
    for firstIndex in xrange(0, totalCount, partialCount):
        lastIndex = firstIndex + partialCount
        insideFraction = sequence[firstIndex:lastIndex]
        outsideFraction = sequence[0:firstIndex] + sequence[lastIndex:totalCount]
        yield outsideFraction, insideFraction
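The `cut` generator walks a window of `fraction` of the items across the sequence and yields everything outside the window together with everything inside it, which is the shape of a cross-validation split. Below is a minimal Python 3 sketch of the same logic, using `math.ceil` in place of `numpy.ceil` so that it has no dependencies; the sample ids are hypothetical, and the variable names `trainingIds` and `testIds` reflect one possible use rather than anything stated in the module.

```python
import math
import random

def cut(sequence, fraction, withRandomization=True):
    'Yield (outside, inside) splits where inside holds about fraction of the items'
    sequence = list(sequence)
    if withRandomization:
        random.shuffle(sequence)
    totalCount = len(sequence)
    partialCount = int(math.ceil(totalCount * fraction))
    for firstIndex in range(0, totalCount, partialCount):
        lastIndex = firstIndex + partialCount
        inside = sequence[firstIndex:lastIndex]
        outside = sequence[:firstIndex] + sequence[lastIndex:]
        yield outside, inside

# Four folds over eight hypothetical sample ids, unshuffled for readability
for trainingIds, testIds in cut(range(8), 0.25, withRandomization=False):
    print(testIds, trainingIds)
```

Note that a fraction of zero or an empty sequence makes the step zero, in which case `range` raises a `ValueError`.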
# Import system modules
import os
import re
import sys
import time
import numpy
import datetime
import ConfigParser
# SQL
def execute(function, args, kwargs):
    # Extract
    self = args[0]
    # Execute
    z = function(*args, **kwargs)
    sql, value = z[:2]
    if value is None: self.cursor.execute(sql)
    else: self.cursor.execute(sql, value)
    # Return
    method = z[2] if len(z) > 2 else None
    return self, method
def commit(function):
    def wrapper(*args, **kwargs):
        self = execute(function, args, kwargs)[0]
        self.connection.commit()
        return self.cursor.lastrowid
    return wrapper
def fetchOne(function):
    def wrapper(*args, **kwargs):
        self, method = execute(function, args, kwargs)
        result = self.cursor.fetchone()
        return method(result) if method else result
    return wrapper
def fetchAll(function):
    def wrapper(*args, **kwargs):
        self, method = execute(function, args, kwargs)
        results = self.cursor.fetchall()
        return map(method, results) if method else results
    return wrapper
def pullFirst(result):
    if result is not None: return result[0]
def pullTrueIfResult(result):
    return True if result is not None else False
# File
def replaceFileExtension(filePath, newExtension):
    if not newExtension.startswith('.'): newExtension = '.' + newExtension
    base = os.path.splitext(filePath)[0]
    return base + newExtension
def extractFileBaseName(filePath):
    filename = os.path.split(filePath)[1]
    return os.path.splitext(filename)[0]
def extractFileName(filePath):
    return os.path.split(filePath)[1]
def verifyPath(filePath):
    # Strip whitespace first so that the path we test is the path we return
    filePath = filePath.strip()
    if not os.path.exists(filePath): raise QueueError('Path not found: %s' % filePath)
    return filePath
def fillPath(rootPath, relativeFolderPath, relativeFilePath):
    folderPath = os.path.join(rootPath, relativeFolderPath)
    filePath = os.path.join(folderPath, relativeFilePath)
    return os.path.abspath(filePath)
def makeTimestamp():
    return datetime.datetime.now().strftime('%Y%m%d%H%M%S')
def makeFolderSafely(folderPath):
    if not os.path.exists(folderPath): os.mkdir(folderPath)
    return folderPath
def removeSafely(filePath):
    if os.path.exists(filePath): os.remove(filePath)
# Information
def saveInformation(filePath, valueByOptionBySection, fileExtension='info'):
    # Initialize
    configuration = ConfigParser.RawConfigParser()
    # For each section,
    for section in valueByOptionBySection:
        # Initialize
        valueByOption = valueByOptionBySection[section]
        # Add section
        addConfigurationSection(configuration, section, valueByOption)
    # Write
    filePath = replaceFileExtension(filePath, fileExtension)
    configuration.write(open(filePath, 'wt'))
def addConfigurationSection(configuration, section, valueByOption):
    # Add section
    configuration.add_section(section)
    # For each option,
    for option in valueByOption:
        # Initialize
        value = valueByOption[option]
        # If value is a dictionary,
        if isinstance(value, dict):
            # Add section
            addConfigurationSection(configuration, '%s.%s' % (section, option), value)
        # Otherwise
        else:
            # Add option
            configuration.set(section, option, value)
def loadInformation(filePath, fileExtension='info'):
    # Initialize
    configuration = ConfigParser.RawConfigParser()
    valueByOptionBySection = {}
    # Read
    filePath = replaceFileExtension(filePath, fileExtension)
    configuration.read(filePath)
    # For each section,
    for section in configuration.sections():
        # Initialize
        valueByOption = {}
        # For each option,
        for option in configuration.options(section):
            # Get option and value
            valueByOption[option] = configuration.get(section, option)
        # Store
        valueByOptionBySection[section] = valueByOption
    # Return
    return valueByOptionBySection
def loadQueue(queuePath, convertByName):
    # Load
    valueByNameBySection = loadInformation(queuePath, fileExtension='queue')
    sections = valueByNameBySection.keys()
    # Convert values
    for section in sections:
        valueByNameBySection[section] = convertValueByName(valueByNameBySection[section], convertByName)
    # Set globalParameterByName
    globalParameterByName = valueByNameBySection.get('parameters', {})
    if 'parameters' in sections:
        sections.remove('parameters')
    # Set parameterByTaskByName
    parameterByTaskByName = {}
    for section in sections:
        # Load
        valueByName = globalParameterByName.copy()
        valueByName.update(valueByNameBySection[section])
        # Save
        parameterByTaskByName[section] = valueByName
    # Return
    return parameterByTaskByName
def convertValueByName(valueByName, convertByName):
    # For each name,
    for name, value in valueByName.iteritems():
        # Convert
        try:
            convert = convertByName[name]
            valueByName[name] = convert(value)
        except KeyError:
            raise QueueError('%s is undefined' % name)
        except ValueError:
            raise QueueError('%s=%s has the wrong type' % (name, value))
    # Return
    return valueByName
def stringifyList(items):
    return '\n' + '\n'.join(str(x) for x in items)
def stringifyNestedList(lists):
    return '\n' + '\n'.join(' '.join(str(item) for item in list) for list in lists)
def unstringifyStringList(content):
    lines = map(str.strip, content.splitlines())
    return filter(lambda line: True if line else False, lines)
def unstringifyFloatList(content):
    return map(float, content.split())
def unstringifyIntegerList(content):
    return map(int, content.split())
def unstringifyNestedIntegerList(content):
    return [[int(x) for x in line.split()] for line in unstringifyStringList(content)]
class Information(object):
    def __init__(self, informationPath):
        self.informationPath = informationPath
        self.information = loadInformation(informationPath)
        self.parameterByName = self.information.get('parameters', {})
        self.experimentPath = os.path.dirname(os.path.dirname(os.path.dirname(informationPath)))
    def expandPath(self, relativePath):
        return os.path.join(self.experimentPath, relativePath)
# Time
def recordElapsedTime(function):
    # Define wrapper
    def wrapper(*args, **kwargs):
        # Run
        startTimeInSeconds = time.time()
        resultByName = function(*args, **kwargs)
        endTimeInSeconds = time.time()
        # Record
        elapsedTimeInSeconds = int(round(endTimeInSeconds - startTimeInSeconds))
        if not resultByName:
            resultByName = {}
        resultByName['elapsed time in seconds'] = elapsedTimeInSeconds
        print 'elapsed time in seconds = %s' % elapsedTimeInSeconds
        # Return
        return resultByName
    # Return
    return wrapper
# Module
def getLibraryModule(moduleName):
    return __import__('libraries.' + moduleName, fromlist=['libraries'])
# Error
class QueueError(Exception):
    pass
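The SQL decorators at the top of this module let a store method return just an SQL string, its parameters and an optional post-processing function, while the decorator takes care of executing the statement and then committing or fetching. The following Python 3 sketch copies `execute`, `commit`, `fetchAll` and `pullFirst` from the listing, swapping `map` for a list comprehension so that `fetchAll` still returns a list, and applies them to an in-memory SQLite database. The class `LabelStore`, its table and its column names are invented for illustration and are not part of the tutorial's code.

```python
import sqlite3

def execute(function, args, kwargs):
    'Run the SQL that the wrapped method returns'
    self = args[0]
    z = function(*args, **kwargs)
    sql, value = z[:2]
    if value is None:
        self.cursor.execute(sql)
    else:
        self.cursor.execute(sql, value)
    method = z[2] if len(z) > 2 else None
    return self, method

def commit(function):
    def wrapper(*args, **kwargs):
        self = execute(function, args, kwargs)[0]
        self.connection.commit()
        return self.cursor.lastrowid
    return wrapper

def fetchAll(function):
    def wrapper(*args, **kwargs):
        self, method = execute(function, args, kwargs)
        results = self.cursor.fetchall()
        return [method(x) for x in results] if method else results
    return wrapper

def pullFirst(result):
    if result is not None:
        return result[0]

class LabelStore(object):
    'Hypothetical store; the table and column names are invented'
    def __init__(self):
        self.connection = sqlite3.connect(':memory:')
        self.cursor = self.connection.cursor()
        self.cursor.execute('CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT)')
    @commit
    def addLabel(self, name):
        return 'INSERT INTO labels (name) VALUES (?)', [name]
    @fetchAll
    def getNames(self):
        return 'SELECT name FROM labels ORDER BY id', None, pullFirst

store = LabelStore()
firstId = store.addLabel('building')
store.addLabel('road')
print(firstId, store.getNames())
```

The decorated methods stay declarative: `addLabel` returns the id of the inserted row because `commit` returns `cursor.lastrowid`, and `getNames` returns plain strings because `pullFirst` unwraps each one-column row.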
# Import system modules
import sys
# Feedback
def printDirectly(feedback):
    sys.stdout.write(feedback)
    sys.stdout.flush()
def printPercentUpdate(currentCount, totalCount):
    printDirectly('\r% 3d %% \t%d ' % (100 * currentCount / totalCount, currentCount))
def printPercentFinal(totalCount):
    printDirectly('\r100 %% \t%d \n' % totalCount)
def trackProgress(generator, totalCount, packetLength):
    # Initialize
    items = []
    # For each item,
    for currentCount, item in enumerate(generator):
        items.append(item)
        if currentCount % packetLength == 0:
            printPercentUpdate(currentCount + 1, totalCount)
    # Return
    printPercentFinal(totalCount)
    return items
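`trackProgress` drains a generator into a list and rewrites a single console line with the percentage complete every `packetLength` items, using a carriage return to overwrite the previous update. Here is a condensed Python 3 sketch of the same idea; it folds the two print helpers into the function and uses integer division explicitly, and the workload of squaring ten numbers is hypothetical.

```python
import sys

def printDirectly(feedback):
    sys.stdout.write(feedback)
    sys.stdout.flush()

def trackProgress(generator, totalCount, packetLength):
    'Collect items from generator, reporting progress every packetLength items'
    items = []
    for currentCount, item in enumerate(generator):
        items.append(item)
        if currentCount % packetLength == 0:
            # Integer division keeps the percentage whole on Python 3
            percent = 100 * (currentCount + 1) // totalCount
            printDirectly('\r% 3d %% \t%d ' % (percent, currentCount + 1))
    printDirectly('\r100 %% \t%d \n' % totalCount)
    return items

# Hypothetical workload: square ten numbers, reporting every third item
squares = trackProgress((x * x for x in range(10)), 10, 3)
```

The caller must supply `totalCount` separately because a generator has no length; if the count is wrong, the intermediate percentages will be off, but the returned list is unaffected.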