Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 164 additions & 22 deletions api/memory/graphiti_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pylint: disable=all
import asyncio
import os
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime

Expand All @@ -26,6 +27,22 @@
from litellm import completion


def extract_embedding_model_name(full_model_name: str) -> str:
"""
Extract just the model name without provider prefix for Graphiti.

Args:
full_model_name: Model name that may include provider prefix (e.g., "azure/text-embedding-ada-002")

Returns:
Model name without prefix (e.g., "text-embedding-ada-002")
"""
if "/" in full_model_name:
return full_model_name.split("/", 1)[1] # Remove provider prefix
else:
return full_model_name


class MemoryTool:
"""Memory management tool for handling user memories and interactions."""

Expand All @@ -43,10 +60,12 @@ def __init__(self, user_id: str, graph_id: str):


@classmethod
async def create(cls, user_id: str, graph_id: str) -> "MemoryTool":
async def create(cls, user_id: str, graph_id: str, use_direct_entities: bool = True) -> "MemoryTool":
"""Async factory to construct and initialize the tool."""
self = cls(user_id, graph_id)
await self._ensure_database_node(graph_id, user_id)

await self._ensure_entity_nodes_direct(user_id, graph_id)


vector_size = Config.EMBEDDING_MODEL.get_vector_size()
driver = self.graphiti_client.driver
Expand Down Expand Up @@ -128,6 +147,114 @@ async def _ensure_database_node(self, database_name: str, user_id: str) -> Optio
print(f"Error creating database node for {database_name}: {e}")
return None

async def _ensure_entity_nodes_direct(self, user_id: str, database_name: str) -> bool:
"""
Ensure user and database entity nodes exist using direct Cypher queries.
This function creates Entity nodes similar to what Graphiti does but with hardcoded Cypher.
"""
try:
graph_driver = self.graphiti_client.driver

# Check if user entity node already exists
user_node_name = f"User {user_id}"
check_user_query = """
MATCH (n:Entity {name: $name})
RETURN n.uuid AS uuid
LIMIT 1
"""
user_check_result = await graph_driver.execute_query(check_user_query, name=user_node_name)

if not user_check_result[0]: # If no records found, create user node
user_uuid = str(uuid.uuid4())
user_name_embedding = Config.EMBEDDING_MODEL.embed(user_node_name)[0]

user_node_data = {
'uuid': user_uuid,
'name': user_node_name,
'group_id': '\\_',
Comment thread
galshubeli marked this conversation as resolved.
'created_at': datetime.now().isoformat(),
'summary': f'User {user_id} is using QueryWeaver',
'name_embedding': user_name_embedding
}

# Execute Cypher query for user entity node
user_cypher = """
MERGE (n:Entity {uuid: $node.uuid})
SET n = $node
SET n.timestamp = timestamp()
WITH n, $node AS node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
"""

await graph_driver.execute_query(user_cypher, node=user_node_data)
print(f"Created user entity node: {user_node_name} with UUID: {user_uuid}")
else:
print(f"User entity node already exists: {user_node_name}")

# Check if database entity node already exists
database_node_name = f"Database {database_name}"
check_database_query = """
MATCH (n:Entity {name: $name})
RETURN n.uuid AS uuid
LIMIT 1
"""
database_check_result = await graph_driver.execute_query(check_database_query, name=database_node_name)

if not database_check_result[0]: # If no records found, create database node
database_uuid = str(uuid.uuid4())
database_name_embedding = Config.EMBEDDING_MODEL.embed(database_node_name)[0]

database_node_data = {
'uuid': database_uuid,
'name': database_node_name,
'group_id': '\\_',
'created_at': datetime.now().isoformat(),
'summary': f'Database {database_name} available for querying by user {user_id}',
'name_embedding': database_name_embedding
}

# Execute Cypher query for database entity node
database_cypher = """
MERGE (n:Entity {uuid: $node.uuid})
SET n = $node
SET n.timestamp = timestamp()
WITH n, $node AS node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
"""

await graph_driver.execute_query(database_cypher, node=database_node_data)
print(f"Created database entity node: {database_node_name} with UUID: {database_uuid}")
else:
print(f"Database entity node already exists: {database_node_name}")

# Create HAS_DATABASE relationship between user and database entities
try:
relationship_query = """
MATCH (user:Entity {name: $user_name})
MATCH (db:Entity {name: $database_name})
MERGE (user)-[r:HAS_DATABASE]->(db)
RETURN r
"""

await graph_driver.execute_query(
relationship_query,
user_name=user_node_name,
database_name=database_node_name
)
print(f"Created HAS_DATABASE relationship between {user_node_name} and {database_node_name}")

except Exception as rel_error:
print(f"Error creating HAS_DATABASE relationship: {rel_error}")
# Don't fail the entire function if relationship creation fails

return True

except Exception as e:
print(f"Error creating entity nodes directly for user {user_id} and database {database_name}: {e}")
return False

async def add_new_memory(self, conversation: Dict[str, Any]) -> bool:
# Use LLM to analyze and summarize the conversation with focus on graph-oriented database facts
analysis = await self.summarize_conversation(conversation)
Expand Down Expand Up @@ -177,26 +304,24 @@ async def save_query_memory(self, query: str, sql_query: str, success: bool, err
"""
try:
database_name = self.graph_id

# Find the database node
database_node_name = f"Database {database_name}"
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
node_search_config.limit = 1
graph_driver = self.graphiti_client.driver

database_node_results = await self.graphiti_client.search_(
query=database_node_name,
config=node_search_config,
)
# Find the database node using direct Cypher query
find_database_query = """
MATCH (n:Entity {name: $name})
RETURN n.uuid AS uuid
LIMIT 1
"""

database_result = await graph_driver.execute_query(find_database_query, name=database_node_name)

# Check if database node exists
database_node_exists = False
for node in database_node_results.nodes:
if node.name == database_node_name:
database_node_exists = True
database_node_uuid = node.uuid
break
if not database_node_exists:
if not database_result[0]: # If no records found
print(f"Database entity node {database_node_name} not found")
return False

database_node_uuid = database_result[0][0]['uuid']

# Check if Query node with same user_query and sql_query already exists
relationship_type = "SUCCESS" if success else "FAILED"
Expand Down Expand Up @@ -238,7 +363,7 @@ async def save_query_memory(self, query: str, sql_query: str, success: bool, err
CREATE (db)-[:{relationship_type} {{timestamp: timestamp()}}]->(q)
RETURN q.uuid as query_uuid
"""

# Execute the Cypher query through Graphiti's graph driver
try:
result = await graph_driver.execute_query(cypher_query, embedding=embeddings)
Expand Down Expand Up @@ -598,7 +723,10 @@ def __init__(self):
self.endpoint = os.getenv('AZURE_API_BASE')
self.api_version = os.getenv('AZURE_API_VERSION', '2024-02-01')
self.model_choice = "gpt-4.1" # Use the model name directly
self.embedding_model = "text-embedding-ada-002" # Use model name, not deployment

# Extract just the model name without provider prefix for Graphiti
self.embedding_model = extract_embedding_model_name(Config.EMBEDDING_MODEL_NAME)

self.small_model = os.getenv('AZURE_SMALL_MODEL', 'gpt-4o-mini')

# Use model names directly instead of deployment names
Expand Down Expand Up @@ -652,7 +780,10 @@ def create_graphiti_client(falkor_driver: FalkorDriver) -> Graphiti:
graph_driver=falkor_driver,
llm_client=OpenAIClient(config=azure_llm_config, client=llm_client_azure),
embedder=OpenAIEmbedder(
config=OpenAIEmbedderConfig(embedding_model=config.embedding_deployment),
config=OpenAIEmbedderConfig(
embedding_model=config.embedding_deployment,
embedding_dim=1536
),
client=embedding_client_azure,
),
cross_encoder=OpenAIRerankerClient(
Expand All @@ -662,8 +793,19 @@ def create_graphiti_client(falkor_driver: FalkorDriver) -> Graphiti:
client=llm_client_azure,
),
)
else: # Fallback to default OpenAI config
graphiti_client = Graphiti(graph_driver=falkor_driver)
else: # Fallback to default OpenAI config but use Config's embedding model
# Extract just the model name without provider prefix for Graphiti
embedding_model_name = extract_embedding_model_name(Config.EMBEDDING_MODEL_NAME)

graphiti_client = Graphiti(
graph_driver=falkor_driver,
embedder=OpenAIEmbedder(
config=OpenAIEmbedderConfig(
embedding_model=embedding_model_name,
embedding_dim=1536
)
),
)

return graphiti_client

11 changes: 11 additions & 0 deletions app/public/css/buttons.css
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

.input-button {
background: var(--falkor-primary);
color: var(--text-primary);
cursor: pointer;
transition: all 0.2s ease;
border: none;
Expand All @@ -15,6 +16,11 @@
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
}

.input-button:disabled {
opacity: 0.5;
cursor: not-allowed;
}

.input-button img {
filter: var(--icon-filter) brightness(0.8) saturate(0.3);
width: 100%;
Expand Down Expand Up @@ -46,6 +52,11 @@
position: relative;
}

.action-button:disabled {
opacity: 0.5;
cursor: not-allowed;
}

.action-button:hover {
box-shadow: 0 4px 12px rgba(0,0,0,0.25);
background: var(--falkor-primary);
Expand Down
3 changes: 2 additions & 1 deletion app/public/css/chat-components.css
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
.final-result-message-container::before {
height: 32px;
width: 32px;
content: 'Bot';
content: 'QW';
color: var(--text-secondary);
display: flex;
align-items: center;
Expand Down Expand Up @@ -230,6 +230,7 @@

#pause-button {
display: none;
color: var(--text-primary);
}

#reset-button {
Expand Down
9 changes: 4 additions & 5 deletions app/public/css/menu.css
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,11 @@
transition: border-color 0.2s;
min-width: 180px;
appearance: none;
background-image: linear-gradient(45deg, transparent 50%, var(--text-secondary) 50%),
linear-gradient(135deg, var(--text-secondary) 50%, transparent 50%);
background-position: calc(100% - 20px) center, calc(100% - 15px) center;
background-size: 5px 5px, 5px 5px;
background-repeat: no-repeat;
cursor: pointer;
display: flex;
align-items: center;
justify-content: space-between;
gap: 4px;
}

#open-pg-modal {
Expand Down
7 changes: 0 additions & 7 deletions app/public/icons/pause.svg

This file was deleted.

9 changes: 5 additions & 4 deletions app/templates/components/chat_header.j2
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
{# Chat header with logo, title, and action buttons #}
<div class="chat-header">
<img src="/static/icons/queryweaver.svg" alt="Chat Logo" class="logo">
<h1>Natural Language to SQL Generator</h1>
<h1>Text-to-SQL for Enterprise Databases</h1>
<div class="button-container">
<button class="action-button" id="graph-select-refresh">
<button class="action-button" id="graph-select-refresh" disabled>
<svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
<path fill="none" stroke="currentColor" stroke-width="2"
d="M20,8 C18.5974037,5.04031171 15.536972,3 12,3 C7.02943725,3 3,7.02943725 3,12 C3,16.9705627 7.02943725,21 12,21 L12,21 C16.9705627,21 21,16.9705627 21,12 M21,3 L21,9 L15,9" />
</svg>
</button>
<div id="graph-custom-dropdown" class="graph-custom-dropdown">
<div id="graph-selected" class="graph-selected dropdown-selected" title="Select Database">
<span class="dropdown-text">Select database</span>
<span class="dropdown-text">Select Database</span>
<span class="dropdown-arrow">▼</span>
</div>
<div id="graph-options" class="graph-options dropdown-options" aria-hidden="true"></div>
Expand All @@ -22,7 +22,8 @@
<input title="Upload Schema" id="schema-upload" type="file" accept=".json" style="display: none;" tabindex="-1"
disabled />
<label for="schema-upload" id="custom-file-upload">
Upload Schema
<span>Upload Schema</span>
<span class="dropdown-arrow">▼</span>
</label>
<div class="vertical-separator"></div>
<button class="connect-database-btn header-button">Connect Database</button>
Expand Down
14 changes: 7 additions & 7 deletions app/templates/components/chat_input.j2
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
{# Chat input area with text input and action buttons #}
<div class="chat-input">
<div class="input-container" id="input-container">
<input type="text" id="message-input" placeholder="Describe the SQL query you want..." />
<button class="input-button" title="Submit" id="submit-button">
<svg width="54" height="54" viewBox="0 0 54 54" fill="none" xmlns="http://www.w3.org/2000/svg">
<textarea type="text" id="message-input" placeholder="Describe the SQL query you want..." rows="1"></textarea>
Comment thread
galshubeli marked this conversation as resolved.
<button class="input-button" title="Submit" id="submit-button" disabled>
<svg width="54" height="54" viewBox="0 0 54 54" xmlns="http://www.w3.org/2000/svg">
<!-- Rounded rectangle background -->
<rect x="2" y="2" width="50" height="50" rx="6" ry="6" fill="none" />
<!-- Submit arrow icon -->
<path
d="M24.332 26.4667V35C24.332 35.7555 24.588 36.3893 25.1 36.9013C25.612 37.4133 26.2449 37.6684 26.9987 37.6667C27.7525 37.6649 28.3863 37.4089 28.9 36.8987C29.4138 36.3884 29.6689 35.7555 29.6654 35V26.4667L32.0654 28.8666C32.5543 29.3555 33.1765 29.6 33.932 29.6C34.6876 29.6 35.3098 29.3555 35.7987 28.8666C36.2876 28.3778 36.532 27.7555 36.532 27C36.532 26.2444 36.2876 25.6222 35.7987 25.1333L28.8654 18.2C28.332 17.6666 27.7098 17.4 26.9987 17.4C26.2876 17.4 25.6654 17.6666 25.132 18.2L18.1987 25.1333C17.7098 25.6222 17.4654 26.2444 17.4654 27C17.4654 27.7555 17.7098 28.3778 18.1987 28.8666C18.6876 29.3555 19.3098 29.6 20.0654 29.6C20.8209 29.6 21.4431 29.3555 21.932 28.8666L24.332 26.4667Z"
fill="white" />
fill="currentColor" />
</svg>
</button>
<button class="input-button" title="Pause" id="pause-button">
<svg width="54" height="54" viewBox="0 0 54 54" fill="none" xmlns="http://www.w3.org/2000/svg">
<svg width="54" height="54" viewBox="0 0 54 54" xmlns="http://www.w3.org/2000/svg">
<!-- Rounded rectangle background -->
<rect x="2" y="2" width="50" height="50" rx="6" ry="6" fill="none" />
<!-- Pause bars -->
<rect x="18.6667" y="16" width="5.3333" height="21.3333" fill="white" />
<rect x="29.3333" y="16" width="5.3333" height="21.3333" fill="white" />
<rect x="18.6667" y="16" width="5.3333" height="21.3333" fill="currentColor" />
<rect x="29.3333" y="16" width="5.3333" height="21.3333" fill="currentColor" />
</svg>

</button>
Expand Down
Loading