Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions repro_749.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tensorboardX import SummaryWriter
22 changes: 21 additions & 1 deletion tensorboardX/event_file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,25 @@ def __init__(self, logdir, max_queue_size=10, flush_secs=120, filename_suffix=''

self._worker.start()

def start(self):
self._worker.start()

def get_logdir(self):
"""Returns the directory where event file will be written."""
return self._logdir

def __getstate__(self):
state = self.__dict__.copy()
# Do not pickle the thread and the internal writer (which contains file handles)
state['_worker'] = None
state['_ev_writer'] = None
return state

def __setstate__(self, state):
self.__dict__.update(state)
# In child processes, we don't start a new thread.
# add_event will continue to work as it only uses self._event_queue.

def reopen(self):
"""Reopens the EventFileWriter.
Can be called after `close()` to add more events in the same directory.
Expand All @@ -133,6 +148,8 @@ def add_event(self, event):
event: An `Event` protocol buffer.
"""
if not self._closed:
if isinstance(event, event_pb2.Event):
event = event.SerializeToString()
self._event_queue.put(event)

def flush(self):
Expand Down Expand Up @@ -202,7 +219,10 @@ def run(self):

if type(data) == type(self._shutdown_signal):
return
self._record_writer.write_event(data)
if isinstance(data, bytes):
self._record_writer._write_serialized_event(data)
else:
self._record_writer.write_event(data)
self._has_pending_data = True
except queue.Empty:
pass
Expand Down
35 changes: 31 additions & 4 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,22 @@ def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=''):
self.event_writer = EventFileWriter(
logdir, max_queue, flush_secs, filename_suffix)

def cleanup():
self.event_writer.close()

atexit.register(cleanup)
self._pid = os.getpid()
atexit.register(_clean_up_file_writer, self)
self._default_metadata = {}

def get_logdir(self):
"""Returns the directory where event file will be written."""
return self.event_writer.get_logdir()

def __getstate__(self):
state = self.__dict__.copy()
state['event_writer'] = None
return state

def __setstate__(self, state):
self.__dict__.update(state)

def add_event(self, event, step=None, walltime=None):
"""Adds an event to the event file.
Args:
Expand Down Expand Up @@ -247,6 +253,11 @@ def use_metadata(self, *, global_step=None, walltime=None):
self._default_metadata = {}


def _clean_up_file_writer(writer):
if os.getpid() == writer._pid:
writer.close()


class SummaryWriter:
"""Writes entries directly to event files in the logdir to be
consumed by TensorBoard.
Expand Down Expand Up @@ -354,6 +365,22 @@ def __init__(
self.scalar_dict = {}
self._default_metadata = {}

def __getstate__(self):
state = self.__dict__.copy()
# Do not pickle the writers and comet logger as they are not picklable
# or contain resources that should not be shared across processes.
state['file_writer'] = None
state['all_writers'] = None
state['_comet_logger'] = None
return state

def __setstate__(self, state):
self.__dict__.update(state)
# Note: We do NOT call self._get_file_writer() here.
# It will be lazily re-initialized when add_scalar/etc is called in the child process.
# This ensures the child process gets its own clean writer/queue state if needed,
# or just works with the existing logic.

def __append_to_scalar_dict(self, tag, scalar_value, global_step,
timestamp):
"""This adds an entry to the self.scalar_dict datastructure with format
Expand Down
54 changes: 54 additions & 0 deletions tests/test_issue_727.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
import multiprocessing as mp
import os
import time
import numpy as np
from tensorboardX import SummaryWriter
from tensorboardX.proto import event_pb2

def worker_fn(w):
# In 'spawn' mode, 'w' is a unpickled copy.
# On Windows, the thread isn't running here, but add_scalar
# should still put serialized bytes into the shared queue.
try:
w.add_scalar('worker_metric', 1.23, global_step=10)
except Exception as e:
print(f"Worker failed: {e}")

class Issue727Test(unittest.TestCase):
def test_multiprocess_serialization(self):
# Use 'spawn' to simulate Windows behavior if possible
try:
ctx = mp.get_context('spawn')
except ValueError:
ctx = mp.get_context('fork')

writer = SummaryWriter()
event_filename = writer.file_writer.event_writer._ev_writer._file_name

p = ctx.Process(target=worker_fn, args=(writer,))
p.start()
p.join()

writer.close()

# Verify data was written correctly
from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import PyRecordReader_New
r = PyRecordReader_New(event_filename)
r.GetNext() # meta data
r.GetNext() # the metric

ev = event_pb2.Event()
ev.ParseFromString(r.record())
self.assertEqual(ev.step, 10)
self.assertEqual(ev.summary.value[0].tag, 'worker_metric')
self.assertAlmostEqual(ev.summary.value[0].simple_value, 1.23)

def test_atexit_pid_check(self):
writer = SummaryWriter()
original_pid = writer.file_writer._pid
self.assertEqual(original_pid, os.getpid())
self.assertTrue(hasattr(writer.file_writer, '_pid'))

if __name__ == '__main__':
unittest.main()
Loading