Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions openwpm/storage/in_process_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""In-process StorageController for testing.

Runs the StorageController's asyncio event loop in a daemon thread instead of
a subprocess, eliminating subprocess spawn overhead. Uses the same TCP server
and protocol, so the WebExtension connects to it identically.
"""

import asyncio
import logging
import random
import threading
import time
from typing import List, Optional, Tuple

from multiprocess import Queue
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This handle runs entirely in-process (thread-based), but it uses multiprocess.Queue. multiprocessing/multiprocess queues have semantics optimized for IPC and their .empty()/.qsize() behavior can be unreliable; they can also add unnecessary overhead. Prefer queue.Queue/queue.SimpleQueue for thread communication here (and update the drain logic accordingly).

Copilot uses AI. Check for mistakes.

from ..types import BrowserId, VisitId
from .storage_controller import StorageController
from .storage_providers import StructuredStorageProvider, UnstructuredStorageProvider


class InProcessStorageControllerHandle:
"""StorageControllerHandle replacement that runs in a thread, not a subprocess.

Implements the same interface as StorageControllerHandle (satisfies
StorageInterface protocol) but runs the asyncio event loop in a daemon
thread within the current process. This avoids subprocess spawn overhead
for testing.
"""

def __init__(
self,
structured_storage: StructuredStorageProvider,
unstructured_storage: Optional[UnstructuredStorageProvider],
) -> None:
self.listener_address: Optional[Tuple[str, int]] = None
self.status_queue: Queue = Queue()
self.completion_queue: Queue = Queue()
self.shutdown_queue: Queue = Queue()
self._last_status: Optional[int] = None
self._last_status_received: Optional[float] = None
self.logger = logging.getLogger("openwpm")
self._storage_controller = StorageController(
structured_storage,
unstructured_storage,
status_queue=self.status_queue,
completion_queue=self.completion_queue,
shutdown_queue=self.shutdown_queue,
)
self._thread: Optional[threading.Thread] = None

def get_next_visit_id(self) -> VisitId:
"""Generate visit id as randomly generated positive integer less than 2^53."""
return VisitId(random.getrandbits(53))

def get_next_browser_id(self) -> BrowserId:
"""Generate crawl id as randomly generated positive 32bit integer."""
return BrowserId(random.getrandbits(32))
Comment on lines +54 to +58
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random.getrandbits() can return 0, which contradicts the docstring ('positive integer') and may collide with sentinel/invalid IDs if 0 is treated specially elsewhere. Generate IDs in a strictly-positive range (e.g., loop until non-zero, or add an offset while preserving the intended bit bounds).

Suggested change
return VisitId(random.getrandbits(53))
def get_next_browser_id(self) -> BrowserId:
"""Generate crawl id as randomly generated positive 32bit integer."""
return BrowserId(random.getrandbits(32))
visit_id = 0
while visit_id == 0:
visit_id = random.getrandbits(53)
return VisitId(visit_id)
def get_next_browser_id(self) -> BrowserId:
"""Generate crawl id as randomly generated positive 32bit integer."""
browser_id = 0
while browser_id == 0:
browser_id = random.getrandbits(32)
return BrowserId(browser_id)

Copilot uses AI. Check for mistakes.

def _run_loop(self) -> None:
"""Run the storage controller's asyncio loop in this thread."""
logging.getLogger("asyncio").setLevel(logging.WARNING)
asyncio.run(self._storage_controller._run(), debug=True)
Comment on lines +60 to +63
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This introduces global side effects for the whole process by changing the asyncio logger level, and it always enables asyncio debug mode (debug=True) which can significantly slow down tests and change behavior. Consider avoiding global logger mutation (or scoping it) and making debug configurable/defaulting to False.

Copilot uses AI. Check for mistakes.

def launch(self) -> None:
"""Start the storage controller in a daemon thread."""
self._thread = threading.Thread(
target=self._run_loop, name="InProcessStorageController", daemon=True
)
self._thread.start()
# Wait for the listener address from the status queue
self.listener_address = self.status_queue.get()
Comment on lines +65 to +72
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.status_queue.get() blocks indefinitely if the controller fails to start or dies before publishing the listener address, which can hang tests/CI. Use a bounded timeout here and raise a clear error if the address doesn't arrive (and consider surfacing thread exceptions via a shared variable/queue).

Copilot uses AI. Check for mistakes.

def get_new_completed_visits(self) -> List[Tuple[int, bool]]:
"""Return visit ids completed since last call."""
finished_visit_ids = list()
while not self.completion_queue.empty():
finished_visit_ids.append(self.completion_queue.get())
return finished_visit_ids
Comment on lines +74 to +79
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Draining a queue using while not queue.empty(): queue.get() is race-prone (items can arrive between empty() and get()) and .empty() is not reliable for multiprocessing-style queues. Use non-blocking gets in a try/except loop (e.g., get_nowait) until empty is raised, which is correct for both thread queues and process queues.

Copilot uses AI. Check for mistakes.

def shutdown(self, relaxed: bool = True) -> None:
"""Signal the storage controller to shut down and wait for the thread."""
assert self._thread is not None
self.logger.debug("Sending shutdown signal to in-process StorageController...")
self.shutdown_queue.put(("SHUTDOWN", relaxed))
start_time = time.time()
self._thread.join(timeout=60)
self.logger.debug(
"%s took %s seconds to close."
% (type(self).__name__, str(time.time() - start_time))
)
Comment on lines +88 to +91
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After join(timeout=60), the code doesn’t check whether the thread actually stopped. This can silently leak a daemon thread and cause later tests to behave unpredictably. Check self._thread.is_alive() after join and raise/handle it (and consider a longer timeout or a deterministic shutdown acknowledgement from the controller).

Suggested change
self.logger.debug(
"%s took %s seconds to close."
% (type(self).__name__, str(time.time() - start_time))
)
elapsed = time.time() - start_time
self.logger.debug(
"%s took %s seconds to close."
% (type(self).__name__, str(elapsed))
)
if self._thread.is_alive():
msg = (
"InProcessStorageController thread failed to shut down "
"within the 60-second timeout."
)
self.logger.error(msg)
raise RuntimeError(msg)

Copilot uses AI. Check for mistakes.

def get_most_recent_status(self) -> int:
"""Return the most recent queue size."""
if self._last_status is None:
return self.get_status()

while not self.status_queue.empty():
self._last_status = self.status_queue.get()
self._last_status_received = time.time()

if self._last_status_received is not None and (
time.time() - self._last_status_received
) > 120:
raise RuntimeError(
"No status update from the storage controller "
"for %d seconds." % (time.time() - self._last_status_received)
)

return self._last_status
Comment on lines +93 to +110
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_most_recent_status() is annotated to return int, but it can return a non-int if any non-status message ends up on status_queue (there is no type check here, unlike get_status()). Also, the queue-drain loop uses .empty() which is race-prone/unreliable. Mirror get_status()'s validation (assert/raise if non-int) and use a get_nowait drain pattern.

Copilot uses AI. Check for mistakes.

def get_status(self) -> int:
"""Get listener process status. If the status queue is empty, block."""
import queue

try:
self._last_status = self.status_queue.get(block=True, timeout=120)
self._last_status_received = time.time()
except queue.Empty:
assert self._last_status_received is not None
raise RuntimeError(
"No status update from the storage controller "
"for %d seconds." % (time.time() - self._last_status_received)
)
assert isinstance(self._last_status, int)
return self._last_status

def save_configuration(self, *args, **kwargs) -> None:
"""Save configuration - delegates to a DataSocket like StorageControllerHandle."""
from .storage_controller import DataSocket, INVALID_VISIT_ID
from ..config import BrowserParamsInternal, ManagerParamsInternal
from .storage_providers import TableName

assert self.listener_address is not None
manager_params: ManagerParamsInternal = args[0]
browser_params: List[BrowserParamsInternal] = args[1]
openwpm_version: str = args[2]
browser_version: str = args[3]
Comment on lines +128 to +138
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using *args/**kwargs and positional indexing makes this API easy to misuse and harder to type-check (e.g., callers passing keywords or refactoring argument order). Prefer an explicit signature matching the production handle’s save_configuration(...) parameters (including typing), and avoid relying on args[n].

Suggested change
def save_configuration(self, *args, **kwargs) -> None:
"""Save configuration - delegates to a DataSocket like StorageControllerHandle."""
from .storage_controller import DataSocket, INVALID_VISIT_ID
from ..config import BrowserParamsInternal, ManagerParamsInternal
from .storage_providers import TableName
assert self.listener_address is not None
manager_params: ManagerParamsInternal = args[0]
browser_params: List[BrowserParamsInternal] = args[1]
openwpm_version: str = args[2]
browser_version: str = args[3]
def save_configuration(
self,
manager_params: "ManagerParamsInternal",
browser_params: List["BrowserParamsInternal"],
openwpm_version: str,
browser_version: str,
) -> None:
"""Save configuration - delegates to a DataSocket like StorageControllerHandle."""
from .storage_controller import DataSocket, INVALID_VISIT_ID
from ..config import BrowserParamsInternal, ManagerParamsInternal
from .storage_providers import TableName
assert self.listener_address is not None

Copilot uses AI. Check for mistakes.

sock = DataSocket(self.listener_address, "StorageControllerHandle")
task_id = random.getrandbits(32)
sock.store_record(
TableName("task"),
INVALID_VISIT_ID,
{
"task_id": task_id,
"manager_params": manager_params.to_json(),
"openwpm_version": openwpm_version,
"browser_version": browser_version,
},
)
for browser_param in browser_params:
sock.store_record(
TableName("crawl"),
INVALID_VISIT_ID,
{
"browser_id": browser_param.browser_id,
"task_id": task_id,
"browser_params": browser_param.to_json(),
},
)
sock.finalize_visit_id(INVALID_VISIT_ID, success=True)
sock.close()
31 changes: 31 additions & 0 deletions openwpm/storage/storage_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Protocol defining the interface TaskManager uses to interact with storage.

This decouples TaskManager from the concrete StorageControllerHandle,
allowing tests to use lightweight in-process alternatives.
"""

from typing import List, Optional, Protocol, Tuple

from ..types import BrowserId, VisitId


class StorageInterface(Protocol):
"""Interface for storage controller handles.

StorageControllerHandle implements this protocol for production use.
InProcessStorageControllerHandle implements it for testing.
"""

def get_next_visit_id(self) -> VisitId: ...

def get_next_browser_id(self) -> BrowserId: ...

def get_most_recent_status(self) -> int: ...

def get_new_completed_visits(self) -> List[Tuple[int, bool]]: ...

def launch(self) -> None: ...

listener_address: Optional[Tuple[str, int]]

def shutdown(self, relaxed: bool = True) -> None: ...
28 changes: 23 additions & 5 deletions test/storage/test_storage_controller.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
from typing import Any, Type, Union

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from openwpm.mp_logger import MPLogger
from openwpm.storage.in_memory_storage import (
MemoryArrowProvider,
MemoryStructuredProvider,
)
from openwpm.storage.in_process_storage import InProcessStorageControllerHandle
from openwpm.storage.storage_controller import (
INVALID_VISIT_ID,
DataSocket,
StorageControllerHandle,
)
from test.storage.fixtures import dt_test_values
from test.storage.test_values import dt_test_values

HandleCls = Type[Union[StorageControllerHandle, InProcessStorageControllerHandle]]


@pytest.fixture(params=["subprocess", "in_process"])
def controller_handle_cls(request: Any) -> HandleCls:
if request.param == "subprocess":
return StorageControllerHandle
else:
return InProcessStorageControllerHandle


def test_startup_and_shutdown(mp_logger: MPLogger, test_values: dt_test_values) -> None:
def test_startup_and_shutdown(
mp_logger: MPLogger, test_values: dt_test_values, controller_handle_cls: HandleCls
) -> None:
test_table, visit_ids = test_values
structured = MemoryStructuredProvider()
controller_handle = StorageControllerHandle(structured, None)
controller_handle = controller_handle_cls(structured, None)
controller_handle.launch()
assert controller_handle.listener_address is not None
cs = DataSocket(controller_handle.listener_address, "Test")
Expand All @@ -40,10 +56,12 @@ def test_startup_and_shutdown(mp_logger: MPLogger, test_values: dt_test_values)
assert handle.storage[table] == [data]


def test_arrow_provider(mp_logger: MPLogger, test_values: dt_test_values) -> None:
def test_arrow_provider(
mp_logger: MPLogger, test_values: dt_test_values, controller_handle_cls: HandleCls
) -> None:
test_table, visit_ids = test_values
structured = MemoryArrowProvider()
controller_handle = StorageControllerHandle(structured, None)
controller_handle = controller_handle_cls(structured, None)
controller_handle.launch()

assert controller_handle.listener_address is not None
Expand Down
Loading