|
4 | 4 | constructing the computational graph. |
5 | 5 | """ |
6 | 6 |
|
7 | | - |
8 | | -from typing import List, Optional, Iterable |
9 | | - |
| 7 | +from typing import List, Optional |
10 | 8 | import tensorflow as tf |
11 | 9 |
|
12 | | -from neuralmonkey.logging import log, debug |
13 | | -from neuralmonkey.dataset import Dataset |
14 | | -from neuralmonkey.runners.base_runner import BaseRunner |
15 | | - |
16 | 10 |
|
17 | 11 | class CheckingException(Exception): |
18 | 12 | pass |
19 | 13 |
|
20 | 14 |
|
21 | | -def check_dataset_and_coders(dataset: Dataset, |
22 | | - runners: Iterable[BaseRunner]) -> None: |
23 | | - # pylint: disable=protected-access |
24 | | - |
25 | | - data_list = [] |
26 | | - for runner in runners: |
27 | | - for c in runner.feedables: |
28 | | - if hasattr(c, "data_id"): |
29 | | - data_list.append((getattr(c, "data_id"), c)) |
30 | | - elif hasattr(c, "data_ids"): |
31 | | - data_list.extend([(d, c) for d in getattr(c, "data_ids")]) |
32 | | - elif hasattr(c, "input_sequence"): |
33 | | - inpseq = getattr(c, "input_sequence") |
34 | | - if hasattr(inpseq, "data_id"): |
35 | | - data_list.append((getattr(inpseq, "data_id"), c)) |
36 | | - elif hasattr(inpseq, "data_ids"): |
37 | | - data_list.extend( |
38 | | - [(d, c) for d in getattr(inpseq, "data_ids")]) |
39 | | - else: |
40 | | - log("Input sequence: {} does not have a data attribute" |
41 | | - .format(str(inpseq))) |
42 | | - else: |
43 | | - log(("Coder: {} has neither an input sequence attribute nor a " |
44 | | - "a data attribute.").format(c)) |
45 | | - |
46 | | - debug("Found series: {}".format(str(data_list)), "checking") |
47 | | - missing = [] |
48 | | - |
49 | | - for (serie, coder) in data_list: |
50 | | - if serie not in dataset: |
51 | | - log("dataset {} does not have serie {}".format( |
52 | | - dataset.name, serie)) |
53 | | - missing.append((coder, serie)) |
54 | | - |
55 | | - if missing: |
56 | | - formated = ["{} ({}, {}.{})" .format(serie, str(cod), |
57 | | - cod.__class__.__module__, |
58 | | - cod.__class__.__name__) |
59 | | - for cod, serie in missing] |
60 | | - |
61 | | - raise CheckingException("Dataset '{}' is mising series {}:" |
62 | | - .format(dataset.name, ", ".join(formated))) |
63 | | - |
64 | | - |
65 | 15 | def assert_shape(tensor: tf.Tensor, |
66 | 16 | expected_shape: List[Optional[int]]) -> None: |
67 | 17 | """Check shape of a tensor. |
|
0 commit comments