Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tensorboardX/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
"""

from .record_writer import RecordWriter
from .record_reader import RecordReader
from .torchvis import TorchVis
from .writer import FileWriter, SummaryWriter
from .reader import SummaryReader
from .global_writer import GlobalSummaryWriter
__version__ = "2.0" # will be overwritten if run setup.py

__version__ = "2.0" # will be overwritten if run setup.py
20 changes: 20 additions & 0 deletions tensorboardX/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Provides an API for reading protocol buffers from event files"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .record_reader import RecordReader
from .proto.event_pb2 import Event


class SummaryReader:
    """Iterable view over the events stored in a single event file.

    Iterating hands the stored filename to ``RecordReader`` and decodes
    every raw record it produces into an ``Event`` protocol buffer.
    """

    def __init__(self, filename):
        """Store the path to read; the file is not touched here."""
        self.filename = filename

    def __iter__(self):
        """Yield one parsed ``Event`` per record found in the file."""
        for serialized in RecordReader(self.filename):
            parsed = Event()
            parsed.ParseFromString(serialized)
            yield parsed
289 changes: 289 additions & 0 deletions tensorboardX/record_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""A wrapper for TensorFlow SWIG-generated bindings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import array
import struct
from pathlib import Path


# Device-placement constants; every one of them is defined as 0 in this module.
# NOTE(review): these look like inert stand-ins for TensorFlow eager (TFE)
# enum values so that code referencing the names still works — confirm
# against callers before relying on the values.
TFE_DEVICE_PLACEMENT_WARN = 0
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 0
TFE_DEVICE_PLACEMENT_SILENT = 0
TFE_DEVICE_PLACEMENT_EXPLICIT = 0


def __getattr__(attr):
    """Module-level attribute fallback: unknown public names resolve to 0.

    Python calls this function when normal lookup of a module attribute
    fails. Ordinary names keep resolving to 0, but dunder names raise
    ``AttributeError`` as usual. Answering 0 for protocol probes such as
    ``__all__`` or ``__path__`` makes ``hasattr`` lie and breaks
    ``from <module> import *``, which then tries to iterate over the int.

    Args:
        attr: name of the attribute that was not found on the module.

    Returns:
        0 for every non-dunder name.

    Raises:
        AttributeError: if ``attr`` is a dunder name (``__x__``).
    """
    if attr.startswith("__") and attr.endswith("__"):
        raise AttributeError(attr)
    return 0


def TF_bfloat16_type():
    """Return 0 unconditionally.

    NOTE(review): presumably a placeholder for a TensorFlow binding of the
    same name — confirm against callers.
    """
    placeholder = 0
    return placeholder


def masked_crc32c(data):
    """Return the masked CRC-32C of ``data`` as an unsigned 32-bit integer.

    The raw checksum is rotated right by 15 bits and a fixed delta is then
    added; each step is truncated to 32 bits.
    """
    checksum = u32(crc32c(data))
    rotated = (checksum >> 15) | u32(checksum << 17)
    return u32(rotated + 0xA282EAD8)


def u32(x):
    """Truncate ``x`` to its low 32 bits."""
    low_bits = x & 0xFFFFFFFF
    return low_bits


# fmt: off
# 256-entry lookup table, one 32-bit constant per possible byte value.
# crc_update() indexes it with the low byte of (crc ^ input_byte) to advance
# the checksum one input byte at a time.
CRC_TABLE = (
0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4,
0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB,
0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B,
0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24,
0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B,
0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384,
0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54,
0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B,
0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A,
0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35,
0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5,
0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA,
0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45,
0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A,
0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A,
0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595,
0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48,
0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957,
0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687,
0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198,
0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927,
0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38,
0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8,
0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7,
0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096,
0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789,
0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859,
0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46,
0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9,
0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6,
0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36,
0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829,
0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C,
0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93,
0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043,
0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C,
0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3,
0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC,
0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C,
0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033,
0xA24BB5A6, 0x502036A5, 0x4370C551, 0xB11B4652,
0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D,
0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D,
0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982,
0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D,
0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622,
0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2,
0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED,
0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530,
0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F,
0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF,
0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0,
0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F,
0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540,
0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90,
0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F,
0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE,
0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1,
0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321,
0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E,
0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81,
0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E,
0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E,
0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351,
)
# fmt: on


# Starting value of the running checksum; crc32c() passes it to crc_update().
CRC_INIT = 0

# All-ones 32-bit mask: XORed with the CRC to invert it, and ANDed with
# intermediate results to keep them within 32 bits.
_MASK = 0xFFFFFFFF


def crc_update(crc, data):
    """Update a running CRC-32C checksum with more data.

    Args:
        crc: current 32-bit checksum value.
        data: bytes-like object or iterable of byte values. Anything that
            is not already a one-byte-per-item ``array.array`` is copied
            into an unsigned-byte array first.

    Returns:
        The updated 32-bit CRC-32C value.
    """
    # isinstance (rather than an exact type comparison) also lets
    # array.array subclasses be consumed directly without a copy.
    if isinstance(data, array.array) and data.itemsize == 1:
        buf = data
    else:
        buf = array.array("B", data)

    # The register is bit-inverted while bytes are folded in and inverted
    # back before returning.
    crc ^= _MASK
    for b in buf:
        table_index = (crc ^ b) & 0xFF
        crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK
    return crc ^ _MASK


def crc_finalize(crc):
    """Clamp a running CRC-32C value to 32 bits.

    Meant to be the last step of a checksum computation.

    Args:
        crc: checksum value produced by the update step.

    Returns:
        ``crc`` restricted to its low 32 bits.
    """
    finalized = crc & _MASK
    return finalized


def crc32c(data):
    """Compute the CRC-32C checksum of ``data`` in one call.

    Args:
        data: bytes-like object or iterable of byte values.

    Returns:
        The 32-bit CRC-32C checksum of ``data``.
    """
    running = crc_update(CRC_INIT, data)
    return crc_finalize(running)


class BufferedReader:
    """Binary file reader that keeps a copy of the bytes it has handed out.

    Used as a context manager: the file is opened on entry and closed on
    exit. ``reset_buffer`` discards the retained bytes.
    """

    def __init__(self, filename):
        """Remember the path; the file itself is opened by ``__enter__``."""
        self.filename = filename
        self.file_handle = None
        self._buffer = b""
        self._buffer_pos = 0

    def __enter__(self):
        """Open the file for binary reading and start with an empty buffer.

        Raises:
            RuntimeError: if the reader is already open.
        """
        if self.file_handle is not None:
            raise RuntimeError('Dont use reader multiple times')
        self.file_handle = open(self.filename, 'rb')
        self._buffer, self._buffer_pos = b"", 0
        return self

    def __exit__(self, *args):
        """Close the file and drop any buffered bytes."""
        self.file_handle.close()
        self.file_handle = None
        self._buffer, self._buffer_pos = b"", 0

    def read(self, n_bytes):
        """Read up to ``n_bytes`` bytes, serving buffered data first.

        Bytes still available in the buffer from ``self._buffer_pos``
        onwards are returned first; whatever is still missing is read from
        the file, appended to the buffer, and the position is moved past
        everything returned.

        Args:
            n_bytes: non-negative number of bytes to read.

        Returns:
            Bytestring of the data read, at most ``n_bytes`` long.
        """
        start = self._buffer_pos
        chunk = self._buffer[start:start + n_bytes]
        self._buffer_pos = start + len(chunk)
        missing = n_bytes - len(chunk)
        if missing <= 0:
            return chunk
        fresh = self.file_handle.read(missing)
        self._buffer += fresh
        self._buffer_pos += len(fresh)
        return chunk + fresh

    def reset_buffer(self):
        """Forget all retained bytes and rewind the buffer position."""
        self._buffer, self._buffer_pos = b"", 0


class CorruptedDataError(RuntimeError):
    """Raised when a record is truncated or fails a checksum comparison."""


class RecordReader:
    """Iterate over the raw records stored in an event file.

    Each record on disk is laid out as::

        8 bytes  little-endian uint64 payload length
        4 bytes  little-endian uint32 masked CRC-32C of the length bytes
        N bytes  payload
        4 bytes  little-endian uint32 masked CRC-32C of the payload

    Args:
        filename: path of the file to read.
        start_offset: must be falsy; reading from an offset is not supported.
        compression_type: must be falsy; compressed files are not supported.
        status: stored on the instance; not otherwise used by this class.

    Raises:
        FileNotFoundError: if no filename is given or the path does not exist.
        NotImplementedError: if ``start_offset`` or ``compression_type`` is set.
    """

    def __init__(
        self, filename=None, start_offset=0, compression_type=None, status=None
    ):
        if filename is None:
            raise FileNotFoundError(
                "No filename provided, cannot read Events"
            )

        filename = Path(filename)
        # ``exists`` is a method and has to be called: the bare bound method
        # is always truthy, so testing it would never detect a missing file.
        if not filename.exists():
            raise FileNotFoundError(
                "{} does not point to valid Events file".format(filename),
            )
        if start_offset:
            raise NotImplementedError(
                "start offset not supported by compat reader"
            )
        if compression_type:
            # TODO: Handle gzip and zlib compressed files
            raise NotImplementedError(
                "compression not supported by compat reader"
            )

        self.filename = filename
        self.start_offset = start_offset
        self.compression_type = compression_type
        self.status = status
        self.curr_event = None

    def __iter__(self):
        """Yield each record payload as bytes after verifying both checksums.

        Iteration ends when the file is exhausted exactly at a record
        boundary.

        Raises:
            CorruptedDataError: if a record is truncated or a checksum does
                not match.
        """
        with BufferedReader(self.filename) as buf:
            while True:
                # Each record starts with a clean buffer so retained bytes
                # never grow beyond a single record.
                buf.reset_buffer()
                header_str = buf.read(8)
                if not header_str:
                    # Nothing left at a record boundary: normal end of file.
                    return
                if len(header_str) < 8:
                    raise CorruptedDataError("Could not read header, data corrupted")
                (header_len,) = struct.unpack("<Q", header_str)

                # The header is followed by its own 4-byte checksum; verify
                # it before trusting the length it announces.
                crc_header_str = buf.read(4)
                if len(crc_header_str) < 4:
                    raise CorruptedDataError("Could not read CRC32 checksum, data corrupted")
                (crc_header,) = struct.unpack("<I", crc_header_str)
                if masked_crc32c(header_str) != crc_header:
                    raise CorruptedDataError('Checksum did not check out')

                # The header value is the payload size in bytes.
                event_str = buf.read(header_len)
                if len(event_str) < header_len:
                    raise CorruptedDataError("Could not read record, data corrupted")

                # The payload is followed by a 4-byte checksum of its own.
                crc_event_str = buf.read(4)
                if len(crc_event_str) < 4:
                    raise CorruptedDataError("Could not read CRC32 data-checksum, data corrupted")
                (crc_event,) = struct.unpack("<I", crc_event_str)
                if masked_crc32c(event_str) != crc_event:
                    raise CorruptedDataError('Data-checksum did not check out')

                yield event_str
59 changes: 59 additions & 0 deletions tests/record_reader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# """Tests for RecordReader"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import os
from tensorboardX.record_writer import RecordWriter
from tensorboardX.record_reader import RecordReader
import unittest


class RecordReaderTest(unittest.TestCase):
    """Round-trip tests: write with RecordWriter, read back with RecordReader."""

    def get_temp_dir(self):
        """Create a fresh temporary directory, removed when the test ends."""
        import shutil
        import tempfile

        temp_dir = tempfile.mkdtemp()
        # mkdtemp never cleans up after itself; register removal so test
        # runs do not leave directories behind.
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        return temp_dir

    def test_empty_record(self):
        """A single zero-length record is read back as one empty payload."""
        filename = os.path.join(self.get_temp_dir(), "empty_record")
        w = RecordWriter(filename)
        bytes_to_write = b""
        w.write(bytes_to_write)
        w.close()

        records = list(RecordReader(filename))
        self.assertEqual(records, [bytes_to_write])

    def test_record_reader_roundtrip(self):
        """Every written record is read back, in full and unchanged."""
        filename = os.path.join(self.get_temp_dir(), "record_writer_roundtrip")
        w = RecordWriter(filename)
        bytes_to_write = b"hello world"
        times_to_test = 50
        for _ in range(times_to_test):
            w.write(bytes_to_write)
        w.close()

        # Collect first so a reader that yields nothing fails the count
        # assertion instead of leaving a loop variable unbound.
        records = list(RecordReader(filename))
        self.assertEqual(len(records), times_to_test)
        for record in records:
            self.assertEqual(record, bytes_to_write)

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading