Skip to content

Commit

Permalink
Fix uploader counts of scalars and tensors (#4698)
Browse files — browse the repository at this point in the history
  • Loading branch information
bileschi authored Feb 24, 2021
1 parent 2441835 commit 7bc40d0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tensorboard/uploader/uploader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -529,7 +529,6 @@ def add_event(self, run_name, event, value, metadata):
raise RuntimeError("add_event failed despite flush")

def _add_event_internal(self, run_name, event, value, metadata):
self._num_values += 1
run_proto = self._runs.get(run_name)
if run_proto is None:
run_proto = self._create_run(run_name)
Expand All @@ -539,6 +538,7 @@ def _add_event_internal(self, run_name, event, value, metadata):
tag_proto = self._create_tag(run_proto, value.tag, metadata)
self._tags[(run_name, value.tag)] = tag_proto
self._create_point(tag_proto, event, value)
self._num_values += 1

def flush(self):
"""Sends the active request after removing empty runs and tags.
Expand Down Expand Up @@ -703,6 +703,7 @@ def _add_event_internal(self, run_name, event, value, metadata):
tag_proto = self._create_tag(run_proto, value.tag, metadata)
self._tags[(run_name, value.tag)] = tag_proto
self._create_point(tag_proto, event, value, run_name)
self._num_values += 1

def flush(self):
"""Sends the active request after removing empty runs and tags.
Expand Down Expand Up @@ -788,7 +789,6 @@ def _create_point(self, tag_proto, event, value, run_name):
point.value.CopyFrom(value.tensor)
util.set_timestamp(point.wall_time, event.wall_time)

self._num_values += 1
self._tensor_bytes += point.value.ByteSize()
if point.value.ByteSize() > self._max_tensor_point_size:
logger.warning(
Expand Down
19 changes: 15 additions & 4 deletions tensorboard/uploader/uploader_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -198,7 +198,10 @@ def _create_request_sender(


def _create_scalar_request_sender(
experiment_id=None, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT
experiment_id=None,
api=_USE_DEFAULT,
max_request_size=_USE_DEFAULT,
tracker=None,
):
if api is _USE_DEFAULT:
api = _create_mock_client()
Expand All @@ -209,7 +212,7 @@ def _create_scalar_request_sender(
api=api,
rpc_rate_limiter=util.RateLimiter(0),
max_request_size=max_request_size,
tracker=upload_tracker.UploadTracker(verbosity=0),
tracker=tracker or upload_tracker.UploadTracker(verbosity=0),
)


Expand All @@ -218,6 +221,7 @@ def _create_tensor_request_sender(
api=_USE_DEFAULT,
max_request_size=_USE_DEFAULT,
max_tensor_point_size=_USE_DEFAULT,
tracker=None,
):
if api is _USE_DEFAULT:
api = _create_mock_client()
Expand All @@ -231,7 +235,7 @@ def _create_tensor_request_sender(
rpc_rate_limiter=util.RateLimiter(0),
max_request_size=max_request_size,
max_tensor_point_size=max_tensor_point_size,
tracker=upload_tracker.UploadTracker(verbosity=0),
tracker=tracker or upload_tracker.UploadTracker(verbosity=0),
)


Expand Down Expand Up @@ -1302,12 +1306,13 @@ def test_break_at_scalar_point_boundary(self):
if step > 0:
summary.value[0].ClearField("metadata")
events.append(event_pb2.Event(summary=summary, step=step))

tracker = upload_tracker.UploadTracker(verbosity=0)
sender = _create_scalar_request_sender(
"123",
mock_client,
# Set a limit to request size
max_request_size=1024,
tracker=tracker,
)
self._add_events(sender, "train", _apply_compat(events))
sender.flush()
Expand Down Expand Up @@ -1337,6 +1342,8 @@ def test_break_at_scalar_point_boundary(self):
total_points_in_result += 1
self.assertLessEqual(request.ByteSize(), 1024)
self.assertEqual(total_points_in_result, point_count)
with self.subTest("Scalar report count correct."):
self.assertEqual(tracker._stats.num_scalars, point_count)

def test_prunes_tags_and_runs(self):
mock_client = _create_mock_client()
Expand Down Expand Up @@ -1674,11 +1681,13 @@ def test_break_at_tensor_point_boundary(self):
event.summary.value.add(tag="histo", tensor=tensor_proto)
events.append(event)

tracker = upload_tracker.UploadTracker(verbosity=0)
sender = _create_tensor_request_sender(
"123",
mock_client,
# Set a limit to request size
max_request_size=1024,
tracker=tracker,
)
self._add_events(sender, "train", _apply_compat(events))
sender.flush()
Expand All @@ -1705,6 +1714,8 @@ def test_break_at_tensor_point_boundary(self):
total_points_in_result += 1
self.assertLessEqual(request.ByteSize(), 1024)
self.assertEqual(total_points_in_result, point_count)
with self.subTest("Tensor report count correct."):
self.assertEqual(tracker._stats.num_tensors, point_count)

def test_strip_large_tensors(self):
# Generate test data with varying tensor point sizes. Use raw bytes.
Expand Down

0 comments on commit 7bc40d0

Please sign in to comment.