GuidesSRAIMultidimensional Embeddings

Multidimensional Embeddings for Geospatial Data with Hex2Vec and Storing in Geobase

This tutorial demonstrates how to create hex2vec embeddings for geospatial data using srai and store them in a geobase database. We will use the following Python libraries: srai (version 0.7.5), geopandas, and sqlalchemy.

How Hex2Vec Embedder Works

The Hex2Vec embedder is designed to generate embeddings for hexagonally partitioned geographical regions. These embeddings capture the spatial relationships and the characteristics of the features present within those regions.

How are geospatial embeddings created?

The creation of geospatial embeddings involves several steps:

  1. select a geographical area of interest
  2. load OpenStreetMap (OSM) data for that area
  3. regionalize the area into hexagonal regions
  4. join the regions with the OSM features
  5. create embeddings for the regions based on the features
  6. store the embeddings in a geobase database

This is a high-level overview of the process:

Hex2Vec Embedder Flow Diagram

In the rest of this tutorial, we will walk through each of these steps in detail.

Installing Required Packages

To begin, you need to install several Python packages. These packages include libraries for working with geospatial data, embedding models, and Geobase.

!pip install srai[all]
!pip install folium matplotlib mapclassify
!pip install pgvector
!pip install geoalchemy2

Importing Necessary Libraries

import pandas as pd
from sqlalchemy import create_engine, Table, Column, Integer, Float, MetaData
from sqlalchemy.dialects.postgresql import ARRAY
import geopandas as gpd
from pgvector.sqlalchemy import Vector
from geoalchemy2 import Geometry
from sqlalchemy import text, insert
import numpy as np
  • sqlalchemy: A Python SQL toolkit for database interaction.
  • pgvector: Allows you to store vector data in PostgreSQL.
  • geoalchemy2: Extends SQLAlchemy to support geospatial data types.

Import Necessary SRAI Libraries

from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.plotting import plot_regions, plot_numeric_data
  • OSMPbfLoader: Loads OSM data for a specified area.
  • H3Regionalizer: Converts geographical areas into H3 hexagons at a specific resolution.
  • IntersectionJoiner: Joins geospatial regions with OSM features.

Load OSM Data and Regionalize

  • Load OSM Data: Use OSMOnlineLoader to load features for a specific area and query.
  • Regionalize Data: We use the H3Regionalizer to divide the area of interest into hexagonal regions. The H3 indexing system is a hierarchical hexagonal grid system that is useful for spatial analysis.
loader = OSMPbfLoader()
regionalizer = H3Regionalizer(resolution=11)
 
area = geocode_to_region_gdf("City of London")
regions = regionalizer.transform(area)

Load Features

Using the loader, we load the features from OSM within the specified area. The HEX2VEC_FILTER ensures that only relevant features for the Hex2Vec model are loaded.

features = loader.load(area, HEX2VEC_FILTER)

Initialize the Joiner

The IntersectionJoiner is used to join the spatial features with the defined regions. This step ensures that each region has associated feature data for embedding.

joiner = IntersectionJoiner()

Join Regions and Features

The joiner transforms the regions and features into a joint representation, necessary for embedding.

joint = joiner.transform(regions, features)

Joint Representation

Define the Neighbourhood

We define the spatial neighbourhood using the H3Neighbourhood class, which will help the embedder understand spatial relationships.

neighbourhood = H3Neighbourhood(regions_gdf=regions)

Initialize, Fit and Transform Embedder

We fit the embedder to the regions, features, joint representation, and neighbourhood. This step trains the model to create meaningful embeddings.

embedder = Hex2VecEmbedder()
embedder = Hex2VecEmbedder([15, 10, 3])
 
embedder.fit(regions, features, joint, neighbourhood, batch_size=128)
embeddings = embedder.transform(regions, features, joint)

Visualize the Results

Use srai.plotting to visualize the regions and the embeddings.

folium_map = plot_regions(area, colormap=["rgba(0,0,0,0.1)"], tiles_style="CartoDB positron")
plot_numeric_data(regions, 0, embeddings, map=folium_map)
 

Embeddings Visualization

Saving Embeddings to Geobase

  • Create a Connection to the Database: Use sqlalchemy to create an engine for the geobase database.
  • Prepare Data for Storage: Convert the embeddings to a DataFrame and then to a GeoDataFrame.
  • Define the Table Schema: Define the metadata and table structure.
ℹ️

Note: The pgvector extension must be enabled in the Geobase project to support vector data types.

Here is how you can enable the pgvector extension in Geobase: Enable pgvector Extension

Create a Connection to the Database

As a first step, create a connection to the geobase database using sqlalchemy.

To copy the connection string from the Geobase project, navigate to the project settings. Geobase Connection String

Import database connection details from environment variables.

from google.colab import userdata
db_string = userdata.get('db_string')

Or you can import using the .env file

from dotenv import load_dotenv
load_dotenv()
import os
db_string = os.getenv('DB_STRING')

Create an engine using the connection string.

# Create a connection to the database
engine = create_engine(db_string)
ℹ️

Note: The connection string should be in the format: postgresql://username:password@host:port/database

To check if the connection is successful, you can use the following code:

try:
    connection = engine.connect()
    print("Connection to the database successful!")
    connection.close()
except Exception as e:
    print("Connection failed:", e)
ℹ️

Note: While executing any SQL commands you need to start and finish transactions using BEGIN; and COMMIT; respectively. Or you can use conn.begin() and conn.commit() methods if you are using a connection object.

create_table_sql = """
BEGIN;
CREATE TABLE embeddings_vector (
    id SERIAL PRIMARY KEY,
    region GEOMETRY(POLYGON, 4326),
    embeddings VECTOR(3)
);
COMMIT;
"""
 
with engine.connect() as connection:
    connection.execute(text(create_table_sql))
    connection.commit()

Now we can save embeddings into database in two ways:

  • using sqlalchemy insert command
  • using pandas to_sql builtin method

Using sqlalchemy insert command

metadata = MetaData()
malmo_embeddings = Table(
    'embeddings_vector', metadata,
    Column('id', Integer, primary_key=True),
    Column('region', Geometry('GEOMETRY', srid=4326)),
    Column('embeddings', Vector(3))
)
 
metadata.create_all(engine)
 
records = []
for _, row in embeddings_gdf.iterrows():
    records.append({
        'region': row['region'].wkt,  # Convert geometry to WKT format for insertion
        'embeddings': np.array(row['embeddings']).tolist()  # Ensure it's a list
    })
 
with engine.connect() as conn:
    with conn.begin():
        conn.execute(insert(malmo_embeddings), records)
        conn.commit()
  • Metadata: Holds information about the database tables.
  • Table: Defines the table structure.
  • insert: Inserts records into the table.

Using pandas to_sql method

# Convert 'region' column to WKT format
embeddings_gdf['region'] = embeddings_gdf['region'].apply(lambda geom: geom.wkt)
 
# Use to_sql to save the data
embeddings_gdf.to_sql('embeddings_vector', engine, if_exists='append', index=False, dtype={ 'embeddings': ARRAY(Float)})

Geobase Table

ℹ️

Note: The geopandas methods to_sql and read_postgis will convert the vector data to a string format when saving or reading from the database. You will need to convert the string back to a list or array when reading the data.

Load embeddings in geoDataframe from geobase

Use the geopandas method read_postgis to read data from the geobase

query = "SELECT * FROM embeddings_vector"
read_df = gpd.read_postgis(query, engine,geom_col='region')

Verifying and Converting the ‘embeddings’ Column to a Numeric Vector

# Verify if the 'embeddings' column can be converted to a numeric vector
import ast
try:
    read_df['embeddings'] = read_df['embeddings'].apply(ast.literal_eval)
    print("The 'embeddings' column can be converted to a numeric vector.")
except Exception as e:
    print(f"Error converting 'embeddings' column: {e}")