Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions comtypes/logutil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# logutil.py
import functools
import logging
import warnings
from ctypes import WinDLL
from ctypes.wintypes import LPCSTR, LPCWSTR

Expand All @@ -18,19 +20,28 @@ class NTDebugHandler(logging.Handler):
def emit(
self,
record,
writeA=_OutputDebugStringA,
writeW=_OutputDebugStringW,
):
text = self.format(record)
if isinstance(text, str):
writeA(text + "\n")
else:
writeW(text + "\n")
writeW(text + "\n")


logging.NTDebugHandler = NTDebugHandler


def deprecated(reason: str):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
warnings.warn(reason, category=DeprecationWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator


@deprecated("Deprecated. See https://github.com/enthought/comtypes/issues/920.")
def setup_logging(*pathnames):
import configparser

Expand Down
240 changes: 240 additions & 0 deletions comtypes/test/test_logutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import contextlib
import ctypes
import logging
import threading
import unittest as ut
from collections.abc import Iterator
from ctypes import POINTER, WinDLL, c_void_p
from ctypes import c_size_t as SIZE_T
from ctypes.wintypes import BOOL, DWORD, HANDLE, LPCWSTR
from queue import Queue
from typing import TYPE_CHECKING, Optional
from typing import Union as _UnionT

from comtypes.client._events import SECURITY_ATTRIBUTES
from comtypes.logutil import NTDebugHandler, deprecated
from comtypes.logutil import (
_OutputDebugStringW as OutputDebugStringW,
)

if TYPE_CHECKING:
from ctypes import _CArgObject, _Pointer


class Test_deprecated(ut.TestCase):
def test_warning_is_raised(self):
reason_text = "This is deprecated."

@deprecated(reason_text)
def test_func():
return "success"

with self.assertWarns(DeprecationWarning) as cm:
result = test_func()
self.assertEqual(result, "success")
self.assertEqual(reason_text, str(cm.warning))


_kernel32 = WinDLL("kernel32", use_last_error=True)

# https://learn.microsoft.com/en-us/windows/win32/api/synchapi/nf-synchapi-createeventw
_CreateEventW = _kernel32.CreateEventW
_CreateEventW.argtypes = [POINTER(SECURITY_ATTRIBUTES), BOOL, BOOL, LPCWSTR]
_CreateEventW.restype = HANDLE

# https://learn.microsoft.com/en-us/windows/win32/api/synchapi/nf-synchapi-setevent
_SetEvent = _kernel32.SetEvent
_SetEvent.argtypes = [HANDLE]
_SetEvent.restype = BOOL

# https://learn.microsoft.com/en-us/windows/win32/api/synchapi/nf-synchapi-waitforsingleobject
_WaitForSingleObject = _kernel32.WaitForSingleObject
_WaitForSingleObject.argtypes = [HANDLE, DWORD]
_WaitForSingleObject.restype = DWORD

# https://learn.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-createfilemappingw
_CreateFileMappingW = _kernel32.CreateFileMappingW
_CreateFileMappingW.argtypes = [
HANDLE,
POINTER(SECURITY_ATTRIBUTES),
DWORD,
DWORD,
DWORD,
LPCWSTR,
]
_CreateFileMappingW.restype = HANDLE

# https://learn.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-mapviewoffile
_MapViewOfFile = _kernel32.MapViewOfFile
_MapViewOfFile.argtypes = [HANDLE, DWORD, DWORD, DWORD, SIZE_T]
_MapViewOfFile.restype = c_void_p

# https://learn.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-unmapviewoffile
_UnmapViewOfFile = _kernel32.UnmapViewOfFile
_UnmapViewOfFile.argtypes = [c_void_p]
_UnmapViewOfFile.restype = BOOL

# https://learn.microsoft.com/en-us/windows/win32/api/handleapi/nf-handleapi-closehandle
_CloseHandle = _kernel32.CloseHandle
_CloseHandle.argtypes = [HANDLE]
_CloseHandle.restype = BOOL

# https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getcurrentprocessid
_GetCurrentProcessId = _kernel32.GetCurrentProcessId
_GetCurrentProcessId.argtypes = []
_GetCurrentProcessId.restype = DWORD


@contextlib.contextmanager
def create_file_mapping(
hfile: int,
security: _UnionT["_Pointer[SECURITY_ATTRIBUTES]", "_CArgObject", None],
flprotect: int,
size_high: int,
size_low: int,
name: Optional[str],
) -> Iterator[int]:
"""Context manager to creates a Windows file mapping object."""
handle = _CreateFileMappingW(hfile, security, flprotect, size_high, size_low, name)
assert handle, ctypes.FormatError(ctypes.get_last_error())
try:
yield handle
finally:
_CloseHandle(handle)


@contextlib.contextmanager
def map_view_of_file(
handle: int, access: int, offset_high: int, offset_low: int, size: int
) -> Iterator[int]:
"""Context manager to map a view of a file mapping into the process's
address space.
"""
p_view = _MapViewOfFile(handle, access, offset_high, offset_low, size)
assert p_view, ctypes.FormatError(ctypes.get_last_error())
try:
yield p_view
finally:
_UnmapViewOfFile(p_view)


@contextlib.contextmanager
def create_event(
security: _UnionT["_Pointer[SECURITY_ATTRIBUTES]", "_CArgObject", None],
manual: bool,
init: bool,
name: Optional[str],
) -> Iterator[int]:
"""Context manager to creates a Windows event object."""
handle = _CreateEventW(security, manual, init, name)
assert handle, ctypes.FormatError(ctypes.get_last_error())
try:
yield handle
finally:
_CloseHandle(handle)


DBWIN_BUFFER_SIZE = 4096 # Longer messages are truncated at the source by the OS
WAIT_OBJECT_0 = 0x00000000
PAGE_READWRITE = 0x04
FILE_MAP_READ = 0x04
INVALID_HANDLE_VALUE = -1 # Backed by the system paging file instead of a file on disk


@contextlib.contextmanager
def open_dbwin_debug_channels() -> Iterator[tuple[int, int, int]]:
"""Context manager to open the standard Windows debug output channels
(events and shared memory).
Yields handles to `DBWIN_BUFFER_READY`, `DBWIN_DATA_READY`, and a pointer
to `DBWIN_BUFFER`.
"""
with (
# "DBWIN_BUFFER_READY": An event signaled by the listener to indicate
# it's ready to receive debug output. `OutputDebugString` waits for this.
create_event(None, False, False, "DBWIN_BUFFER_READY") as h_buffer_ready,
# "DBWIN_DATA_READY": An event signaled by `OutputDebugString` to
# indicate new data is written to the shared buffer. Listener waits.
create_event(None, False, False, "DBWIN_DATA_READY") as h_data_ready,
# "DBWIN_BUFFER": A shared memory region where `OutputDebugString`
# writes the debug string data.
create_file_mapping(
INVALID_HANDLE_VALUE,
None,
PAGE_READWRITE,
0,
DBWIN_BUFFER_SIZE,
"DBWIN_BUFFER",
) as h_mapping,
# Map the shared memory region into the listener's address space
# for reading the debug strings.
map_view_of_file(h_mapping, FILE_MAP_READ, 0, 0, DBWIN_BUFFER_SIZE) as p_view,
):
yield (h_buffer_ready, h_data_ready, p_view)


@contextlib.contextmanager
def capture_debug_strings(ready: threading.Event, *, interval: int) -> Iterator[Queue]:
"""Context manager to capture debug strings emitted via `OutputDebugString`.
Spawns a listener thread to monitor the debug channels.
"""
captured = Queue()
finished = threading.Event()

def _listener(
q: Queue, rdy: threading.Event, fin: threading.Event, pid: int
) -> None:
# Create/open named events and file mapping for interprocess communication.
# These objects are part of the Windows Debugging API contract.
with open_dbwin_debug_channels() as (h_buffer_ready, h_data_ready, p_view):
rdy.set() # Signal to the main thread that listener is ready.
while not fin.is_set(): # Loop until the main thread signals to finish.
_SetEvent(h_buffer_ready) # Signal readiness to `OutputDebugString`.
# Wait for `OutputDebugString` to signal that data is ready.
if _WaitForSingleObject(h_data_ready, interval) == WAIT_OBJECT_0:
# Debug string buffer format: [4 bytes: PID][N bytes: string].
# Check if the process ID in the buffer matches the current PID.
if ctypes.cast(p_view, POINTER(DWORD)).contents.value == pid:
# Extract the null-terminated string, skipping the PID,
# and put it into the queue.
q.put(ctypes.string_at(p_view + 4).strip(b"\x00"))

th = threading.Thread(
target=_listener,
args=(captured, ready, finished, _GetCurrentProcessId()),
daemon=True,
)
th.start()
try:
yield captured
finally:
finished.set()
th.join()


class Test_OutputDebugStringW(ut.TestCase):
def test(self):
ready = threading.Event()
with capture_debug_strings(ready, interval=100) as cap:
ready.wait(timeout=5) # Wait for the listener to be ready
OutputDebugStringW("hello world")
OutputDebugStringW("test message")
self.assertEqual(cap.get(), b"hello world")
self.assertEqual(cap.get(), b"test message")


class Test_NTDebugHandler(ut.TestCase):
def test_emit(self):
ready = threading.Event()
handler = NTDebugHandler()
logger = logging.getLogger("test_ntdebug_handler")
# Clear existing handlers to prevent interference from other tests
logger.handlers = []
logger.addHandler(handler)
logger.setLevel(logging.INFO)
with capture_debug_strings(ready, interval=100) as cap:
ready.wait(timeout=5) # Wait for the listener to be ready
msg = "This is a test message from NTDebugHandler."
logger.info(msg)
logger.removeHandler(handler)
handler.close()
self.assertEqual(cap.get(), msg.encode("utf-8") + b"\n")