From c7fc3b6337b499845dd98db2f4cfb2af3feb8fe1 Mon Sep 17 00:00:00 2001 From: coltonpeltier-db Date: Thu, 12 Feb 2026 11:50:30 +0000 Subject: [PATCH 1/4] Initial checkpointer code --- .../lakebase_checkpointer.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 src/databricks_ai_bridge/lakebase_checkpointer.py diff --git a/src/databricks_ai_bridge/lakebase_checkpointer.py b/src/databricks_ai_bridge/lakebase_checkpointer.py new file mode 100644 index 000000000..9ccb810b7 --- /dev/null +++ b/src/databricks_ai_bridge/lakebase_checkpointer.py @@ -0,0 +1,102 @@ +from pydantic import BaseModel +from datetime import datetime +from .lakebase import LakebaseClient +from typing import Tuple +from psycopg.types.json import Json + +class Checkpoint(BaseModel): + id : str + state : dict + creation_timestamp : datetime + update_timestamp : datetime + def get_insert_params(self) -> Tuple: + return ( + self.id, + Json(self.state), + self.creation_timestamp, + self.update_timestamp + ) + def get_update_params(self) -> Tuple: + return ( + Json(self.state), + self.update_timestamp + ) + +class LakebaseCheckpointer: + def __init__(self, lakebase_client : LakebaseClient, sessions_table_name : str): + self.client = lakebase_client + self.table_name = sessions_table_name + + def init_schema(self) -> None: + sql = f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + lb_id bigserial PRIMARY_KEY, + id text NOT NULL, + state jsonb NOT NULL default '{{}}', + creation_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(id,creation_timestamp) + ) + """ + response = self.client.execute(sql=sql) + + def get_most_recent_checkpoint(self, id : str) -> Checkpoint | None: + # Get the most recently updated checkpoint for this id + sql = f""" + SELECT id, state, creation_timestamp, update_timestamp + FROM {self.table_name} + WHERE id = %s + ORDER BY update_timestamp DESC + LIMIT 1 + """ + 
response = self.client.execute(sql=sql, params = (id,)) + if response: + response = Checkpoint(**response) + return response + + def checkpoint_exists(self, id : str) -> bool: + sql = f""" + SELECT COUNT(*) AS count + FROM {self.table_name} + WHERE id = %s + """ + response = self.client.execute(sql=sql, params=(id,)) + count = response[0]["count"] + return True if count > 0 else False + + def update_checkpoint(self, id : str, state : dict) -> None: + # Get the most recent checkpoint + _checkpoint = self.get_most_recent_checkpoint(id = id) + # Set the new state value locally + _checkpoint.state = state + _checkpoint.update_timestamp = datetime.now() + sql = f""" + UPDATE {self.table_name} + SET state = %s, update_timestamp = %s + WHERE id = %s AND creation_timestamp = %s + """ + self.client.execute(sql=sql, params=_checkpoint.get_update_params + (_checkpoint.id, _checkpoint.creation_timestamp)) + return + + + def insert_checkpoint(self, id : str, state : dict) -> None: + _checkpoint = Checkpoint(id = id, state = state, creation_timestamp = datetime.now(), update_timestamp = datetime.now()) + sql = f""" + INSERT INTO {self.table_name} + (id, state, creation_timestamp, update_timestamp) + VALUES (%s, %s, %s, %s) + """ + response = self.client.execute(sql=sql, params=_checkpoint.get_insert_params()) + return + + + def save_checkpoint(self, id : str, state : dict, overwrite : bool = False) -> None: + if overwrite: + checkpoint_exists = self.checkpoint_exists(id = id) + if checkpoint_exists: + self.update_checkpoint(id = id, state = state) + else: + self.insert_checkpoint(id = id, state = state) + else: + # Don't overwrite existing checkpoints, so just insert a new one + self.insert_checkpoint(id = id, state = state) \ No newline at end of file From 6613060941f1582571f8444f38bebe99c46fee39 Mon Sep 17 00:00:00 2001 From: coltonpeltier-db Date: Fri, 13 Feb 2026 11:58:32 +0000 Subject: [PATCH 2/4] move sql generation into GenericCheckpoint class --- 
.../lakebase_checkpointer.py | 105 ++++++++++-------- 1 file changed, 61 insertions(+), 44 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase_checkpointer.py b/src/databricks_ai_bridge/lakebase_checkpointer.py index 9ccb810b7..b2321a8a0 100644 --- a/src/databricks_ai_bridge/lakebase_checkpointer.py +++ b/src/databricks_ai_bridge/lakebase_checkpointer.py @@ -3,34 +3,37 @@ from .lakebase import LakebaseClient from typing import Tuple from psycopg.types.json import Json +from abc import ABC, abstractmethod +from typing import Any, Union, Tuple +from uuid import uuid4 -class Checkpoint(BaseModel): - id : str - state : dict - creation_timestamp : datetime - update_timestamp : datetime - def get_insert_params(self) -> Tuple: - return ( - self.id, - Json(self.state), - self.creation_timestamp, - self.update_timestamp - ) - def get_update_params(self) -> Tuple: - return ( - Json(self.state), - self.update_timestamp - ) - -class LakebaseCheckpointer: - def __init__(self, lakebase_client : LakebaseClient, sessions_table_name : str): - self.client = lakebase_client - self.table_name = sessions_table_name +class GenericCheckpoint(BaseModel): + # Define the fields for your checkpoint + id : str = str(uuid4()) + state : dict = {} + creation_timestamp : datetime = datetime.now() + update_timestamp : datetime = datetime.now() + # Implement abstract methods + def generate_insert_sql(self, table_name : str) -> Tuple[str, tuple]: + sql = f""" + INSERT INTO {table_name} + (id, state, creation_timestamp, update_timestamp) + VALUES (%s, %s, %s, %s) + """ + return sql, (self.id, Json(self.state), self.creation_timestamp, self.update_timestamp) - def init_schema(self) -> None: + def generate_update_sql(self, table_name : str) -> Tuple[str, tuple]: sql = f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - lb_id bigserial PRIMARY_KEY, + UPDATE {table_name} + SET state = %s, update_timestamp = %s + WHERE id = %s AND creation_timestamp = %s + """ + return sql, (Json(self.state), 
self.update_timestamp, self.id, self.creation_timestamp) + + def generate_init_sql(self, table_name : str) -> Tuple[str, tuple]: + sql = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + lb_id bigserial PRIMARY KEY, id text NOT NULL, state jsonb NOT NULL default '{{}}', creation_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -38,20 +41,40 @@ def init_schema(self) -> None: UNIQUE(id,creation_timestamp) ) """ - response = self.client.execute(sql=sql) + return sql, None - def get_most_recent_checkpoint(self, id : str) -> Checkpoint | None: - # Get the most recently updated checkpoint for this id + def generate_retrieve_checkpoint_sql(self, table_name : str) -> Tuple[str, tuple]: sql = f""" SELECT id, state, creation_timestamp, update_timestamp - FROM {self.table_name} + FROM {table_name} WHERE id = %s ORDER BY update_timestamp DESC LIMIT 1 """ - response = self.client.execute(sql=sql, params = (id,)) + return sql, (self.id,) + +class LakebaseCheckpointer: + def __init__(self, + lakebase_client : LakebaseClient, + sessions_table_name : str, + checkpoint_class : GenericCheckpoint = GenericCheckpoint): + self.checkpoint_class = checkpoint_class + self.client = lakebase_client + self.table_name = sessions_table_name + self.init_schema() + + def init_schema(self) -> None: + _checkpoint = self.checkpoint_class() + sql, params = _checkpoint.generate_init_sql(table_name = self.table_name) + response = self.client.execute(sql=sql, params = params) + + def get_most_recent_checkpoint(self, id : str) -> GenericCheckpoint | None: + _checkpoint = self.checkpoint_class(id = id) + # Get the most recently updated checkpoint for this id + sql, params = _checkpoint.generate_retrieve_checkpoint_sql(table_name = self.table_name) + response = self.client.execute(sql=sql, params = params)[0] if response: - response = Checkpoint(**response) + response = self.checkpoint_class(**response) return response def checkpoint_exists(self, id : str) -> bool: @@ -70,23 +93,17 @@ def 
update_checkpoint(self, id : str, state : dict) -> None: # Set the new state value locally _checkpoint.state = state _checkpoint.update_timestamp = datetime.now() - sql = f""" - UPDATE {self.table_name} - SET state = %s, update_timestamp = %s - WHERE id = %s AND creation_timestamp = %s - """ - self.client.execute(sql=sql, params=_checkpoint.get_update_params + (_checkpoint.id, _checkpoint.creation_timestamp)) + sql, params = _checkpoint.generate_update_sql(table_name = self.table_name) + self.client.execute(sql = sql, params = params) return def insert_checkpoint(self, id : str, state : dict) -> None: - _checkpoint = Checkpoint(id = id, state = state, creation_timestamp = datetime.now(), update_timestamp = datetime.now()) - sql = f""" - INSERT INTO {self.table_name} - (id, state, creation_timestamp, update_timestamp) - VALUES (%s, %s, %s, %s) - """ - response = self.client.execute(sql=sql, params=_checkpoint.get_insert_params()) + _checkpoint = self.checkpoint_class( + id = id, state = state, creation_timestamp = datetime.now(), update_timestamp = datetime.now() + ) + sql, params = _checkpoint.generate_insert_sql(table_name = self.table_name) + response = self.client.execute(sql=sql, params=params) return From 6feb1b359a106f732e44433327ca7e9ab32adc23 Mon Sep 17 00:00:00 2001 From: coltonpeltier-db Date: Fri, 13 Feb 2026 11:59:01 +0000 Subject: [PATCH 3/4] remove empty lines --- src/databricks_ai_bridge/lakebase_checkpointer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase_checkpointer.py b/src/databricks_ai_bridge/lakebase_checkpointer.py index b2321a8a0..be119d8da 100644 --- a/src/databricks_ai_bridge/lakebase_checkpointer.py +++ b/src/databricks_ai_bridge/lakebase_checkpointer.py @@ -97,7 +97,6 @@ def update_checkpoint(self, id : str, state : dict) -> None: self.client.execute(sql = sql, params = params) return - def insert_checkpoint(self, id : str, state : dict) -> None: _checkpoint = self.checkpoint_class( id = id, state 
= state, creation_timestamp = datetime.now(), update_timestamp = datetime.now() @@ -106,7 +105,6 @@ def insert_checkpoint(self, id : str, state : dict) -> None: response = self.client.execute(sql=sql, params=params) return - def save_checkpoint(self, id : str, state : dict, overwrite : bool = False) -> None: if overwrite: checkpoint_exists = self.checkpoint_exists(id = id) From 50d7fc38a86b069cc63f09b712e669d511582e91 Mon Sep 17 00:00:00 2001 From: coltonpeltier-db Date: Fri, 13 Feb 2026 12:22:53 +0000 Subject: [PATCH 4/4] add factory initializations, add function to get all checkpoints, make functions more generic to handle future subclasses --- .../lakebase_checkpointer.py | 68 +++++++++++++------ 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase_checkpointer.py b/src/databricks_ai_bridge/lakebase_checkpointer.py index be119d8da..eaea5f082 100644 --- a/src/databricks_ai_bridge/lakebase_checkpointer.py +++ b/src/databricks_ai_bridge/lakebase_checkpointer.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field from datetime import datetime from .lakebase import LakebaseClient from typing import Tuple @@ -8,12 +8,26 @@ from uuid import uuid4 class GenericCheckpoint(BaseModel): - # Define the fields for your checkpoint - id : str = str(uuid4()) + #################### + # CHECKPOINT COLUMNS + #################### + id : str = Field(default_factory = lambda: str(uuid4())) # `id` column must exist for locating the checkpoint, even in subclasses. 
state : dict = {} - creation_timestamp : datetime = datetime.now() - update_timestamp : datetime = datetime.now() - # Implement abstract methods + creation_timestamp : datetime = Field(default_factory=datetime.now) + update_timestamp : datetime = Field(default_factory=datetime.now) + #################### + # UPDATE ATTRIBUTES + #################### + def update(self, **kwargs): + for k,v in kwargs.items(): + # Check attribute exists + assert hasattr(self, k), f"Attribute {k} does not exist in {self.__class__.__name__}" + setattr(self, k, v) + self.update_timestamp = datetime.now() + ################################# + # SQL GENERATION IMPLEMENTATIONS + ################################# + # If subclassing the `GenericCheckpoint` class, implement new methods to handle any changes in attributes. def generate_insert_sql(self, table_name : str) -> Tuple[str, tuple]: sql = f""" INSERT INTO {table_name} @@ -44,12 +58,18 @@ def generate_init_sql(self, table_name : str) -> Tuple[str, tuple]: return sql, None def generate_retrieve_checkpoint_sql(self, table_name : str) -> Tuple[str, tuple]: + sql_all, params = self.generate_retrieval_all_checkpoints_sql(table_name = table_name) + sql = f""" + {sql_all} LIMIT 1 + """ + return sql, params + + def generate_retrieval_all_checkpoints_sql(self, table_name : str) -> Tuple[str, tuple]: sql = f""" SELECT id, state, creation_timestamp, update_timestamp FROM {table_name} WHERE id = %s ORDER BY update_timestamp DESC - LIMIT 1 """ return sql, (self.id,) @@ -72,10 +92,16 @@ def get_most_recent_checkpoint(self, id : str) -> GenericCheckpoint | None: _checkpoint = self.checkpoint_class(id = id) # Get the most recently updated checkpoint for this id sql, params = _checkpoint.generate_retrieve_checkpoint_sql(table_name = self.table_name) - response = self.client.execute(sql=sql, params = params)[0] - if response: - response = self.checkpoint_class(**response) - return response + response = self.client.execute(sql=sql, params = params) + if 
len(response) == 0: + return None + return self.checkpoint_class(**response[0]) + + def get_all_checkpoints(self, id : str) -> list[GenericCheckpoint]: + _checkpoint = self.checkpoint_class(id = id) + sql, params = _checkpoint.generate_retrieval_all_checkpoints_sql(table_name = self.table_name) + response = self.client.execute(sql=sql, params = params) + return [self.checkpoint_class(**_resp) for _resp in response] def checkpoint_exists(self, id : str) -> bool: sql = f""" @@ -87,31 +113,29 @@ def checkpoint_exists(self, id : str) -> bool: count = response[0]["count"] return True if count > 0 else False - def update_checkpoint(self, id : str, state : dict) -> None: + def update_most_recent_checkpoint(self, id : str, **checkpoint_kwargs) -> None: # Get the most recent checkpoint _checkpoint = self.get_most_recent_checkpoint(id = id) - # Set the new state value locally - _checkpoint.state = state - _checkpoint.update_timestamp = datetime.now() + _checkpoint.update(**checkpoint_kwargs) sql, params = _checkpoint.generate_update_sql(table_name = self.table_name) self.client.execute(sql = sql, params = params) return - def insert_checkpoint(self, id : str, state : dict) -> None: + def insert_checkpoint(self, id : str, **checkpoint_kwargs) -> None: _checkpoint = self.checkpoint_class( - id = id, state = state, creation_timestamp = datetime.now(), update_timestamp = datetime.now() + id = id, **checkpoint_kwargs ) sql, params = _checkpoint.generate_insert_sql(table_name = self.table_name) response = self.client.execute(sql=sql, params=params) return - def save_checkpoint(self, id : str, state : dict, overwrite : bool = False) -> None: + def save_checkpoint(self, id : str, overwrite : bool = False, **checkpoint_kwargs) -> None: if overwrite: checkpoint_exists = self.checkpoint_exists(id = id) if checkpoint_exists: - self.update_checkpoint(id = id, state = state) + self.update_checkpoint(id = id, **checkpoint_kwargs) else: - self.insert_checkpoint(id = id, state = state) + 
self.insert_checkpoint(id = id, **checkpoint_kwargs) else: - # Don't overwrite existing checkpoints, so just insert a new one - self.insert_checkpoint(id = id, state = state) \ No newline at end of file + # Don't overwrite the most recent checkpoint, so just insert a new one + self.insert_checkpoint(id = id, **checkpoint_kwargs) \ No newline at end of file