2525
2626from tensorboard .data import provider
2727from tensorboard .plugins .hparams import api_pb2
28+ from tensorboard .plugins .hparams import backend_context as backend_context_lib
2829from tensorboard .plugins .hparams import error
2930from tensorboard .plugins .hparams import json_format_compat
3031from tensorboard .plugins .hparams import metadata
3132from tensorboard .plugins .hparams import metrics
33+ from tensorboard .plugins .hparams import plugin_data_pb2
3234
3335
3436class Handler :
@@ -93,13 +95,15 @@ def _session_groups_from_tags(self):
9395 hparams_run_to_tag_to_content ,
9496 # Don't pass any information from the DataProvider since we are only
9597 # examining session groups based on tag metadata
96- [],
98+ provider .ListHyperparametersResult (
99+ hyperparameters = [], session_groups = []
100+ ),
97101 )
98102 extractors = _create_extractors (self ._request .col_params )
99103 filters = _create_filters (self ._request .col_params , extractors )
100104
101105 session_groups = self ._build_session_groups (
102- hparams_run_to_tag_to_content , experiment
106+ hparams_run_to_tag_to_content , experiment . metric_infos
103107 )
104108 session_groups = self ._filter (session_groups , filters )
105109 self ._sort (session_groups , extractors )
@@ -116,16 +120,37 @@ def _session_groups_from_data_provider(self):
116120 sort ,
117121 )
118122
123+ metric_infos = self ._backend_context .compute_metric_infos_from_data_provider_session_groups (
124+ self ._request_context , self ._experiment_id , response
125+ )
126+
127+ all_metric_evals = self ._backend_context .read_last_scalars (
128+ self ._request_context ,
129+ self ._experiment_id ,
130+ run_tag_filter = None ,
131+ )
132+
119133 session_groups = []
120134 for provider_group in response :
121- sessions = [
122- api_pb2 .Session (name = f"{ s .experiment_id } /{ s .run } " )
123- for s in provider_group .sessions
124- ]
125- name = (
126- f"{ provider_group .root .experiment_id } /{ provider_group .root .run } "
127- if provider_group .root .run
128- else provider_group .root .experiment_id
135+ sessions = []
136+ for session in provider_group .sessions :
137+ session_name = (
138+ backend_context_lib .generate_data_provider_session_name (
139+ self ._experiment_id , session
140+ )
141+ )
142+ sessions .append (
143+ self ._build_session (
144+ metric_infos ,
145+ session_name ,
146+ plugin_data_pb2 .SessionStartInfo (),
147+ plugin_data_pb2 .SessionEndInfo (),
148+ all_metric_evals ,
149+ )
150+ )
151+
152+ name = backend_context_lib .generate_data_provider_session_name (
153+ self ._experiment_id , provider_group .root
129154 )
130155 session_group = api_pb2 .SessionGroup (
131156 name = name ,
@@ -154,9 +179,16 @@ def _session_groups_from_data_provider(self):
154179
155180 session_groups .append (session_group )
156181
182+ # Compute the session group's aggregated metrics for each group.
183+ for group in session_groups :
184+ if group .sessions :
185+ self ._aggregate_metrics (group )
186+
157187 return session_groups
158188
159- def _build_session_groups (self , hparams_run_to_tag_to_content , experiment ):
189+ def _build_session_groups (
190+ self , hparams_run_to_tag_to_content , metric_infos
191+ ):
160192 """Returns a list of SessionGroups protobuffers from the summary
161193 data."""
162194
@@ -178,7 +210,7 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
178210 metric_runs = set ()
179211 metric_tags = set ()
180212 for session_name in session_names :
181- for metric in experiment . metric_infos :
213+ for metric in metric_infos :
182214 metric_name = metric .name
183215 (run , tag ) = metrics .run_tag_from_session_and_metric (
184216 session_name , metric_name
@@ -207,7 +239,11 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
207239 tag_to_content [metadata .SESSION_END_INFO_TAG ]
208240 )
209241 session = self ._build_session (
210- experiment , session_name , start_info , end_info , all_metric_evals
242+ metric_infos ,
243+ session_name ,
244+ start_info ,
245+ end_info ,
246+ all_metric_evals ,
211247 )
212248 if session .status in self ._request .allowed_statuses :
213249 self ._add_session (session , start_info , groups_by_name )
@@ -263,7 +299,7 @@ def _add_session(self, session, start_info, groups_by_name):
263299 groups_by_name [group_name ] = group
264300
265301 def _build_session (
266- self , experiment , name , start_info , end_info , all_metric_evals
302+ self , metric_infos , name , start_info , end_info , all_metric_evals
267303 ):
268304 """Builds a session object."""
269305
@@ -273,7 +309,7 @@ def _build_session(
273309 start_time_secs = start_info .start_time_secs ,
274310 model_uri = start_info .model_uri ,
275311 metric_values = self ._build_session_metric_values (
276- experiment , name , all_metric_evals
312+ metric_infos , name , all_metric_evals
277313 ),
278314 monitor_url = start_info .monitor_url ,
279315 )
@@ -283,13 +319,13 @@ def _build_session(
283319 return result
284320
285321 def _build_session_metric_values (
286- self , experiment , session_name , all_metric_evals
322+ self , metric_infos , session_name , all_metric_evals
287323 ):
288324 """Builds the session metric values."""
289325
290326 # result is a list of api_pb2.MetricValue instances.
291327 result = []
292- for metric_info in experiment . metric_infos :
328+ for metric_info in metric_infos :
293329 metric_name = metric_info .name
294330 (run , tag ) = metrics .run_tag_from_session_and_metric (
295331 session_name , metric_name
0 commit comments