From 9625693bff6ed4236f1f92976929b0fd250420cf Mon Sep 17 00:00:00 2001 From: Tzu-Wei Huang Date: Wed, 8 Apr 2026 16:21:28 +0800 Subject: [PATCH 1/3] Trigger CI and coverage tests on all pushes and pull requests --- .github/workflows/test-matrix.yml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/test-matrix.yml b/.github/workflows/test-matrix.yml index e910f74..18339f0 100644 --- a/.github/workflows/test-matrix.yml +++ b/.github/workflows/test-matrix.yml @@ -3,11 +3,7 @@ name: Python application -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] +on: [push, pull_request] jobs: From a4f194312505f4fc44c564576e4c310dd62220b3 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Huang Date: Wed, 8 Apr 2026 16:34:53 +0800 Subject: [PATCH 2/3] Remove obsolete TorchVis class and torchvis.py --- docs/tensorboard.rst | 5 ----- tensorboardX/__init__.py | 1 - tensorboardX/torchvis.py | 47 ---------------------------------------- 3 files changed, 53 deletions(-) delete mode 100644 tensorboardX/torchvis.py diff --git a/docs/tensorboard.rst b/docs/tensorboard.rst index 99f03ea..75e8377 100644 --- a/docs/tensorboard.rst +++ b/docs/tensorboard.rst @@ -11,8 +11,3 @@ tensorboardX :members: .. automethod:: __init__ - -.. autoclass:: TorchVis - :members: - - .. automethod:: __init__ \ No newline at end of file diff --git a/tensorboardX/__init__.py b/tensorboardX/__init__.py index 5181247..4244d39 100644 --- a/tensorboardX/__init__.py +++ b/tensorboardX/__init__.py @@ -3,7 +3,6 @@ from .global_writer import GlobalSummaryWriter from .record_writer import RecordWriter -from .torchvis import TorchVis from .writer import FileWriter, SummaryWriter try: diff --git a/tensorboardX/torchvis.py b/tensorboardX/torchvis.py deleted file mode 100644 index 735abfe..0000000 --- a/tensorboardX/torchvis.py +++ /dev/null @@ -1,47 +0,0 @@ -import gc - -from .writer import SummaryWriter - -# Supports TensorBoard visualization -vis_formats = {'tensorboard': SummaryWriter} - - -class TorchVis: - def __init__(self, *args, **init_kwargs): - """ - Args: - args (list of strings): The name of the visualization target(s). - Accepted targets are 'tensorboard'. - init_kwargs: Additional keyword parameters for the writer. - """ - self.subscribers = {} - self.register(*args, **init_kwargs) - - def register(self, *args, **init_kwargs): - # Sets tensorboard as the default visualization format if not specified - formats = args if args else ['tensorboard'] - for format in formats: - if self.subscribers.get(format) is None and format in vis_formats: - self.subscribers[format] = vis_formats[format](**init_kwargs.get(format, {})) - - def unregister(self, *args): - for format in args: - if format in self.subscribers: - self.subscribers[format].close() - del self.subscribers[format] - gc.collect() - - def __getattr__(self, attr): - if not self.subscribers: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") - - def wrapper(*args, **kwargs): - for _, subscriber in self.subscribers.items(): - if hasattr(subscriber, attr): - getattr(subscriber, attr)(*args, **kwargs) - return wrapper - - # Handle writer management (open/close) for the user - def __del__(self): - for _, subscriber in self.subscribers.items(): - subscriber.close() From 410a92045d9a7629ae6f4f78cf4ff8a63d293d2f Mon Sep 17 00:00:00 2001 From: Tzu-Wei Huang Date: Wed, 8 Apr 2026 21:26:01 +0800 Subject: [PATCH 3/3] Fix issue #749: Add py.typed and fix type hints --- MANIFEST.in | 1 + pyproject.toml | 8 ++++++++ tensorboardX/py.typed | 0 tensorboardX/writer.py | 40 ++++++++++++++++++++-------------------- 4 files changed, 29 insertions(+), 20 deletions(-) create mode 100644 tensorboardX/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index fb37de6..99bba38 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ include HISTORY.rst include LICENSE include compile.sh +include tensorboardX/py.typed recursive-include tensorboardX/proto * recursive-exclude test * recursive-exclude examples * diff --git a/pyproject.toml b/pyproject.toml index 35b8374..94b4e60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ dev = [ "ruff>=0.8.4", "pillow==11.0.0", "setuptools==81.0.0", + "mypy>=1.14.1", ] [tool.ruff] @@ -95,3 +96,10 @@ select = [ ] ignore = ["F401", "E501", "E721", "E741"] +[tool.mypy] +ignore_missing_imports = true +exclude = [ + 'tensorboardX/proto/', +] +disable_error_code = ["attr-defined", "name-defined", "arg-type", "var-annotated"] + diff --git a/tensorboardX/py.typed b/tensorboardX/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py index 4a98f52..0587be0 100644 --- a/tensorboardX/writer.py +++ b/tensorboardX/writer.py @@ -8,7 +8,7 @@ import logging import os import time -from typing import Optional, Union +from typing import Any, Optional, Union import numpy @@ -38,12 +38,9 @@ logger = logging.getLogger(__name__) -numpy_compatible = numpy.ndarray -try: +numpy_compatible = Any +with contextlib.suppress(ImportError): import torch - numpy_compatible = torch.Tensor -except ImportError: - pass class DummyFileWriter: @@ -264,7 +261,7 @@ class SummaryWriter: def __init__( self, logdir: Optional[str] = None, - comment: Optional[str] = "", + comment: str = "", purge_step: Optional[int] = None, max_queue: Optional[int] = 10, flush_secs: Optional[int] = 120, @@ -327,7 +324,7 @@ def __init__( from datetime import datetime current_time = datetime.now().strftime('%b%d_%H-%M-%S') logdir = os.path.join( - 'runs', current_time + '_' + socket.gethostname() + comment) + 'runs', current_time + '_' + socket.gethostname() + (comment if comment else "")) self.logdir = logdir self.purge_step = purge_step self._max_queue = max_queue @@ -340,7 +337,8 @@ def __init__( # Initialize the file writers, but they can be cleared out on close # and recreated later as needed. - self.file_writer = self.all_writers = None + self.file_writer: Optional[Union[FileWriter, DummyFileWriter]] = None + self.all_writers: Optional[dict[str, Union[FileWriter, DummyFileWriter]]] = None self._get_file_writer() # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard @@ -367,7 +365,7 @@ def __append_to_scalar_dict(self, tag, scalar_value, global_step, self.scalar_dict[tag].append( [timestamp, global_step, float(make_np(scalar_value).squeeze())]) - def _get_file_writer(self): + def _get_file_writer(self) -> Union[FileWriter, DummyFileWriter]: """Returns the default FileWriter instance. Recreates it if closed.""" if not self._write_to_disk: self.file_writer = DummyFileWriter(logdir=self.logdir) @@ -434,10 +432,10 @@ def add_hparams( if not name: name = str(time.time()) - with SummaryWriter(logdir=os.path.join(self.file_writer.get_logdir(), name)) as w_hp: - w_hp.file_writer.add_summary(exp) - w_hp.file_writer.add_summary(ssi) - w_hp.file_writer.add_summary(sei) + with SummaryWriter(logdir=os.path.join(self._get_file_writer().get_logdir(), name)) as w_hp: + w_hp._get_file_writer().add_summary(exp) + w_hp._get_file_writer().add_summary(ssi) + w_hp._get_file_writer().add_summary(sei) for k, v in metric_dict.items(): w_hp.add_scalar(k, v, global_step) self._get_comet_logger().log_parameters(hparam_dict, step=global_step) @@ -517,6 +515,8 @@ def add_scalars( """ walltime = time.time() if walltime is None else walltime fw_logdir = self._get_file_writer().get_logdir() + if self.all_writers is None: + self.all_writers = {} for tag, scalar_value in tag_scalar_dict.items(): fw_tag = os.path.join(str(fw_logdir), main_tag, tag) if fw_tag in self.all_writers: @@ -551,7 +551,7 @@ def add_histogram( tag: str, values: numpy_compatible, global_step: Optional[int] = None, - bins: Optional[str] = 'tensorflow', + bins: Union[Optional[str], list, Any] = 'tensorflow', walltime: Optional[float] = None, max_bins=None): """Add histogram to summary. @@ -755,7 +755,7 @@ def add_images( """ if isinstance(img_tensor, list): # a list of tensors in CHW or HWC - if dataformats.upper() != 'CHW' and dataformats.upper() != 'HWC': + if dataformats is None or (dataformats.upper() != 'CHW' and dataformats.upper() != 'HWC'): print('A list of image is passed, but the dataformat is neither CHW nor HWC.') print('Nothing is written.') return @@ -766,7 +766,7 @@ def add_images( import numpy as np img_tensor = np.stack(img_tensor, 0) - dataformats = 'N' + dataformats + dataformats = 'N' + (dataformats if dataformats else "") summary = image(tag, img_tensor, dataformats=dataformats) encoded_image_string = summary.value[0].image.encoded_image_string @@ -967,7 +967,7 @@ def add_embedding( self, mat: numpy_compatible, metadata=None, - label_img: numpy_compatible = None, + label_img: Optional[numpy_compatible] = None, global_step: Optional[int] = None, tag='default', metadata_header=None): @@ -1202,8 +1202,8 @@ def add_mesh( self, tag: str, vertices: numpy_compatible, - colors: numpy_compatible = None, - faces: numpy_compatible = None, + colors: Optional[numpy_compatible] = None, + faces: Optional[numpy_compatible] = None, config_dict=None, global_step: Optional[int] = None, walltime: Optional[float] = None):