diff --git a/README.md b/README.md
index 1b357ea..65daf76 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,9 @@
# UnionChatBot
+
+
+
+
`UnionChatBot` - REST-сервис с агентом на базе LangChain/LangGraph и YandexGPT API.
Проект включает:
- FastAPI API слой;
diff --git a/coverage.xml b/coverage.xml
new file mode 100644
index 0000000..6c4e691
--- /dev/null
+++ b/coverage.xml
@@ -0,0 +1,1855 @@
+
+
+
+
+
+ /Users/aleksandrsamofalov/PycharmProjects/GeneralPurposeChatBot/src
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/pyproject.toml b/pyproject.toml
index 07101b2..b906068 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,6 +61,10 @@ dev = [
"ruff",
"pylint",
"pre-commit",
+ "pytest",
+ "pytest-asyncio",
+ "httpx",
+ "pytest-cov>=7.1.0",
]
[tool.ruff]
diff --git a/src/agents/profkom_consultant/nodes/base.py b/src/agents/profkom_consultant/nodes/base.py
index e019b2e..78498b2 100644
--- a/src/agents/profkom_consultant/nodes/base.py
+++ b/src/agents/profkom_consultant/nodes/base.py
@@ -1,9 +1,32 @@
+from contextlib import asynccontextmanager
+
from langchain_core.prompts import ChatPromptTemplate
+from langfuse.callback import CallbackHandler
from agents.profkom_consultant.states import AgentState
+from service.logger.context_vars import current_span, current_trace
class BaseAgentNodes:
+ @asynccontextmanager
+ async def _node_span(self, name: str, state: AgentState):
+ trace = current_trace.get()
+ span = None
+ if trace:
+ span = trace.span(name=name, input={"text": state.get("text", "")})
+ current_span.set(span)
+ try:
+ yield span
+ finally:
+ if span:
+ span.end()
+ current_span.set(None)
+
+ def _llm_config(self, span):
+ if span:
+ return {"callbacks": [CallbackHandler(stateful_client=span)]}
+ return {}
+
async def validate_text(self, state: AgentState) -> AgentState:
"""Проверяем, что текст вопроса пользователя соответсвует публичной политики.
@@ -14,39 +37,40 @@ async def validate_text(self, state: AgentState) -> AgentState:
Return:
Бинарное значение - спам или нормальный вопрос к агенту.
"""
- question = state["text"]
- try:
- cached_result = self.cache.get(meta_info="validate_input", query=question)
- if cached_result:
- self.logger.debug(f"Cached result {cached_result}")
- state["is_valid"] = cached_result.get("json").get("is_valid")
- state["final_answer"] = cached_result.get("json").get("final_answer")
- return state
- else:
- self.logger.debug(f"Cached result {cached_result}")
- prompt = ChatPromptTemplate.from_template(
- self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt() # TO DO: FIX
- )
- chain = prompt | self.llm
- output = await chain.ainvoke({"text": state["text"]})
- output = output.content.strip().lower()
- self.logger.info(f"Output: {output}")
-
- is_valid = "да" in output
+ async with self._node_span("validate_text", state) as span:
+ question = state["text"]
+ try:
+ cached_result = self.cache.get(meta_info="validate_input", query=question)
+ if cached_result:
+ self.logger.debug(f"Cached result {cached_result}")
+ state["is_valid"] = cached_result.get("json").get("is_valid")
+ state["final_answer"] = cached_result.get("json").get("final_answer")
+ return state
+ else:
+ self.logger.debug(f"Cached result {cached_result}")
+ prompt = ChatPromptTemplate.from_template(
+ self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt() # TO DO: FIX
+ )
+ chain = prompt | self.llm
+ output = await chain.ainvoke({"text": state["text"]}, config=self._llm_config(span))
+ output = output.content.strip().lower()
+ self.logger.info(f"Output: {output}")
+
+ is_valid = "да" in output
+
+ cache_data = {"is_valid": is_valid}
- cache_data = {"is_valid": is_valid}
-
- if not is_valid:
- cache_data["final_answer"] = "Не прошёл валидацию"
- state["final_answer"] = cache_data["final_answer"]
+ if not is_valid:
+ cache_data["final_answer"] = "Не прошёл валидацию"
+ state["final_answer"] = cache_data["final_answer"]
- state["is_valid"] = is_valid
- self.logger.debug(f"is_valid: {is_valid}")
- self.cache.save(meta_info="validate_input", query=question, output="", json_data=cache_data)
- return state
+ state["is_valid"] = is_valid
+ self.logger.debug(f"is_valid: {is_valid}")
+ self.cache.save(meta_info="validate_input", query=question, output="", json_data=cache_data)
+ return state
- except Exception as e:
- print(f"Validate error at validate_input: {e}")
+ except Exception as e:
+ self.logger.error(f"Validate error at validate_input: {e}")
async def validate_final_answer(self, state: AgentState) -> AgentState:
"""Проверяем, что текст ответа модели соответсвует публичной политики.
@@ -58,28 +82,31 @@ async def validate_final_answer(self, state: AgentState) -> AgentState:
Return:
Бинарное значение - спам или нормальный ответ от агента.
"""
- final_answer = state.get("final_answer", "")
- try:
- cached_result = self.cache.get(meta_info="validate_final_answer", query=final_answer)
- if cached_result:
- state["is_valid"] = cached_result.get("json").get("is_valid") or True
- return state
- else:
- prompt = self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt()
- prompt = ChatPromptTemplate.from_template(prompt)
- chain = prompt | self.llm
- output = await chain.ainvoke({"text": final_answer})
-
- is_valid = "да" in output.content.strip().lower()
- cache_data = {"answer": is_valid}
- if not is_valid:
- state["final_answer"] = "Не прошёл валидацию"
- state["is_valid"] = is_valid
- self.cache.save(meta_info="validate_final_answer", query=final_answer, output="", json_data=cache_data)
- return state
-
- except Exception as e:
- print(f"Error at validate_final_answer: {e}")
+ async with self._node_span("validate_final_answer", state) as span:
+ final_answer = state.get("final_answer", "")
+ try:
+ cached_result = self.cache.get(meta_info="validate_final_answer", query=final_answer)
+ if cached_result:
+ state["is_valid"] = cached_result.get("json").get("is_valid") or True
+ return state
+ else:
+ prompt = self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt()
+ prompt = ChatPromptTemplate.from_template(prompt)
+ chain = prompt | self.llm
+ output = await chain.ainvoke({"text": final_answer}, config=self._llm_config(span))
+
+ is_valid = "да" in output.content.strip().lower()
+ cache_data = {"answer": is_valid}
+ if not is_valid:
+ state["final_answer"] = "Не прошёл валидацию"
+ state["is_valid"] = is_valid
+ self.cache.save(
+ meta_info="validate_final_answer", query=final_answer, output="", json_data=cache_data
+ )
+ return state
+
+ except Exception as e:
+ self.logger.error(f"Error at validate_final_answer: {e}")
def update_user_history_context(self, state: AgentState) -> AgentState:
"""Обновляет историю вопросов/ответов: аппендит текущий вопрос + ответ, тримирует до HISTORY_LIMIT.
@@ -101,9 +128,8 @@ def update_user_history_context(self, state: AgentState) -> AgentState:
state["model_answers"] = [state["final_answer"]]
if len(state["user_history"]) > self.HISTORY_LIMIT:
- trim_count = len(state["user_history"]) - self.HISTORY_LIMIT
state["user_history"] = state["user_history"][-self.HISTORY_LIMIT :]
- state["model_answers"] = state["model_answers"][-trim_count:]
+ state["model_answers"] = state["model_answers"][-self.HISTORY_LIMIT :]
return {"user_history": state["user_history"], "model_answers": state["model_answers"]}
diff --git a/src/agents/profkom_consultant/nodes/core.py b/src/agents/profkom_consultant/nodes/core.py
index ff4aeac..1701910 100644
--- a/src/agents/profkom_consultant/nodes/core.py
+++ b/src/agents/profkom_consultant/nodes/core.py
@@ -46,10 +46,14 @@ async def _detect_topics_for_question(self, question: str) -> str:
Returns:
Relevant topic.
"""
+ from service.logger.context_vars import current_span
+
prompt = self.langfuse_client.get_prompt("topic_choose_router").get_langchain_prompt()
prompt = ChatPromptTemplate.from_template(prompt)
chain = prompt | self.llm
- response = await chain.ainvoke({"question": question})
+ span = current_span.get()
+ config = self._llm_config(span)
+ response = await chain.ainvoke({"question": question}, config=config)
return response.content.strip()
async def decompose_question(self, state: AgentState) -> None | dict[str, Any] | dict[str, list[Any]]:
@@ -66,30 +70,35 @@ async def decompose_question(self, state: AgentState) -> None | dict[str, Any] |
Return:
Словарь простых вопросов пользователя.
"""
- question = state["text"]
- try:
- cached_result = self.cache.get(meta_info="decompose_question_" + state["user_id"], query=question)
- if cached_result:
- return {"parts": cached_result.get("json").get("parts")}
- else:
- prompt = self.langfuse_client.get_prompt("decompose_question").get_langchain_prompt()
- prompt = ChatPromptTemplate.from_template(prompt)
- chain = prompt | self.llm
- response = await chain.ainvoke(
- {"user_question": question, "user_history": state.get("user_history", "")}
- )
- response = response.content.strip()
+ async with self._node_span("decompose_question", state) as span:
+ question = state["text"]
+ try:
+ cached_result = self.cache.get(meta_info="decompose_question_" + state["user_id"], query=question)
+ if cached_result:
+ return {"parts": cached_result.get("json").get("parts")}
+ else:
+ prompt = self.langfuse_client.get_prompt("decompose_question").get_langchain_prompt()
+ prompt = ChatPromptTemplate.from_template(prompt)
+ chain = prompt | self.llm
+ response = await chain.ainvoke(
+ {"user_question": question, "user_history": state.get("user_history", "")},
+ config=self._llm_config(span),
+ )
+ response = response.content.strip()
- content = re.search(r"<ЗАДАЧИ.*?>(.*?)ЗАДАЧИ>", response, re.IGNORECASE | re.DOTALL)
- content = content.group(1) if content else response
+ content = re.search(r"<ЗАДАЧИ.*?>(.*?)ЗАДАЧИ>", response, re.IGNORECASE | re.DOTALL)
+ content = content.group(1) if content else response
- cache_data = {"parts": [p.strip() for p in content.split("") if p.strip()]}
- self.cache.save(
- meta_info="decompose_question_" + state["user_id"], query=question, output="", json_data=cache_data
- )
- return cache_data
- except Exception as e:
- print(f"Error at decompose_question: {e}")
+ cache_data = {"parts": [p.strip() for p in content.split("") if p.strip()]}
+ self.cache.save(
+ meta_info="decompose_question_" + state["user_id"],
+ query=question,
+ output="",
+ json_data=cache_data,
+ )
+ return cache_data
+ except Exception as e:
+ self.logger.error(f"Error at decompose_question: {e}")
async def answer_parts_async(self, state: AgentState, max_concurrent: int = 8) -> AgentState:
"""Генерируем асинхронные ответы на список вопросов.
@@ -100,37 +109,42 @@ async def answer_parts_async(self, state: AgentState, max_concurrent: int = 8) -
Returns:
Список простых ответов на глобальный вопрос пользователя.
"""
- state["answers"] = []
- semaphore = asyncio.Semaphore(max_concurrent)
+ async with self._node_span("answer_parts_async", state) as span:
+ state["answers"] = []
+ semaphore = asyncio.Semaphore(max_concurrent)
- prompt = self.langfuse_client.get_prompt("query_worker").get_langchain_prompt()
- prompt = ChatPromptTemplate.from_template(prompt)
- # TO DO: CHECK что мы умеем работать с данными RAG
- chain = prompt | self.llm
+ prompt = self.langfuse_client.get_prompt("query_worker").get_langchain_prompt()
+ prompt = ChatPromptTemplate.from_template(prompt)
+ # TO DO: CHECK что мы умеем работать с данными RAG
+ chain = prompt | self.llm
- async def call_llm(part: str) -> str:
- self.logger.info(f"Calling {part}")
- async with semaphore:
- cached_result = self.cache.get(meta_info="answer_parts_async", query=part)
- if cached_result:
- return cached_result.get("json").get("answer")
- else:
- topic = await self._detect_topics_for_question(part)
- self.logger.info(f"Topic: {topic}")
- retrived_data = await asyncio.to_thread(
- self.chorma_client.get_info, query=part, collection_name=self.COLLECTION_NAME, topics=[topic]
- )
- html_data = retrived_data.to_html()
- result = await chain.ainvoke({"text": part, "rag": html_data})
- cache_data = {"answer": result.content.strip()}
- self.cache.save(meta_info="answer_parts_async", query=part, output="", json_data=cache_data)
- return cache_data.get("answer")
-
- if state.get("parts"):
- tasks = [asyncio.create_task(call_llm(part)) for part in state["parts"]]
- results = await asyncio.gather(*tasks, return_exceptions=True)
- state["answers"] = [str(r) if not isinstance(r, Exception) else f"Error: {r}" for r in results]
- return state
+ async def call_llm(part: str) -> str:
+ self.logger.info(f"Calling {part}")
+ async with semaphore:
+ cached_result = self.cache.get(meta_info="answer_parts_async", query=part)
+ if cached_result:
+ return cached_result.get("json").get("answer")
+ else:
+ topic = await self._detect_topics_for_question(part)
+ self.logger.info(f"Topic: {topic}")
+ retrived_data = await asyncio.to_thread(
+ self.chorma_client.get_info,
+ query=part,
+ collection_name=self.COLLECTION_NAME,
+ topics=[topic],
+ )
+ html_data = retrived_data.to_html()
+ config = self._llm_config(span)
+ result = await chain.ainvoke({"text": part, "rag": html_data}, config=config)
+ cache_data = {"answer": result.content.strip()}
+ self.cache.save(meta_info="answer_parts_async", query=part, output="", json_data=cache_data)
+ return cache_data.get("answer")
+
+ if state.get("parts"):
+ tasks = [asyncio.create_task(call_llm(part)) for part in state["parts"]]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ state["answers"] = [str(r) if not isinstance(r, Exception) else f"Error: {r}" for r in results]
+ return state
async def collect_final_answer(self, state: AgentState) -> AgentState:
"""Собираем итоговый ответ на вопрос пользователя.
@@ -143,26 +157,28 @@ async def collect_final_answer(self, state: AgentState) -> AgentState:
Return:
Итоговый текст ответа пользователю на вопрос.
"""
- question = state["text"]
- if state.get("answers"):
- answers_text = "\n".join(f"{i + 1}. {ans}" for i, ans in enumerate(state["answers"]) if ans)
- prompt = self.langfuse_client.get_prompt("summary_response").get_langchain_prompt()
- prompt = ChatPromptTemplate.from_template(prompt)
- chain = prompt | self.llm
- # TO DO: CHECK что у нас огромный промпт не ломает ответ
- response = await chain.ainvoke(
- {
- "task_responses": answers_text,
- "user_history": state.get("user_history", "Нет истории запросов."),
- "original_question": question,
- "model_answers": state.get("model_answers", "Нет истории ответов от модели"),
- "additional_info": state.get(
- "additional_info", "Нет дополнительной информации по предыдущим ответам."
- ),
- }
- )
- response = response.content.strip()
- state["final_answer"] = response
- else:
- state["final_answer"] = "Нет данных для итогового ответа."
- return state
+ async with self._node_span("collect_final_answer", state) as span:
+ question = state["text"]
+ if state.get("answers"):
+ answers_text = "\n".join(f"{i + 1}. {ans}" for i, ans in enumerate(state["answers"]) if ans)
+ prompt = self.langfuse_client.get_prompt("summary_response").get_langchain_prompt()
+ prompt = ChatPromptTemplate.from_template(prompt)
+ chain = prompt | self.llm
+ # TO DO: CHECK что у нас огромный промпт не ломает ответ
+ response = await chain.ainvoke(
+ {
+ "task_responses": answers_text,
+ "user_history": state.get("user_history", "Нет истории запросов."),
+ "original_question": question,
+ "model_answers": state.get("model_answers", "Нет истории ответов от модели"),
+ "additional_info": state.get(
+ "additional_info", "Нет дополнительной информации по предыдущим ответам."
+ ),
+ },
+ config=self._llm_config(span),
+ )
+ response = response.content.strip()
+ state["final_answer"] = response
+ else:
+ state["final_answer"] = "Нет данных для итогового ответа."
+ return state
diff --git a/src/agents/profkom_consultant/nodes/loop.py b/src/agents/profkom_consultant/nodes/loop.py
index f64987a..4a31bc1 100644
--- a/src/agents/profkom_consultant/nodes/loop.py
+++ b/src/agents/profkom_consultant/nodes/loop.py
@@ -21,37 +21,39 @@ async def check_user_answer(self, state: AgentState) -> AgentState:
- status="DONE" если final_answer релевантен text.
- status="AGAIN" + counter_loop +=1 если нет (max 3).
"""
- prompt = self.langfuse_client.get_prompt("check_user_answer").get_langchain_prompt()
- prompt = ChatPromptTemplate.from_template(prompt)
- chain = prompt | self.llm
- response = await chain.ainvoke(
- {
- "question": state["text"],
- "parts": state.get("parts", "[]"),
- "history_questions": state.get("user_history", "[]"),
- "answer": state["final_answer"],
- }
- )
- response = "DONE" in response.content.strip().upper()
-
- if response:
- state["status"] = AgentStatus.DONE
- state["counter_loop"] = 0
- state["additional_info"] = ""
- else:
- counter = state.get("counter_loop", 0)
- if counter >= self.MAX_LOOP_GENERATION:
+ async with self._node_span("check_user_answer", state) as span:
+ prompt = self.langfuse_client.get_prompt("check_user_answer").get_langchain_prompt()
+ prompt = ChatPromptTemplate.from_template(prompt)
+ chain = prompt | self.llm
+ response = await chain.ainvoke(
+ {
+ "question": state["text"],
+ "parts": state.get("parts", "[]"),
+ "history_questions": state.get("user_history", "[]"),
+ "answer": state["final_answer"],
+ },
+ config=self._llm_config(span),
+ )
+ response = "DONE" in response.content.strip().upper()
+
+ if response:
state["status"] = AgentStatus.DONE
state["counter_loop"] = 0
state["additional_info"] = ""
else:
- if not state.get("counter_loop"):
+ counter = state.get("counter_loop", 0)
+ if counter >= self.MAX_LOOP_GENERATION:
+ state["status"] = AgentStatus.DONE
state["counter_loop"] = 0
+ state["additional_info"] = ""
+ else:
+ if not state.get("counter_loop"):
+ state["counter_loop"] = 0
- state["counter_loop"] += 1
- state["additional_info"] = state["final_answer"]
- state["status"] = AgentStatus.AGAIN
- return state
+ state["counter_loop"] += 1
+ state["additional_info"] = state["final_answer"]
+ state["status"] = AgentStatus.AGAIN
+ return state
async def generate_additional_questions(self, state) -> AgentState:
"""Генерируем новые вопросы чтобы ответить на вопрос пользователя.
@@ -63,22 +65,24 @@ async def generate_additional_questions(self, state) -> AgentState:
Return:
Новый список вопросов.
"""
- prompt = self.langfuse_client.get_prompt("generate_additional_questions").get_langchain_prompt()
- prompt = ChatPromptTemplate.from_template(prompt)
- chain = prompt | self.llm
- response = await chain.ainvoke(
- {
- "question": state["text"],
- "history_questions": state.get("user_history", "[]"),
- "answer": state["final_answer"],
- "parts": state.get("parts", "[]"),
- }
- )
-
- response = response.content.strip()
-
- content = re.search(r"<ЗАДАЧИ.*?>(.*?)ЗАДАЧИ>", response, re.IGNORECASE | re.DOTALL)
- content = content.group(1) if content else response
-
- data = {"parts": [p.strip() for p in content.split("") if p.strip()]}
- return data
+ async with self._node_span("generate_additional_questions", state) as span:
+ prompt = self.langfuse_client.get_prompt("generate_additional_questions").get_langchain_prompt()
+ prompt = ChatPromptTemplate.from_template(prompt)
+ chain = prompt | self.llm
+ response = await chain.ainvoke(
+ {
+ "question": state["text"],
+ "history_questions": state.get("user_history", "[]"),
+ "answer": state["final_answer"],
+ "parts": state.get("parts", "[]"),
+ },
+ config=self._llm_config(span),
+ )
+
+ response = response.content.strip()
+
+ content = re.search(r"<ЗАДАЧИ.*?>(.*?)ЗАДАЧИ>", response, re.IGNORECASE | re.DOTALL)
+ content = content.group(1) if content else response
+
+ data = {"parts": [p.strip() for p in content.split("") if p.strip()]}
+ return data
diff --git a/src/modules/chroma_ext/base.py b/src/modules/chroma_ext/base.py
index 4da0798..5729779 100644
--- a/src/modules/chroma_ext/base.py
+++ b/src/modules/chroma_ext/base.py
@@ -5,6 +5,7 @@
from chromadb import QueryResult
from service.logger import LoggerConfigurator
+from service.logger.context_vars import current_span, current_trace
from .utils import BM25Reranker, MyEmbeddingFunction
@@ -84,6 +85,15 @@ def embedding_function(self):
self.logger.debug("embedding_function initialized")
return self._embedding_function
+ def _start_span(self, name: str, input_data: dict):
+ span = current_span.get()
+ if span:
+ return span.span(name=name, input=input_data)
+ trace = current_trace.get()
+ if trace:
+ return trace.span(name=name, input=input_data)
+ return None
+
def get_info_from_db(
self, query: str, collection_name: str, n_results: int = 30, where: dict | None = None, **kwargs
) -> QueryResult:
@@ -99,15 +109,33 @@ def get_info_from_db(
Returns:
relevant documents
"""
- self.logger.debug(f"get_info_from_db called for {collection_name}")
- collection = self.client.get_collection(name=collection_name, embedding_function=self.embedding_function)
-
- return collection.query(
- query_texts=[query],
- n_results=n_results,
- include=["documents", "metadatas", "distances"],
- where=where,
+ span = self._start_span(
+ "chroma_query",
+ {
+ "query": query,
+ "collection": collection_name,
+ "n_results": n_results,
+ "where": where,
+ },
)
+ try:
+ self.logger.debug(f"get_info_from_db called for {collection_name}")
+ collection = self.client.get_collection(name=collection_name, embedding_function=self.embedding_function)
+
+ result = collection.query(
+ query_texts=[query],
+ n_results=n_results,
+ include=["documents", "metadatas", "distances"],
+ where=where,
+ )
+ if span:
+ docs = result.get("documents", [[]])[0]
+ span.end(output={"documents_returned": len(docs)})
+ return result
+ except Exception as e:
+ if span:
+ span.end(level="ERROR", status_message=str(e))
+ raise
def get_filtered_documents(self, data_raw: Dict[str, Any]) -> dict:
self.logger.debug(f"get_filtered_documents: documents number {len(data_raw['documents'])}")
@@ -137,41 +165,58 @@ def apply_reranker(self, query, documents):
def get_info(self, query: str, collection_name: str, topics: list[str] | None = None) -> pd.DataFrame:
# TO DO: фильтрация по метаданным и потом только query!
- self.logger.debug(f"called {query} in get_info for {collection_name} and topics {topics}")
-
- where = None
- if topics:
- # один topic можно передать прямо строкой, несколько — через $in
- if len(topics) == 1:
- where = {"topic": topics[0]}
- else:
- where = {"topic": {"$in": topics}}
-
- data_raw = self.get_info_from_db(
- query=query,
- collection_name=collection_name,
- n_results=self.max_rag_documents,
- where=where,
+ span = self._start_span(
+ "chroma_rag",
+ {
+ "query": query,
+ "collection": collection_name,
+ "topics": topics,
+ },
)
- filtered_documents = self.get_filtered_documents(data_raw)
-
- if not filtered_documents["documents"]:
- self.logger.debug(f"no documents found in {collection_name}")
+ try:
+ self.logger.debug(f"called {query} in get_info for {collection_name} and topics {topics}")
+
+ where = None
+ if topics:
+ # один topic можно передать прямо строкой, несколько — через $in
+ if len(topics) == 1:
+ where = {"topic": topics[0]}
+ else:
+ where = {"topic": {"$in": topics}}
+
+ data_raw = self.get_info_from_db(
+ query=query,
+ collection_name=collection_name,
+ n_results=self.max_rag_documents,
+ where=where,
+ )
+ filtered_documents = self.get_filtered_documents(data_raw)
+
+ if not filtered_documents["documents"]:
+ self.logger.debug(f"no documents found in {collection_name}")
+ if span:
+ span.end(output={"documents_found": 0})
+ return pd.DataFrame.from_dict(
+ data={
+ "documents": [],
+ "metadatas": [],
+ }
+ )
+
+ idx_relevant_documents = self.apply_reranker(query=query, documents=filtered_documents["documents"])
+ self.logger.debug(f"Finished get_info for {query} returned {len(idx_relevant_documents)} documents")
+ if span:
+ span.end(output={"documents_found": len(idx_relevant_documents)})
return pd.DataFrame.from_dict(
data={
- "documents": [],
- "metadatas": [],
+ "documents": [filtered_documents["documents"][idx] for idx in idx_relevant_documents],
+ "metadatas": [filtered_documents["metadatas"][idx] for idx in idx_relevant_documents],
}
)
-
- idx_relevant_documents = self.apply_reranker(query=query, documents=filtered_documents["documents"])
- self.logger.debug(f"Finished get_info for {query} returned {len(idx_relevant_documents)} documents")
- return pd.DataFrame.from_dict(
- data={
- "documents": [filtered_documents["documents"][idx] for idx in idx_relevant_documents],
- "metadatas": [filtered_documents["metadatas"][idx] for idx in idx_relevant_documents],
- }
- )
+ except Exception as e:
+ if span:
+ span.end(level="ERROR", status_message=str(e))
+ raise
def health_check(self) -> bool:
"""Simple Chroma check"""
diff --git a/src/modules/postgres_ext/base.py b/src/modules/postgres_ext/base.py
index 89abba6..a97c799 100644
--- a/src/modules/postgres_ext/base.py
+++ b/src/modules/postgres_ext/base.py
@@ -116,8 +116,8 @@ async def get_pool_stats(self) -> dict[str, Any] | None:
return None
stats = self._pool.get_stats()
self.logger.info(
- f"Postgres.get_pool_stats: {id(self): pool_size={stats.get('pool_size', 0)}}"
- f"pool_available={stats.get('pool_available', 0)}"
+ f"Postgres.get_pool_stats: id={id(self)} pool_size={stats.get('pool_size', 0)} "
+ f"pool_available={stats.get('pool_available', 0)} "
f"request_waiting={stats.get('request_waiting', 0)}"
)
return stats
diff --git a/src/modules/redis_ext/base.py b/src/modules/redis_ext/base.py
index 79a9d48..445d931 100644
--- a/src/modules/redis_ext/base.py
+++ b/src/modules/redis_ext/base.py
@@ -7,6 +7,7 @@
from langchain_redis import RedisSemanticCache
from service.logger import LoggerConfigurator
+from service.logger.context_vars import current_span, current_trace
class RedisAdapter:
@@ -36,31 +37,58 @@ def __init__(
self.logger.info(f"REDIS_THRESHOLD: {self.redis_threshold}")
self.logger.info(f"REDIS_TTL: {self.redis_ttl}")
+ def _start_span(self, name: str, input_data: dict):
+ span = current_span.get()
+ if span:
+ return span.span(name=name, input=input_data)
+ trace = current_trace.get()
+ if trace:
+ return trace.span(name=name, input=input_data)
+ return None
+
def save(self, meta_info: str, query: str = "", output: str = "", json_data: Optional[dict] = None):
"""
output=str в text, json_data=dict в metadata.
"""
- # self.logger.debug("saving query")
- metadata = {"json": json_data} if json_data else {}
- metadata["query"] = query
- metadata["output"] = output
+ span = self._start_span("redis_save", {"query": query, "meta_info": meta_info})
+ try:
+ metadata = {"json": json_data} if json_data else {}
+ metadata["query"] = query
+ metadata["output"] = output
- json_str = json.dumps(metadata)
+ json_str = json.dumps(metadata)
- result = [Generation(text=json_str)]
- self.semantic_cache.update(query, meta_info, result)
+ result = [Generation(text=json_str)]
+ self.semantic_cache.update(query, meta_info, result)
+ if span:
+ span.end(output={"status": "saved"})
+ except Exception as e:
+ if span:
+ span.end(level="ERROR", status_message=str(e))
+ raise
def get(self, meta_info: str, query: str = "") -> Optional[Dict[str, Any]]:
"""Возвращает полный dict из JSON в text."""
- # self.logger.debug("getting query")
- result = self.semantic_cache.lookup(query, meta_info)
- if result:
- try:
- return json.loads(result[0].text)
- except json.JSONDecodeError as e:
- self.logger.error(f"JSON decode error: {e}")
- return None
- return None
+ span = self._start_span("redis_get", {"query": query, "meta_info": meta_info})
+ try:
+ result = self.semantic_cache.lookup(query, meta_info)
+ if result:
+ parsed = json.loads(result[0].text)
+ if span:
+ span.end(output={"hit": True})
+ return parsed
+ if span:
+ span.end(output={"hit": False})
+ return None
+ except json.JSONDecodeError as e:
+ self.logger.error(f"JSON decode error: {e}")
+ if span:
+ span.end(level="ERROR", status_message=str(e))
+ return None
+ except Exception as e:
+ if span:
+ span.end(level="ERROR", status_message=str(e))
+ raise
def health_check(self) -> bool:
"""Simple health check"""
diff --git a/src/service/api/v1/router.py b/src/service/api/v1/router.py
index c0f8eab..4669aa7 100644
--- a/src/service/api/v1/router.py
+++ b/src/service/api/v1/router.py
@@ -10,6 +10,7 @@
from agents.profkom_consultant import AgentStatus, build_builder
from service.config import APP_CONFIG
from service.context import APP_CTX
+from service.logger.context_vars import current_trace
from . import schemas
from .schemas import AgentChatRequest, AgentChatResponse, FailedDependecyResponse, YandexGPTAPITestResponse
@@ -76,10 +77,16 @@ async def chat(
agent_graph = build_builder(agent=APP_CTX.get_agent(), checkpointer=checkpointer)
langfuse = await APP_CTX.get_langfuse()
+ trace = langfuse.client.trace(
+ name="chat",
+ user_id=headers.get("x-user-id"),
+ session_id=headers.get("x-trace-id"),
+ metadata={"stage": APP_CONFIG.app.stage},
+ )
+ current_trace.set(trace)
config = {
"configurable": {"thread_id": headers.get("x-user-id")},
- "callbacks": [langfuse.handler],
"metadata": {
"stage": APP_CONFIG.app.stage,
"langfuse_session_id": headers.get("x-trace-id"),
diff --git a/src/service/logger/context_vars.py b/src/service/logger/context_vars.py
index 0518d22..12d8ffc 100644
--- a/src/service/logger/context_vars.py
+++ b/src/service/logger/context_vars.py
@@ -1,7 +1,11 @@
from contextvars import ContextVar
+from typing import Any
from .models import ContextLog
+current_trace: ContextVar[Any | None] = ContextVar("current_trace", default=None)
+current_span: ContextVar[Any | None] = ContextVar("current_span", default=None)
+
class ContextVarsContainer:
@property
@@ -46,4 +50,4 @@ def get_context_vars(self):
)
-__all__ = ["ContextVarsContainer"]
+__all__ = ["ContextVarsContainer", "current_trace", "current_span"]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/agents/__init__.py b/tests/agents/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/agents/profkom_consultant/__init__.py b/tests/agents/profkom_consultant/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/agents/profkom_consultant/nodes/__init__.py b/tests/agents/profkom_consultant/nodes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/agents/profkom_consultant/nodes/test_base.py b/tests/agents/profkom_consultant/nodes/test_base.py
new file mode 100644
index 0000000..6b0777b
--- /dev/null
+++ b/tests/agents/profkom_consultant/nodes/test_base.py
@@ -0,0 +1,52 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from agents.profkom_consultant.nodes.base import BaseAgentNodes
+
+
+@pytest.fixture
+def agent():
+ instance = BaseAgentNodes.__new__(BaseAgentNodes)
+ instance.HISTORY_LIMIT = 3
+ return instance
+
+
+class TestUpdateUserHistoryContext:
+ def test_appends_question_and_answer(self, agent):
+ state = {
+ "text": "Новый вопрос",
+ "final_answer": "Новый ответ",
+ }
+
+ result = agent.update_user_history_context(state)
+
+ assert result["user_history"] == ["Новый вопрос"]
+ assert result["model_answers"] == ["Новый ответ"]
+
+ def test_trims_to_history_limit(self, agent):
+ state = {
+ "user_history": ["вопрос 1", "вопрос 2", "вопрос 3"],
+ "model_answers": ["ответ 1", "ответ 2", "ответ 3"],
+ "text": "вопрос 4",
+ "final_answer": "ответ 4",
+ }
+
+ result = agent.update_user_history_context(state)
+
+ assert result["user_history"] == ["вопрос 2", "вопрос 3", "вопрос 4"]
+ assert result["model_answers"] == ["ответ 2", "ответ 3", "ответ 4"]
+
+ def test_maintains_one_to_one_sync_after_trim(self, agent):
+ state = {
+ "user_history": ["вопрос 1", "вопрос 2"],
+ "model_answers": ["ответ 1", "ответ 2"],
+ "text": "вопрос 3",
+ "final_answer": "ответ 3",
+ }
+
+ result = agent.update_user_history_context(state)
+
+ assert len(result["user_history"]) == len(result["model_answers"])
+ assert result["user_history"][-1] == "вопрос 3"
+ assert result["model_answers"][-1] == "ответ 3"
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..154f766
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,33 @@
+import pytest
+import pytest_asyncio
+from httpx import AsyncClient, ASGITransport
+
+from service.api import create_app
+from service.context import APP_CTX
+
+
+@pytest.fixture
+def app(monkeypatch):
+ async def _noop(*args, **kwargs):
+ pass
+
+ monkeypatch.setattr(APP_CTX, "on_startup", _noop)
+ monkeypatch.setattr(APP_CTX, "on_shutdown", _noop)
+ return create_app()
+
+
+@pytest_asyncio.fixture
+async def async_client(app):
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
+ yield client
+
+
+@pytest.fixture
+def mock_headers():
+ return {
+ "x-trace-id": "test-trace-id",
+ "x-request-time": "2024-01-01T00:00:00+03:00",
+ "x-source-name": "pytest",
+ "x-user-id": "test-user-id",
+ }
diff --git a/tests/modules/chroma_ext/__init__.py b/tests/modules/chroma_ext/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/chroma_ext/scripts/__init__.py b/tests/modules/chroma_ext/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/chroma_ext/scripts/test_data_reader.py b/tests/modules/chroma_ext/scripts/test_data_reader.py
new file mode 100644
index 0000000..01b6dc6
--- /dev/null
+++ b/tests/modules/chroma_ext/scripts/test_data_reader.py
@@ -0,0 +1,117 @@
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from modules.chroma_ext.scripts.data_reader import (
+ DocumentChunk,
+ _build_topic_prefix,
+ _calc_signature,
+ _read_docx,
+ _split_into_chunks,
+ load_docx_with_metadata,
+)
+
+
+class TestReadDocx:
+ @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process")
+ def test_returns_stripped_text(self, mock_process):
+ mock_process.return_value = " hello world \n\n"
+ result = _read_docx(Path("/fake/doc.docx"))
+ assert result == "hello world"
+
+ @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process")
+ def test_returns_empty_when_none(self, mock_process):
+ mock_process.return_value = None
+ assert _read_docx(Path("/fake/doc.docx")) == ""
+
+
+class TestSplitIntoChunks:
+ def test_empty_text(self):
+ assert _split_into_chunks("") == []
+
+ def test_exact_size_no_overlap(self):
+ text = "a" * 10
+ chunks = _split_into_chunks(text, chunk_size=5, chunk_overlap=0)
+ assert chunks == ["a" * 5, "a" * 5]
+
+ def test_overlap(self):
+ text = "a" * 10
+ chunks = _split_into_chunks(text, chunk_size=6, chunk_overlap=2)
+ assert chunks == ["a" * 6, "a" * 6]
+
+ def test_single_chunk_when_text_shorter(self):
+ text = "short"
+ chunks = _split_into_chunks(text, chunk_size=100, chunk_overlap=10)
+ assert chunks == ["short"]
+
+
+class TestCalcSignature:
+ def test_deterministic(self):
+ assert _calc_signature("hello") == _calc_signature("hello")
+ assert _calc_signature("hello") != _calc_signature("world")
+
+
+class TestBuildTopicPrefix:
+ def test_empty(self):
+ assert _build_topic_prefix("") == ""
+
+ def test_truncates_to_max_tokens(self):
+ text = "one two three four five"
+ assert _build_topic_prefix(text, max_tokens=3) == "one two three"
+
+ def test_full_when_short(self):
+ text = "one two"
+ assert _build_topic_prefix(text, max_tokens=10) == "one two"
+
+
+class TestLoadDocxWithMetadata:
+ @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process")
+ def test_skips_empty_files(self, mock_process, tmp_path):
+ mock_process.return_value = ""
+ (tmp_path / "empty.docx").write_text("fake")
+ logger = MagicMock()
+ result = load_docx_with_metadata(logger, tmp_path)
+ assert result == []
+
+ @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process")
+ def test_loads_single_file_with_topic_prefix(self, mock_process, tmp_path):
+ text = "Title of document. " + "body " * 200
+ mock_process.return_value = text
+ (tmp_path / "contract.docx").write_text("fake")
+ logger = MagicMock()
+ result = load_docx_with_metadata(logger, tmp_path, chunk_size=50, chunk_overlap=10, topic_tokens=5)
+
+ assert len(result) >= 1
+ chunk = result[0]
+ assert isinstance(chunk, DocumentChunk)
+ assert chunk.id == "contract.docx::chunk:0"
+ assert chunk.metadata["filename"] == "contract.docx"
+ assert chunk.metadata["topic"] == "general"
+ assert "Title of" in chunk.text
+ assert chunk.metadata["file_signature"] == _calc_signature(text.strip())
+
+ @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process")
+ def test_nested_directory_topic(self, mock_process, tmp_path):
+ text = "some content here"
+ mock_process.return_value = text
+ nested = tmp_path / "hr"
+ nested.mkdir()
+ (nested / "rules.docx").write_text("fake")
+ logger = MagicMock()
+ result = load_docx_with_metadata(logger, tmp_path)
+ assert len(result) == 1
+ assert result[0].metadata["topic"] == "hr"
+ assert result[0].id == "hr/rules.docx::chunk:0"
+
+ @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process")
+ def test_multiple_chunks(self, mock_process, tmp_path):
+ text = "a" * 100
+ mock_process.return_value = text
+ (tmp_path / "long.docx").write_text("fake")
+ logger = MagicMock()
+ result = load_docx_with_metadata(logger, tmp_path, chunk_size=30, chunk_overlap=5)
+ assert len(result) > 1
+ for idx, chunk in enumerate(result):
+ assert chunk.metadata["chunk_index"] == idx
+ assert chunk.metadata["num_chunks"] == len(result)
diff --git a/tests/modules/chroma_ext/scripts/test_db_writer.py b/tests/modules/chroma_ext/scripts/test_db_writer.py
new file mode 100644
index 0000000..b3e400e
--- /dev/null
+++ b/tests/modules/chroma_ext/scripts/test_db_writer.py
@@ -0,0 +1,138 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from modules.chroma_ext.scripts.data_reader import DocumentChunk
+from modules.chroma_ext.scripts.db_writer import (
+ _collect_current_sources,
+ _group_by_source,
+ sync_docx_directory_to_collection,
+)
+
+
+class TestGroupBySource:
+ def test_groups_by_source(self):
+ chunks = [
+ DocumentChunk(id="1", text="a", metadata={"source": "/a.docx"}),
+ DocumentChunk(id="2", text="b", metadata={"source": "/a.docx"}),
+ DocumentChunk(id="3", text="c", metadata={"source": "/b.docx"}),
+ ]
+ grouped = _group_by_source(chunks)
+ assert len(grouped["/a.docx"]) == 2
+ assert len(grouped["/b.docx"]) == 1
+
+
+class TestCollectCurrentSources:
+ def test_collects_docx_paths(self, tmp_path):
+ (tmp_path / "a.docx").write_text("x")
+ nested = tmp_path / "sub"
+ nested.mkdir()
+ (nested / "b.docx").write_text("y")
+ result = _collect_current_sources(str(tmp_path))
+ assert result == {str(tmp_path / "a.docx"), str(tmp_path / "sub" / "b.docx")}
+
+
+class TestSyncDocxDirectoryToCollection:
+ @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient")
+ @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction")
+ @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata")
+ def test_no_chunks_early_return(self, mock_load, mock_embed, mock_client):
+ mock_load.return_value = []
+ logger = MagicMock()
+ sync_docx_directory_to_collection(
+ logger, "/docs", "test_collection", api_key="k", folder_id="f", host="h", port=8000
+ )
+ logger.warning.assert_called_once_with("No .docx files found, nothing to index")
+ # HttpClient is created before the empty check in current implementation
+ mock_client.assert_called_once()
+
+ @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient")
+ @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction")
+ @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata")
+ def test_unchanged_file_skipped(self, mock_load, mock_embed, mock_client):
+ collection = MagicMock()
+ collection.get.return_value = {
+ "ids": ["old"],
+ "metadatas": [{"file_signature": "sig1"}],
+ }
+ mock_client.return_value.get_or_create_collection.return_value = collection
+
+ chunks = [
+ DocumentChunk(
+ id="f::chunk:0",
+ text="txt",
+ metadata={"source": "/f.docx", "file_signature": "sig1"},
+ )
+ ]
+ mock_load.return_value = chunks
+ logger = MagicMock()
+
+ sync_docx_directory_to_collection(
+ logger, "/docs", "test_collection", api_key="k", folder_id="f", host="h", port=8000
+ )
+ collection.add.assert_not_called()
+ collection.delete.assert_not_called()
+
+ @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient")
+ @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction")
+ @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata")
+ def test_changed_file_deletes_and_adds(self, mock_load, mock_embed, mock_client):
+ collection = MagicMock()
+ collection.get.side_effect = [
+ {
+ "ids": ["old"],
+ "metadatas": [{"file_signature": "old_sig"}],
+ },
+ {"ids": [], "metadatas": []},
+ ]
+ mock_client.return_value.get_or_create_collection.return_value = collection
+
+ chunks = [
+ DocumentChunk(
+ id="f::chunk:0",
+ text="new text",
+ metadata={"source": "/f.docx", "file_signature": "new_sig"},
+ )
+ ]
+ mock_load.return_value = chunks
+ logger = MagicMock()
+
+ sync_docx_directory_to_collection(
+ logger, "/docs", "test_collection", api_key="k", folder_id="f", host="h", port=8000
+ )
+ collection.delete.assert_any_call(where={"source": "/f.docx"})
+ collection.add.assert_called_once_with(
+ ids=["f::chunk:0"],
+ documents=["new text"],
+ metadatas=[{"source": "/f.docx", "file_signature": "new_sig"}],
+ )
+
+ @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient")
+ @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction")
+ @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata")
+ def test_removes_orphaned_sources(self, mock_load, mock_embed, mock_client, tmp_path):
+ collection = MagicMock()
+ collection.get.side_effect = [
+ {"ids": [], "metadatas": []},
+ {
+ "ids": ["old"],
+ "metadatas": [{"source": str(tmp_path / "gone.docx")}],
+ },
+ ]
+ mock_client.return_value.get_or_create_collection.return_value = collection
+
+ (tmp_path / "keep.docx").write_text("x")
+ chunks = [
+ DocumentChunk(
+ id="keep::chunk:0",
+ text="keep",
+ metadata={"source": str(tmp_path / "keep.docx"), "file_signature": "sig"},
+ )
+ ]
+ mock_load.return_value = chunks
+ logger = MagicMock()
+
+ sync_docx_directory_to_collection(
+ logger, str(tmp_path), "test_collection", api_key="k", folder_id="f", host="h", port=8000
+ )
+ collection.delete.assert_any_call(where={"source": str(tmp_path / "gone.docx")})
diff --git a/tests/modules/chroma_ext/test_base.py b/tests/modules/chroma_ext/test_base.py
new file mode 100644
index 0000000..261d92d
--- /dev/null
+++ b/tests/modules/chroma_ext/test_base.py
@@ -0,0 +1,210 @@
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+import pytest
+
+from modules.chroma_ext.base import ChromaAdapter
+
+
+@pytest.fixture
+def adapter():
+ logger = MagicMock()
+ with patch("modules.chroma_ext.base.chromadb.HttpClient") as MockClient, \
+ patch("modules.chroma_ext.base.BM25Reranker") as MockReranker:
+ mock_client = MagicMock()
+ MockClient.return_value = mock_client
+ mock_reranker = MagicMock()
+ MockReranker.return_value = mock_reranker
+ inst = ChromaAdapter(
+ logger=logger,
+ similarity_filter=1.0,
+ reranker_type="bm25",
+ text_type="query",
+ API_KEY="key123456789",
+ FOLDER_ID="fld123456789",
+ CHROMA_HOST="testhost",
+ CHROMA_PORT=9000,
+ CHROMA_TOPK_DOCUMENTS=3,
+ CHROMA_MAX_RAG_DOCUMENTS=10,
+ )
+ inst._mock_client = mock_client
+ inst._mock_reranker = mock_reranker
+ return inst
+
+
+class TestChromaAdapterInit:
+ def test_validation_errors(self):
+ logger = MagicMock()
+ with patch("modules.chroma_ext.base.chromadb.HttpClient"):
+ # FOLDER_ID=None and API_KEY=None currently raise TypeError due to slice
+ # before validation (bug in code)
+ with pytest.raises(TypeError):
+ ChromaAdapter(logger=logger, API_KEY="key", FOLDER_ID=None)
+ with pytest.raises(TypeError):
+ ChromaAdapter(logger=logger, API_KEY=None, FOLDER_ID="fld123456789")
+ with pytest.raises(ValueError, match="TOPK"):
+ ChromaAdapter(logger=logger, API_KEY="key", FOLDER_ID="fld123456789", CHROMA_TOPK_DOCUMENTS=20, CHROMA_MAX_RAG_DOCUMENTS=20)
+
+ def test_unsupported_reranker_does_not_raise(self):
+ # Current code instantiates NotImplementedError but does not raise it (bug).
+ logger = MagicMock()
+ with patch("modules.chroma_ext.base.chromadb.HttpClient"):
+ adapter = ChromaAdapter(
+ logger=logger,
+ API_KEY="key123456789",
+ FOLDER_ID="fld123456789",
+ reranker_type="unknown",
+ )
+ assert getattr(adapter, "reranker", None) is None
+
+ def test_params_set(self, adapter):
+ assert adapter.host == "testhost"
+ assert adapter.port == 9000
+ assert adapter.topk_documents == 3
+ assert adapter.max_rag_documents == 10
+ assert adapter.similarity_filter == 1.0
+
+
+class TestChromaAdapterEmbeddingFunction:
+ @patch("modules.chroma_ext.base.MyEmbeddingFunction")
+ def test_lazy_initialization(self, MockEmb, adapter):
+ mock_ef = MagicMock()
+ MockEmb.return_value = mock_ef
+ ef = adapter.embedding_function
+ assert ef is mock_ef
+ MockEmb.assert_called_once()
+ # second call returns cached
+ assert adapter.embedding_function is mock_ef
+
+
+class TestChromaAdapterStartSpan:
+ def test_prefers_current_span(self, adapter):
+ parent = MagicMock()
+ child = MagicMock()
+ parent.span.return_value = child
+ with patch("modules.chroma_ext.base.current_span") as mock_cs, \
+ patch("modules.chroma_ext.base.current_trace") as mock_ct:
+ mock_cs.get.return_value = parent
+ mock_ct.get.return_value = MagicMock()
+ result = adapter._start_span("chroma_test", {"a": 1})
+ assert result is child
+
+ def test_fallback_to_trace(self, adapter):
+ trace = MagicMock()
+ child = MagicMock()
+ trace.span.return_value = child
+ with patch("modules.chroma_ext.base.current_span") as mock_cs, \
+ patch("modules.chroma_ext.base.current_trace") as mock_ct:
+ mock_cs.get.return_value = None
+ mock_ct.get.return_value = trace
+ result = adapter._start_span("chroma_test", {"a": 1})
+ assert result is child
+
+ def test_none_when_no_context(self, adapter):
+ with patch("modules.chroma_ext.base.current_span") as mock_cs, \
+ patch("modules.chroma_ext.base.current_trace") as mock_ct:
+ mock_cs.get.return_value = None
+ mock_ct.get.return_value = None
+ assert adapter._start_span("chroma_test", {"a": 1}) is None
+
+
+class TestChromaAdapterGetInfoFromDb:
+ def test_success(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ mock_collection = MagicMock()
+ mock_collection.query.return_value = {
+ "documents": [["doc1", "doc2"]],
+ "metadatas": [[{"m": 1}, {"m": 2}]],
+ "distances": [[0.1, 0.2]],
+ }
+ adapter._mock_client.get_collection.return_value = mock_collection
+
+ result = adapter.get_info_from_db("q", "coll", n_results=5, where={"topic": "t"})
+ assert result["documents"][0] == ["doc1", "doc2"]
+ span.end.assert_called_once_with(output={"documents_returned": 2})
+
+ def test_error_ends_span(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ adapter._mock_client.get_collection.side_effect = RuntimeError("chroma down")
+ with pytest.raises(RuntimeError, match="chroma down"):
+ adapter.get_info_from_db("q", "coll")
+ span.end.assert_called_once_with(level="ERROR", status_message="chroma down")
+
+
+class TestChromaAdapterGetFilteredDocuments:
+ def test_filters_by_distance_and_strips_body(self, adapter):
+ data_raw = {
+ "documents": [["keep1", "keep2", "drop"]],
+ "metadatas": [[{"a": 1}, {"a": 2}, {"a": 3}]],
+ "distances": [[0.5, 0.9, 1.5]],
+ }
+ result = adapter.get_filtered_documents(data_raw)
+ assert result["documents"] == ["keep1", "keep2"]
+ assert result["metadatas"] == [{"a": 1}, {"a": 2}]
+
+
+class TestChromaAdapterGetPairs:
+ def test_builds_pairs(self, adapter):
+ result = adapter.get_pairs("query", ["d1", "d2"])
+ assert result == [["query", "d1"], ["query", "d2"]]
+
+
+class TestChromaAdapterApplyReranker:
+ def test_delegates_to_bm25(self, adapter):
+ adapter._mock_reranker.rerank.return_value = [1, 0]
+ idx = adapter.apply_reranker("q", ["d1", "d2", "d3"])
+ adapter._mock_reranker.fit.assert_called_once_with(["d1", "d2", "d3"])
+ adapter._mock_reranker.rerank.assert_called_once_with(query="q", top_k=3)
+ assert idx == [1, 0]
+
+
+class TestChromaAdapterGetInfo:
+ def test_full_flow_returns_dataframe(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ mock_collection = MagicMock()
+ mock_collection.query.return_value = {
+ "documents": [["d1", "d2"]],
+ "metadatas": [[{"t": "a"}, {"t": "b"}]],
+ "distances": [[0.1, 0.2]],
+ }
+ adapter._mock_client.get_collection.return_value = mock_collection
+ adapter._mock_reranker.rerank.return_value = [0]
+
+ df = adapter.get_info("query", "coll", topics=["a", "b"])
+ assert isinstance(df, pd.DataFrame)
+ assert df["documents"].tolist() == ["d1"]
+ # get_info creates its own span and get_info_from_db creates another;
+ # last end call belongs to the outer chroma_rag span
+ span.end.assert_called_with(output={"documents_found": 1})
+
+ def test_empty_filtered_documents(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ mock_collection = MagicMock()
+ mock_collection.query.return_value = {
+ "documents": [["drop"]],
+ "metadatas": [[{"t": "a"}]],
+ "distances": [[2.0]],
+ }
+ adapter._mock_client.get_collection.return_value = mock_collection
+
+ df = adapter.get_info("query", "coll")
+ assert isinstance(df, pd.DataFrame)
+ assert df.empty
+ span.end.assert_called_with(output={"documents_found": 0})
+
+ def test_exception_ends_span(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ adapter._mock_client.get_collection.side_effect = ValueError("fail")
+ with pytest.raises(ValueError, match="fail"):
+ adapter.get_info("query", "coll")
+ span.end.assert_called_with(level="ERROR", status_message="fail")
+
+
+class TestChromaAdapterHealthCheck:
+ def test_always_true(self, adapter):
+ assert adapter.health_check() is True
diff --git a/tests/modules/chroma_ext/utils/__init__.py b/tests/modules/chroma_ext/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/chroma_ext/utils/test_embedings.py b/tests/modules/chroma_ext/utils/test_embedings.py
new file mode 100644
index 0000000..c89562d
--- /dev/null
+++ b/tests/modules/chroma_ext/utils/test_embedings.py
@@ -0,0 +1,134 @@
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+from modules.chroma_ext.utils.embedings import MyEmbeddingFunction
+
+
+@pytest.fixture
+def embedder():
+ logger = MagicMock()
+ return MyEmbeddingFunction(
+ logger=logger,
+ folder_id="b1g2d3f4",
+ iam_token="t0k3n-12345678",
+ doc_model_uri="doc-uri",
+ query_model_uri="query-uri",
+ text_type="doc",
+ time_sleep=0,
+ max_retries=2,
+ request_timeout=5,
+ batch_size=2,
+ sleep_between_batches=0,
+ )
+
+
+class TestMyEmbeddingFunctionInit:
+ def test_defaults_and_kwargs(self, embedder):
+ assert embedder.api_url == "https://llm.api.cloud.yandex.net:443/foundationModels/v1/textEmbedding"
+ assert embedder.folder_id == "b1g2d3f4"
+ assert embedder.iam_token == "t0k3n-12345678"
+ assert embedder.text_type == "doc"
+ assert embedder.doc_model_uri == "doc-uri"
+ assert embedder.query_model_uri == "query-uri"
+ assert embedder.max_retries == 2
+ assert embedder.batch_size == 2
+
+
+class TestMyEmbeddingFunctionGetSingleEmbedding:
+ @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None)
+ @patch("modules.chroma_ext.utils.embedings.requests.post")
+ def test_success_returns_ndarray(self, mock_post, mock_sleep, embedder):
+ mock_resp = MagicMock()
+ mock_resp.ok = True
+ mock_resp.json.return_value = {"embedding": [0.1, 0.2, 0.3]}
+ mock_post.return_value = mock_resp
+
+ result = embedder._get_single_embedding("hello")
+ assert isinstance(result, np.ndarray)
+ np.testing.assert_array_equal(result, np.array([0.1, 0.2, 0.3]))
+
+ @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None)
+ @patch("modules.chroma_ext.utils.embedings.requests.post")
+ def test_transient_retries_then_success(self, mock_post, mock_sleep, embedder):
+ bad_resp = MagicMock()
+ bad_resp.ok = False
+ bad_resp.status_code = 503
+ bad_resp.text = "busy"
+
+ good_resp = MagicMock()
+ good_resp.ok = True
+ good_resp.json.return_value = {"embedding": [1.0]}
+
+ mock_post.side_effect = [bad_resp, good_resp]
+
+ result = embedder._get_single_embedding("hello")
+ np.testing.assert_array_equal(result, np.array([1.0]))
+ assert mock_post.call_count == 2
+ embedder.logger.warning.assert_called()
+
+ @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None)
+ @patch("modules.chroma_ext.utils.embedings.requests.post")
+ def test_non_transient_4xx_raises(self, mock_post, mock_sleep, embedder):
+ bad_resp = MagicMock()
+ bad_resp.ok = False
+ bad_resp.status_code = 400
+ bad_resp.text = "bad request"
+ bad_resp.raise_for_status.side_effect = Exception("HTTP 400")
+ mock_post.return_value = bad_resp
+
+ with pytest.raises(Exception, match="HTTP 400"):
+ embedder._get_single_embedding("hello")
+
+ @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None)
+ @patch("modules.chroma_ext.utils.embedings.requests.post")
+ def test_timeout_retries_then_raises(self, mock_post, mock_sleep, embedder):
+ from requests.exceptions import ConnectTimeout
+
+ mock_post.side_effect = ConnectTimeout("timeout")
+
+ with pytest.raises(ConnectTimeout):
+ embedder._get_single_embedding("hello")
+ assert mock_post.call_count == 2
+
+
+class TestMyEmbeddingFunctionBatched:
+ def test_batched_exact_and_remainder(self, embedder):
+ result = list(embedder._batched(["a", "b", "c", "d", "e"], 2))
+ assert result == [["a", "b"], ["c", "d"], ["e"]]
+
+ def test_batched_empty(self, embedder):
+ result = list(embedder._batched([], 3))
+ assert result == []
+
+
+class TestMyEmbeddingFunctionCall:
+ @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None)
+ @patch("modules.chroma_ext.utils.embedings.requests.post")
+ def test_call_single_string(self, mock_post, mock_sleep, embedder):
+ mock_resp = MagicMock()
+ mock_resp.ok = True
+ mock_resp.json.return_value = {"embedding": [0.5]}
+ mock_post.return_value = mock_resp
+
+ result = embedder("hello")
+ assert isinstance(result, list)
+ assert len(result) == 1
+ assert isinstance(result[0], np.ndarray)
+
+ @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None)
+ @patch("modules.chroma_ext.utils.embedings.requests.post")
+ def test_call_batches_with_sleep(self, mock_post, mock_sleep, embedder):
+ mock_resp = MagicMock()
+ mock_resp.ok = True
+ mock_resp.json.return_value = {"embedding": [0.1]}
+ mock_post.return_value = mock_resp
+
+ embedder.batch_size = 2
+ result = embedder(["a", "b", "c"])
+ assert isinstance(result, list)
+ assert len(result) == 3
+ assert all(isinstance(r, np.ndarray) for r in result)
+ # sleep вызывается: base sleep для каждого запроса + sleep_between_batches
+ assert mock_sleep.call_count >= 1
diff --git a/tests/modules/chroma_ext/utils/test_reranker.py b/tests/modules/chroma_ext/utils/test_reranker.py
new file mode 100644
index 0000000..02b1bc1
--- /dev/null
+++ b/tests/modules/chroma_ext/utils/test_reranker.py
@@ -0,0 +1,60 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from modules.chroma_ext.utils.reranker import BM25Reranker
+
+
+@pytest.fixture
+def reranker():
+ logger = MagicMock()
+ return BM25Reranker(logger=logger, tokenizer_name="gpt-3.5-turbo")
+
+
+class TestBM25RerankerInit:
+ def test_initializes_tokenizer(self, reranker):
+ assert reranker.tokenizer is not None
+ reranker.logger.info.assert_called()
+
+
+class TestBM25RerankerPreprocess:
+ def test_preprocess_lowercases_and_tokenizes(self, reranker):
+ tokens = reranker.preprocess("Hello World")
+ assert isinstance(tokens, list)
+ assert len(tokens) > 0
+ # tiktoken tokens converted back to strings; "hello" and " world" or similar
+ text = " ".join(tokens)
+ assert "hello" in text.lower()
+
+ def test_preprocess_filters_empty(self, reranker):
+ tokens = reranker.preprocess("")
+ assert tokens == []
+
+
+class TestBM25RerankerFit:
+ def test_fit_builds_bm25(self, reranker):
+ reranker.fit(["first document", "second document"])
+ assert reranker.bm25 is not None
+
+
+class TestBM25RerankerRerank:
+ def test_rerank_before_fit_raises(self, reranker):
+ with pytest.raises(ValueError, match="not fitted"):
+ reranker.rerank("query", top_k=2)
+
+ def test_rerank_returns_top_k_indices(self, reranker):
+ docs = [
+ "python programming language",
+ "cooking recipes for dinner",
+ "python snakes in the wild",
+ ]
+ reranker.fit(docs)
+ indices = reranker.rerank("python code", top_k=2)
+ assert len(indices) == 2
+ assert all(isinstance(i, int) for i in indices)
+
+ def test_rerank_top_k_larger_than_docs(self, reranker):
+ docs = ["only one document here"]
+ reranker.fit(docs)
+ indices = reranker.rerank("query", top_k=5)
+ assert len(indices) == 1
diff --git a/tests/modules/langfuse_ext/__init__.py b/tests/modules/langfuse_ext/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/langfuse_ext/test_base.py b/tests/modules/langfuse_ext/test_base.py
new file mode 100644
index 0000000..a367133
--- /dev/null
+++ b/tests/modules/langfuse_ext/test_base.py
@@ -0,0 +1,84 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from modules.langfuse_ext.base import LangfuseClient
+
+
+@pytest.fixture
+def client():
+ config = MagicMock()
+ config.host = "https://langfuse.test"
+ config.secret_key = "secret_key_123"
+ config.public_key = "public_key_456"
+ config.stage = "test"
+ logger = MagicMock()
+ return LangfuseClient(app_config=config, logger=logger)
+
+
+class TestLangfuseClientInit:
+ def test_logs_masked_keys(self, client):
+ client.logger.debug.assert_any_call("Secret Key: secr**_123")
+ client.logger.debug.assert_any_call("Public Key: publ**_456")
+ assert client.client is not None
+ assert client.handler is not None
+
+
+class TestLangfuseClientCreateClient:
+ @patch("modules.langfuse_ext.base.Langfuse")
+ def test_creates_langfuse_instance(self, MockLangfuse):
+ config = MagicMock()
+ config.host = "h"
+ config.secret_key = "s"
+ config.public_key = "p"
+ logger = MagicMock()
+ client = LangfuseClient(app_config=config, logger=logger)
+ # access property to trigger creation
+ _ = client._LangfuseClient__create_client
+ MockLangfuse.assert_called_with(secret_key="s", public_key="p", host="h")
+
+
+class TestLangfuseClientCreateCallbackHandler:
+ @patch("modules.langfuse_ext.base.CallbackHandler")
+ def test_creates_handler(self, MockHandler):
+ config = MagicMock()
+ config.host = "h"
+ config.secret_key = "s"
+ config.public_key = "p"
+ config.stage = "stage"
+ logger = MagicMock()
+ client = LangfuseClient(app_config=config, logger=logger)
+ _ = client._LangfuseClient__create_callback_handler
+ MockHandler.assert_called_with(
+ public_key="p", secret_key="s", host="h", trace_name="stage"
+ )
+
+
+class TestLangfuseClientHealthCheck:
+ def test_true_when_auth_ok(self, client):
+ client.client = MagicMock()
+ client.client.auth_check.return_value = True
+ assert client.health_check() is True
+
+ def test_false_when_auth_fails(self, client):
+ client.client = MagicMock()
+ client.client.auth_check.return_value = False
+ assert client.health_check() is False
+
+
+class TestLangfuseClientOnStartup:
+ @pytest.mark.asyncio
+ async def test_reassigns_and_checks(self, client):
+ client.client = MagicMock()
+ client.handler = MagicMock()
+ with patch.object(client, "health_check", return_value=True) as mock_hc:
+ await client.on_startup()
+ mock_hc.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_catches_exceptions(self, client):
+ with patch.object(
+ type(client), "_LangfuseClient__create_client", property(lambda self: (_ for _ in ()).throw(RuntimeError("boom")))
+ ):
+ # Should not raise
+ await client.on_startup()
diff --git a/tests/modules/postgres_ext/__init__.py b/tests/modules/postgres_ext/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/postgres_ext/test_base.py b/tests/modules/postgres_ext/test_base.py
new file mode 100644
index 0000000..40473f0
--- /dev/null
+++ b/tests/modules/postgres_ext/test_base.py
@@ -0,0 +1,155 @@
+import asyncio
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from modules.postgres_ext.base import PostgresClient
+
+
+@pytest.fixture
+def client():
+ config = MagicMock()
+ config.encoded_pass = "secret"
+ config.user = "u"
+ config.host = "h"
+ config.port = 5432
+ config.postgres_db = "db"
+ config.pool_min_size = 1
+ config.pool_max_size = 5
+ config.pool_max_idle = 30.0
+ config.conninfo = "postgresql://u:secret@h:5432/db"
+ logger = MagicMock()
+ return PostgresClient(config=config, logger=logger)
+
+
+class TestPostgresClientInit:
+ def test_pool_starts_none(self, client):
+ assert client._pool is None
+ assert isinstance(client._lock, asyncio.Lock)
+ client.logger.info.assert_called()
+
+
+@pytest.mark.asyncio
+class TestPostgresClientEnsurePool:
+ async def test_creates_pool_once(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 1})
+ MockPool.return_value = mock_pool
+
+ await client.ensure_pool()
+ await client.ensure_pool()
+
+ MockPool.assert_called_once_with(
+ conninfo=client.settings.conninfo,
+ min_size=client.settings.pool_min_size,
+ max_size=client.settings.pool_max_size,
+ max_idle=client.settings.pool_max_idle,
+ )
+ assert mock_pool.open.await_count == 1
+
+ async def test_race_condition_safe(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 1})
+ MockPool.return_value = mock_pool
+
+ async def task():
+ await client.ensure_pool()
+
+ await asyncio.gather(task(), task(), task())
+ MockPool.assert_called_once()
+
+
+@pytest.mark.asyncio
+class TestPostgresClientTakeConnInLoop:
+ async def test_happy_path(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 1})
+ mock_conn = AsyncMock()
+ mock_pool.getconn.return_value = mock_conn
+ MockPool.return_value = mock_pool
+ await client.ensure_pool()
+
+ conn = await client._take_conn_in_loop(0, 3)
+ assert conn is mock_conn
+ mock_conn.execute.assert_any_call("SELECT 1")
+ mock_conn.execute.assert_any_call("ROLLBACK")
+
+ async def test_retries_then_none(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 1})
+ mock_conn = AsyncMock()
+ mock_conn.execute.side_effect = RuntimeError("dead")
+ mock_pool.getconn.return_value = mock_conn
+ MockPool.return_value = mock_pool
+ await client.ensure_pool()
+
+ conn = await client._take_conn_in_loop(0, 2)
+ assert conn is None
+ assert mock_conn.close.await_count == 2
+
+
+@pytest.mark.asyncio
+class TestPostgresClientGetUserCheckpointer:
+ async def test_yields_saver_and_returns_conn(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool, \
+ patch("modules.postgres_ext.base.AsyncPostgresSaver") as MockSaver:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 1})
+ mock_conn = AsyncMock()
+ mock_pool.getconn.return_value = mock_conn
+ MockPool.return_value = mock_pool
+ mock_saver = MagicMock()
+ MockSaver.return_value = mock_saver
+
+ await client.ensure_pool()
+ async with client.get_user_checkpointer() as saver:
+ assert saver is mock_saver
+
+ mock_conn.set_autocommit.assert_awaited_once_with(True)
+ mock_pool.putconn.assert_awaited_once_with(mock_conn)
+
+
+@pytest.mark.asyncio
+class TestPostgresClientGetPoolStats:
+ async def test_none_when_no_pool(self, client):
+ assert await client.get_pool_stats() is None
+
+ async def test_returns_stats(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 2})
+ MockPool.return_value = mock_pool
+ await client.ensure_pool()
+ stats = await client.get_pool_stats()
+ assert stats == {"pool_size": 2}
+
+
+@pytest.mark.asyncio
+class TestPostgresClientClose:
+ async def test_closes_and_nulls_pool(self, client):
+ with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool:
+ mock_pool = AsyncMock()
+ mock_pool.get_stats = MagicMock(return_value={"pool_size": 1})
+ MockPool.return_value = mock_pool
+ await client.ensure_pool()
+ await client.close()
+ mock_pool.close.assert_awaited_once()
+ assert client._pool is None
+
+ async def test_idempotent(self, client):
+ await client.close()
+ assert client._pool is None
+
+
+class TestPostgresClientHealthCheck:
+ def test_false_before_init(self, client):
+ assert client.health_check() is False
+
+ def test_true_after_init(self, client):
+ # We can't easily async ensure_pool in sync test, so set pool manually
+ client._pool = MagicMock()
+ assert client.health_check() is True
diff --git a/tests/modules/redis_ext/__init__.py b/tests/modules/redis_ext/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/redis_ext/test_base.py b/tests/modules/redis_ext/test_base.py
new file mode 100644
index 0000000..d88ed01
--- /dev/null
+++ b/tests/modules/redis_ext/test_base.py
@@ -0,0 +1,186 @@
+import json
+from unittest.mock import ANY, MagicMock, patch
+
+import pytest
+
+from modules.redis_ext.base import RedisAdapter
+
+
+@pytest.fixture
+def mock_embeddings():
+ return MagicMock()
+
+
+@pytest.fixture
+def mock_logger():
+ return MagicMock()
+
+
+@pytest.fixture
+def adapter(mock_logger, mock_embeddings):
+ with patch("modules.redis_ext.base.RedisSemanticCache") as MockCache:
+ mock_cache = MagicMock()
+ MockCache.return_value = mock_cache
+ inst = RedisAdapter(
+ logger=mock_logger,
+ embeddings=mock_embeddings,
+ redis_url="redis://test:6379",
+ redis_threshold=0.1,
+ redis_ttl=60,
+ )
+ inst._mock_cache = mock_cache
+ return inst
+
+
+class TestRedisAdapterInit:
+ def test_uses_explicit_args(self, mock_logger, mock_embeddings):
+ with patch("modules.redis_ext.base.RedisSemanticCache") as MockCache:
+ RedisAdapter(
+ logger=mock_logger,
+ embeddings=mock_embeddings,
+ redis_url="redis://explicit:6379",
+ redis_threshold=0.2,
+ redis_ttl=120,
+ )
+ MockCache.assert_called_once_with(
+ redis_url="redis://explicit:6379",
+ embeddings=mock_embeddings,
+ distance_threshold=0.2,
+ ttl=120,
+ )
+
+ def test_fallback_to_env_vars(self, monkeypatch, mock_logger, mock_embeddings):
+ monkeypatch.setenv("REDIS_URL", "redis://env:6379")
+ monkeypatch.setenv("REDIS_THRESHOLD", "0.3")
+ monkeypatch.setenv("REDIS_TTL", "240")
+ with patch("modules.redis_ext.base.RedisSemanticCache") as MockCache:
+ RedisAdapter(
+ logger=mock_logger,
+ embeddings=mock_embeddings,
+ redis_url=None,
+ redis_threshold=None,
+ redis_ttl=None,
+ )
+ MockCache.assert_called_once_with(
+ redis_url="redis://env:6379",
+ embeddings=mock_embeddings,
+ distance_threshold=0.3,
+ ttl=240,
+ )
+
+
+class TestRedisAdapterSave:
+ def test_save_calls_update_and_ends_span(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ adapter.save(meta_info="meta", query="q", output="out", json_data={"k": "v"})
+
+ adapter._mock_cache.update.assert_called_once()
+ args = adapter._mock_cache.update.call_args[0]
+ assert args[0] == "q"
+ assert args[1] == "meta"
+ generation = args[2][0]
+ assert json.loads(generation.text)["output"] == "out"
+ span.end.assert_called_once_with(output={"status": "saved"})
+
+ def test_save_error_ends_span_with_error(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ adapter._mock_cache.update.side_effect = RuntimeError("boom")
+
+ with pytest.raises(RuntimeError, match="boom"):
+ adapter.save(meta_info="meta", query="q")
+
+ span.end.assert_called_once_with(level="ERROR", status_message="boom")
+
+
+class TestRedisAdapterGet:
+ def test_get_hit_parses_json(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ payload = {"output": "hello", "json": {"x": 1}}
+ gen = MagicMock()
+ gen.text = json.dumps(payload)
+ adapter._mock_cache.lookup.return_value = [gen]
+
+ result = adapter.get(meta_info="meta", query="q")
+ assert result == payload
+ span.end.assert_called_once_with(output={"hit": True})
+
+ def test_get_miss_returns_none(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ adapter._mock_cache.lookup.return_value = None
+
+ result = adapter.get(meta_info="meta", query="q")
+ assert result is None
+ span.end.assert_called_once_with(output={"hit": False})
+
+ def test_get_json_decode_error_logs_and_returns_none(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ gen = MagicMock()
+ gen.text = "not-json"
+ adapter._mock_cache.lookup.return_value = [gen]
+
+ result = adapter.get(meta_info="meta", query="q")
+ assert result is None
+ adapter.logger.error.assert_called()
+ span.end.assert_called_once_with(level="ERROR", status_message=ANY)
+
+ def test_get_exception_propagates(self, adapter):
+ span = MagicMock()
+ adapter._start_span = MagicMock(return_value=span)
+ adapter._mock_cache.lookup.side_effect = ValueError("redis down")
+
+ with pytest.raises(ValueError, match="redis down"):
+ adapter.get(meta_info="meta", query="q")
+
+ span.end.assert_called_once_with(level="ERROR", status_message="redis down")
+
+
+class TestRedisAdapterHealthCheck:
+ def test_true_when_cache_present(self, adapter):
+ assert adapter.health_check() is True
+
+ def test_false_when_cache_missing(self, adapter):
+ adapter.semantic_cache = None
+ assert adapter.health_check() is False
+
+
+class TestRedisAdapterStartSpan:
+ def test_prefers_current_span(self, adapter):
+ child_span = MagicMock()
+ parent_span = MagicMock()
+ parent_span.span.return_value = child_span
+
+ with patch("modules.redis_ext.base.current_span") as mock_cs, \
+ patch("modules.redis_ext.base.current_trace") as mock_ct:
+ mock_cs.get.return_value = parent_span
+ mock_ct.get.return_value = MagicMock()
+
+ result = adapter._start_span("redis_test", {"a": 1})
+ assert result is child_span
+ parent_span.span.assert_called_once_with(name="redis_test", input={"a": 1})
+
+ def test_fallback_to_trace(self, adapter):
+ trace = MagicMock()
+ child = MagicMock()
+ trace.span.return_value = child
+
+ with patch("modules.redis_ext.base.current_span") as mock_cs, \
+ patch("modules.redis_ext.base.current_trace") as mock_ct:
+ mock_cs.get.return_value = None
+ mock_ct.get.return_value = trace
+
+ result = adapter._start_span("redis_test", {"a": 1})
+ assert result is child
+ trace.span.assert_called_once_with(name="redis_test", input={"a": 1})
+
+ def test_none_when_no_context(self, adapter):
+ with patch("modules.redis_ext.base.current_span") as mock_cs, \
+ patch("modules.redis_ext.base.current_trace") as mock_ct:
+ mock_cs.get.return_value = None
+ mock_ct.get.return_value = None
+
+ assert adapter._start_span("redis_test", {"a": 1}) is None
diff --git a/tests/modules/redis_ext/utils/__init__.py b/tests/modules/redis_ext/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/redis_ext/utils/test_RedisAdapters.py b/tests/modules/redis_ext/utils/test_RedisAdapters.py
new file mode 100644
index 0000000..5b5056a
--- /dev/null
+++ b/tests/modules/redis_ext/utils/test_RedisAdapters.py
@@ -0,0 +1,122 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from modules.redis_ext.utils.RedisAdapters import UserRateLimiter
+
+
+@pytest.fixture
+def limiter():
+ mock_logger = MagicMock()
+ with patch("modules.redis_ext.utils.RedisAdapters.redis.Redis") as MockRedis:
+ mock_redis = MagicMock()
+ MockRedis.return_value = mock_redis
+ inst = UserRateLimiter(
+ logger=mock_logger,
+ host="test-host",
+ port=6380,
+ db=3,
+ decode_responses=True,
+ USER_QUERY_LIMIT_N=5,
+ USER_QUERY_LIMIT_TTL_SECONDS=60,
+ RATE_LIMIT_TEMPLATE="rl:{user_id}",
+ )
+ inst._mock_redis = mock_redis
+ return inst
+
+
+class TestUserRateLimiterInit:
+ def test_uses_defaults(self):
+ mock_logger = MagicMock()
+ with patch("modules.redis_ext.utils.RedisAdapters.redis.Redis") as MockRedis:
+ UserRateLimiter(logger=mock_logger)
+ MockRedis.assert_called_once_with(
+ host="127.0.0.1",
+ port=6379,
+ db=2,
+ decode_responses=True,
+ )
+
+ def test_uses_kwargs(self):
+ mock_logger = MagicMock()
+ with patch("modules.redis_ext.utils.RedisAdapters.redis.Redis") as MockRedis:
+ UserRateLimiter(
+ logger=mock_logger,
+ host="h",
+ port=1234,
+ db=7,
+ decode_responses=False,
+ )
+ MockRedis.assert_called_once_with(
+ host="h",
+ port=1234,
+ db=7,
+ decode_responses=False,
+ )
+
+
+class TestUserRateLimiterCheckAndIncrement:
+ def test_new_key_sets_expire(self, limiter):
+ pipe = MagicMock()
+ pipe.incr.return_value = None
+ pipe.ttl.return_value = None
+ pipe.execute.return_value = [1, -2]
+ limiter._mock_redis.pipeline.return_value.__enter__.return_value = pipe
+
+ allowed, count = limiter.check_and_increment("u1")
+ assert allowed is True
+ assert count == 1
+ limiter._mock_redis.expire.assert_called_once_with("rl:u1", 60)
+
+ def test_within_limit_no_expire(self, limiter):
+ pipe = MagicMock()
+ pipe.execute.return_value = [3, 55]
+ limiter._mock_redis.pipeline.return_value.__enter__.return_value = pipe
+
+ allowed, count = limiter.check_and_increment("u1")
+ assert allowed is True
+ assert count == 3
+ limiter._mock_redis.expire.assert_not_called()
+
+ def test_exceeds_limit(self, limiter):
+ pipe = MagicMock()
+ pipe.execute.return_value = [6, 10]
+ limiter._mock_redis.pipeline.return_value.__enter__.return_value = pipe
+
+ allowed, count = limiter.check_and_increment("u1")
+ assert allowed is False
+ assert count == 6
+
+
+class TestUserRateLimiterGetRemaining:
+ def test_key_exists(self, limiter):
+ limiter._mock_redis.get.return_value = "3"
+ assert limiter.get_remaining("u1") == 2
+
+ def test_key_missing(self, limiter):
+ limiter._mock_redis.get.return_value = None
+ assert limiter.get_remaining("u1") == 5
+
+
+class TestUserRateLimiterResetCounter:
+ def test_deletes_key(self, limiter):
+ limiter.reset_counter("u1")
+ limiter._mock_redis.delete.assert_called_once_with("rl:u1")
+
+
+class TestUserRateLimiterTtl:
+ def test_delegates_to_redis(self, limiter):
+ limiter._mock_redis.ttl.return_value = 42
+ assert limiter.ttl("u1") == 42
+ limiter._mock_redis.ttl.assert_called_once_with("rl:u1")
+
+
+class TestUserRateLimiterHealthCheck:
+ def test_healthy(self, limiter):
+ limiter._mock_redis.ping.return_value = True
+ assert limiter.health_check() is True
+
+ def test_unhealthy(self, limiter):
+ limiter._mock_redis.ping.return_value = False
+ assert limiter.health_check() is False
+ limiter.logger.warning.assert_called_once()
diff --git a/tests/service/__init__.py b/tests/service/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/service/api/__init__.py b/tests/service/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/service/api/v1/__init__.py b/tests/service/api/v1/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/service/api/v1/test_router.py b/tests/service/api/v1/test_router.py
new file mode 100644
index 0000000..5f3282b
--- /dev/null
+++ b/tests/service/api/v1/test_router.py
@@ -0,0 +1,117 @@
+import sys
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from agents.profkom_consultant import AgentStatus
+from service.context import APP_CTX
+
+
+class FakeMessage:
+ def __init__(self, content):
+ self.content = content
+
+
+@pytest.mark.anyio
+async def test_test_invoke_success(async_client, monkeypatch, mock_headers):
+ router_module = sys.modules["service.api.v1.router"]
+ monkeypatch.setattr(
+ router_module,
+ "ChatOpenAI",
+ lambda **kwargs: MagicMock(invoke=lambda question: FakeMessage("Mocked answer")),
+ )
+
+ payload = {
+ "question": "Кто ты воин?",
+ "generation_params": {},
+ }
+
+ response = await async_client.post("/api/v1/test_invoke", json=payload, headers=mock_headers)
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["answer"] == "Mocked answer"
+
+
+@pytest.mark.anyio
+async def test_test_invoke_failed_dependency(async_client, monkeypatch, mock_headers):
+ def _raise(*args, **kwargs):
+ raise RuntimeError("YandexGPT down")
+
+ router_module = sys.modules["service.api.v1.router"]
+ monkeypatch.setattr(
+ router_module,
+ "ChatOpenAI",
+ lambda **kwargs: MagicMock(invoke=_raise),
+ )
+
+ payload = {"question": "Кто ты воин?"}
+ response = await async_client.post("/api/v1/test_invoke", json=payload, headers=mock_headers)
+
+ assert response.status_code == 424
+ data = response.json()
+ assert "YandexGPT down" in data["error_description"]
+
+
+@pytest.mark.anyio
+async def test_chat_success(async_client, monkeypatch, mock_headers):
+ rate_limiter_mock = MagicMock()
+ rate_limiter_mock.check_and_increment.return_value = (True, 1)
+ monkeypatch.setattr(APP_CTX, "get_ratelimiter", AsyncMock(return_value=rate_limiter_mock))
+
+ checkpointer_mock = AsyncMock()
+ checkpointer_cm = AsyncMock()
+ checkpointer_cm.__aenter__ = AsyncMock(return_value=checkpointer_mock)
+ checkpointer_cm.__aexit__ = AsyncMock(return_value=None)
+
+ postgres_mock = MagicMock()
+ postgres_mock.get_user_checkpointer.return_value = checkpointer_cm
+ monkeypatch.setattr(APP_CTX, "get_postgres_client", AsyncMock(return_value=postgres_mock))
+
+ langfuse_mock = MagicMock()
+ langfuse_mock.client.trace.return_value = MagicMock()
+ monkeypatch.setattr(APP_CTX, "get_langfuse", AsyncMock(return_value=langfuse_mock))
+
+ agent_mock = MagicMock()
+ monkeypatch.setattr(APP_CTX, "get_agent", lambda: agent_mock)
+
+ graph_mock = AsyncMock()
+ graph_mock.ainvoke.return_value = {"final_answer": "Это ответ агента"}
+
+ router_module = sys.modules["service.api.v1.router"]
+ monkeypatch.setattr(router_module, "build_builder", lambda agent, checkpointer: graph_mock)
+
+ payload = {
+ "text": "Как вступить в профсоюз?",
+ "organisation": "ППО Невинномысский Азот",
+ }
+
+ response = await async_client.post("/api/v1/chat", json=payload, headers=mock_headers)
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["response"] == "Это ответ агента"
+
+ graph_mock.ainvoke.assert_awaited_once()
+ call_kwargs = graph_mock.ainvoke.call_args.kwargs
+ assert call_kwargs["input"]["status"] == AgentStatus.ACTIVE
+
+
+@pytest.mark.anyio
+async def test_chat_rate_limit(async_client, monkeypatch, mock_headers):
+ rate_limiter_mock = MagicMock()
+ rate_limiter_mock.check_and_increment.return_value = (False, 10)
+ rate_limiter_mock.ttl.return_value = 42
+ monkeypatch.setattr(APP_CTX, "get_ratelimiter", AsyncMock(return_value=rate_limiter_mock))
+
+ payload = {
+ "text": "Как вступить в профсоюз?",
+ "organisation": "ППО Невинномысский Азот",
+ }
+
+ response = await async_client.post("/api/v1/chat", json=payload, headers=mock_headers)
+
+ assert response.status_code == 200
+ data = response.json()
+ assert "превысили свой лимит" in data["response"]
+ assert "42" in data["response"]
diff --git a/uv.lock b/uv.lock
index f359195..622a18c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -423,6 +423,30 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" },
]
+[[package]]
+name = "coverage"
+version = "7.13.5"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179", size = 915967, upload-time = "2026-03-17T10:33:18.341Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01", size = 219554, upload-time = "2026-03-17T10:30:42.208Z" },
+ { url = "https://files.pythonhosted.org/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422", size = 219908, upload-time = "2026-03-17T10:30:43.906Z" },
+ { url = "https://files.pythonhosted.org/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f", size = 251419, upload-time = "2026-03-17T10:30:45.545Z" },
+ { url = "https://files.pythonhosted.org/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5", size = 254159, upload-time = "2026-03-17T10:30:47.204Z" },
+ { url = "https://files.pythonhosted.org/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376", size = 255270, upload-time = "2026-03-17T10:30:48.812Z" },
+ { url = "https://files.pythonhosted.org/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256", size = 257538, upload-time = "2026-03-17T10:30:50.77Z" },
+ { url = "https://files.pythonhosted.org/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c", size = 251821, upload-time = "2026-03-17T10:30:52.5Z" },
+ { url = "https://files.pythonhosted.org/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5", size = 253191, upload-time = "2026-03-17T10:30:54.543Z" },
+ { url = "https://files.pythonhosted.org/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09", size = 251337, upload-time = "2026-03-17T10:30:56.663Z" },
+ { url = "https://files.pythonhosted.org/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9", size = 255404, upload-time = "2026-03-17T10:30:58.427Z" },
+ { url = "https://files.pythonhosted.org/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf", size = 250903, upload-time = "2026-03-17T10:31:00.093Z" },
+ { url = "https://files.pythonhosted.org/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c", size = 252780, upload-time = "2026-03-17T10:31:01.916Z" },
+ { url = "https://files.pythonhosted.org/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf", size = 222093, upload-time = "2026-03-17T10:31:03.642Z" },
+ { url = "https://files.pythonhosted.org/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810", size = 222900, upload-time = "2026-03-17T10:31:05.651Z" },
+ { url = "https://files.pythonhosted.org/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de", size = 221515, upload-time = "2026-03-17T10:31:07.293Z" },
+ { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" },
+]
+
[[package]]
name = "cryptography"
version = "46.0.4"
@@ -876,6 +900,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" },
]
+[[package]]
+name = "iniconfig"
+version = "2.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
+]
+
[[package]]
name = "ipykernel"
version = "7.1.0"
@@ -2098,6 +2131,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" },
]
+[[package]]
+name = "pluggy"
+version = "1.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
+]
+
[[package]]
name = "ply"
version = "3.11"
@@ -2436,6 +2478,49 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216, upload-time = "2024-09-29T09:24:11.978Z" },
]
+[[package]]
+name = "pytest"
+version = "9.0.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "iniconfig" },
+ { name = "packaging" },
+ { name = "pluggy" },
+ { name = "pygments" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
+]
+
+[[package]]
+name = "pytest-asyncio"
+version = "1.3.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pytest" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
+]
+
+[[package]]
+name = "pytest-cov"
+version = "7.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "coverage" },
+ { name = "pluggy" },
+ { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b1/51/a849f96e117386044471c8ec2bd6cfebacda285da9525c9106aeb28da671/pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2", size = 55592, upload-time = "2026-03-21T20:11:16.284Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" },
+]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -3144,8 +3229,12 @@ dependencies = [
[package.dev-dependencies]
dev = [
+ { name = "httpx" },
{ name = "pre-commit" },
{ name = "pylint" },
+ { name = "pytest" },
+ { name = "pytest-asyncio" },
+ { name = "pytest-cov" },
{ name = "ruff" },
]
@@ -3182,8 +3271,12 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
+ { name = "httpx" },
{ name = "pre-commit" },
{ name = "pylint" },
+ { name = "pytest" },
+ { name = "pytest-asyncio" },
+ { name = "pytest-cov", specifier = ">=7.1.0" },
{ name = "ruff" },
]