Skip to content

Commit 0c77659

Browse files
author
Pierre
authored
Merge branch 'main' into pierre-example-chatbot-ecom
2 parents 37dec3d + 8da35a6 commit 0c77659

File tree

3 files changed

+170
-15
lines changed

3 files changed

+170
-15
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
This example demonstrates how to implement Contextual Retrieval's context generation
3+
as described by Anthropic (https://www.anthropic.com/news/contextual-retrieval).
4+
It shows how to generate concise, contextual descriptions for document chunks.
5+
"""
6+
7+
import asyncio
8+
9+
from pydantic import BaseModel, Field
10+
11+
import workflowai
12+
from workflowai import Model, Run
13+
14+
15+
class ContextGeneratorInput(BaseModel):
16+
"""Input for generating context for a document chunk."""
17+
doc_content: str = Field(
18+
description="The full text content of the document",
19+
)
20+
chunk_content: str = Field(
21+
description="The specific chunk of text to generate context for",
22+
)
23+
24+
25+
class ContextGeneratorOutput(BaseModel):
26+
"""Output containing the generated context for a chunk."""
27+
context: str = Field(
28+
description="Generated contextual information for the chunk",
29+
examples=[
30+
"This chunk is from section 3.2 discussing revenue growth in Q2 2023",
31+
"This appears in the methodology section explaining the experimental setup",
32+
],
33+
)
34+
35+
36+
@workflowai.agent(
37+
id="context-generator",
38+
model=Model.CLAUDE_3_5_SONNET_LATEST,
39+
)
40+
async def generate_chunk_context(context_input: ContextGeneratorInput) -> Run[ContextGeneratorOutput]:
41+
"""
42+
Here is the chunk we want to situate within the whole document.
43+
Please give a short succinct context to situate this chunk within the overall document
44+
for the purposes of improving search retrieval of the chunk.
45+
"""
46+
...
47+
48+
49+
async def main():
50+
# Example: Generate context for a document chunk
51+
print("\nGenerating context for document chunk")
52+
print("-" * 50)
53+
54+
# Example document content
55+
doc_content = """
56+
ACME Corporation (NASDAQ: ACME)
57+
Second Quarter 2023 Financial Results
58+
59+
Executive Summary
60+
ACME Corp. delivered strong performance in Q2 2023, with revenue growth
61+
exceeding market expectations. The company's strategic initiatives in
62+
AI and cloud services continued to drive expansion.
63+
64+
Financial Highlights
65+
The company's revenue grew by 3% over the previous quarter, reaching
66+
$323.4M. This growth was primarily driven by our enterprise segment,
67+
which saw a 15% increase in cloud service subscriptions.
68+
69+
Operational Metrics
70+
- Customer base expanded to 15,000 enterprise clients
71+
- Cloud platform usage increased by 25%
72+
- AI solutions adoption rate reached 40%
73+
"""
74+
75+
# Example chunk from the Financial Highlights section
76+
chunk_content = "The company's revenue grew by 3% over the previous quarter, reaching $323.4M."
77+
78+
context_input = ContextGeneratorInput(
79+
doc_content=doc_content,
80+
chunk_content=chunk_content,
81+
)
82+
83+
run = await generate_chunk_context(context_input)
84+
print("\nGenerated Context:")
85+
print(run.output.context)
86+
87+
88+
if __name__ == "__main__":
89+
asyncio.run(main())

workflowai/core/utils/_tools.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import contextlib
12
import inspect
2-
from enum import Enum
3-
from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints
3+
from typing import Any, Callable, NamedTuple, Optional, get_type_hints
44

5-
from pydantic import BaseModel
5+
from pydantic import TypeAdapter
66

77
from workflowai.core.utils._schema_generator import JsonSchemaGenerator
88

@@ -24,10 +24,6 @@ def _get_type_schema(param_type: type):
2424
Returns:
2525
A dictionary containing the JSON schema type definition
2626
"""
27-
if issubclass(param_type, Enum):
28-
if not issubclass(param_type, str):
29-
raise ValueError(f"Non string enums are not supported: {param_type}")
30-
return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]})
3127

3228
if param_type is str:
3329
return SchemaDeserializer({"type": "string"})
@@ -41,11 +37,13 @@ def _get_type_schema(param_type: type):
4137
if param_type is bool:
4238
return SchemaDeserializer({"type": "boolean"})
4339

44-
if issubclass(param_type, BaseModel):
40+
# Attempting to build a type adapter with pydantic
41+
with contextlib.suppress(Exception):
42+
adapter = TypeAdapter[Any](param_type)
4543
return SchemaDeserializer(
46-
schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator),
47-
serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType]
48-
deserializer=param_type.model_validate,
44+
schema=adapter.json_schema(schema_generator=JsonSchemaGenerator),
45+
deserializer=adapter.validate_python, # pyright: ignore [reportUnknownLambdaType]
46+
serializer=lambda x: adapter.dump_python(x, mode="json"), # pyright: ignore [reportUnknownLambdaType]
4947
)
5048

5149
raise ValueError(f"Unsupported type: {param_type}")

workflowai/core/utils/_tools_test.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,54 @@
1+
import json
2+
from datetime import datetime
13
from enum import Enum
2-
from typing import Annotated
4+
from typing import Annotated, Any
35

6+
import pytest
47
from pydantic import BaseModel
5-
6-
from workflowai.core.utils._tools import tool_schema
8+
from zoneinfo import ZoneInfo
9+
10+
from workflowai.core.utils._tools import _get_type_schema, tool_schema # pyright: ignore [reportPrivateUsage]
11+
12+
13+
class TestGetTypeSchema:
14+
class _BasicEnum(str, Enum):
15+
A = "a"
16+
B = "b"
17+
18+
class _BasicModel(BaseModel):
19+
a: int
20+
b: str
21+
22+
@pytest.mark.parametrize(
23+
("param_type", "value"),
24+
[
25+
(int, 1),
26+
(float, 1.0),
27+
(bool, True),
28+
(str, "test"),
29+
(datetime, datetime.now(tz=ZoneInfo("UTC"))),
30+
(ZoneInfo, ZoneInfo("UTC")),
31+
(list[int], [1, 2, 3]),
32+
(dict[str, int], {"a": 1, "b": 2}),
33+
(_BasicEnum, _BasicEnum.A),
34+
(_BasicModel, _BasicModel(a=1, b="test")),
35+
(list[_BasicModel], [_BasicModel(a=1, b="test"), _BasicModel(a=2, b="test2")]),
36+
(tuple[int, str], (1, "test")),
37+
],
38+
)
39+
def test_get_type_schema(self, param_type: Any, value: Any):
40+
schema = _get_type_schema(param_type)
41+
if schema.serializer is None or schema.deserializer is None:
42+
assert schema.serializer is None
43+
assert schema.deserializer is None
44+
45+
# Check that the value is serializable and deserializable with plain json
46+
assert json.loads(json.dumps(value)) == value
47+
return
48+
49+
serialized = schema.serializer(value)
50+
deserialized = schema.deserializer(serialized)
51+
assert deserialized == value
752

853

954
class TestToolSchema:
@@ -17,6 +62,7 @@ def sample_func(
1762
age: int,
1863
height: float,
1964
is_active: bool,
65+
date: datetime,
2066
mode: TestMode = TestMode.FAST,
2167
) -> bool:
2268
"""Sample function for testing"""
@@ -44,8 +90,12 @@ def sample_func(
4490
"type": "string",
4591
"enum": ["fast", "slow"],
4692
},
93+
"date": {
94+
"type": "string",
95+
"format": "date-time",
96+
},
4797
},
48-
"required": ["name", "age", "height", "is_active"], # 'mode' is not required
98+
"required": ["name", "age", "height", "is_active", "date"], # 'mode' is not required
4999
}
50100
assert output_schema.schema == {
51101
"type": "boolean",
@@ -123,3 +173,21 @@ def sample_func() -> TestModel: ...
123173
}
124174
assert output_schema.serializer is not None
125175
assert output_schema.serializer(TestModel(val=10)) == {"val": 10}
176+
177+
def test_with_datetime_in_input(self):
178+
def sample_func(time: datetime) -> str: ...
179+
180+
input_schema, _ = tool_schema(sample_func)
181+
182+
assert input_schema.deserializer is not None
183+
assert input_schema.deserializer({"time": "2024-01-01T12:00:00+00:00"}) == {
184+
"time": datetime(
185+
2024,
186+
1,
187+
1,
188+
12,
189+
0,
190+
0,
191+
tzinfo=ZoneInfo("UTC"),
192+
),
193+
}

0 commit comments

Comments
 (0)