Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions nemo_curator/stages/text/download/arxiv/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@


class ArxivIterator(DocumentIterator):
"""Processes downloaded Arxiv files and extracts article content."""
"""Processes downloaded Arxiv files and extracts article content.

def __init__(self, log_frequency: int = 1000):
Args:
log_frequency: How often to log progress.
extract_tmp_dir: Directory in which temporary extraction directories are
created. If None, the system default temporary directory is used.
"""

def __init__(self, log_frequency: int = 1000, extract_tmp_dir: str | None = None):
    """Store the iterator configuration and reset the progress counter.

    Args:
        log_frequency: How often to log progress.
        extract_tmp_dir: Directory in which temporary extraction directories
            are created. ``None`` means no ``dir`` is passed when the
            temporary directory is created, so the system default is used.
    """
    super().__init__()
    # Per-call progress counter; ``iterate`` resets it to 0 on every invocation.
    self._counter = 0
    self._extract_tmp_dir = extract_tmp_dir
    self._log_frequency = log_frequency

def _tex_proj_loader(self, file_or_dir_path: str) -> list[str] | None:
Expand Down Expand Up @@ -121,10 +128,10 @@ def _format_arxiv_id(self, arxiv_id: str) -> str:

def iterate(self, file_path: str) -> Iterator[dict[str, Any]]:
self._counter = 0
download_dir = os.path.split(file_path)[0]
bname = os.path.split(file_path)[-1]

with tempfile.TemporaryDirectory(dir=download_dir) as tmpdir, tarfile.open(file_path) as tf:
tmpdir_kwargs = {"dir": self._extract_tmp_dir} if self._extract_tmp_dir is not None else {}
with tempfile.TemporaryDirectory(**tmpdir_kwargs) as tmpdir, tarfile.open(file_path) as tf:
# Use safe extraction instead of extractall to prevent path traversal attacks
tar_safe_extract(tf, tmpdir)
for _i, item in enumerate(get_all_file_paths_under(tmpdir, recurse_subdirectories=True)):
Expand Down
28 changes: 27 additions & 1 deletion tests/stages/text/download/arxiv/test_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import io
import tarfile
import tempfile
from pathlib import Path
from unittest import mock

import pytest

Expand Down Expand Up @@ -51,7 +53,13 @@ def test_arxiv_iterator(self, tmp_path: Path) -> None:
outer_tar.add(inner_tar_path, arcname="2103.00001.tar")

iterator = ArxivIterator(log_frequency=1)
results = list(iterator.iterate(str(outer_tar_path)))
with mock.patch(
"nemo_curator.stages.text.download.arxiv.iterator.tempfile.TemporaryDirectory",
wraps=tempfile.TemporaryDirectory,
) as tmpdir_mock:
results = list(iterator.iterate(str(outer_tar_path)))

tmpdir_mock.assert_called_once_with()
# Expect one paper extracted.
assert len(results) == 1
tex_files = results[0]
Expand All @@ -63,6 +71,24 @@ def test_arxiv_iterator(self, tmp_path: Path) -> None:
assert isinstance(tex_files["content"], list)
assert dummy_tex_content in tex_files["content"]

def test_arxiv_iterator_custom_extract_tmp_dir(self, tmp_path: Path) -> None:
    """Check that a configured ``extract_tmp_dir`` is forwarded to ``TemporaryDirectory``."""
    # An empty outer archive is sufficient: only the temp-dir call is under test,
    # and iterating an archive with no members must yield nothing.
    archive_path = tmp_path / "dummy_main.tar"
    tarfile.open(archive_path, "w").close()

    scratch_root = tmp_path / "extract_tmp"
    scratch_root.mkdir()
    scratch_root_str = str(scratch_root)

    iterator = ArxivIterator(extract_tmp_dir=scratch_root_str)
    patch_target = "nemo_curator.stages.text.download.arxiv.iterator.tempfile.TemporaryDirectory"
    # ``wraps`` keeps the real TemporaryDirectory behaviour while recording the call.
    with mock.patch(patch_target, wraps=tempfile.TemporaryDirectory) as tmpdir_mock:
        records = list(iterator.iterate(str(archive_path)))

    tmpdir_mock.assert_called_once_with(dir=scratch_root_str)
    assert records == []


class TestSafeExtract:
"""Test suite for tar_safe_extract function."""
Expand Down