Save and load geometries and fields from a shapefile with GDAL

Here we demonstrate a GDAL wrapper for processing geometries (points, lines, polygons) together with their associated fields, which are also called properties or attributes.

  • Save geometries with their associated fields to SHP, KML, GPX files.
  • Load geometries with their associated fields from SHP, KML, GPX files.

See the complete table of vector formats supported by GDAL. Driver names are listed in the column entitled Code.

Example

Download the code, unzip it, and start IPython.

wget http://invisibleroads.com/tutorials/_downloads/gdal-shapefile-geometries.zip
unzip gdal-shapefile-geometries.zip
cd gdal-shapefile-geometries
ipython

Import modules.

# Import system modules
import osgeo.ogr
import shapely.geometry
# Import custom modules
import geometry_store

Save points to a GPX file using the longitude and latitude spatial reference.

# Each coordinate tuple is wrapped in a shapely Point before it is saved;
# driverName selects the GDAL driver that writes the file
geometry_store.save_points('communities.gpx', geometry_store.proj4LL, [
    (0, 0),
    (1, 1),
], driverName='GPX')

Save lines with fields to a KML file using the spherical mercator spatial reference.

# The three lists are, in order: the geometries, one pack of field values
# per geometry, and one (name, type) definition per field
geometry_store.save('roads.kml', geometry_store.proj4SM, [
    shapely.geometry.LineString([(0, 5), (5, 0)]),
    shapely.geometry.LineString([(5, 0), (10, 0)]),
    shapely.geometry.LineString([(10, 0), (10, 5)]),
], [
    ('Left road', 1980),
    ('Middle road', 1985),
    ('Right road', 1995),
], [
    ('Name', osgeo.ogr.OFTString),
    ('Year', osgeo.ogr.OFTInteger),
], driverName='KML')

Save polygons to a compressed (zipped) ESRI shapefile, transforming them from the longitude and latitude spatial reference to the spherical mercator spatial reference.

# The .zip extension makes the zip_store.save decorator compress the saved
# files; the geometries are given in proj4LL and stored in targetProj4
geometry_store.save('geometries.shp.zip', geometry_store.proj4LL, [
    shapely.geometry.Polygon([(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]),
    shapely.geometry.MultiPolygon([
        # One polygon given as (exterior ring, [interior rings])
        (((0, 0), (0, 3), (3, 3), (3, 0)), [((0, 0), (0, 2), (2, 2), (2, 0))]),
    ]),
], targetProj4=geometry_store.proj4SM)

Load geometries with fields from a compressed ESRI shapefile transformed to the longitude and latitude spatial reference.

# Load the geometries, transforming them to longitude and latitude
proj4, shapelyGeometries, fieldPacks, fieldDefinitions = geometry_store.load(
    'geometries.shp.zip', targetProj4=geometry_store.proj4LL)
# Display each geometry preceded by its field values
for shapelyGeometry, fieldPack in zip(shapelyGeometries, fieldPacks):
    # Print a blank line between features; print('') behaves the same
    # on Python 2 and Python 3
    print('')
    for fieldValue, (fieldName, fieldType) in zip(fieldPack, fieldDefinitions):
        print('%s = %s' % (fieldName, fieldValue))
    print(shapelyGeometry)

Code

'Tests for geometry_store'
# Import system modules
import os
import shutil
import tempfile
import unittest
from osgeo import ogr
from shapely import geometry
# Import custom modules
import geometry_store


# Define constants

# Two 10 x 10 squares that share the edge from (10, 0) to (10, 10)
shapelyGeometries = [
    geometry.Polygon([(0, 0), (0, 10), (10, 10), (10, 0), (0, 0)]),
    geometry.Polygon([(10, 0), (10, 10), (20, 10), (20, 0), (10, 0)]),
]
# One pack of field values per geometry, with the values listed in the
# same order as fieldDefinitions
fieldPacks = [
    ('xxx', 11111, 44444.44),
    ('yyy', 22222, 88888.88),
]
# Each definition is a (fieldName, fieldType) pair using OGR field types
fieldDefinitions = [
    ('Name', ogr.OFTString),
    ('Population', ogr.OFTInteger),
    ('GDP', ogr.OFTReal),
]


# Define tests

class GeometryStoreTest(unittest.TestCase):
    'Demonstrate geometry_store usage'

    # Counter used by getPath to give every generated file a distinct name
    index = 0

    def setUp(self):
        # Make a fresh folder to hold the files saved during the test
        self.temporaryFolder = tempfile.mkdtemp()

    def tearDown(self):
        # Remove the folder together with everything saved inside it
        shutil.rmtree(self.temporaryFolder)

    def getPath(self, fileExtension):
        'Return a path with the given fileExtension in temporaryFolder'
        # The first increment reads the class-level 0 and stores the new
        # count on the instance
        self.index += 1
        return os.path.join(
            self.temporaryFolder, str(self.index) + fileExtension)

    def test(self):
        'Run tests'
        # NOTE: geometry_store.load returns the tuple
        # (proj4, shapelyGeometries, fieldPacks, fieldDefinitions),
        # so result[0] is the proj4, result[1] the geometries, and so on

        print('Save and load a SHP file without attributes')
        path = self.getPath('.shp')
        geometry_store.save(path, geometry_store.proj4LL, shapelyGeometries)
        result = geometry_store.load(path)
        # The loaded proj4 is stripped of surrounding whitespace
        # before it is compared
        self.assertEqual(result[0].strip(), geometry_store.proj4LL)
        self.assertEqual(len(result[1]), len(shapelyGeometries))

        print('Save and load a SHP file with attributes')
        path = self.getPath('.shp')
        geometry_store.save(
            path, geometry_store.proj4LL, shapelyGeometries, fieldPacks,
            fieldDefinitions)
        result = geometry_store.load(path)
        self.assertEqual(len(result[2]), len(fieldPacks))
        # Display each loaded geometry preceded by its field values
        for shapelyGeometry, fieldPack in zip(result[1], result[2]):
            print()
            for fieldValue, (
                fieldName, fieldType
            ) in zip(fieldPack, result[3]):
                print('%s = %s' % (fieldName, fieldValue))
            print(shapelyGeometry)

        print('Save a SHP file with attributes with different targetProj4')
        path = self.getPath('.shp')
        geometry_store.save(
            path, geometry_store.proj4LL, shapelyGeometries, fieldPacks,
            fieldDefinitions, targetProj4=geometry_store.proj4SM)
        result = geometry_store.load(path)
        # The file was saved using proj4SM, so the proj4 read back from it
        # should no longer be proj4LL
        self.assertNotEqual(result[0].strip(), geometry_store.proj4LL)

        print('Load a SHP file with attributes with different targetProj4')
        path = self.getPath('.shp')
        geometry_store.save(
            path, geometry_store.proj4LL, shapelyGeometries, fieldPacks,
            fieldDefinitions)
        result = geometry_store.load(path, targetProj4=geometry_store.proj4SM)
        # load reports targetProj4 when one is given
        self.assertNotEqual(result[0].strip(), geometry_store.proj4LL)

        print('Save and load a ZIP file without attributes using save')
        # The .zip extension routes the call through the zip_store decorators
        path = self.getPath('.shp.zip')
        geometry_store.save(path, geometry_store.proj4LL, shapelyGeometries)
        result = geometry_store.load(path)
        self.assertEqual(result[0].strip(), geometry_store.proj4LL)
        self.assertEqual(len(result[1]), len(shapelyGeometries))

        print('Save and load a ZIP file with attributes using save')
        path = self.getPath('.shp.zip')
        geometry_store.save(
            path, geometry_store.proj4LL, shapelyGeometries, fieldPacks,
            fieldDefinitions)
        result = geometry_store.load(path)
        self.assertEqual(len(result[2]), len(fieldPacks))

        print('Test saving and loading ZIP files of point coordinates')
        path = self.getPath('.shp.zip')
        # Only one point is given for two fieldPacks; save pairs geometries
        # with fieldPacks using zip, so only the first fieldPack is used
        geometry_store.save_points(
            path, geometry_store.proj4LL, [(0, 0)], fieldPacks,
            fieldDefinitions)
        result = geometry_store.load_points(path)
        # load_points returns coordinate tuples instead of geometries
        self.assertEqual(result[1], [(0, 0)])

        print('Test get_transform_point')
        # transform_point0 has identical source and target; transform_point1
        # goes from proj4LL to proj4SM; their results at (0, 0) must differ
        transform_point0 = geometry_store.get_transform_point(
            geometry_store.proj4LL, geometry_store.proj4LL)
        transform_point1 = geometry_store.get_transform_point(
            geometry_store.proj4LL, geometry_store.proj4SM)
        self.assertNotEqual(transform_point0(0, 0), transform_point1(0, 0))

        print('Test get_transform_geometry')
        transform_geometry = geometry_store.get_transform_geometry(
            geometry_store.proj4LL, geometry_store.proj4SM)
        # The transform must hand back the same kind of object it was given:
        # shapely in, shapely out and OGR in, OGR out
        self.assertEqual(type(transform_geometry(
            geometry.Point(0, 0))), type(geometry.Point(0, 0)))
        self.assertEqual(type(transform_geometry(ogr.CreateGeometryFromWkt(
            'POINT (0 0)'))), type(ogr.CreateGeometryFromWkt('POINT (0 0)')))

        print('Test get_coordinateTransformation')
        # Only checks that the call completes without raising
        geometry_store.get_coordinateTransformation(
            geometry_store.proj4LL, geometry_store.proj4SM)

        print('Test get_spatialReference')
        geometry_store.get_spatialReference(geometry_store.proj4LL)
        # An empty proj4 string is expected to be rejected
        with self.assertRaises(geometry_store.GeometryError):
            geometry_store.get_spatialReference('')

        print('Test get_geometryType')
        # Only checks that the call completes without raising
        geometry_store.get_geometryType(shapelyGeometries)

        print('Test save() when a fieldPack has fewer fields than definitions')
        with self.assertRaises(geometry_store.GeometryError):
            path = self.getPath('.shp')
            # x[1:] drops the first value, leaving two values
            # for three definitions
            geometry_store.save(
                path, geometry_store.proj4LL, shapelyGeometries,
                [x[1:] for x in fieldPacks], fieldDefinitions)

        print('Test save() when a fieldPack has more fields than definitions')
        with self.assertRaises(geometry_store.GeometryError):
            path = self.getPath('.shp')
            # x * 2 repeats each tuple, giving six values
            # for three definitions
            geometry_store.save(
                path, geometry_store.proj4LL, shapelyGeometries,
                [x * 2 for x in fieldPacks], fieldDefinitions)

        print('Test save() when the driverName is unrecognized')
        with self.assertRaises(geometry_store.GeometryError):
            path = self.getPath('.shp')
            # An empty driverName is expected to match no GDAL driver
            geometry_store.save(
                path, geometry_store.proj4LL, shapelyGeometries, driverName='')

        print('Test load() when format is unrecognized')
        with self.assertRaises(geometry_store.GeometryError):
            # The path has no extension and nothing has been saved to it
            path = self.getPath('')
            geometry_store.load(path)
"""
GDAL wrapper for reading and writing geospatial data
to a variety of vector formats.

For a list of supported vector formats and driver names,
please see http://www.gdal.org/ogr/ogr_formats.html
"""
# Import system modules
import os
from osgeo import ogr, osr
from shapely import wkb, geometry
# Import custom modules
import zip_store


# Set constants
# proj4 definition of longitude and latitude on the WGS84 datum
proj4LL = '+proj=longlat +datum=WGS84 +no_defs'
# proj4 definition of the mercator projection on a sphere (a equals b),
# with coordinates in metres; noqa silences the line-length warning
proj4SM = '+proj=merc +a=6378137 +b=6378137 +lat_ts=0.0 +lon_0=0.0 +x_0=0.0 +y_0=0 +k=1.0 +units=m +nadgrids=@null +no_defs'  # noqa


# Define shortcuts

def save_points(
        targetPath, sourceProj4, coordinateTuples, fieldPacks=None,
        fieldDefinitions=None, driverName='ESRI Shapefile', targetProj4=''):
    """
    Wrap each coordinate tuple in a shapely Point and hand the points to save.

    All other arguments are forwarded to save unchanged, and whatever save
    returns is returned as is.
    """
    points = list(map(geometry.Point, coordinateTuples))
    return save(
        targetPath, sourceProj4, points, fieldPacks, fieldDefinitions,
        driverName, targetProj4)


def load_points(sourcePath, sourceProj4='', targetProj4=''):
    """
    Call load and replace each loaded geometry with its (x, y) tuple.

    Returns (proj4, coordinateTuples, fieldPacks, fieldDefinitions), where
    every element except coordinateTuples comes straight from load.
    """
    loadedPack = load(sourcePath, sourceProj4, targetProj4)
    proj4, points, fieldPacks, fieldDefinitions = loadedPack
    coordinateTuples = [(p.x, p.y) for p in points]
    return proj4, coordinateTuples, fieldPacks, fieldDefinitions


# Define core

@zip_store.save
def save(
        targetPath, sourceProj4, shapelyGeometries, fieldPacks=None,
        fieldDefinitions=None, driverName='ESRI Shapefile', targetProj4=''):
    """
    Save shapelyGeometries and their fields to targetPath.

    targetPath -- where to write; an existing file at that path is removed
    sourceProj4 -- proj4 of the given shapelyGeometries
    shapelyGeometries -- geometries to save
    fieldPacks -- optional sequence of field values, one pack per geometry;
        geometries and packs are paired with zip, so surplus items in the
        longer sequence are ignored
    fieldDefinitions -- optional (fieldName, fieldType) pairs, one for each
        value in a fieldPack
    driverName -- name of the GDAL driver that writes the file
    targetProj4 -- if given, the geometries are transformed to this proj4
        and the layer is created with it instead of sourceProj4

    Returns targetPath.
    Raises GeometryError if a fieldPack does not have exactly one value per
    fieldDefinition, if the driver cannot be loaded, or if the driver cannot
    create the dataSource or the layer.
    """
    # Validate arguments
    if not fieldPacks:
        fieldPacks = []
    if not fieldDefinitions:
        fieldDefinitions = []
    fieldCount = len(fieldDefinitions)
    if any(len(x) != fieldCount for x in fieldPacks):
        raise GeometryError('A field definition is required for each field')
    # Make dataSource
    if os.path.exists(targetPath):
        os.remove(targetPath)
    dataDriver = ogr.GetDriverByName(driverName)
    if not dataDriver:
        raise GeometryError('Could not load driver: {}'.format(driverName))
    dataSource = dataDriver.CreateDataSource(targetPath)
    # Compare against None because a new dataSource has no layers yet and
    # may therefore evaluate as false
    if dataSource is None:
        raise GeometryError('Could not create {}'.format(
            os.path.basename(targetPath)))
    # Make layer
    layerName = os.path.splitext(os.path.basename(targetPath))[0]
    spatialReference = get_spatialReference(targetProj4 or sourceProj4)
    geometryType = get_geometryType(shapelyGeometries)
    layer = dataSource.CreateLayer(layerName, spatialReference, geometryType)
    # Compare against None because a new layer has no features yet and
    # may therefore evaluate as false
    if layer is None:
        raise GeometryError('Could not create layer: {}'.format(layerName))
    # Make fieldDefinitions in featureDefinition
    for fieldName, fieldType in fieldDefinitions:
        layer.CreateField(ogr.FieldDefn(fieldName, fieldType))
    featureDefinition = layer.GetLayerDefn()
    # Pair each geometry with its fieldPack, or with no fields at all
    if fieldPacks:
        geometryPacks = zip(shapelyGeometries, fieldPacks)
    else:
        geometryPacks = ((x, []) for x in shapelyGeometries)
    # Save features
    transform_geometry = get_transform_geometry(sourceProj4, targetProj4)
    for shapelyGeometry, fieldPack in geometryPacks:
        # Prepare feature
        feature = ogr.Feature(featureDefinition)
        feature.SetGeometry(transform_geometry(ogr.CreateGeometryFromWkb(
            shapelyGeometry.wkb)))
        for fieldIndex, fieldValue in enumerate(fieldPack):
            feature.SetField(fieldIndex, fieldValue)
        # Save feature
        layer.CreateFeature(feature)
        # Clean up
        feature.Destroy()
    # Drop our references so that GDAL writes everything to disk before
    # the zip_store.save decorator compresses the resulting files
    featureDefinition = None
    layer = None
    dataSource = None
    # Return
    return targetPath


@zip_store.load
def load(sourcePath, sourceProj4='', targetProj4=''):
    """
    Load the geometries and fields stored at sourcePath.

    The proj4 exported by the layer is used when it is not empty; otherwise
    the given sourceProj4 is used.  Geometries are passed through the
    transform built from that proj4 and targetProj4.

    Returns (proj4, shapelyGeometries, fieldPacks, fieldDefinitions), where
    proj4 is targetProj4 if given and the source proj4 otherwise.
    Raises GeometryError if sourcePath cannot be opened.
    """
    # Open the file and take its layer
    dataSource = ogr.Open(sourcePath)
    if not dataSource:
        sourceName = os.path.basename(sourcePath)
        raise GeometryError('Could not load {}'.format(sourceName))
    layer = dataSource.GetLayer()
    # Describe each field as (fieldName, fieldType)
    featureDefinition = layer.GetLayerDefn()
    fieldIndices = range(featureDefinition.GetFieldCount())
    gdalFieldDefinitions = [
        featureDefinition.GetFieldDefn(i) for i in fieldIndices]
    fieldDefinitions = [
        (d.GetName(), d.GetType()) for d in gdalFieldDefinitions]
    # Prefer the proj4 recorded in the layer over the given sourceProj4
    spatialReference = layer.GetSpatialRef()
    layerProj4 = spatialReference.ExportToProj4() if spatialReference else ''
    if layerProj4:
        sourceProj4 = layerProj4
    # Walk through the features until the layer runs out
    transform_geometry = get_transform_geometry(sourceProj4, targetProj4)
    shapelyGeometries = []
    fieldPacks = []
    while True:
        feature = layer.GetNextFeature()
        if not feature:
            break
        gdalGeometry = transform_geometry(feature.GetGeometryRef())
        shapelyGeometries.append(wkb.loads(gdalGeometry.ExportToWkb()))
        fieldPacks.append([feature.GetField(i) for i in fieldIndices])
    proj4 = targetProj4 or sourceProj4
    return proj4, shapelyGeometries, fieldPacks, fieldDefinitions


def get_transform_point(sourceProj4, targetProj4=proj4LL):
    """
    Return a function that converts (x, y) from sourceProj4 to targetProj4.

    When both proj4 strings are equal, the returned function hands back
    the coordinates unchanged.
    """
    if sourceProj4 != targetProj4:
        coordinateTransformation = get_coordinateTransformation(
            sourceProj4, targetProj4)

        def transform_point(x, y):
            # Keep only the first two values of the transformed point
            return coordinateTransformation.TransformPoint(x, y)[:2]

        return transform_point

    def keep_point(x, y):
        return (x, y)

    return keep_point


def get_transform_geometry(sourceProj4, targetProj4=proj4LL):
    """
    Return a function that converts a geometry from sourceProj4 to targetProj4.

    The returned function accepts either a shapelyGeometry or a gdalGeometry
    and returns the same kind of object; a gdalGeometry is transformed in
    place.  When targetProj4 is empty or equal to sourceProj4, the returned
    function hands back its argument unchanged.
    """
    needsTransform = bool(targetProj4) and sourceProj4 != targetProj4
    if not needsTransform:
        return lambda x: x
    coordinateTransformation = get_coordinateTransformation(
        sourceProj4, targetProj4)

    def transform_in_place(gdalGeometry):
        # A true result from Transform is treated as a failure
        if gdalGeometry.Transform(coordinateTransformation):
            raise GeometryError(
                'Could not transform geometry: {}'.format(
                    gdalGeometry.ExportToWkt()))
        return gdalGeometry

    def transform_geometry(g):
        # A gdalGeometry can be transformed directly
        if not isinstance(g, geometry.base.BaseGeometry):
            return transform_in_place(g)
        # A shapelyGeometry makes a round trip through its WKB form
        gdalGeometry = transform_in_place(ogr.CreateGeometryFromWkb(g.wkb))
        return wkb.loads(gdalGeometry.ExportToWkb())

    return transform_geometry


def get_coordinateTransformation(sourceProj4, targetProj4=proj4LL):
    """
    Build an osr.CoordinateTransformation from sourceProj4 to targetProj4.

    Each proj4 string is first converted by get_spatialReference.
    """
    return osr.CoordinateTransformation(
        get_spatialReference(sourceProj4),
        get_spatialReference(targetProj4))


def get_spatialReference(proj4):
    """
    Return an osr.SpatialReference made from the given proj4 string.

    Raises GeometryError if the import reports an error.
    """
    spatialReference = osr.SpatialReference()
    # A true result from ImportFromProj4 is treated as a failure
    importError = spatialReference.ImportFromProj4(proj4)
    if not importError:
        return spatialReference
    raise GeometryError('Could not import proj4: {}'.format(proj4))


def get_geometryType(shapelyGeometries):
    """
    Return the OGR geometry type shared by the given shapelyGeometries.

    Geometries are classified by their geom_type name instead of by their
    class, so the lookup does not depend on the adapter classes that exist
    only in older Shapely releases.

    Returns ogr.wkbUnknown if the geometries are empty, have more than one
    geom_type, or have a geom_type that has no entry in the table below.
    """
    geometryTypeByName = {
        'Point': ogr.wkbPoint,
        'LineString': ogr.wkbLineString,
        'Polygon': ogr.wkbPolygon,
        'MultiPoint': ogr.wkbMultiPoint,
        'MultiLineString': ogr.wkbMultiLineString,
        'MultiPolygon': ogr.wkbMultiPolygon,
    }
    geometryNames = set(x.geom_type for x in shapelyGeometries)
    # Zero names means there is nothing to classify and several names
    # mean that no single type describes every geometry
    if len(geometryNames) != 1:
        return ogr.wkbUnknown
    return geometryTypeByName.get(geometryNames.pop(), ogr.wkbUnknown)


# Define errors

class GeometryError(Exception):
    """Signal that geometries could not be loaded or saved."""
"""
ZipFile wrapper for reading and writing to ZIP files.
"""
# Import system modules
import os
import shutil
import zipfile
import tempfile
from decorator import decorator


# Define core

@decorator
def save(function, *args, **kwargs):
    """
    Decorator to support saving to ZIP files

    If the target path has a ZIP extension, it runs the function on the
    target path minus the ZIP extension inside a temporary folder and
    compresses the resulting files into the target path.

    The target path is taken from the targetPath keyword argument if there
    is one and from the first positional argument otherwise.

    Returns the result of the function if the target path does not have
    a ZIP extension and the target path otherwise.
    """
    # Get targetPath without touching args when it is given by keyword
    hasKeywordPath = 'targetPath' in kwargs
    targetPath = kwargs['targetPath'] if hasKeywordPath else args[0]
    targetBase, targetExtension = os.path.splitext(targetPath)
    # If the targetPath does not have a ZIP extension,
    if targetExtension.lower() != '.zip':
        # Run function as usual
        return function(*args, **kwargs)
    # Make temporaryFolder
    with TemporaryFolder() as temporaryFolder:
        # Run function in temporaryFolder, replacing the path in the same
        # place where the caller supplied it
        temporaryPath = os.path.join(
            temporaryFolder, os.path.basename(targetBase))
        if hasKeywordPath:
            function(*args, **dict(kwargs, targetPath=temporaryPath))
        else:
            function(temporaryPath, *args[1:], **kwargs)
        # Make zipFile
        with zipfile.ZipFile(targetPath, 'w', zipfile.ZIP_DEFLATED) as zipFile:
            # Walk temporaryFolder
            for rootPath, directories, fileNames in os.walk(temporaryFolder):
                # For each file,
                for fileName in fileNames:
                    filePath = os.path.join(rootPath, fileName)
                    # Store the file under its path relative to
                    # temporaryFolder
                    relativePath = os.path.relpath(filePath, temporaryFolder)
                    zipFile.write(filePath, relativePath, zipfile.ZIP_DEFLATED)
        # Return
        return targetPath

@decorator
def load(function, *args, **kwargs):
    """
    Decorator to support loading from ZIP files

    If the source path has a ZIP extension, it uncompresses the archive into
    a temporary folder and runs the function on each extracted name in turn,
    returning the first result that does not raise an exception.

    The source path is taken from the sourcePath keyword argument if there
    is one and from the first positional argument otherwise.

    Raises ZipError if the function fails on every name in the archive.
    """
    # Get sourcePath without touching args when it is given by keyword
    hasKeywordPath = 'sourcePath' in kwargs
    sourcePath = kwargs['sourcePath'] if hasKeywordPath else args[0]
    sourceExtension = os.path.splitext(sourcePath)[1]
    # If the sourcePath does not have a ZIP extension,
    if sourceExtension.lower() != '.zip':
        # Run function as usual
        return function(*args, **kwargs)
    # Make temporaryFolder
    with TemporaryFolder() as temporaryFolder:
        # Unzip to temporaryFolder and close the archive before running
        # the function
        with zipfile.ZipFile(sourcePath) as zipFile:
            zipFile.extractall(temporaryFolder)
            fileNames = zipFile.namelist()
        # Run function on extracted files and exit on first success
        errors = []
        for fileName in fileNames:
            temporaryPath = os.path.join(temporaryFolder, fileName)
            try:
                if hasKeywordPath:
                    return function(
                        *args, **dict(kwargs, sourcePath=temporaryPath))
                return function(temporaryPath, *args[1:], **kwargs)
            except Exception as error:
                # Failures are expected here because only some of the
                # archived files may be loadable; keep each message for
                # the final report
                errors.append(str(error))
        # Reaching this point means that no file could be loaded
        raise ZipError('Could not run {} on any file in {}:\n{}'.format(
            function,
            sourcePath,
            '\n'.join(errors),
        ))


# Define wrappers

class TemporaryFolder(object):
    """
    Context manager that makes a temporary folder and removes it on exit.

    The arguments are stored and later passed to tempfile.mkdtemp.
    """

    def __init__(self, suffix='', prefix='tmp', dir=None):
        self.suffix, self.prefix, self.dir = suffix, prefix, dir

    def __enter__(self):
        # Remember the folder so that __exit__ knows what to remove
        folderPath = tempfile.mkdtemp(
            suffix=self.suffix, prefix=self.prefix, dir=self.dir)
        self.temporaryFolder = folderPath
        return folderPath

    def __exit__(self, type, value, traceback):
        # Nothing is returned, so an exception raised in the body propagates
        shutil.rmtree(self.temporaryFolder)


# Define errors

class ZipError(Exception):
    """
    Exception raised when a ZIP file cannot be processed.

    It must derive from Exception; a class deriving from object can be
    neither raised nor constructed with a message.
    """