From 54d8b6c91154c43ff5b726d550ccf897e23849f3 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Huang Date: Fri, 10 Apr 2026 12:43:16 +0800 Subject: [PATCH] Fix issue #727: Improve multiprocessing safety and prevent segfaults on Windows --- repro_749.py | 1 + tensorboardX/event_file_writer.py | 22 ++++++++++++- tensorboardX/writer.py | 35 +++++++++++++++++--- tests/test_issue_727.py | 54 +++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 repro_749.py create mode 100644 tests/test_issue_727.py diff --git a/repro_749.py b/repro_749.py new file mode 100644 index 00000000..c2b43267 --- /dev/null +++ b/repro_749.py @@ -0,0 +1 @@ +from tensorboardX import SummaryWriter diff --git a/tensorboardX/event_file_writer.py b/tensorboardX/event_file_writer.py index bafb476f..4ad40ad2 100644 --- a/tensorboardX/event_file_writer.py +++ b/tensorboardX/event_file_writer.py @@ -109,10 +109,25 @@ def __init__(self, logdir, max_queue_size=10, flush_secs=120, filename_suffix='' self._worker.start() + def start(self): + self._worker.start() + def get_logdir(self): """Returns the directory where event file will be written.""" return self._logdir + def __getstate__(self): + state = self.__dict__.copy() + # Do not pickle the thread and the internal writer (which contains file handles) + state['_worker'] = None + state['_ev_writer'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + # In child processes, we don't start a new thread. + # add_event will continue to work as it only uses self._event_queue. + def reopen(self): """Reopens the EventFileWriter. Can be called after `close()` to add more events in the same directory. @@ -133,6 +148,8 @@ def add_event(self, event): event: An `Event` protocol buffer. 
""" if not self._closed: + if isinstance(event, event_pb2.Event): + event = event.SerializeToString() self._event_queue.put(event) def flush(self): @@ -202,7 +219,10 @@ def run(self): if type(data) == type(self._shutdown_signal): return - self._record_writer.write_event(data) + if isinstance(data, bytes): + self._record_writer._write_serialized_event(data) + else: + self._record_writer.write_event(data) self._has_pending_data = True except queue.Empty: pass diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py index 0587be09..e674eaac 100644 --- a/tensorboardX/writer.py +++ b/tensorboardX/writer.py @@ -115,16 +115,22 @@ def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=''): self.event_writer = EventFileWriter( logdir, max_queue, flush_secs, filename_suffix) - def cleanup(): - self.event_writer.close() - - atexit.register(cleanup) + self._pid = os.getpid() + atexit.register(_clean_up_file_writer, self) self._default_metadata = {} def get_logdir(self): """Returns the directory where event file will be written.""" return self.event_writer.get_logdir() + def __getstate__(self): + state = self.__dict__.copy() + state['event_writer'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + def add_event(self, event, step=None, walltime=None): """Adds an event to the event file. Args: @@ -247,6 +253,11 @@ def use_metadata(self, *, global_step=None, walltime=None): self._default_metadata = {} +def _clean_up_file_writer(writer): + if os.getpid() == writer._pid: + writer.close() + + class SummaryWriter: """Writes entries directly to event files in the logdir to be consumed by TensorBoard. @@ -354,6 +365,22 @@ def __init__( self.scalar_dict = {} self._default_metadata = {} + def __getstate__(self): + state = self.__dict__.copy() + # Do not pickle the writers and comet logger as they are not picklable + # or contain resources that should not be shared across processes. 
+ state['file_writer'] = None + state['all_writers'] = None + state['_comet_logger'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + # Note: We do NOT call self._get_file_writer() here. + # It will be lazily re-initialized when add_scalar/etc is called in the child process. + # This ensures the child process gets its own clean writer/queue state if needed, + # or just works with the existing logic. + def __append_to_scalar_dict(self, tag, scalar_value, global_step, timestamp): """This adds an entry to the self.scalar_dict datastructure with format diff --git a/tests/test_issue_727.py b/tests/test_issue_727.py new file mode 100644 index 00000000..845b543b --- /dev/null +++ b/tests/test_issue_727.py @@ -0,0 +1,54 @@ +import unittest +import multiprocessing as mp +import os +import time +import numpy as np +from tensorboardX import SummaryWriter +from tensorboardX.proto import event_pb2 + +def worker_fn(w): + # In 'spawn' mode, 'w' is an unpickled copy. + # On Windows, the thread isn't running here, but add_scalar + # should still put serialized bytes into the shared queue. NOTE(review): SummaryWriter.__getstate__ in this patch sets file_writer and all_writers to None, so confirm 'w' can still reach the parent's queue.
+ try: + w.add_scalar('worker_metric', 1.23, global_step=10) + except Exception as e: + print(f"Worker failed: {e}") + +class Issue727Test(unittest.TestCase): + def test_multiprocess_serialization(self): + # Use 'spawn' to simulate Windows behavior if possible + try: + ctx = mp.get_context('spawn') + except ValueError: + ctx = mp.get_context('fork') + + writer = SummaryWriter() + event_filename = writer.file_writer.event_writer._ev_writer._file_name + + p = ctx.Process(target=worker_fn, args=(writer,)) + p.start() + p.join() + + writer.close() + + # Verify data was written correctly + from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import PyRecordReader_New + r = PyRecordReader_New(event_filename) + r.GetNext() # meta data + r.GetNext() # the metric + + ev = event_pb2.Event() + ev.ParseFromString(r.record()) + self.assertEqual(ev.step, 10) + self.assertEqual(ev.summary.value[0].tag, 'worker_metric') + self.assertAlmostEqual(ev.summary.value[0].simple_value, 1.23) + + def test_atexit_pid_check(self): + writer = SummaryWriter() + original_pid = writer.file_writer._pid + self.assertEqual(original_pid, os.getpid()) + self.assertTrue(hasattr(writer.file_writer, '_pid')) + +if __name__ == '__main__': + unittest.main()