Extract raster windows from satellite imagery with GDAL

Extract and display raster windows from 16-bit satellite imagery using pixel coordinates or geo coordinates. We first presented this tutorial as part of a three-hour session on Working with Geographic Information Systems in Python during the 2009 Python Conference in Chicago, Illinois.

Example

Make sure all dependencies are installed.

yum install gdal gdal-python numpy python-imaging ipython wxPython qgis

Download the code and data and run the scripts.

wget http://invisibleroads.com/tutorials/_downloads/gdal-raster-extract.zip
unzip gdal-raster-extract.zip
cd gdal-raster-extract
cd examples
python ../extractSamples.py multispectral.tif panchromatic.tif locations.shp
python ../browseSamples.py

Open the generated SQLite database samples.db and browse the extracted samples. Below you can see the first sample with the four low-resolution multispectral bands on the left and the high-resolution panchromatic band on the right.

_images/gdal-raster-extract-browse.png
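
The database can also be inspected without the browser. As a minimal sketch, the following builds an in-memory table with the same columns that sample_store.py creates and runs two summary queries. The two rows and their coordinates are made up for illustration, and the BLOB columns are left empty; to query the real file, you would connect to 'samples.db' instead of ':memory:'.

```python
import sqlite3

# Assumption: an in-memory stand-in for samples.db with the same schema
connection = sqlite3.connect(':memory:')
cursor = connection.cursor()
cursor.execute(
    'CREATE TABLE samples (hasRoof INTEGER, xGeo REAL, yGeo REAL, '
    'multispectralData BLOB, panchromaticData BLOB)')

# Hypothetical rows; real rows hold pickled windows in the BLOB columns
rows = [
    (1, 476100.0, 4429900.0, None, None),
    (0, 476250.0, 4429750.0, None, None),
]
cursor.executemany('INSERT INTO samples VALUES (?,?,?,?,?)', rows)
connection.commit()

# Count all samples and the samples labelled as having a roof
cursor.execute('SELECT COUNT(*), SUM(hasRoof) FROM samples')
total, positive = cursor.fetchone()

# List sample locations in insertion order
cursor.execute('SELECT rowid, xGeo, yGeo FROM samples ORDER BY rowid')
locations = cursor.fetchall()
connection.close()
```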

Requirements

To run browseSamples.py, you will also need the following:

Walkthrough

Make sure all dependencies are installed.

yum install gdal gdal-python numpy python-imaging ipython wxPython qgis

Download the code and data and start IPython.

wget http://invisibleroads.com/tutorials/_downloads/gdal-raster-extract.zip
unzip gdal-raster-extract.zip
cd gdal-raster-extract
ipython

The package contains two scripts and example data as well as supporting library modules.

In [1]: ls
browseSamples.py*  examples/  extractSamples.py*  libraries/

In [2]: ls examples/
locations.shp
locations.shx
locations.prj
locations.dbf
multispectral.tif
multispectral.tif.aux.xml
panchromatic.tif

In [3]: ls libraries/
__init__.py
image_store.py
point_store.py
sample_store.py
sample_store_lush.py
sequence.py
store.py
window_process.py
view.py

The example data contains a low-resolution multispectral image, a high-resolution panchromatic image, and several hand-marked building locations. The imagery shows a region of Boulder, Colorado, cropped from DigitalGlobe’s sample standard 16-bit imagery. For higher georeferencing accuracy, you would probably start from ortho-ready imagery rather than the coarsely orthorectified standard imagery and perform the orthorectification yourself with a high-resolution digital elevation model (DEM), as described in the tutorial Orthorectify satellite images with ENVI.

Multispectral image with locations

_images/gdal-raster-extract-multispectral-view.png

Panchromatic image with locations

_images/gdal-raster-extract-panchromatic-view.png

Extract raster window from satellite image using a pixel location

Load panchromatic image.

import osgeo.gdal
imageDataset = osgeo.gdal.Open('examples/panchromatic.tif')

Define functions that we will use later. Note that we display the window with pylab.imshow() because PIL’s Image class has difficulty handling 16-bit image data. We want to keep the full 16-bit values because in remote sensing every bit of the raw data can carry information.

import struct, numpy, pylab

def extractWindow(pixelX, pixelY, pixelWidth, pixelHeight):
    # Extract raw data
    band = imageDataset.GetRasterBand(1)
    byteString = band.ReadRaster(pixelX, pixelY, pixelWidth, pixelHeight)
    # Convert to a matrix
    valueType = {osgeo.gdal.GDT_Byte: 'B', osgeo.gdal.GDT_UInt16: 'H'}[band.DataType]
    values = struct.unpack('%d%s' % (pixelWidth * pixelHeight, valueType), byteString)
    # ReadRaster returns the values row by row, so the shape is (rows, columns)
    matrix = numpy.reshape(values, (pixelHeight, pixelWidth))
    # Display matrix
    pylab.imshow(matrix, cmap=pylab.cm.gray)
    pylab.show()
    # Return
    return matrix

def extractCenteredWindow(pixelX, pixelY, pixelWidth, pixelHeight):
    centeredPixelX = pixelX - pixelWidth / 2
    centeredPixelY = pixelY - pixelHeight / 2
    return extractWindow(centeredPixelX, centeredPixelY, pixelWidth, pixelHeight)
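
To see how the byte string becomes a matrix, here is a small sketch that needs neither GDAL nor an image. It assumes, as the code above does, that the bytes arrive row by row in native byte order; the function name unpack_window and the synthetic values are made up for illustration.

```python
import struct


def unpack_window(byte_string, pixel_width, pixel_height, type_code='H'):
    """Unpack a row-major byte string into a list of rows.

    type_code 'H' is an unsigned 16-bit value and 'B' an unsigned byte.
    """
    count = pixel_width * pixel_height
    values = struct.unpack('%d%s' % (count, type_code), byte_string)
    # Slice the flat tuple into pixel_height rows of pixel_width values
    return [
        list(values[row * pixel_width:(row + 1) * pixel_width])
        for row in range(pixel_height)
    ]


# Synthetic stand-in for band.ReadRaster(): a window 3 pixels wide, 2 tall
raw = struct.pack('6H', 0, 1000, 65535, 10, 20, 30)
rows = unpack_window(raw, 3, 2)
```

The result has two rows of three values, which is why the matrix shape is given as (height, width) rather than (width, height). Values above 255, such as 1000 and 65535, survive unchanged because nothing is rescaled to 8 bits.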

Extract a 100x100 pixel window near the middle of the image.

extractCenteredWindow(imageDataset.RasterXSize / 2, imageDataset.RasterYSize / 2, 100, 100)
_images/gdal-raster-extract-pixel.png

Extract raster window from satellite image using a geo location

To extract raster windows using geo coordinates, we must convert the geo coordinates to their corresponding pixel locations in the image. Each image carries six coefficients, called the GeoTransform, that define the affine mapping between pixel locations and geo locations.
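
As a rough illustration of what the GeoTransform encodes, the sketch below applies the pixel-to-geo direction of the mapping to the corners of an image. The coefficient values and the 2000x1000 image size are hypothetical, chosen so that the arithmetic is easy to follow; pixel (0, 0) refers to the upper-left corner of the upper-left pixel.

```python
def pixel_to_geo(x_pixel, y_pixel, geo_transform):
    """Apply the affine GeoTransform to a pixel location."""
    g0, g1, g2, g3, g4, g5 = geo_transform
    x_geo = g0 + x_pixel * g1 + y_pixel * g2
    y_geo = g3 + x_pixel * g4 + y_pixel * g5
    return x_geo, y_geo


# Hypothetical north-up transform: origin at (476000, 4430000),
# 0.5 map units per pixel, no rotation (g2 == g4 == 0).
# g5 is negative because rows run downward while northings run upward.
example_transform = (476000.0, 0.5, 0.0, 4430000.0, 0.0, -0.5)

upper_left = pixel_to_geo(0, 0, example_transform)
lower_right = pixel_to_geo(2000, 1000, example_transform)
```

Here upper_left is the origin (476000.0, 4430000.0), and lower_right lies 1000 units east and 500 units south of it at (477000.0, 4429500.0).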

Get image georeferencing information.

g0, g1, g2, g3, g4, g5 = imageDataset.GetGeoTransform()

Define conversion functions.

def convertGeoLocationToPixelLocation(geoLocation):
    xGeo, yGeo = geoLocation
    if g2 == 0:
        xPixel = (xGeo - g0) / float(g1)
        yPixel = (yGeo - g3 - xPixel*g4) / float(g5)
    else:
        xPixel = (yGeo*g2 - xGeo*g5 + g0*g5 - g2*g3) / float(g2*g4 - g1*g5)
        yPixel = (xGeo - g0 - xPixel*g1) / float(g2)
    return int(round(xPixel)), int(round(yPixel))

def convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight):
    return int(round(abs(float(geoWidth) / g1))), int(round(abs(float(geoHeight) / g5)))
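
One way to gain confidence in the inverse mapping is a round-trip check: convert pixel locations to geo locations and back, and confirm that nothing moves. The sketch below is not the code above; it uses the general 2x2 matrix inverse in a single formula instead of branching on g2, and it takes the transform as a parameter. Both transforms are hypothetical, one north-up and one with rotation terms.

```python
def pixel_to_geo(pixel_location, geo_transform):
    g0, g1, g2, g3, g4, g5 = geo_transform
    x_pixel, y_pixel = pixel_location
    return (g0 + x_pixel * g1 + y_pixel * g2,
            g3 + x_pixel * g4 + y_pixel * g5)


def geo_to_pixel(geo_location, geo_transform):
    g0, g1, g2, g3, g4, g5 = geo_transform
    x_offset = geo_location[0] - g0
    y_offset = geo_location[1] - g3
    # Invert the matrix [[g1, g2], [g4, g5]]
    determinant = float(g1 * g5 - g2 * g4)
    x_pixel = (g5 * x_offset - g2 * y_offset) / determinant
    y_pixel = (g1 * y_offset - g4 * x_offset) / determinant
    return int(round(x_pixel)), int(round(y_pixel))


# Hypothetical transforms for illustration only
north_up = (476000.0, 0.6, 0.0, 4430000.0, 0.0, -0.6)
rotated = (476000.0, 0.5, 0.1, 4430000.0, 0.1, -0.5)
pixel_locations = [(0, 0), (120, 45), (1023, 767)]

# True for every location that survives the round trip unchanged
round_trips = [
    geo_to_pixel(pixel_to_geo(location, transform), transform) == location
    for transform in (north_up, rotated)
    for location in pixel_locations
]
```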

Load locations from shapefile. For details, please see the tutorial Load points from a shapefile with GDAL.

from libraries import point_store
geoLocations, spatialReference = point_store.load('examples/locations.shp')

Convert the first geo location to a pixel location.

windowPixelX, windowPixelY = convertGeoLocationToPixelLocation(geoLocations[0])

Convert window dimensions from 25 meters to their equivalent in pixels.

windowPixelWidth, windowPixelHeight = convertGeoDimensionsToPixelDimensions(25, 25)
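
The same ground distance covers a different number of pixels in each image. As a worked example, assume pixel sizes of 0.6 meters for the panchromatic image and 2.4 meters for the multispectral image; these values are assumptions for illustration and are not read from the example files.

```python
def geo_dimensions_to_pixel_dimensions(geo_width, geo_height, geo_transform):
    """Convert ground dimensions to pixel dimensions for a north-up image."""
    g0, g1, g2, g3, g4, g5 = geo_transform
    pixel_width = int(round(abs(float(geo_width) / g1)))
    pixel_height = int(round(abs(float(geo_height) / g5)))
    return pixel_width, pixel_height


# Hypothetical transforms that differ only in pixel size
panchromatic_transform = (476000.0, 0.6, 0.0, 4430000.0, 0.0, -0.6)
multispectral_transform = (476000.0, 2.4, 0.0, 4430000.0, 0.0, -2.4)

# 25 / 0.6 is about 41.7 and 25 / 2.4 is about 10.4 before rounding
panchromatic_size = geo_dimensions_to_pixel_dimensions(
    25, 25, panchromatic_transform)
multispectral_size = geo_dimensions_to_pixel_dimensions(
    25, 25, multispectral_transform)
```

Under these assumptions a 25 meter window is 42x42 pixels in the panchromatic image and 10x10 pixels in the multispectral image. Because of the rounding, the two windows cover slightly different ground areas (25.2 meters and 24 meters on a side).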

Now you can use the function you defined in Extract raster window from satellite image using a pixel location.

extractCenteredWindow(windowPixelX, windowPixelY, windowPixelWidth, windowPixelHeight)
_images/gdal-raster-extract-geo.png
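
Locations close to the image border can produce windows that extend past the raster. The walkthrough functions do not check for this, while the library code listed below skips such windows. A minimal sketch of that check, with hypothetical function names and a made-up 1000x800 raster, might look like this:

```python
def center_window(pixel_center, pixel_width, pixel_height):
    """Return the upper-left corner of a window centered on pixel_center."""
    left = pixel_center[0] - pixel_width // 2
    top = pixel_center[1] - pixel_height // 2
    return left, top


def is_window_inside(left, top, pixel_width, pixel_height,
                     raster_width, raster_height):
    """Return True if the whole window lies within the raster."""
    return (left >= 0 and top >= 0 and
            left + pixel_width <= raster_width and
            top + pixel_height <= raster_height)


def window_fits(pixel_center, pixel_width, pixel_height,
                raster_width, raster_height):
    left, top = center_window(pixel_center, pixel_width, pixel_height)
    return is_window_inside(left, top, pixel_width, pixel_height,
                            raster_width, raster_height)


# Hypothetical 1000x800 raster and 100x100 windows
fits_in_middle = window_fits((500, 400), 100, 100, 1000, 800)
fits_at_left_edge = window_fits((20, 400), 100, 100, 1000, 800)
fits_at_lower_right = window_fits((980, 790), 100, 100, 1000, 800)
```

Only the first window fits; the other two would need to be skipped or clipped before reading.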

Code

extractSamples.py

#!/usr/bin/env python


# Import system modules
import optparse
# Import custom modules
from libraries import image_store, point_store, window_process


# If we are running the script from the command-line,
if __name__ == '__main__':
    # Parse options and arguments
    optionParser = optparse.OptionParser(
        usage='%prog MULTISPECTRAL-PATH PANCHROMATIC-PATH SHAPE-PATH',
        epilog=(
            'Extracts raster windows from MULTISPECTRAL-PATH and '
            'PANCHROMATIC-PATH using locations from SHAPE-PATH and '
            'saves results in OUTPUT-PATH.'
        )
    )
    optionParser.add_option('-o', '--output-path', dest='outputPath', 
        metavar='OUTPUT-PATH', default='samples.db',
        help='save results in OUTPUT-PATH')
    optionParser.add_option('-m', '--meters', dest='windowGeoLength', 
        metavar='LENGTH', default=25, type='int',
        help='specify centered window length in meters')
    optionParser.add_option('-l', '--label', dest='windowLabel', 
        metavar='INTEGER', default=1, type='int',
        help='specify window label')
    options, arguments = optionParser.parse_args()
    # Verify
    if len(arguments) == 3:
        # Extract
        multispectralImagePath, panchromaticImagePath, locationPath = arguments
        window_process.extract(options.outputPath, 
            options.windowLabel, options.windowGeoLength,
            image_store.load(multispectralImagePath), 
            image_store.load(panchromaticImagePath),
            point_store.load(locationPath)[0])
    else:
        # Show help
        optionParser.print_help()

browseSamples.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# generated by wxGlade 0.6.3 on Sat Sep 27 02:31:00 2008


# Import system modules
import sys
import itertools
import wx
import os
# Import custom modules
from libraries import sample_store, sample_store_lush, image_store, store


# begin wxGlade: extracode
# end wxGlade


# Set
size_small = 100, 100
size_large = 400, 400


class MainFrame(wx.Frame):

    def __init__(self, *args, **kwds):
        # begin wxGlade: MainFrame.__init__
        kwds["style"] = wx.DEFAULT_FRAME_STYLE
        wx.Frame.__init__(self, *args, **kwds)
        self.button_open = wx.Button(self, -1, "Open")
        self.slider = wx.Slider(self, -1, 0, 0, 10)
        self.checkbox_hasRoof = wx.CheckBox(self, -1, "Has Roof")
        self.label_location = wx.StaticText(self, -1, "")
        self.label_red = wx.StaticText(self, -1, "Red", style=wx.ALIGN_CENTRE)
        self.bitmap_red = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_green = wx.StaticText(self, -1, "Green", style=wx.ALIGN_CENTRE)
        self.bitmap_green = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_blue = wx.StaticText(self, -1, "Blue", style=wx.ALIGN_CENTRE)
        self.bitmap_blue = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_infrared = wx.StaticText(self, -1, "Infrared", style=wx.ALIGN_CENTRE)
        self.bitmap_infrared = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.label_panchromatic = wx.StaticText(self, -1, "Panchromatic", style=wx.ALIGN_CENTRE)
        self.bitmap_panchromatic = wx.StaticBitmap(self, -1, wx.NullBitmap)
        self.button_1 = wx.Button(self, -1, "Save")

        self.__set_properties()
        self.__do_layout()

        self.Bind(wx.EVT_BUTTON, self.onClickOpen, self.button_open)
        self.Bind(wx.EVT_COMMAND_SCROLL, self.onScroll, self.slider)
        self.Bind(wx.EVT_BUTTON, self.onClickSave, self.button_1)
        # end wxGlade
        self.samples = []
        self.flag_changed = False

    def __set_properties(self):
        # begin wxGlade: MainFrame.__set_properties
        self.SetTitle("Sample Browser")
        self.SetSize((800, 600))
        self.checkbox_hasRoof.Enable(False)
        # end wxGlade

    def __do_layout(self):
        # begin wxGlade: MainFrame.__do_layout
        sizer_1 = wx.BoxSizer(wx.VERTICAL)
        sizer_4 = wx.BoxSizer(wx.HORIZONTAL)
        sizer_6 = wx.BoxSizer(wx.VERTICAL)
        grid_sizer_1 = wx.GridSizer(2, 2, 0, 0)
        sizer_8 = wx.BoxSizer(wx.VERTICAL)
        sizer_5 = wx.BoxSizer(wx.VERTICAL)
        sizer_3 = wx.BoxSizer(wx.VERTICAL)
        sizer_2 = wx.BoxSizer(wx.VERTICAL)
        sizer_7 = wx.BoxSizer(wx.HORIZONTAL)
        sizer_1.Add(self.button_open, 0, wx.EXPAND, 0)
        sizer_1.Add(self.slider, 0, wx.EXPAND, 0)
        sizer_7.Add(self.checkbox_hasRoof, 0, 0, 0)
        sizer_7.Add((20, 20), 0, 0, 0)
        sizer_7.Add(self.label_location, 1, wx.EXPAND, 0)
        sizer_1.Add(sizer_7, 0, wx.EXPAND, 0)
        sizer_2.Add(self.label_red, 0, wx.EXPAND, 0)
        sizer_2.Add(self.bitmap_red, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_2, 1, wx.EXPAND, 0)
        sizer_3.Add(self.label_green, 0, wx.EXPAND, 0)
        sizer_3.Add(self.bitmap_green, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_3, 1, wx.EXPAND, 0)
        sizer_5.Add(self.label_blue, 0, wx.EXPAND, 0)
        sizer_5.Add(self.bitmap_blue, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_5, 1, wx.EXPAND, 0)
        sizer_8.Add(self.label_infrared, 0, wx.EXPAND, 0)
        sizer_8.Add(self.bitmap_infrared, 0, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        grid_sizer_1.Add(sizer_8, 1, wx.EXPAND, 0)
        sizer_4.Add(grid_sizer_1, 2, wx.EXPAND, 0)
        sizer_6.Add(self.label_panchromatic, 0, wx.EXPAND, 0)
        sizer_6.Add(self.bitmap_panchromatic, 1, wx.ALIGN_CENTER_HORIZONTAL|wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_4.Add(sizer_6, 3, wx.EXPAND, 0)
        sizer_1.Add(sizer_4, 1, wx.EXPAND, 0)
        sizer_1.Add(self.button_1, 0, wx.EXPAND, 0)
        self.SetSizer(sizer_1)
        self.Layout()
        # end wxGlade

    def onClickOpen(self, event): # wxGlade: MainFrame.<event_handler>
        # Get file information
        fileTypes = 'Sample databases (*.db)|*.db', 'Lush samples (*Samples)|*Samples'
        dialog = wx.FileDialog(self, 'Open', style=wx.OPEN, wildcard='|'.join(fileTypes))
        if dialog.ShowModal() != wx.ID_OK: return
        filePath = dialog.GetPath()
        self.fileType = dialog.GetFilterIndex()
        dialog.Destroy()
        # If the user wants to open a database,
        if self.fileType == 0: 
            self.samples = sample_store.load(filePath).getSamples()
        # If the user wants to open Lush samples,
        else: 
            filePath = filePath.replace('Samples', '')
            self.samples = [(label, sample) for label, sample in itertools.izip(sample_store_lush.makeLabelGeneratorFromLushDataset(filePath), sample_store_lush.makeSampleGeneratorFromLushDataset(filePath))]
        # Set slider
        self.slider.SetRange(0, len(self.samples) - 1)
        # Show first sample
        self.refresh()

    def onScroll(self, event): # wxGlade: MainFrame.<event_handler>
        self.refresh()

    def onClickSave(self, event): # wxGlade: MainFrame.<event_handler>
        # Get file information
        fileTypes = 'Sample databases (*.db)|*.db', 'Matlab matrices (*.mat)|*.mat'
        dialog = wx.FileDialog(self, 'Save', style=wx.SAVE, wildcard='|'.join(fileTypes))
        if dialog.ShowModal() != wx.ID_OK: return
        targetPath = dialog.GetPath()
        fileType = dialog.GetFilterIndex()
        dialog.Destroy()
        # If the user wants to save the data as a database, save it
        if fileType == 0: 
            sample_store.save(self.samples, store.replaceFileExtension(targetPath, 'db'))
        # If the user wants to save the data in Matlab format, save it
        else: 
            sample_store.saveForMatlab(self.samples, store.replaceFileExtension(targetPath, 'mat'))

    def refresh(self):
        # Do nothing until samples have been loaded
        if not self.samples: return
        self.showSample(self.slider.GetValue())
        self.Fit()

    def showSample(self, index):
        # Get sample
        sample = self.samples[index]
        # If we have samples from a database
        if self.fileType == 0:
            hasRoof, geoLocation, multispectralWindow, panchromaticWindow = sample
            # Display
            self.checkbox_hasRoof.SetValue(hasRoof)
            self.label_location.SetLabel('%s, %s' % geoLocation)
            # Display multispectral images
            if multispectralWindow:
                label_red = 'Red (%dx%d pixels)' % (multispectralWindow.width, multispectralWindow.height)
                red, green, blue, infrared = multispectralWindow.getImages_pylab(*size_small)
                setBitmapFromPIL(self.bitmap_red, red)
                setBitmapFromPIL(self.bitmap_green, green)
                setBitmapFromPIL(self.bitmap_blue, blue)
                setBitmapFromPIL(self.bitmap_infrared, infrared)
            else:
                label_red = 'Red'
                self.bitmap_red.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                self.bitmap_green.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                self.bitmap_blue.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                self.bitmap_infrared.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                print 'Empty multispectral'
            # Display panchromatic image
            if panchromaticWindow:
                label_panchromatic = 'Panchromatic (%dx%d pixels)' % (panchromaticWindow.width, panchromaticWindow.height)
                panchromatic = panchromaticWindow.getImages_pylab(*size_large)[0]
                setBitmapFromPIL(self.bitmap_panchromatic, panchromatic)
            else: 
                label_panchromatic = 'Panchromatic'
                self.bitmap_panchromatic.SetBitmap(wx.EmptyImage(0,0).ConvertToBitmap())
                print 'Empty panchromatic'
        # Otherwise,
        else:
            # Expand
            lushLabel, lushSample = sample
            # Set label
            self.checkbox_hasRoof.SetValue(lushLabel)
            self.label_location.SetLabel('')
            # Set multispectral
            red, green, blue, infrared = image_store.getImages_pylab('imshow', lushSample[:4], *size_small) 
            setBitmapFromPIL(self.bitmap_red, red)
            setBitmapFromPIL(self.bitmap_green, green)
            setBitmapFromPIL(self.bitmap_blue, blue)
            setBitmapFromPIL(self.bitmap_infrared, infrared)
            # Set panchromatic
            panchromatic = image_store.getImages_pylab('imshow', [lushSample[4]], *size_large)[0]
            setBitmapFromPIL(self.bitmap_panchromatic, panchromatic)

# end of class MainFrame


def setBitmapFromPIL(bitmap, pilImage):
    width, height = pilImage.size
    wxImage = wx.EmptyImage(width, height)
    wxImage.SetData(pilImage.convert('RGB').tostring())
    bitmap.SetBitmap(wxImage.ConvertToBitmap())


if __name__ == "__main__":
    app = wx.PySimpleApp(0)
    wx.InitAllImageHandlers()
    mainFrame = MainFrame(None, -1, "")
    app.SetTopWindow(mainFrame)
    mainFrame.Show()
    app.MainLoop()

window_process.py

# Import system modules
# Import custom modules
import sample_store
import view


def extract(targetDatasetPath, label, windowGeoLength, multispectralImage, panchromaticImage, geoCenters):
    # Initialize
    dataset = sample_store.create(targetDatasetPath)
    windowCount = len(geoCenters)
    # For each geoCenter,
    for windowIndex, geoCenter in enumerate(geoCenters):
        window = [x.extractCenteredGeoWindow(geoCenter, windowGeoLength, windowGeoLength) for x in multispectralImage, panchromaticImage]
        if window[0] and window[1]: dataset.addSample(label, geoCenter, *window)
        if windowIndex % 100 == 0: 
            view.printPercentUpdate(windowIndex + 1, windowCount)
    view.printPercentFinal(windowCount)
    # Return
    return dataset

image_store.py

# Import system modules
import Image
import cPickle as pickle
import StringIO
import numpy
import osgeo.gdal
import osgeo.osr
import struct
# Import custom modules
import store
import view


# Types used by struct.unpack
typeByGDT = {
    osgeo.gdal.GDT_Byte: 'B',
    osgeo.gdal.GDT_UInt16: 'H', 
}
modeByType = {
    osgeo.gdal.GDT_Byte: 'L',
    osgeo.gdal.GDT_UInt16: 'I;16', 
}


# Define shortcuts

def load(imagePath): 
    return GeoImage(imagePath)

def convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight, geoTransform):
    g0, g1, g2, g3, g4, g5 = geoTransform
    return int(round(abs(float(geoWidth) / g1))), int(round(abs(float(geoHeight) / g5)))

def convertPixelLocationsToGeoLocations(pixelLocations, geoTransform):
    return [convertPixelLocationToGeoLocation(x, geoTransform) for x in pixelLocations]

def convertPixelLocationToGeoLocation(pixelLocation, geoTransform):
    g0, g1, g2, g3, g4, g5 = geoTransform
    xPixel, yPixel = pixelLocation
    xGeo = g0 + xPixel*g1 + yPixel*g2
    yGeo = g3 + xPixel*g4 + yPixel*g5
    return xGeo, yGeo


# Define core

class GeoImage(object):

    # Constructor

    def __init__(self, imagePath):
        # Initialize
        self.imagePath = imagePath
        self.dataset = osgeo.gdal.Open(imagePath)
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.geoTransform = self.dataset.GetGeoTransform()
        # Get spatialReference
        spatialReference = osgeo.osr.SpatialReference()
        spatialReference.ImportFromWkt(self.dataset.GetProjectionRef())
        self.spatialReferenceAsProj4 = spatialReference.ExportToProj4()

    # Get

    def getPath(self):
        return self.imagePath

    def getPixelWidth(self):
        return self.width

    def getPixelHeight(self):
        return self.height

    def getGeoTransform(self):
        return self.geoTransform

    def getSpatialReference(self):
        return self.spatialReferenceAsProj4

    # Extract

    def extractCenteredGeoWindow(self, geoCenter, geoWidth, geoHeight):
        pixelCenter = self.convertGeoLocationToPixelLocation(geoCenter)
        pixelWidth, pixelHeight = self.convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight)
        return self.extractCenteredPixelWindow(pixelCenter, pixelWidth, pixelHeight)

    def extractCenteredPixelWindow(self, pixelCenter, pixelWidth, pixelHeight):
        pixelUpperLeft = centerPixelFrame(pixelCenter, pixelWidth, pixelHeight)[:2]
        return self.extractPixelWindow(pixelUpperLeft, pixelWidth, pixelHeight)

    def extractPixelWindow(self, pixelUpperLeft, pixelWidth, pixelHeight):
        # Set
        iLeft = int(pixelUpperLeft[0])
        iTop = int(pixelUpperLeft[1])
        iWidth = int(pixelWidth)
        iHeight = int(pixelHeight)
        # If the box is outside, return
        if not self.isWindowInside(iLeft, iTop, iWidth, iHeight): return
        # Extract
        bands = map(self.dataset.GetRasterBand, xrange(1, self.dataset.RasterCount + 1))
        packs = [(x.DataType, x.ReadRaster(iLeft, iTop, iWidth, iHeight)) for x in bands]
        # Return
        return Window(iLeft, iTop, iWidth, iHeight, packs)

    # Is

    def isWindowInside(self, left, top, width, height):
        right = left + width
        bottom = top + height
        return left >= 0 and top >= 0 and right <= self.width and bottom <= self.height

    # Convert

    def convertGeoFrameToPixelFrame(self, geoFrame):
        geoLeft, geoTop, geoRight, geoBottom = geoFrame
        pixelLeft, pixelTop = self.convertGeoLocationToPixelLocation((geoLeft, geoTop))
        pixelRight, pixelBottom = self.convertGeoLocationToPixelLocation((geoRight, geoBottom))
        pixelLeft, pixelRight = sorted((pixelLeft, pixelRight))
        pixelTop, pixelBottom = sorted((pixelTop, pixelBottom))
        return pixelLeft, pixelTop, pixelRight, pixelBottom

    def convertPixelFrameToGeoFrame(self, pixelFrame):
        pixelLeft, pixelTop, pixelRight, pixelBottom = pixelFrame
        geoLeft, geoTop = self.convertPixelLocationToGeoLocation((pixelLeft, pixelTop))
        geoRight, geoBottom = self.convertPixelLocationToGeoLocation((pixelRight, pixelBottom))
        geoLeft, geoRight = sorted((geoLeft, geoRight))
        geoTop, geoBottom = sorted((geoTop, geoBottom))
        return geoLeft, geoTop, geoRight, geoBottom

    def convertGeoLocationToPixelLocation(self, geoLocation):
        g0, g1, g2, g3, g4, g5 = self.geoTransform
        xGeo, yGeo = geoLocation
        if g2 == 0:
            xPixel = (xGeo - g0) / float(g1)
            yPixel = (yGeo - g3 - xPixel*g4) / float(g5)
        else:
            xPixel = (yGeo*g2 - xGeo*g5 + g0*g5 - g2*g3) / float(g2*g4 - g1*g5)
            yPixel = (xGeo - g0 - xPixel*g1) / float(g2)
        return int(round(xPixel)), int(round(yPixel))

    def convertPixelLocationToGeoLocation(self, pixelLocation):
        return convertPixelLocationToGeoLocation(pixelLocation, self.geoTransform)

    def convertGeoLocationsToPixelLocations(self, geoLocations):
        return map(self.convertGeoLocationToPixelLocation, geoLocations)

    def convertPixelLocationsToGeoLocations(self, pixelLocations):
        return map(self.convertPixelLocationToGeoLocation, pixelLocations)

    def convertGeoDimensionsToPixelDimensions(self, geoWidth, geoHeight):
        return convertGeoDimensionsToPixelDimensions(geoWidth, geoHeight, self.geoTransform)

    def convertPixelDimensionsToGeoDimensions(self, pixelWidth, pixelHeight):
        g0, g1, g2, g3, g4, g5 = self.geoTransform
        return abs(pixelWidth * g1), abs(pixelHeight * g5)


class Window(object):

    values = None
    matrices = None
    matrixImages = None, None
    pairwiseMatrices = None
    images_pil = None
    images_pylab = None, None
    
    def __init__(self, left, top, width, height, packs):
        # Set
        self.width = width
        self.height = height
        self.packs = packs
        # Initialize
        self.frame = left, top, left + width, top + height

    def getFrame(self):
        return self.frame

    def getValues(self):
        if not self.values:
            length = self.width * self.height
            self.values = [struct.unpack('%d%s' % (length, typeByGDT[x[0]]), x[1]) for x in self.packs]
        return self.values

    def getMatrices(self):
        if not self.matrices:
            self.matrices = [numpy.reshape(x, (self.height, self.width)) for x in self.getValues()]
        return self.matrices

    def getMatrixImages(self, imageWidthInPixels, imageHeightInPixels):
        imageDimensions = imageWidthInPixels, imageHeightInPixels
        if not self.matrixImages[0] or self.matrixImages[1] != imageDimensions:
            self.matrixImages = getImages_pylab('matshow', self.getMatrices(), *imageDimensions), imageDimensions
        return self.matrixImages[0]

    def getImages_pil(self):
        if not self.images_pil:
            imageDimensions = self.width, self.height
            self.images_pil = [Image.fromstring(modeByType[x[0]], imageDimensions, x[1]) for x in self.packs]
        return self.images_pil

    def getImages_pylab(self, imageWidthInPixels, imageHeightInPixels):
        imageDimensions = imageWidthInPixels, imageHeightInPixels
        if not self.images_pylab[0] or self.images_pylab[1] != imageDimensions:
            self.images_pylab = getImages_pylab('imshow', self.getMatrices(), *imageDimensions), imageDimensions
        return self.images_pylab[0]

    def pickle(self):
        return pickle.dumps((self.width, self.height, self.packs))


def unpickle(data):
    if not data: return None
    width, height, packs = pickle.loads(data)
    return Window(0, 0, width, height, packs)


# Get helpers

def centerPixelFrame(pixelCenter, pixelWidth, pixelHeight):
    # Compute frame
    pixelLeft = pixelCenter[0] - pixelWidth / 2
    pixelTop = pixelCenter[1] - pixelHeight / 2
    pixelRight = pixelLeft + pixelWidth
    pixelBottom = pixelTop + pixelHeight
    # Return
    return pixelLeft, pixelTop, pixelRight, pixelBottom

def getCenter(frame):
    left, top, right, bottom = frame
    widthHalved = (right - left) / 2
    heightHalved = (bottom - top) / 2
    return int(left + widthHalved), int(top + heightHalved)

def getWindowCenter(windowLocation, windowPixelWidth, windowPixelHeight):
    left, top = windowLocation
    widthHalved = windowPixelWidth / 2
    heightHalved = windowPixelHeight / 2
    return int(left + widthHalved), int(top + heightHalved)

def getPixelsInFrame(frame):
    left, top, right, bottom = frame
    xRange = xrange(left, right)
    yRange = xrange(top, bottom)
    return [(x, y) for x in xRange for y in yRange]

def getImages_pylab(string_method, matrices, imageWidthInPixels, imageHeightInPixels):
    # Import system modules
    import matplotlib
    import pylab
    # Initialize
    images = []
    convertPixelLengthToInchLength = lambda x: x / float(matplotlib.rc_params()['figure.dpi'])
    imageWidthInInches = convertPixelLengthToInchLength(imageWidthInPixels)
    imageHeightInInches = convertPixelLengthToInchLength(imageHeightInPixels)
    imageDimensions = imageWidthInInches, imageHeightInInches
    # For each matrix,
    for matrix in matrices:
        # Create a new figure
        pylab.figure(figsize=imageDimensions)
        axes = pylab.axes()
        # Plot grayscale matrix
        method = getattr(axes, string_method)
        method(matrix, cmap=pylab.cm.gray)
        # Clear tick marks
        axes.set_xticks([])
        axes.set_yticks([])
        # Save image
        imageFile = StringIO.StringIO()
        pylab.savefig(imageFile, format='png')
        imageFile.seek(0)
        # Close the figure so that figures do not accumulate in memory
        pylab.close()
        # Append
        images.append(Image.open(imageFile))
    # Return
    return images


class Information(object):

    # Constructor

    def __init__(self, informationPath):
        self.information = store.loadInformation(informationPath)

    # Positive location

    def getPositiveLocationPath(self):
        return self.information['positive location']['path']

    # Multispectral

    def getMultispectralImagePath(self):
        return self.information['multispectral image']['path']

    def getMultispectralImage(self):
        return load(self.getMultispectralImagePath())

    # Panchromatic

    def getPanchromaticImagePath(self):
        return self.information['panchromatic image']['path']

    def getPanchromaticImage(self):
        return load(self.getPanchromaticImagePath())

sample_store.py

# Import system modules
import sqlite3
import os
import numpy
import random
import scipy.io
import itertools
import cPickle as pickle
# Import custom modules
import view
import store
import image_store
import sequence


# Set SQL
sql_getSample = 'SELECT hasRoof, xGeo, yGeo, multispectralData, panchromaticData FROM samples'


# Shortcuts

def create(datasetPath):
    datasetPath = store.replaceFileExtension(datasetPath, 'db')
    if os.path.exists(datasetPath): os.remove(datasetPath)
    return Store(datasetPath)

def load(datasetPath):
    datasetPath = store.replaceFileExtension(datasetPath, 'db')
    if not os.path.exists(datasetPath): raise IOError('Dataset does not exist: ' + datasetPath)
    return Store(datasetPath)


# Restore

def restoreResult(result):
    # Extract
    hasRoof, xGeo, yGeo, multispectralData, panchromaticData = result
    geoCenter = xGeo, yGeo
    # Restore
    multispectralWindow = image_store.unpickle(multispectralData)
    panchromaticWindow = image_store.unpickle(panchromaticData)
    # Return
    return hasRoof, geoCenter, multispectralWindow, panchromaticWindow


class Store(object):

    # Constructor

    def __init__(self, datasetPath):
        # Fix extension
        datasetPath = store.replaceFileExtension(datasetPath, 'db')
        # Check whether the dataset exists
        flag_exists = os.path.exists(datasetPath)
        # Connect
        self.connection = sqlite3.connect(datasetPath)
        self.connection.text_factory = str
        self.cursor = self.connection.cursor()
        # If the dataset doesn't exist, create it
        if not flag_exists: 
            self.cursor.execute('CREATE TABLE samples (hasRoof INTEGER, xGeo REAL, yGeo REAL, multispectralData BLOB, panchromaticData BLOB)')
            self.connection.commit()
        # Remember
        self.datasetPath = datasetPath
    
    # Destructor

    def __del__(self):
        self.connection.close()

    # Add

    @store.commit
    def addSample(self, hasRoof, geoCenter, multispectralWindow, panchromaticWindow):
        xGeo, yGeo = geoCenter
        multispectralData = multispectralWindow.pickle()
        panchromaticData = panchromaticWindow.pickle()
        return 'INSERT INTO samples (hasRoof, xGeo, yGeo, multispectralData, panchromaticData) VALUES (?,?,?,?,?)', (hasRoof, xGeo, yGeo, multispectralData, panchromaticData)

    # Delete

    @store.commit
    def deleteSample(self, sampleID):
        return 'DELETE FROM samples WHERE rowid=?', [sampleID]

    # Get

    def getDatasetPath(self): 
        return self.datasetPath
    
    @store.fetchAll
    def getSampleIDs(self):
        return 'SELECT rowid FROM samples', None, store.pullFirst

    @store.fetchAll
    def getRandomSampleIDs(self):
        return 'SELECT rowid FROM samples ORDER BY RANDOM()', None, store.pullFirst

    @store.fetchAll
    def getPositiveSampleIDs(self):
        return 'SELECT rowid FROM samples WHERE hasRoof=1', None, store.pullFirst

    @store.fetchAll
    def getNegativeSampleIDs(self):
        return 'SELECT rowid FROM samples WHERE hasRoof=0', None, store.pullFirst

    @store.fetchAll
    def getSamplesByIDs(self, sampleIDs):
        return sql_getSample + ' WHERE samples.rowID IN (%s)' % ','.join(map(str, sampleIDs)), None, restoreResult

    @store.fetchAll
    def getSamples(self):
        return sql_getSample, None, restoreResult

    @store.fetchOne
    def getSample(self, sampleID):
        return sql_getSample + ' WHERE samples.rowID=?', (sampleID,), restoreResult

    @store.fetchAll
    def getGeoCenters(self):
        return 'SELECT xGeo, yGeo FROM samples', None

    # Count

    @store.fetchOne
    def countSamples(self):
        return 'SELECT COUNT(*) FROM samples', None, store.pullFirst

    @store.fetchOne
    def countPositiveSamples(self):
        return 'SELECT COUNT(*) FROM samples WHERE hasRoof=1', None, store.pullFirst

    @store.fetchOne
    def countNegativeSamples(self):
        return 'SELECT COUNT(*) FROM samples WHERE hasRoof=0', None, store.pullFirst

    def getStatistics(self):
        return {
            'total': self.countSamples(),
            'positive': self.countPositiveSamples(),
            'negative': self.countNegativeSamples(),
        }

    # Cut

    def cutIDs(self, testFraction, withRandomization=True):
        # For each cut,
        for cutPack in sequence.cut(self.getSampleIDs(), testFraction, withRandomization):
            # Yield
            yield cutPack

    # Save

    def saveForMatlab(self, targetPath):
        'Save labels, geoCenters and band matrices in MATLAB format'
        # Initialize
        labels = []; geoCenters = []
        reds = []; greens = []; blues = []; infrareds = []; panchromatics = []
        # Assemble
        for hasRoof, geoCenter, multispectralWindow, panchromaticWindow in self.getSamples():
            # Gather
            labels.append(hasRoof); geoCenters.append(geoCenter)
            # Gather multispectral
            red, green, blue, infrared = multispectralWindow.getMatrices()
            reds.append(red); greens.append(green); blues.append(blue); infrareds.append(infrared)
            # Gather panchromatic
            panchromatics.append(panchromaticWindow.getMatrices()[0])
        # Stack so that the third axis indexes the sample
        matrixDictionary = {
            'labels': numpy.dstack(labels), 'geoCenters': numpy.dstack(geoCenters),
            'reds': numpy.dstack(reds), 'greens': numpy.dstack(greens),
            'blues': numpy.dstack(blues), 'infrareds': numpy.dstack(infrareds),
            'panchromatics': numpy.dstack(panchromatics),
        }
        # Save, appending .mat to targetPath if the extension is missing
        scipy.io.savemat(targetPath, matrixDictionary, appendmat=True)


# Save

def save(targetPath, datasetSampleIDs):
    # Open targetDataset
    targetDataset = Store(targetPath)
    # Save samples
    sampleCount = len(datasetSampleIDs)
    for sampleIndex, (sourceDataset, sampleID) in enumerate(datasetSampleIDs):
        targetDataset.addSample(*sourceDataset.getSample(sampleID))
        if sampleIndex % 100 == 0: 
            view.printPercentUpdate(sampleIndex + 1, sampleCount)
    view.printPercentFinal(sampleCount)
    # Return
    return targetDataset
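
saveForMatlab above relies on numpy.dstack to turn per-sample lists into arrays whose last axis indexes the sample. The following sketch shows the shapes that result. The windows, labels and coordinates are made-up stand-ins, not values from the example data, and the block is written for Python 3 with NumPy installed.

```python
import numpy

# Hypothetical stand-ins for three extracted samples
windows = [numpy.zeros((2, 2)), numpy.ones((2, 2)), numpy.full((2, 2), 2.0)]
labels = [1, 0, 1]
geoCenters = [(10.0, 20.0), (11.0, 21.0), (12.0, 22.0)]

# Each 2D window becomes one slice along the third axis
stackedWindows = numpy.dstack(windows)
# Scalars are promoted to 1x1 slices before stacking
stackedLabels = numpy.dstack(labels)
# Coordinate pairs are promoted to 1x2 slices before stacking
stackedGeoCenters = numpy.dstack(geoCenters)

print(stackedWindows.shape)
print(stackedLabels.shape)
print(stackedGeoCenters.shape)
```

Because the sample index is the third axis, the k-th sample is the k-th slice of each saved array.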

sample_store_lush.py

# Import system modules
import numpy


def makeSampleLabelPaths(basePath):
    return basePath + '-samples', basePath + '-labels'

def makeLabelGeneratorFromLushDataset(filePath):
    # Set paths
    labelPath = makeSampleLabelPaths(filePath)[1]
    # Open the file and close it when the generator finishes
    with open(labelPath) as labelFile:
        # Skip header
        next(labelFile)
        # For each line,
        for line in labelFile:
            # Yield each label
            for term in line.split():
                yield int(term)

def makeSampleGeneratorFromLushDataset(filePath):
    # Set paths
    samplePath = makeSampleLabelPaths(filePath)[0]
    # Open the file and close it when the generator finishes
    with open(samplePath) as sampleFile:
        # Read the shape of a single sample from the header
        header = next(sampleFile)
        sampleShape = [int(x) for x in header.split()[3:]]
        # Use numpy.prod because numpy.product is a deprecated alias
        sampleSize = int(numpy.prod(sampleShape))
        terms = []
        # For each line,
        for line in sampleFile:
            # Extend terms
            terms.extend(float(x) for x in line.split())
            # If we have enough terms,
            while len(terms) >= sampleSize:
                # Yield sample
                yield numpy.array(terms[:sampleSize]).reshape(sampleShape)
                # Set terms
                terms = terms[sampleSize:]
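
The sample generator above buffers numbers across lines and emits an array whenever it has enough for one sample. Here is a minimal sketch of that buffering on an in-memory file. The helper name iterSamples and the file content are hypothetical. The header is invented to fit the listing's header.split()[3:] convention and is not taken from a real Lush dataset. The block is written for Python 3 with NumPy installed.

```python
import io
import numpy

def iterSamples(sampleFile):
    'Yield one array per sample from an open text file'
    # The fourth and later header terms give the shape of a single sample
    header = next(sampleFile)
    sampleShape = [int(x) for x in header.split()[3:]]
    sampleSize = int(numpy.prod(sampleShape))
    terms = []
    for line in sampleFile:
        terms.extend(float(x) for x in line.split())
        # A sample may span lines or share a line with the next sample,
        # so emit whenever enough terms are buffered
        while len(terms) >= sampleSize:
            yield numpy.array(terms[:sampleSize]).reshape(sampleShape)
            terms = terms[sampleSize:]

# Hypothetical content: two 2x2 samples spread unevenly across three lines
content = u'.MAT 3 2 2 2\n1 2 3\n4 5 6 7\n8\n'
samples = list(iterSamples(io.StringIO(content)))
print(len(samples))
print(samples[1])
```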

point_store.py

'Save and load points to a shapefile'
# Import system modules
import osgeo.ogr
import osgeo.osr
import os


# Core

def save(shapePath, geoLocations, proj4):
    'Save points in the given shapePath'
    # Get driver
    driver = osgeo.ogr.GetDriverByName('ESRI Shapefile')
    # Create shapeData
    shapePath = validateShapePath(shapePath)
    # Delete the existing shapefile together with its .shx, .dbf and .prj files
    if os.path.exists(shapePath):
        driver.DeleteDataSource(shapePath)
    shapeData = driver.CreateDataSource(shapePath)
    # Create spatialReference
    spatialReference = getSpatialReferenceFromProj4(proj4)
    # Create layer
    layerName = os.path.splitext(os.path.split(shapePath)[1])[0]
    layer = shapeData.CreateLayer(layerName, spatialReference, osgeo.ogr.wkbPoint)
    layerDefinition = layer.GetLayerDefn()
    # For each point,
    for pointIndex, geoLocation in enumerate(geoLocations):
        # Create point
        geometry = osgeo.ogr.Geometry(osgeo.ogr.wkbPoint)
        geometry.SetPoint(0, geoLocation[0], geoLocation[1])
        # Create feature
        feature = osgeo.ogr.Feature(layerDefinition)
        feature.SetGeometry(geometry)
        feature.SetFID(pointIndex)
        # Save feature
        layer.CreateFeature(feature)
        # Cleanup
        geometry.Destroy()
        feature.Destroy()
    # Cleanup
    shapeData.Destroy()
    # Return
    return shapePath

def load(shapePath):
    'Given a shapePath, return a list of points in GIS coordinates'
    # Open shapeData
    shapeData = osgeo.ogr.Open(validateShapePath(shapePath))
    # Validate shapeData
    validateShapeData(shapeData)
    # Get the first layer
    layer = shapeData.GetLayer()
    # Initialize
    points = []
    # For each point,
    for index in xrange(layer.GetFeatureCount()):
        # Get
        feature = layer.GetFeature(index)
        geometry = feature.GetGeometryRef()
        # Make sure that it is a point
        if geometry.GetGeometryType() != osgeo.ogr.wkbPoint: 
            raise ShapeDataError('This module can only load points; use geometry_store.py')
        # Get pointCoordinates
        pointCoordinates = geometry.GetX(), geometry.GetY()
        # Append
        points.append(pointCoordinates)
        # Cleanup
        feature.Destroy()
    # Get spatial reference as proj4
    proj4 = layer.GetSpatialRef().ExportToProj4()
    # Cleanup
    shapeData.Destroy()
    # Return
    return points, proj4

def merge(sourcePaths, targetPath):
    'Merge a list of shapefiles into a single shapefile'
    # Load; note that load() and save() already validate the shapefile extension
    items = [load(x) for x in sourcePaths]
    # Concatenate the points from every shapefile
    points = [point for pointList, proj4 in items for point in pointList]
    # Make sure that all the spatial references (as proj4 strings) are the same
    proj4s = [proj4 for pointList, proj4 in items]
    if len(set(proj4s)) != 1:
        raise ShapeDataError('The shapefiles must have the same spatial reference')
    # Save
    save(targetPath, points, proj4s[0])

def getSpatialReferenceFromProj4(proj4):
    'Return GDAL spatial reference object from proj4 string'
    spatialReference = osgeo.osr.SpatialReference()
    spatialReference.ImportFromProj4(proj4)
    return spatialReference


# Validate

def validateShapePath(shapePath):
    'Validate shapefile extension'
    return os.path.splitext(str(shapePath))[0] + '.shp'

def validateShapeData(shapeData):
    'Make sure we can access the shapefile'
    # Make sure the shapefile exists
    if not shapeData:
        raise ShapeDataError('The shapefile is invalid')
    # Make sure there is exactly one layer
    if shapeData.GetLayerCount() != 1:
        raise ShapeDataError('The shapefile must have exactly one layer')


# Error

class ShapeDataError(Exception):
    pass
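
A shapefile is several files that share a base name, which is why validateShapePath above can normalize any of them to the .shp path. The sketch below lists the files that travel together. The helper listShapefileParts is hypothetical, and the four extensions are simply those of the example data's locations files; other shapefiles may carry additional sidecar files.

```python
import os

# Extensions of the files that accompany the example data's locations.shp
SIDECAR_EXTENSIONS = '.shp', '.shx', '.dbf', '.prj'

def listShapefileParts(shapePath):
    'Return the paths that together make up one shapefile (hypothetical helper)'
    # Drop whatever extension was given, as validateShapePath does
    basePath = os.path.splitext(str(shapePath))[0]
    return [basePath + extension for extension in SIDECAR_EXTENSIONS]

parts = listShapefileParts('examples/locations')
print(parts)
```

This matters when copying or deleting a shapefile by hand, because handling only the .shp file leaves the other parts behind.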

sequence.py

# Import system modules
import random
import numpy


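def uniquify(sequence):
    'Remove duplicates while preserving the order of first appearance'
    seen = set()
    # seen.add(x) returns None, so "not seen.add(x)" is always True and records x as seen
    return [x for x in sequence if x not in seen and not seen.add(x)]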


def cut(sequence, fraction, withRandomization=True):
    'Yield (outsideFraction, insideFraction) pairs, one for each consecutive slice of the sequence'
    # Copy
    sequence = list(sequence)
    # Randomize
    if withRandomization: random.shuffle(sequence)
    # Count
    totalCount = len(sequence)
    partialCount = int(numpy.ceil(totalCount * fraction))
    # Yield nothing if the slice is empty because xrange rejects a step of zero
    if partialCount < 1: return
    # Split
    for firstIndex in xrange(0, totalCount, partialCount):
        lastIndex = firstIndex + partialCount
        insideFraction = sequence[firstIndex:lastIndex]
        outsideFraction = sequence[0:firstIndex] + sequence[lastIndex:totalCount]
        yield outsideFraction, insideFraction
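
The cut function above produces cross-validation style folds: each consecutive slice takes a turn as the inside (test) part while the rest forms the outside (training) part. The sketch below reproduces that logic without randomization so that the folds are predictable. The name cutWithoutNumpy is hypothetical, math.ceil stands in for numpy.ceil, and the block is written for Python 3.

```python
import math

def cutWithoutNumpy(sequence, fraction):
    'Yield (outside, inside) pairs in their original order'
    sequence = list(sequence)
    totalCount = len(sequence)
    partialCount = int(math.ceil(totalCount * fraction))
    # range() rejects a step of zero, so yield nothing for an empty slice
    if partialCount < 1:
        return
    for firstIndex in range(0, totalCount, partialCount):
        lastIndex = firstIndex + partialCount
        inside = sequence[firstIndex:lastIndex]
        outside = sequence[:firstIndex] + sequence[lastIndex:]
        yield outside, inside

# Eight IDs with a test fraction of one quarter give four folds of two
folds = list(cutWithoutNumpy(range(8), 0.25))
for trainingIDs, testIDs in folds:
    print(trainingIDs, testIDs)
```

When the fraction does not divide the sequence evenly, the last fold is smaller: five items with a fraction of 0.5 give folds of three and two.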

store.py

# Import system modules
import os
import re
import sys
import time
import numpy
import datetime
import ConfigParser


# SQL

def execute(function, args, kwargs):
    'Run the SQL from function, which returns (sql, value) or (sql, value, method)'
    # Extract
    self = args[0]
    # Execute
    z = function(*args, **kwargs)
    sql, value = z[:2]
    if value is None: self.cursor.execute(sql)
    else: self.cursor.execute(sql, value)
    # Return
    method = z[2] if len(z) > 2 else None
    return self, method

def commit(function):
    def wrapper(*args, **kwargs):
        self = execute(function, args, kwargs)[0]
        self.connection.commit()
        return self.cursor.lastrowid
    return wrapper

def fetchOne(function):
    def wrapper(*args, **kwargs):
        self, method = execute(function, args, kwargs)
        result = self.cursor.fetchone()
        return method(result) if method else result
    return wrapper

def fetchAll(function):
    def wrapper(*args, **kwargs):
        self, method = execute(function, args, kwargs)
        results = self.cursor.fetchall()
        return map(method, results) if method else results
    return wrapper

def pullFirst(result):
    if result is not None: return result[0]

def pullTrueIfResult(result):
    return result is not None


# File 

def replaceFileExtension(filePath, newExtension):
    if not newExtension.startswith('.'): newExtension = '.' + newExtension
    base = os.path.splitext(filePath)[0]
    return base + newExtension

def extractFileBaseName(filePath):
    filename = os.path.split(filePath)[1]
    return os.path.splitext(filename)[0]

def extractFileName(filePath):
    return os.path.split(filePath)[1]

def verifyPath(filePath):
    # Strip whitespace first so that we check the same path that we return
    filePath = filePath.strip()
    if not os.path.exists(filePath): raise QueueError('Path not found: %s' % filePath)
    return filePath

def fillPath(rootPath, relativeFolderPath, relativeFilePath):
    folderPath = os.path.join(rootPath, relativeFolderPath)
    filePath = os.path.join(folderPath, relativeFilePath)
    return os.path.abspath(filePath)

def makeTimestamp():
    return datetime.datetime.now().strftime('%Y%m%d%H%M%S')

def makeFolderSafely(folderPath):
    if not os.path.exists(folderPath): os.mkdir(folderPath)
    return folderPath

def removeSafely(filePath):
    if os.path.exists(filePath): os.remove(filePath)


# Information

def saveInformation(filePath, valueByOptionBySection, fileExtension='info'):
    # Initialize
    configuration = ConfigParser.RawConfigParser()
    # For each section,
    for section in valueByOptionBySection:
        # Initialize
        valueByOption = valueByOptionBySection[section]
        # Add section
        addConfigurationSection(configuration, section, valueByOption)
    # Write and close the file
    filePath = replaceFileExtension(filePath, fileExtension)
    with open(filePath, 'wt') as configurationFile:
        configuration.write(configurationFile)

def addConfigurationSection(configuration, section, valueByOption):
    # Add section
    configuration.add_section(section)
    # For each option,
    for option in valueByOption:
        # Initialize
        value = valueByOption[option]
        # If value is a dictionary,
        if isinstance(value, dict):
            # Add section
            addConfigurationSection(configuration, '%s.%s' % (section, option), value)
        # Otherwise
        else:
            # Add option
            configuration.set(section, option, value)

def loadInformation(filePath, fileExtension='info'):
    # Initialize
    configuration = ConfigParser.RawConfigParser()
    valueByOptionBySection = {}
    # Read
    filePath = replaceFileExtension(filePath, fileExtension)
    configuration.read(filePath)
    # For each section,
    for section in configuration.sections():
        # Initialize
        valueByOption = {}
        # For each option,
        for option in configuration.options(section):
            # Get option and value
            valueByOption[option] = configuration.get(section, option)
        # Store
        valueByOptionBySection[section] = valueByOption
    # Return
    return valueByOptionBySection

def loadQueue(queuePath, convertByName):
    # Load
    valueByNameBySection = loadInformation(queuePath, fileExtension='queue')
    sections = valueByNameBySection.keys()
    # Convert values
    for section in sections:
        valueByNameBySection[section] = convertValueByName(valueByNameBySection[section], convertByName)
    # Set globalParameterByName
    globalParameterByName = valueByNameBySection.get('parameters', {})
    if 'parameters' in sections: 
        sections.remove('parameters')
    # Set parameterByTaskByName
    parameterByTaskByName = {}
    for section in sections:
        # Load
        valueByName = globalParameterByName.copy()
        valueByName.update(valueByNameBySection[section])
        # Save
        parameterByTaskByName[section] = valueByName
    # Return
    return parameterByTaskByName

def convertValueByName(valueByName, convertByName):
    # For each name,
    for name, value in valueByName.items():
        # Convert
        try: 
            convert = convertByName[name]
            valueByName[name] = convert(value)
        except KeyError:
            raise QueueError('%s is undefined' % name)
        except ValueError: 
            raise QueueError('%s=%s has the wrong type' % (name, value))
    # Return
    return valueByName

def stringifyList(items):
    return '\n' + '\n'.join(str(x) for x in items)

def stringifyNestedList(lists):
    return '\n' + '\n'.join(' '.join(str(item) for item in items) for items in lists)

def unstringifyStringList(content):
    # Strip each line and drop the empty ones
    return [line.strip() for line in content.splitlines() if line.strip()]

def unstringifyFloatList(content):
    return map(float, content.split())

def unstringifyIntegerList(content):
    return map(int, content.split())

def unstringifyNestedIntegerList(content):
    return [[int(x) for x in line.split()] for line in unstringifyStringList(content)]

class Information(object):

    def __init__(self, informationPath):
        self.informationPath = informationPath
        self.information = loadInformation(informationPath)
        self.parameterByName = self.information.get('parameters', {})
        self.experimentPath = os.path.dirname(os.path.dirname(os.path.dirname(informationPath)))

    def expandPath(self, relativePath):
        return os.path.join(self.experimentPath, relativePath)


# Time

def recordElapsedTime(function):
    # Define wrapper
    def wrapper(*args, **kwargs):
        # Run
        startTimeInSeconds = time.time()
        resultByName = function(*args, **kwargs)
        endTimeInSeconds = time.time()
        # Record
        elapsedTimeInSeconds = int(round(endTimeInSeconds - startTimeInSeconds))
        if not resultByName: 
            resultByName = {}
        resultByName['elapsed time in seconds'] = elapsedTimeInSeconds
        print 'elapsed time in seconds = %s' % elapsedTimeInSeconds
        # Return
        return resultByName
    # Return
    return wrapper


# Module

def getLibraryModule(moduleName):
    return __import__('libraries.' + moduleName, fromlist=['libraries'])


# Error

class QueueError(Exception):
    pass
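
The commit, fetchOne and fetchAll decorators in store.py let a store method return only its SQL, its values and an optional function for post-processing each row; the decorator runs the query. The sketch below repeats two of the decorators in a self-contained form and applies them to an in-memory SQLite database. The NoteStore class and its notes table are hypothetical and exist only for this illustration. The block is written for Python 3, so it uses a list comprehension where the listing uses map.

```python
import sqlite3

def execute(function, args, kwargs):
    # The wrapped method returns (sql, values) or (sql, values, method)
    self = args[0]
    pack = function(*args, **kwargs)
    sql, values = pack[:2]
    if values is None:
        self.cursor.execute(sql)
    else:
        self.cursor.execute(sql, values)
    method = pack[2] if len(pack) > 2 else None
    return self, method

def commit(function):
    def wrapper(*args, **kwargs):
        self = execute(function, args, kwargs)[0]
        self.connection.commit()
        return self.cursor.lastrowid
    return wrapper

def fetchAll(function):
    def wrapper(*args, **kwargs):
        self, method = execute(function, args, kwargs)
        results = self.cursor.fetchall()
        return [method(x) for x in results] if method else results
    return wrapper

def pullFirst(result):
    if result is not None:
        return result[0]

class NoteStore(object):
    'Hypothetical store with a single table'

    def __init__(self):
        self.connection = sqlite3.connect(':memory:')
        self.cursor = self.connection.cursor()
        self.cursor.execute('CREATE TABLE notes (text TEXT)')

    @commit
    def addNote(self, text):
        return 'INSERT INTO notes (text) VALUES (?)', (text,)

    @fetchAll
    def getNoteIDs(self):
        # pullFirst unwraps each one-column row into a bare value
        return 'SELECT rowid FROM notes', None, pullFirst

    @fetchAll
    def getNotes(self):
        return 'SELECT rowid, text FROM notes', None

noteStore = NoteStore()
firstID = noteStore.addNote('roof')
secondID = noteStore.addNote('no roof')
print(noteStore.getNoteIDs())
print(noteStore.getNotes())
```

Keeping the cursor handling in the decorators means that each query method stays a single return statement, as in sample_store.py.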

view.py

# Import system modules
import sys


# Feedback

def printDirectly(feedback):
    sys.stdout.write(feedback)
    sys.stdout.flush()

def printPercentUpdate(currentCount, totalCount):
    printDirectly('\r% 3d %% \t%d        ' % (100 * currentCount / totalCount, currentCount))

def printPercentFinal(totalCount):
    printDirectly('\r100 %% \t%d        \n' % totalCount)

def trackProgress(generator, totalCount, packetLength):
    # Initialize
    items = []
    # For each item,
    for currentCount, item in enumerate(generator):
        items.append(item)
        if currentCount % packetLength == 0: 
            printPercentUpdate(currentCount + 1, totalCount)
    # Return
    printPercentFinal(totalCount)
    return items
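
The progress functions in view.py redraw a single console line by starting each message with a carriage return. The sketch below writes the same messages to an in-memory stream so that the updates can be inspected. The helper formatPercentUpdate and the stream parameter are hypothetical additions, and the block is written for Python 3, so it uses // for integer division.

```python
import io

def formatPercentUpdate(currentCount, totalCount):
    'Return the progress text instead of printing it (hypothetical helper)'
    # The carriage return moves the cursor to the start of the line,
    # so each update overwrites the previous one on a console
    return '\r% 3d %% \t%d        ' % (100 * currentCount // totalCount, currentCount)

def trackProgress(generator, totalCount, packetLength, stream):
    'Collect items, writing an update for every packetLength-th item'
    items = []
    for currentCount, item in enumerate(generator):
        items.append(item)
        if currentCount % packetLength == 0:
            stream.write(formatPercentUpdate(currentCount + 1, totalCount))
    stream.write('\r100 %% \t%d        \n' % totalCount)
    return items

stream = io.StringIO()
squares = trackProgress((x * x for x in range(10)), 10, 5, stream)
# Split on the carriage returns to see the individual updates
updates = stream.getvalue().split('\r')[1:]
print(squares)
print(updates)
```

A larger packetLength means fewer writes, which keeps the display from slowing down a fast loop.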