forked from Learning-and-Intelligent-Systems/predicators
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdemo_only.py
More file actions
348 lines (329 loc) · 16.4 KB
/
demo_only.py
File metadata and controls
348 lines (329 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
"""Create offline datasets by collecting demonstrations."""
import functools
import logging
import os
import re
from typing import Callable, List, Set
import dill as pkl
import matplotlib
import matplotlib.pyplot as plt
from predicators import utils
from predicators.approaches import ApproachFailure, ApproachTimeout
from predicators.approaches.oracle_approach import OracleApproach
from predicators.envs import BaseEnv
from predicators.envs.behavior import BehaviorEnv
from predicators.planning import _run_plan_with_option_model
from predicators.settings import CFG
from predicators.structs import Action, Dataset, LowLevelTrajectory, \
ParameterizedOption, State, Task
def create_demo_data(env: BaseEnv, train_tasks: List[Task],
                     known_options: Set[ParameterizedOption]) -> Dataset:
    """Create offline datasets by collecting demos.

    If ``CFG.load_data`` is set, the dataset is obtained through
    ``_create_demo_data_with_loading`` (which loads from ``CFG.data_dir``
    and may generate additional demos). Otherwise a fresh demonstration is
    generated for the training tasks and the dataset is pickled to the
    filename returned by ``utils.create_dataset_filename_str``.

    Args:
        env: Environment used to generate demos (and, for BEHAVIOR, to
            re-ground the options stored on loaded actions).
        train_tasks: The training tasks to demonstrate.
        known_options: Parameterized options whose ground options are kept
            on oracle demo actions; all other options are unset.

    Returns:
        The demonstration dataset. For BEHAVIOR, every action's option is
        re-grounded from ``env.option_name_to_option`` before returning.
    """
    assert CFG.demonstrator in ("oracle", "human")
    dataset_fname, dataset_fname_template = utils.create_dataset_filename_str(
        saving_ground_atoms=False)
    os.makedirs(CFG.data_dir, exist_ok=True)
    if CFG.load_data:
        dataset = _create_demo_data_with_loading(env, train_tasks,
                                                 known_options,
                                                 dataset_fname_template,
                                                 dataset_fname)
    else:
        trajectories = _generate_demonstrations(env,
                                                train_tasks,
                                                known_options,
                                                train_tasks_start_idx=0)
        logging.info(f"\n\nCREATED {len(trajectories)} DEMONSTRATIONS")
        dataset = Dataset(trajectories)
        # NOTE: This is necessary because BEHAVIOR options save
        # the BEHAVIOR environment object in their memory, and this
        # can't be pickled.
        if CFG.env == "behavior":  # pragma: no cover
            for traj in dataset.trajectories:
                for act in traj.actions:
                    act.get_option().memory = {}
        with open(dataset_fname, "wb") as f:
            pkl.dump(dataset, f)
        # Pickle information about dataset created.
        if CFG.env == "behavior":  # pragma: no cover
            assert isinstance(env, BehaviorEnv)
            info = {}
            info["behavior_task_list"] = CFG.behavior_task_list
            info["behavior_train_scene_name"] = CFG.behavior_train_scene_name
            info["behavior_test_scene_name"] = CFG.behavior_test_scene_name
            info["seed"] = CFG.seed
            if len(CFG.behavior_task_list) != 1:
                # Only recorded when more than one BEHAVIOR task is used.
                info["task_list_indices"] = env.task_list_indices
                info["scene_list"] = env.scene_list
            info["task_num_task_instance_id_to_igibson_seed"] = \
                env.task_num_task_instance_id_to_igibson_seed
            # Swap only a trailing ".data" extension. A plain
            # `str.replace(".data", ".info")` would also rewrite any ".data"
            # elsewhere in the path, and if the name contained no ".data"
            # it would return the dataset path itself, so the info dict
            # would overwrite the dataset that was just written.
            if dataset_fname.endswith(".data"):
                info_fname = dataset_fname[:-len(".data")] + ".info"
            else:
                info_fname = dataset_fname + ".info"
            with open(info_fname, "wb") as f:
                pkl.dump(info, f)
    # NOTE: This is necessary because we replace BEHAVIOR
    # options with dummy options in order to pickle them, so
    # when we load them, we need to make sure they have the
    # correct options from the environment.
    if CFG.env == "behavior":  # pragma: no cover
        assert isinstance(env, BehaviorEnv)
        option_name_to_option = env.option_name_to_option
        for traj in dataset.trajectories:
            for act in traj.actions:
                dummy_opt = act.get_option()
                gt_param_opt = option_name_to_option[dummy_opt.name]
                gt_opt = gt_param_opt.ground(dummy_opt.objects,
                                             dummy_opt.params)
                act.set_option(gt_opt)
    return dataset
def _create_demo_data_with_loading(env: BaseEnv, train_tasks: List[Task],
                                   known_options: Set[ParameterizedOption],
                                   dataset_fname_template: str,
                                   dataset_fname: str) -> Dataset:
    """Create demonstration data while handling loading from disk.

    This method takes care of three cases: the demonstrations on disk
    are exactly the desired number, too many, or too few.

    - Case 1: ``dataset_fname`` exists; load and return it unchanged.
    - Case 2: a file in ``CFG.data_dir`` matching ``dataset_fname_template``
      was built from more training tasks than ``CFG.num_train_tasks``;
      load it and keep only trajectories with an in-range
      ``train_task_idx``.
    - Case 3: only matching files built from fewer tasks exist; load the
      largest one, generate demos for the remaining tasks, and pickle the
      combined dataset to ``dataset_fname``.

    Args:
        env: Environment used to generate any missing demos.
        train_tasks: The training tasks to demonstrate.
        known_options: Forwarded to ``_generate_demonstrations``.
        dataset_fname_template: Regex whose first group captures the
            number of training tasks encoded in a dataset filename.
        dataset_fname: Path of the dataset with exactly the desired amount
            of data.

    Returns:
        The loaded (and possibly truncated or extended) dataset.

    Raises:
        ValueError: If no file in ``CFG.data_dir`` matches the template.
    """
    if os.path.exists(dataset_fname):
        # Case 1: we already have a file with the exact name that we need
        # (i.e., the correct amount of data).
        with open(dataset_fname, "rb") as f:
            dataset = pkl.load(f)
        logging.info(f"\n\nLOADED DATASET OF {len(dataset.trajectories)} "
                     "DEMONSTRATIONS")
        return dataset
    fnames_with_less_data = {}  # used later, in Case 3
    # Sorted so that, if several files qualify for Case 2, the choice does
    # not depend on the filesystem's arbitrary directory order.
    for fname in sorted(os.listdir(CFG.data_dir)):
        regex_match = re.match(dataset_fname_template, fname)
        if not regex_match:
            continue
        num_train_tasks = int(regex_match.group(1))
        assert num_train_tasks != CFG.num_train_tasks  # would be Case 1
        # Case 2: we already have a file with MORE data than we need. Load
        # and truncate this data.
        if num_train_tasks > CFG.num_train_tasks:
            with open(os.path.join(CFG.data_dir, fname), "rb") as f:
                dataset = pkl.load(f)
            assert not dataset.has_annotations
            # To truncate, note that we can't simply take the first
            # `CFG.num_train_tasks` elements of `dataset.trajectories`,
            # because some of these might have a `train_task_idx` that is
            # out of range (if there were errors in the course of
            # collecting those demonstrations). The correct thing to do
            # here is to truncate based on the value of `train_task_idx`.
            truncated = Dataset([
                traj for traj in dataset.trajectories
                if traj.train_task_idx < CFG.num_train_tasks
            ])
            # Report the size of what is actually returned, not of the
            # larger file that was loaded.
            logging.info("\n\nLOADED AND TRUNCATED DATASET OF "
                         f"{len(truncated.trajectories)} DEMONSTRATIONS")
            return truncated
        # Save the names of all datasets that have less data than
        # we need, to be used in Case 3.
        fnames_with_less_data[num_train_tasks] = fname
    if not fnames_with_less_data:
        # Give up: we did not find any data file we can load from.
        raise ValueError(f"Cannot load data: {dataset_fname}")
    # Case 3: we already have a file with LESS data than we need. Load
    # this data and generate some more. Specifically, we load from the
    # file with the maximum data among all files that have less data
    # than we need, then we generate the remaining demonstrations.
    train_tasks_start_idx = max(fnames_with_less_data)
    fname = fnames_with_less_data[train_tasks_start_idx]
    with open(os.path.join(CFG.data_dir, fname), "rb") as f:
        dataset = pkl.load(f)
    # NOTE: This is necessary because we replace BEHAVIOR
    # options with dummy options in order to pickle them, so
    # when we load them, we need to make sure they have the
    # correct options from the environment.
    if CFG.env == "behavior":  # pragma: no cover
        assert isinstance(env, BehaviorEnv)
        option_name_to_option = env.option_name_to_option
        for traj in dataset.trajectories:
            for act in traj.actions:
                dummy_opt = act.get_option()
                gt_param_opt = option_name_to_option[dummy_opt.name]
                gt_opt = gt_param_opt.ground(dummy_opt.objects,
                                             dummy_opt.params)
                act.set_option(gt_opt)
    loaded_trajectories = dataset.trajectories
    generated_trajectories = _generate_demonstrations(
        env,
        train_tasks,
        known_options,
        train_tasks_start_idx=train_tasks_start_idx)
    logging.info(f"\n\nLOADED DATASET OF {len(loaded_trajectories)} "
                 "DEMONSTRATIONS")
    logging.info(f"CREATED {len(generated_trajectories)} DEMONSTRATIONS")
    dataset = Dataset(loaded_trajectories + generated_trajectories)
    # NOTE: This is necessary because BEHAVIOR options save
    # the BEHAVIOR environment object in their memory, and this
    # can't be pickled.
    if CFG.env == "behavior":  # pragma: no cover
        for traj in dataset.trajectories:
            for act in traj.actions:
                act.get_option().memory = {}
    with open(dataset_fname, "wb") as f:
        pkl.dump(dataset, f)
    return dataset
def _generate_demonstrations(
        env: BaseEnv, train_tasks: List[Task],
        known_options: Set[ParameterizedOption],
        train_tasks_start_idx: int) -> List[LowLevelTrajectory]:
    """Use the demonstrator to generate demonstrations, one per training task
    starting from train_tasks_start_idx.

    Tasks with index below ``train_tasks_start_idx`` are skipped, and
    generation stops at index ``CFG.max_initial_demos``. Each task is
    attempted up to ``CFG.max_demo_attempts`` times. A task whose attempts
    all fail (approach/environment error, or goal not reached) is omitted,
    so the returned trajectories' ``train_task_idx`` values may have gaps.

    Args:
        env: The environment to demonstrate in.
        train_tasks: All training tasks (indexed from 0).
        known_options: For the oracle demonstrator, actions keep their
            option only if its parent is in this set; otherwise the option
            is unset to prevent option-learning approaches from cheating.
        train_tasks_start_idx: First training-task index to demonstrate.

    Returns:
        One demo-flagged ``LowLevelTrajectory`` per successful task.
    """
    if CFG.demonstrator == "oracle":
        oracle_approach = OracleApproach(
            env.predicates,
            env.options,
            env.types,
            env.action_space,
            train_tasks,
            task_planning_heuristic=CFG.offline_data_task_planning_heuristic,
            max_skeletons_optimized=CFG.offline_data_max_skeletons_optimized)
        # The planning timeout is the same for every task; a value of -1
        # means "fall back to the global CFG.timeout".
        timeout = CFG.offline_data_planning_timeout
        if timeout == -1:
            timeout = CFG.timeout
    else:  # pragma: no cover
        # Disable all built-in keyboard shortcuts.
        keymaps = {k for k in plt.rcParams if k.startswith("keymap.")}
        for k in keymaps:
            plt.rcParams[k].clear()
        # Create the environment-specific method for turning events into
        # actions. This should also log instructions.
        event_to_action = env.get_event_to_action_fn()
    trajectories = []
    num_tasks = min(len(train_tasks), CFG.max_initial_demos)
    for idx, task in enumerate(train_tasks):
        if idx < train_tasks_start_idx:  # ignore demos before this index
            continue
        # Note: we assume in main.py that demonstrations are only generated
        # for train tasks whose index is less than CFG.max_initial_demos. If
        # you modify code around here, make sure that this invariant holds.
        if idx >= CFG.max_initial_demos:
            break
        # Loop over task until successful completion. `succeeded` is bound
        # before the loop so that a non-positive CFG.max_demo_attempts
        # simply skips the task instead of hitting an unbound local.
        attempts = 0
        succeeded = False
        monitor = None
        while attempts < CFG.max_demo_attempts:
            attempts += 1
            # Reset per attempt so the name is always bound (the BEHAVIOR
            # branch never creates a monitor) and never stale.
            monitor = None
            try:
                if CFG.demonstrator == "oracle":
                    oracle_approach.solve(task, timeout=timeout)
                    # Since we're running the oracle approach, we know that
                    # the policy is actually a plan under the hood, and we
                    # can retrieve it with get_last_plan(). We do this
                    # because we want to run the full plan.
                    last_plan = oracle_approach.get_last_plan()
                    policy = utils.option_plan_to_policy(last_plan)
                    # We will stop run_policy() when OptionExecutionFailure()
                    # is hit, which should only happen when the goal has been
                    # reached, as verified by the assertion later.
                    termination_function = lambda s: False
                else:  # pragma: no cover
                    policy = functools.partial(_human_demonstrator_policy, env,
                                               idx, num_tasks, task,
                                               event_to_action)
                    termination_function = task.goal_holds
                if CFG.env == "behavior":  # pragma: no cover
                    # This branch replays the oracle's plan, so it needs
                    # `oracle_approach` and `last_plan`; fail loudly rather
                    # than with a NameError for any other demonstrator.
                    assert CFG.demonstrator == "oracle", \
                        "BEHAVIOR demos require the oracle demonstrator."
                    # For BEHAVIOR we are generating the trajectory by running
                    # our plan on our option models. Since option models
                    # return only states, we will add dummy actions to the
                    # states to create our low-level trajectories.
                    last_traj = oracle_approach.get_last_traj()
                    traj, success = _run_plan_with_option_model(
                        idx, oracle_approach.get_option_model(),
                        last_plan, task=task, last_traj=last_traj)
                    # Is successful if we found a low-level plan that achieves
                    # our goal using option models.
                    if not success:
                        raise ApproachFailure(
                            "Falied execution of low-level plan on option model"
                        )
                else:
                    if CFG.make_demo_videos:
                        monitor = utils.VideoMonitor(env.render)
                    traj, _ = utils.run_policy(
                        policy,
                        env,
                        "train",
                        idx,
                        termination_function=termination_function,
                        max_num_steps=CFG.horizon,
                        exceptions_to_break_on={
                            utils.OptionExecutionFailure,
                            utils.HumanDemonstrationFailure,
                        },
                        monitor=monitor)
            except (ApproachTimeout, ApproachFailure,
                    utils.EnvironmentFailure) as e:
                logging.warning(
                    "WARNING: Approach failed to solve with error: "
                    f"{e}")
                continue
            # Check that the goal holds at the end. Print a warning if not.
            if not task.goal_holds(traj.states[-1]):  # pragma: no cover
                logging.warning("WARNING: Oracle failed on training task.")
                continue
            if CFG.demonstrator == "human":  # pragma: no cover
                logging.info("Successfully collected human demonstration of "
                             f"length {len(traj.states)} for task {idx+1} / "
                             f"{num_tasks}.")
            succeeded = True
            break  # pragma: no cover
        # If no attempt produced a goal-reaching trajectory, move on to the
        # next task without recording a demo.
        if not succeeded:
            continue
        # Add is_demo flag and task index information into the trajectory.
        traj = LowLevelTrajectory(traj.states,
                                  traj.actions,
                                  _is_demo=True,
                                  _train_task_idx=idx)
        # To prevent cheating by option learning approaches, remove all oracle
        # options from the trajectory actions, unless the options are known
        # (via CFG.included_options or CFG.option_learner = 'no_learning').
        if CFG.demonstrator == "oracle":
            for act in traj.actions:
                if act.get_option().parent not in known_options:
                    assert CFG.option_learner != "no_learning"
                    act.unset_option()
        trajectories.append(traj)
        if CFG.make_demo_videos:
            assert monitor is not None
            video = monitor.get_video()
            outfile = f"{CFG.env}__{CFG.seed}__demo__task{idx}.mp4"
            utils.save_video(outfile, video)
    return trajectories
def _human_demonstrator_policy(env: BaseEnv, idx: int, num_tasks: int,
                               task: Task, event_to_action: Callable[
                                   [State, matplotlib.backend_bases.Event],
                                   Action],
                               state: State) -> Action:  # pragma: no cover
    """Render ``state`` and return the action chosen by a human.

    Temporarily switches Matplotlib to the ``Qt5Agg`` backend, draws the
    state via ``env.render_plt`` with a caption showing the task number
    and goal, and blocks until a key or mouse press. The event is turned
    into an action by ``event_to_action``. The previous backend is restored
    on every exit path, including failures.

    Args:
        env: Environment used to render the state.
        idx: Zero-based index of the task (shown one-based in the caption).
        num_tasks: Total number of tasks (shown in the caption).
        task: The task being demonstrated; its goal is shown in the caption.
        event_to_action: Converts (state, GUI event) into an action.
        state: The current state to render and act from.

    Returns:
        The action produced by ``event_to_action`` for the captured event.

    Raises:
        utils.HumanDemonstrationFailure: If the event handler did not
            record an action (e.g. it raised an error).
    """
    # Temporarily change the backend to one that supports a GUI.
    # We do this here because we don't want the rest of the codebase
    # to use GUI-based Matplotlib.
    cur_backend = matplotlib.get_backend()
    matplotlib.use("Qt5Agg")
    try:
        # Render the state.
        caption = (f"Task {idx+1} / {num_tasks}\nPlease demonstrate "
                   f"achieving the goal:\n{task.goal}")
        fig = env.render_plt(caption=caption)
        container = {}

        def _handler(event: matplotlib.backend_bases.Event) -> None:
            container["action"] = event_to_action(state, event)

        keyboard_cid = fig.canvas.mpl_connect("key_press_event", _handler)
        mouse_cid = fig.canvas.mpl_connect("button_press_event", _handler)
        try:
            # Hang until either a mouse press or a keyboard press.
            plt.waitforbuttonpress()
        finally:
            # Always detach the callbacks and close the window, even if
            # waiting was interrupted.
            fig.canvas.mpl_disconnect(keyboard_cid)
            fig.canvas.mpl_disconnect(mouse_cid)
            plt.close()
        if "action" not in container:
            logging.warning("WARNING: Event handler failed. Its error message "
                            "should be printed above. Terminating task.")
            raise utils.HumanDemonstrationFailure("Event handler failed!")
        return container["action"]
    finally:
        # Revert to the previous backend. Previously this only ran on the
        # success path, leaving Qt5Agg active after a failure.
        matplotlib.use(cur_backend)