Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions utils/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@


def resolve_path(base: str | Path, path: str | Path):
"""Resolve a path relative to a base directory.

Args:
base: Base directory used when ``path`` is relative.
path: Absolute or relative path to resolve.

Returns:
The resolved absolute path when ``path`` is absolute, otherwise the
resolved base directory joined with ``path``.
"""
path = Path(path)
if path.is_absolute():
return path.resolve()
Expand All @@ -10,6 +20,16 @@ def resolve_path(base: str | Path, path: str | Path):


def display_path_rel_to_cwd(path: str, cwd: Path | None) -> str:
"""Return a path display string relative to the current working directory.

Args:
path: Path to display.
cwd: Current working directory to make ``path`` relative to, if possible.

Returns:
``path`` relative to ``cwd`` when possible; otherwise, the original path
string or normalized path string.
"""
try:
p = Path(path)
except Exception:
Expand All @@ -25,13 +45,30 @@ def display_path_rel_to_cwd(path: str, cwd: Path | None) -> str:


def ensure_parent_directory(path: str | Path) -> Path:
"""Ensure the parent directory for a path exists.

Args:
path: File path whose parent directory should be created.

Returns:
The input path converted to a ``Path`` instance.
"""
path = Path(path)

path.parent.mkdir(parents=True, exist_ok=True)
return path


def is_binary_file(path: str | Path) -> bool:
"""Check whether a file appears to contain binary data.

Args:
path: File path to inspect.

Returns:
True if the first bytes of the file contain a null byte; otherwise False.
Returns False if the file cannot be read.
"""
try:
with open(path, "rb") as f:
chunk = f.read(8192)
Expand Down
39 changes: 39 additions & 0 deletions utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@


def get_tokenizer(model: str):
"""Get a tokenization function for a model.

Args:
model: Model name used to select a tiktoken encoding.

Returns:
A callable that encodes text into token IDs. Falls back to the
``cl100k_base`` encoding when the model-specific encoding is unavailable.
"""
try:
encoding = tiktoken.encoding_for_model(model)
return encoding.encode
Expand All @@ -11,6 +20,15 @@ def get_tokenizer(model: str):


def count_tokens(text: str, model: str = "gemini-2.0-flash-exp") -> int:
"""Count the number of tokens in text for a model.

Args:
text: Text to count tokens for.
model: Model name used to select a tokenizer.

Returns:
Number of tokens in ``text``.
"""
tokenizer = get_tokenizer(model)

if tokenizer:
Expand All @@ -20,6 +38,14 @@ def count_tokens(text: str, model: str = "gemini-2.0-flash-exp") -> int:


def estimate_tokens(text: str, *, chars_per_token: int = 4) -> int:
    """Estimate the number of tokens in text using character count.

    Args:
        text: Text to estimate the token count for.
        chars_per_token: Assumed number of characters per token. The default
            of 4 keeps the original heuristic.

    Returns:
        ``len(text) // chars_per_token``, but never less than 1, so empty
        text also yields 1.

    Raises:
        ValueError: If ``chars_per_token`` is not positive.
    """
    if chars_per_token <= 0:
        raise ValueError("chars_per_token must be a positive integer")
    return max(1, len(text) // chars_per_token)


Expand All @@ -30,6 +56,19 @@ def truncate_text(
suffix: str = "\n... [truncated]",
preserve_lines: bool = True,
):
"""Truncate text to fit within a maximum token count.

Args:
text: Text to truncate.
model: Model name used to count tokens.
max_tokens: Maximum number of tokens allowed in the returned text.
suffix: Text appended when truncation occurs.
preserve_lines: Whether to truncate only at line boundaries when possible.

Returns:
The original text when it fits within ``max_tokens``; otherwise, a
truncated version with ``suffix`` appended.
"""
current_tokens = count_tokens(text, model)
if current_tokens <= max_tokens:
return text
Expand Down