Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions tensorboard/plugins/hparams/_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Use `tensorboard.plugins.hparams.api` to access this module's contents.
"""


import os
import tensorflow as tf

from tensorboard.plugins.hparams import api_pb2
Expand All @@ -39,7 +39,7 @@ def __init__(self, writer, hparams, trial_id=None):

Args:
writer: The `SummaryWriter` object to which hparams should be
written, or a logdir (as a `str`) to be passed to
written, or a logdir (as a `str` or `PathLike`) to be passed to
`tf.summary.create_file_writer` to create such a writer.
hparams: A `dict` mapping hyperparameters to the values used in
this session. Keys should be the names of `HParam` objects used
Expand All @@ -62,10 +62,12 @@ def __init__(self, writer, hparams, trial_id=None):
summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
if writer is None:
raise TypeError(
"writer must be a `SummaryWriter` or `str`, not None"
"writer must be a `SummaryWriter`, `str` or `PathLike`, not None"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nitpick here but it would be nice to keep the formatting consistent

Suggested change
"writer must be a `SummaryWriter`, `str` or `PathLike`, not None"
"writer must be a `SummaryWriter`, `str`, or `PathLike`, not None"

)
elif isinstance(writer, str):
self._writer = tf.compat.v2.summary.create_file_writer(writer)
elif isinstance(writer, os.PathLike):
self._writer = tf.compat.v2.summary.create_file_writer(os.fspath(writer))
else:
self._writer = writer

Expand Down
9 changes: 9 additions & 0 deletions tensorboard/plugins/hparams/_keras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import os
from pathlib import Path
from unittest import mock

from google.protobuf import text_format
Expand Down Expand Up @@ -145,6 +146,14 @@ def test_explicit_writer(self):
# We'll assume that the contents are correct, as in the case where
# the file writer was constructed implicitly.

def test_pathlib_writer(self):
writer = Path(self.logdir)
self._initialize_model(writer=writer)
self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback])

files = os.listdir(self.logdir)
self.assertEqual(len(files), 1, files)

def test_non_eager_failure(self):
with tf.compat.v1.Graph().as_default():
assert not tf.executing_eagerly()
Expand Down