Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260219225535725727.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "create_final_text_units streaming"
}
62 changes: 50 additions & 12 deletions packages/graphrag-storage/graphrag_storage/tables/csv_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import csv
import inspect
import os
import shutil
import sys
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -63,8 +66,10 @@ def __init__(
transformer: Optional callable to transform each row before
yielding. Receives a dict, returns a transformed dict.
Defaults to identity (no transformation).
truncate: If True (default), truncate file on first write.
If False, append to existing file.
truncate: If True (default), writes go to a temporary file
which is moved over the original on close(). This allows
safe concurrent reads from the original while writes
accumulate. If False, append to existing file.
encoding: Character encoding for reading/writing CSV files.
Defaults to "utf-8".
"""
Expand All @@ -77,6 +82,8 @@ def __init__(
self._write_file: TextIOWrapper | None = None
self._writer: csv.DictWriter | None = None
self._header_written = False
self._temp_path: Path | None = None
self._final_path: Path | None = None

def __aiter__(self) -> AsyncIterator[Any]:
"""Iterate through rows one at a time.
Expand Down Expand Up @@ -128,9 +135,11 @@ async def has(self, row_id: str) -> bool:
async def write(self, row: dict[str, Any]) -> None:
"""Write a single row to the CSV file.

On first write, opens the file. If truncate=True, overwrites any existing
file and writes header. If truncate=False, appends to existing file
(skips header if file exists).
On first write, opens a file handle. When truncate=True, writes
go to a temporary file in the same directory; the temp file is
moved over the original in close(), making it safe to read from
the original while writes are in progress. When truncate=False,
rows are appended directly to the existing file.

Args
----
Expand All @@ -139,13 +148,36 @@ async def write(self, row: dict[str, Any]) -> None:
if isinstance(self._storage, FileStorage) and self._write_file is None:
file_path = self._storage.get_path(self._file_key)
file_path.parent.mkdir(parents=True, exist_ok=True)
file_exists = file_path.exists() and file_path.stat().st_size > 0
mode = "w" if self._truncate else "a"
write_header = self._truncate or not file_exists
self._write_file = Path.open(
file_path, mode, encoding=self._encoding, newline=""

if self._truncate:
fd, tmp = tempfile.mkstemp(
suffix=".csv",
dir=file_path.parent,
)
os.close(fd)
self._temp_path = Path(tmp)
self._final_path = file_path
self._write_file = Path.open(
self._temp_path,
"w",
encoding=self._encoding,
newline="",
)
write_header = True
else:
file_exists = file_path.exists() and file_path.stat().st_size > 0
write_header = not file_exists
self._write_file = Path.open(
file_path,
"a",
encoding=self._encoding,
newline="",
)

self._writer = csv.DictWriter(
self._write_file,
fieldnames=list(row.keys()),
)
self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys()))
if write_header:
self._writer.writeheader()
self._header_written = write_header
Expand All @@ -156,10 +188,16 @@ async def write(self, row: dict[str, Any]) -> None:
async def close(self) -> None:
    """Release the write handle and publish pending output.

    Closing the handle flushes any buffered rows. When a temp file is
    in use (truncate mode), it is moved over the destination here so a
    concurrent reader never observes a partially-written file.
    """
    handle = self._write_file
    if handle is not None:
        handle.close()
        self._write_file = None
        self._writer = None
        self._header_written = False

    tmp, dest = self._temp_path, self._final_path
    if tmp is not None and dest is not None:
        shutil.move(str(tmp), str(dest))
        self._temp_path = None
        self._final_path = None
24 changes: 21 additions & 3 deletions packages/graphrag/graphrag/data_model/dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,31 @@ def _safe_int(series: pd.Series, fill: int = -1) -> pd.Series:
return series.fillna(fill).astype(int)


def _split_list_column(value: Any) -> list[Any]:
"""Split a column containing a list string into an actual list."""
def split_list_column(value: Any) -> list[Any]:
r"""Split a column containing a list string into an actual list.

Handles two CSV serialization formats:
- Comma-separated (standard ``str(list)``): ``"['a', 'b']"``
- Newline-separated (pandas ``to_csv`` of numpy arrays):
``"['a'\\n 'b']"``

Both formats are stripped of brackets, quotes, and whitespace so
that existing indexes produced by the old pandas-based indexer
remain readable.
"""
if isinstance(value, str):
return [item.strip("[] '") for item in value.split(",")] if value else []
if not value:
return []
normalized = value.replace("\n", ",")
items = [item.strip("[] '\"") for item in normalized.split(",")]
return [item for item in items if item]
return value


# Backward-compatible alias so internal callers keep working.
_split_list_column = split_list_column


def entities_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the entities dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
Expand Down
242 changes: 242 additions & 0 deletions packages/graphrag/graphrag/data_model/row_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright (C) 2026 Microsoft

"""Row-level type coercion for streaming Table reads.

Each transformer converts a raw ``dict[str, Any]`` row (as produced by
``csv.DictReader``, where every value is a string) into a dict with
properly typed fields. They serve the same purpose as the DataFrame-
based ``*_typed`` helpers in ``dfs.py``, but operate on single rows so
they can be passed as the *transformer* argument to
``TableProvider.open()``.
"""

from typing import Any

from graphrag.data_model.dfs import split_list_column


def _safe_int(value: Any, fill: int = -1) -> int:
"""Coerce a value to int, returning *fill* when missing or empty.

Handles NaN from Parquet float-promoted nullable int columns
(``row.to_dict()`` yields ``numpy.float64(nan)``) and any other
non-convertible type.
"""
if value is None or value == "":
return fill
try:
return int(value)
except (ValueError, TypeError):
return fill


def _safe_float(value: Any, fill: float = 0.0) -> float:
"""Coerce a value to float, returning *fill* when missing or empty.

Also applies *fill* when the value is NaN (e.g. from Parquet
nullable columns), keeping behavior consistent with the
DataFrame-level ``fillna(fill).astype(float)`` pattern.
"""
if value is None or value == "":
return fill
try:
result = float(value)
except (ValueError, TypeError):
return fill
# math.isnan without importing math
if result != result:
return fill
return result


def _coerce_list(value: Any) -> list[Any]:
"""Parse a value into a list, handling CSV string and array types.

Handles three cases:
- str: CSV-encoded list (e.g. from CSVTable rows)
- list: already a Python list (pass-through)
- array-like with ``tolist`` (e.g. numpy.ndarray from ParquetTable
``row.to_dict()``)
"""
if isinstance(value, str):
return split_list_column(value)
if isinstance(value, list):
return value
if hasattr(value, "tolist"):
return value.tolist()
return []


# -- entities (mirrors entities_typed) ------------------------------------


def transform_entity_row(row: dict[str, Any]) -> dict[str, Any]:
    """Coerce types for an entity row, mirroring ``entities_typed``.

    ``human_readable_id`` -> int (fill -1), ``frequency`` -> int
    (fill 0), ``degree`` -> int (fill 0), ``text_unit_ids`` -> list.
    Fields absent from the row are left untouched.
    """
    for field, fill in (
        ("human_readable_id", -1),
        ("frequency", 0),
        ("degree", 0),
    ):
        if field in row:
            row[field] = _safe_int(row[field], fill)
    if "text_unit_ids" in row:
        row["text_unit_ids"] = _coerce_list(row["text_unit_ids"])
    return row


# -- relationships (mirrors relationships_typed) --------------------------


def transform_relationship_row(
    row: dict[str, Any],
) -> dict[str, Any]:
    """Coerce types for a relationship row, mirroring
    ``relationships_typed``.

    ``human_readable_id`` -> int (fill -1), ``weight`` -> float,
    ``combined_degree`` -> int (fill 0), ``text_unit_ids`` -> list.
    Fields absent from the row are left untouched.
    """
    int_fills = {"human_readable_id": -1, "combined_degree": 0}
    for field, fill in int_fills.items():
        if field in row:
            row[field] = _safe_int(row[field], fill)
    if "weight" in row:
        row["weight"] = _safe_float(row["weight"])
    if "text_unit_ids" in row:
        row["text_unit_ids"] = _coerce_list(row["text_unit_ids"])
    return row


# -- communities (mirrors communities_typed) ------------------------------


def transform_community_row(
    row: dict[str, Any],
) -> dict[str, Any]:
    """Coerce types for a community row, mirroring ``communities_typed``.

    ``human_readable_id`` is coerced to int only when present.
    ``community``, ``level``, ``children``, ``period``, and ``size``
    are always (re)assigned with coerced defaults, even when missing
    from the raw row. The id-list fields are coerced only when present.
    """
    if "human_readable_id" in row:
        row["human_readable_id"] = _safe_int(row["human_readable_id"])
    # Always materialized, with defaults when missing.
    row["community"] = _safe_int(row.get("community"))
    row["level"] = _safe_int(row.get("level"))
    row["children"] = _coerce_list(row.get("children"))
    for field in ("entity_ids", "relationship_ids", "text_unit_ids"):
        if field in row:
            row[field] = _coerce_list(row[field])
    row["period"] = str(row.get("period", ""))
    row["size"] = _safe_int(row.get("size"), 0)
    return row


# -- community reports (mirrors community_reports_typed) ------------------


def transform_community_report_row(
    row: dict[str, Any],
) -> dict[str, Any]:
    """Coerce types for a community report row, mirroring
    ``community_reports_typed``.

    ``human_readable_id`` is coerced to int only when present.
    ``community``, ``level``, ``children``, ``rank``, ``findings``,
    and ``size`` are always (re)assigned with coerced defaults.
    """
    if "human_readable_id" in row:
        row["human_readable_id"] = _safe_int(row["human_readable_id"])
    # update() inserts any missing keys in this (kwarg) order.
    row.update(
        community=_safe_int(row.get("community")),
        level=_safe_int(row.get("level")),
        children=_coerce_list(row.get("children")),
        rank=_safe_float(row.get("rank")),
        findings=_coerce_list(row.get("findings")),
        size=_safe_int(row.get("size"), 0),
    )
    return row


# -- covariates (mirrors covariates_typed) --------------------------------


def transform_covariate_row(
    row: dict[str, Any],
) -> dict[str, Any]:
    """Coerce types for a covariate row, mirroring ``covariates_typed``.

    Only ``human_readable_id`` needs coercion (-> int, fill -1), and
    only when it is present in the row.
    """
    if "human_readable_id" in row:
        row["human_readable_id"] = _safe_int(row["human_readable_id"])
    return row


# -- text units (mirrors text_units_typed) --------------------------------


def transform_text_unit_row(
    row: dict[str, Any],
) -> dict[str, Any]:
    """Coerce types for a text-unit row, mirroring ``text_units_typed``.

    ``human_readable_id`` -> int (only when present); ``n_tokens`` ->
    int (always assigned, default 0); ``entity_ids``,
    ``relationship_ids``, and ``covariate_ids`` -> list (only when
    present).
    """
    if "human_readable_id" in row:
        row["human_readable_id"] = _safe_int(row["human_readable_id"])
    row["n_tokens"] = _safe_int(row.get("n_tokens"), 0)
    for field in ("entity_ids", "relationship_ids", "covariate_ids"):
        if field in row:
            row[field] = _coerce_list(row[field])
    return row


# -- documents (mirrors documents_typed) ----------------------------------


def transform_document_row(
    row: dict[str, Any],
) -> dict[str, Any]:
    """Coerce types for a document row, mirroring ``documents_typed``.

    ``human_readable_id`` -> int and ``text_unit_ids`` -> list, each
    applied only when the field is present in the row.
    """
    if "human_readable_id" in row:
        row["human_readable_id"] = _safe_int(row["human_readable_id"])
    if "text_unit_ids" in row:
        row["text_unit_ids"] = _coerce_list(row["text_unit_ids"])
    return row
Loading
Loading