Skip to content

Commit 00ebb80

Browse files
committed
DependencyTree now support more than one root + GameProgression can track multiple quests.
1 parent fc31a4a commit 00ebb80

File tree

7 files changed

+542
-243
lines changed

7 files changed

+542
-243
lines changed

tests/test_play_generated_games.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def test_play_generated_games():
4747
game_state, reward, done = env.step(command)
4848

4949
if done:
50-
msg = "Finished before playing `max_steps` steps."
50+
msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command)
5151
if game_state.has_won:
5252
msg += " (winning)"
53-
assert game_state._game_progression.winning_policy == []
53+
assert len(game_state._game_progression.winning_policy) == 0
5454

5555
if game_state.has_lost:
5656
msg += " (losing)"

textworld/envs/glulx/git_glulx_ml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def init(self, output: str, game=None,
146146
"""
147147
output = _strip_input_prompt_symbol(output)
148148
super().init(output)
149-
self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward)
149+
self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward)
150150
self._state_tracking = state_tracking
151151
self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0
152152

textworld/generator/dependency_tree.py

Lines changed: 111 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
import textwrap
6+
from typing import List, Any, Iterable
67

78
from textworld.utils import uniquify
89

@@ -18,42 +19,51 @@ class DependencyTreeElement:
1819
`__str__` accordingly.
1920
"""
2021

21-
def __init__(self, value):
22+
def __init__(self, value: Any):
2223
self.value = value
24+
self.parent = None
2325

24-
def depends_on(self, other):
26+
def depends_on(self, other: "DependencyTreeElement") -> bool:
2527
"""
2628
Check whether this element depends on the `other`.
2729
"""
2830
return self.value > other.value
2931

30-
def is_distinct_from(self, others):
32+
def is_distinct_from(self, others: Iterable["DependencyTreeElement"]) -> bool:
3133
"""
3234
Check whether this element is distinct from `others`.
3335
"""
3436
return self.value not in [other.value for other in others]
3537

36-
def __str__(self):
38+
def __str__(self) -> str:
3739
return str(self.value)
3840

3941

4042
class DependencyTree:
4143
class _Node:
42-
def __init__(self, element):
44+
def __init__(self, element: DependencyTreeElement):
4345
self.element = element
4446
self.children = []
47+
self.parent = None
4548

46-
def push(self, node):
49+
def push(self, node: "DependencyTree._Node") -> bool:
4750
if node == self:
48-
return
49-
51+
return True
52+
53+
added = False
5054
for child in self.children:
51-
child.push(node)
55+
added |= child.push(node)
5256

5357
if self.element.depends_on(node.element) and not self.already_added(node):
58+
node = node.copy()
5459
self.children.append(node)
60+
node.element.parent = self.element
61+
node.parent = self
62+
return True
63+
64+
return added
5565

56-
def already_added(self, node):
66+
def already_added(self, node: "DependencyTree._Node") -> bool:
5767
# We want to avoid duplicate information about dependencies.
5868
if node in self.children:
5969
return True
@@ -63,14 +73,15 @@ def already_added(self, node):
6373
if not node.element.is_distinct_from((child.element for child in self.children)):
6474
return True
6575

66-
# for child in self.children:
67-
# # if node.element.value == child.element.value:
68-
# if not node.element.is_distinct_from((child.element):
69-
# return True
70-
7176
return False
77+
78+
def __iter__(self) -> Iterable["DependencyTree._Node"]:
79+
for child in self.children:
80+
yield from list(child)
81+
82+
yield self
7283

73-
def __str__(self):
84+
def __str__(self) -> str:
7485
node_text = str(self.element)
7586

7687
txt = [node_text]
@@ -79,85 +90,112 @@ def __str__(self):
7990

8091
return "\n".join(txt)
8192

82-
def copy(self):
93+
def copy(self) -> "DependencyTree._Node":
8394
node = DependencyTree._Node(self.element)
84-
node.children = [child.copy() for child in self.children]
95+
for child in self.children:
96+
child_ = child.copy()
97+
child_.parent = node
98+
node.children.append(child_)
99+
85100
return node
86101

87-
def __init__(self, element_type=DependencyTreeElement):
88-
self.root = None
102+
def __init__(self, element_type: DependencyTreeElement = DependencyTreeElement, trees: Iterable["DependencyTree"] = []):
103+
self.roots = []
89104
self.element_type = element_type
90-
self._update()
91-
92-
def push(self, value):
93-
element = self.element_type(value)
94-
node = DependencyTree._Node(element)
95-
if self.root is None:
96-
self.root = node
97-
else:
98-
self.root.push(node)
105+
for tree in trees:
106+
self.roots += [root.copy() for root in tree.roots]
99107

100-
# Recompute leaves.
101108
self._update()
102-
if element in self.leaves_elements:
103-
return node
104109

105-
return None
110+
def push(self, value: Any, allow_multi_root: bool = False) -> bool:
111+
""" Add a value to this dependency tree.
106112
107-
def pop(self, value):
113+
Adding a value already present in the tree does not modify the tree.
114+
115+
Args:
116+
value: value to add.
117+
allow_multi_root: if `True`, allow the value to spawn an
118+
additional root if needed.
119+
120+
"""
121+
element = self.element_type(value)
122+
node = DependencyTree._Node(element)
123+
124+
added = False
125+
for root in self.roots:
126+
added |= root.push(node)
127+
128+
if len(self.roots) == 0 or (not added and allow_multi_root):
129+
self.roots.append(node)
130+
added = True
131+
132+
self._update() # Recompute leaves.
133+
return added
134+
135+
def remove(self, value: Any) -> None:
136+
""" Remove all leaves having the given value.
137+
138+
The value to remove needs to belong to at least one leaf in this tree.
139+
Otherwise, the tree remains unchanged.
140+
141+
Args:
142+
value: value to remove from the tree.
143+
144+
Returns:
145+
Whether the tree has changed or not.
146+
"""
108147
if value not in self.leaves_values:
109-
raise ValueError("That element is not a leaf: {!r}.".format(value))
110-
111-
def _visit(node):
112-
for child in list(node.children):
113-
if child.element.value == value:
114-
node.children.remove(child)
115-
116-
self._postorder(self.root, _visit)
117-
if self.root.element.value == value:
118-
self.root = None
119-
120-
# Recompute leaves.
121-
self._update()
122-
123-
def _postorder(self, node, visit):
124-
for child in node.children:
125-
self._postorder(child, visit)
126-
127-
visit(node)
148+
return False
128149

129-
def _update(self):
150+
root_to_remove = []
151+
for node in self:
152+
if node.element.value == value:
153+
if node.parent is not None:
154+
node.parent.children.remove(node)
155+
else:
156+
root_to_remove.append(node)
157+
158+
for node in root_to_remove:
159+
self.roots.remove(node)
160+
161+
self._update() # Recompute leaves.
162+
return True
163+
164+
def _update(self) -> None:
130165
self._leaves_values = []
131-
self._leaves_elements = set()
166+
self._leaves_elements = []
132167

133-
def _visit(node):
168+
for node in self:
134169
if len(node.children) == 0:
135-
self._leaves_elements.add(node.element)
170+
self._leaves_elements.append(node.element)
136171
self._leaves_values.append(node.element.value)
137172

138-
if self.root is not None:
139-
self._postorder(self.root, _visit)
140-
141173
self._leaves_values = uniquify(self._leaves_values)
174+
self._leaves_elements = uniquify(self._leaves_elements)
175+
176+
def copy(self) -> "DependencyTree":
177+
tree = type(self)(element_type=self.element_type)
178+
for root in self.roots:
179+
tree.roots.append(root.copy())
180+
181+
tree._update()
182+
return tree
142183

143-
def copy(self):
144-
tree = DependencyTree(self.element_type)
145-
if self.root is not None:
146-
tree.root = self.root.copy()
147-
tree._update()
184+
def __iter__(self) -> Iterable["DependencyTree._Node"]:
185+
for root in self.roots:
186+
yield from list(root)
148187

149-
return tree
188+
@property
189+
def values(self) -> List[Any]:
190+
return [node.element.value for node in self]
150191

151192
@property
152-
def leaves_elements(self):
193+
def leaves_elements(self) -> List[DependencyTreeElement]:
153194
return self._leaves_elements
154195

155196
@property
156-
def leaves_values(self):
197+
def leaves_values(self) -> List[Any]:
157198
return self._leaves_values
158199

159-
def __str__(self):
160-
if self.root is None:
161-
return ""
162-
163-
return str(self.root)
200+
def __str__(self) -> str:
201+
return "\n".join(map(str, self.roots))

0 commit comments

Comments
 (0)