Multidimensional Embeddings for Geospatial Data with Hex2Vec and Storing in Geobase
This tutorial demonstrates how to create hex2vec embeddings for geospatial data using srai and store them in a geobase database. We will use the following Python libraries: srai
(version 0.7.5
), geopandas
, and sqlalchemy
.
How Hex2Vec Embedder Works
The Hex2Vec embedder is designed to generate embeddings for hexagonally partitioned geographical regions. These embeddings capture the spatial relationships and the characteristics of the features present within those regions.
How are geospatial embeddings created?
The creation of geospatial embeddings involves several steps:
- select a geographical area of interest
- load OpenStreetMap (OSM) data for that area
- regionalize the area into hexagonal regions
- join the regions with the OSM features
- create embeddings for the regions based on the features
- store the embeddings in a geobase database
This is a high-level overview of the process:
In the rest of this tutorial, we will walk through each of these steps in detail.
Installing Required Packages
To begin, you need to install several Python packages. These packages include libraries for working with geospatial data, embedding models, and Geobase.
!pip install srai[all]
!pip install folium matplotlib mapclassify
!pip install pgvector
!pip install geoalchemy2
Importing Necessary Libraries
import pandas as pd
from sqlalchemy import create_engine, Table, Column, Integer, Float, MetaData
from sqlalchemy.dialects.postgresql import ARRAY
import geopandas as gpd
from pgvector.sqlalchemy import Vector
from geoalchemy2 import Geometry
from sqlalchemy import text, insert
import numpy as np
sqlalchemy
: A Python SQL toolkit for database interaction.pgvector
: Allows you to store vector data in PostgreSQL.geoalchemy2
: Extends SQLAlchemy to support geospatial data types.
Import Necessary SRAI Libraries
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.plotting import plot_regions, plot_numeric_data
- OSMPbfLoader: Loads OSM data for a specified area.
- H3Regionalizer: Converts geographical areas into H3 hexagons at a specific resolution.
- IntersectionJoiner: Joins geospatial regions with OSM features.
Load OSM Data and Regionalize
- Load OSM Data: Use OSMOnlineLoader to load features for a specific area and query.
- Regionalize Data: We use the
H3Regionalizer
to divide the area of interest into hexagonal regions. The H3 indexing system is a hierarchical hexagonal grid system that is useful for spatial analysis.
loader = OSMPbfLoader()
regionalizer = H3Regionalizer(resolution=11)
area = geocode_to_region_gdf("City of London")
regions = regionalizer.transform(area)
Load Features
Using the loader, we load the features from OSM within the specified area. The HEX2VEC_FILTER
ensures that only relevant features for the Hex2Vec model are loaded.
features = loader.load(area, HEX2VEC_FILTER)
Initialize the Joiner
The IntersectionJoiner is used to join the spatial features with the defined regions. This step ensures that each region has associated feature data for embedding.
joiner = IntersectionJoiner()
Join Regions and Features
The joiner transforms the regions and features into a joint representation, necessary for embedding.
joint = joiner.transform(regions, features)
Define the Neighbourhood
We define the spatial neighbourhood using the H3Neighbourhood class, which will help the embedder understand spatial relationships.
neighbourhood = H3Neighbourhood(regions_gdf=regions)
Initialize, Fit and Transform Embedder
We fit the embedder to the regions, features, joint representation, and neighbourhood. This step trains the model to create meaningful embeddings.
embedder = Hex2VecEmbedder()
embedder = Hex2VecEmbedder([15, 10, 3])
embedder.fit(regions, features, joint, neighbourhood, batch_size=128)
embeddings = embedder.transform(regions, features, joint)
Visualize the Results
Use srai.plotting
to visualize the regions and the embeddings.
folium_map = plot_regions(area, colormap=["rgba(0,0,0,0.1)"], tiles_style="CartoDB positron")
plot_numeric_data(regions, 0, embeddings, map=folium_map)
Saving Embeddings to Geobase
- Create a Connection to the Database: Use sqlalchemy to create an engine for the geobase database.
- Prepare Data for Storage: Convert the embeddings to a DataFrame and then to a GeoDataFrame.
- Define the Table Schema: Define the metadata and table structure.
Note: The pgvector extension must be enabled in the Geobase project to support vector data types.
Here is how you can enable the pgvector extension in Geobase:
Create a Connection to the Database
As a first step, create a connection to the geobase database using sqlalchemy.
To copy the connection string from the Geobase project, navigate to the project settings.
Import database connection details from environment variables.
from google.colab import userdata
db_string = userdata.get('db_string')
Or you can import using the .env file
from dotenv import load_dotenv
load_dotenv()
import os
db_string = os.getenv('DB_STRING')
Create an engine using the connection string.
# Create a connection to the database
engine = create_engine(db_string)
Note: The connection string should be in the format:
postgresql://username:password@host:port/database
To check if the connection is successful, you can use the following code:
try:
connection = engine.connect()
print("Connection to the database successful!")
connection.close()
except Exception as e:
print("Connection failed:", e)
Note: While executing any SQL commands you need to start and finish
transactions using BEGIN;
and COMMIT;
respectively. Or you can use
conn.begin()
and conn.commit()
methods if you are using a connection
object.
create_table_sql = """
BEGIN;
CREATE TABLE embeddings_vector (
id SERIAL PRIMARY KEY,
region GEOMETRY(POLYGON, 4326),
embeddings VECTOR(3)
);
COMMIT;
"""
with engine.connect() as connection:
connection.execute(text(create_table_sql))
connection.commit()
Now we can save embeddings into database in two ways:
- using sqlalchemy insert command
- using pandas to_sql builtin method
Using sqlalchemy insert command
metadata = MetaData()
malmo_embeddings = Table(
'embeddings_vector', metadata,
Column('id', Integer, primary_key=True),
Column('region', Geometry('GEOMETRY', srid=4326)),
Column('embeddings', Vector(3))
)
metadata.create_all(engine)
records = []
for _, row in embeddings_gdf.iterrows():
records.append({
'region': row['region'].wkt, # Convert geometry to WKT format for insertion
'embeddings': np.array(row['embeddings']).tolist() # Ensure it's a list
})
with engine.connect() as conn:
with conn.begin():
conn.execute(insert(malmo_embeddings), records)
conn.commit()
Metadata
: Holds information about the database tables.Table
: Defines the table structure.insert
: Inserts records into the table.
Using pandas to_sql method
# Convert 'region' column to WKT format
embeddings_gdf['region'] = embeddings_gdf['region'].apply(lambda geom: geom.wkt)
# Use to_sql to save the data
embeddings_gdf.to_sql('embeddings_vector', engine, if_exists='append', index=False, dtype={ 'embeddings': ARRAY(Float)})
Note: The geopandas methods to_sql
and read_postgis
will convert the
vector data to a string format when saving or reading from the database. You
will need to convert the string back to a list or array when reading the data.
Load embeddings in geoDataframe from geobase
Use the geopandas method read_postgis
to read data from the geobase
query = "SELECT * FROM embeddings_vector"
read_df = gpd.read_postgis(query, engine,geom_col='region')
Verifying and Converting the ‘embeddings’ Column to a Numeric Vector
# Verify if the 'embeddings' column can be converted to a numeric vector
import ast
try:
read_df['embeddings'] = read_df['embeddings'].apply(ast.literal_eval)
print("The 'embeddings' column can be converted to a numeric vector.")
except Exception as e:
print(f"Error converting 'embeddings' column: {e}")