forked from Learning-and-Intelligent-Systems/predicators
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathutils.py
More file actions
3006 lines (2568 loc) · 116 KB
/
utils.py
File metadata and controls
3006 lines (2568 loc) · 116 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""General utility methods."""
from __future__ import annotations
import abc
import contextlib
import functools
import gc
import heapq as hq
import io
import itertools
import logging
import os
import re
import subprocess
import sys
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Dict, \
FrozenSet, Generator, Generic, Hashable, Iterator, List, Optional, \
Sequence, Set, Tuple
from typing import Type as TypingType
from typing import TypeVar, Union, cast
import imageio
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pathos.multiprocessing as mp
from gym.spaces import Box
from matplotlib import patches
from pyperplan.heuristics.heuristic_base import \
Heuristic as _PyperplanBaseHeuristic
from pyperplan.planner import HEURISTICS as _PYPERPLAN_HEURISTICS
from predicators.args import create_arg_parser
from predicators.pybullet_helpers.joint import JointPositions
from predicators.settings import CFG, GlobalSettings
from predicators.structs import NSRT, Action, Array, DummyOption, \
EntToEntSub, GroundAtom, GroundAtomTrajectory, \
GroundNSRTOrSTRIPSOperator, Image, LDLRule, LiftedAtom, \
LiftedDecisionList, LiftedOrGroundAtom, LowLevelTrajectory, Metrics, \
NSRTOrSTRIPSOperator, Object, ObjectOrVariable, OptionSpec, \
ParameterizedOption, Predicate, Segment, State, STRIPSOperator, Task, \
Type, Variable, VarToObjSub, Video, _GroundLDLRule, _GroundNSRT, \
_GroundSTRIPSOperator, _Option, _TypedEntity
from predicators.third_party.fast_downward_translator.translate import \
main as downward_translate
if TYPE_CHECKING:
from predicators.envs import BaseEnv
matplotlib.use("Agg")
def count_positives_for_ops(
    strips_ops: List[STRIPSOperator],
    option_specs: List[OptionSpec],
    segments: List[Segment],
    max_groundings: Optional[int] = None,
) -> Tuple[int, int, List[Set[int]], List[Set[int]]]:
    """Score STRIPS operators against segments by true/false positives.

    For every segment, each operator whose option spec names the segment's
    parameterized option is grounded with its option variables pinned to the
    segment option's objects. A ground operator is only considered if its
    preconditions hold in the segment's initial atoms. It then either
    reproduces the segment's add and delete effects exactly (the segment
    becomes a true positive, counted once per segment), or it does not (one
    false positive per such ground operator).

    Args:
        strips_ops: Candidate operators, aligned index-by-index with
            option_specs.
        option_specs: (parameterized option, option variables) per operator.
        segments: The segments to score against.
        max_groundings: If not None, stop enumerating an operator's
            groundings once the grounding index exceeds this value.

    Returns:
        A tuple (num_true_positives, num_false_positives, true_positive_idxs,
        false_positive_idxs). The last two hold, per operator, the indices of
        segments that contributed a true/false positive. They are only useful
        for debugging; the two counts already summarize them.
    """
    assert len(strips_ops) == len(option_specs)
    true_pos_count = 0
    false_pos_count = 0
    # Per-operator segment indices; kept only for debugging.
    true_pos_by_op: List[Set[int]] = [set() for _ in strips_ops]
    false_pos_by_op: List[Set[int]] = [set() for _ in strips_ops]
    for seg_idx, seg in enumerate(segments):
        seg_objects = set(seg.states[0])
        seg_option = seg.get_option()
        seg_option_objs = seg_option.objects
        seg_explained = False
        for op_idx, (op, spec) in enumerate(zip(strips_ops, option_specs)):
            # Operators tied to a different parameterized option are
            # irrelevant for this segment.
            if spec[0] != seg_option.parent:
                continue
            spec_vars = spec[1]
            assert len(spec_vars) == len(seg_option_objs)
            # Pin the option variables to the segment's option objects and
            # enumerate groundings of the remaining operator variables.
            partial_sub = dict(zip(spec_vars, seg_option_objs))
            groundings = all_ground_operators_given_partial(
                op, seg_objects, partial_sub)
            for grounding_idx, ground_op in enumerate(groundings):
                if max_groundings is not None and \
                        grounding_idx > max_groundings:
                    break
                if not ground_op.preconditions.issubset(seg.init_atoms):
                    continue
                effects_match = (
                    ground_op.add_effects == seg.add_effects
                    and ground_op.delete_effects == seg.delete_effects)
                if effects_match:
                    seg_explained = True
                    true_pos_by_op[op_idx].add(seg_idx)
                else:
                    false_pos_by_op[op_idx].add(seg_idx)
                    false_pos_count += 1
        if seg_explained:
            true_pos_count += 1
    return true_pos_count, false_pos_count, true_pos_by_op, false_pos_by_op
def count_branching_factor(strips_ops: List[STRIPSOperator],
                           segments: List[Segment]) -> int:
    """Returns the total branching factor for all states in the segments.

    For each segment, every operator is grounded over the objects of the
    segment's first state. The distinct ground operators applicable in the
    segment's initial atoms are counted, and the per-segment counts are
    summed.
    """
    total = 0
    for seg in segments:
        seg_objects = set(seg.states[0])
        # Collected into a set, so duplicate ground operators count once.
        candidates = {
            ground_op
            for op in strips_ops
            for ground_op in all_ground_operators(op, seg_objects)
        }
        applicable = get_applicable_operators(candidates, seg.init_atoms)
        total += sum(1 for _ in applicable)
    return total
def segment_trajectory_to_state_sequence(
        seg_traj: List[Segment]) -> List[State]:
    """Flatten a chain of segments into the sequence of their boundary states.

    The result holds the first state of every segment followed by the last
    state of the final segment, so it is always exactly one element longer
    than seg_traj. Consecutive segments must chain: each segment's last
    state must be allclose to the next segment's first state (asserted).
    """
    assert len(seg_traj) >= 1
    boundary_states = []
    # Pair each segment with its successor; the last one gets None.
    for seg, nxt in itertools.zip_longest(seg_traj, seg_traj[1:]):
        boundary_states.append(seg.states[0])
        if nxt is not None:
            assert seg.states[-1].allclose(nxt.states[0])
    boundary_states.append(seg_traj[-1].states[-1])
    assert len(boundary_states) == len(seg_traj) + 1
    return boundary_states
def segment_trajectory_to_atoms_sequence(
        seg_traj: List[Segment]) -> List[Set[GroundAtom]]:
    """Flatten a chain of segments into the sequence of their boundary atoms.

    The result holds every segment's init_atoms followed by the final
    segment's final_atoms, so it is always exactly one element longer than
    seg_traj. Consecutive segments must chain: a segment's final_atoms must
    equal the next segment's init_atoms (asserted).
    """
    assert len(seg_traj) >= 1
    boundary_atoms = []
    # Pair each segment with its successor; the last one gets None.
    for seg, nxt in itertools.zip_longest(seg_traj, seg_traj[1:]):
        boundary_atoms.append(seg.init_atoms)
        if nxt is not None:
            assert seg.final_atoms == nxt.init_atoms
    boundary_atoms.append(seg_traj[-1].final_atoms)
    assert len(boundary_atoms) == len(seg_traj) + 1
    return boundary_atoms
def num_options_in_action_sequence(actions: Sequence[Action]) -> int:
    """Given a sequence of actions with options included, get the number of
    options that are encountered.

    A new option is counted whenever an action's option is a different
    object (compared by identity) from the previous action's option. So
    returning to an option after a different one counts again. An empty
    sequence yields 0.
    """
    num_options = 0
    last_option = None
    for action in actions:
        current_option = action.get_option()
        # Identity comparison: consecutive actions share an option only if
        # they reference the very same option object.
        if current_option is not last_option:
            last_option = current_option
            num_options += 1
    return num_options
def entropy(p: float) -> float:
    """Return the entropy, in bits, of a Bernoulli variable with parameter p.

    The endpoints p == 0 and p == 1 are special-cased to 0 so that log2(0)
    is never evaluated.
    """
    assert 0.0 <= p <= 1.0
    if p == 0.0 or p == 1.0:
        return 0.0
    q = 1 - p
    return -(p * np.log2(p) + q * np.log2(q))
def create_state_from_dict(data: Dict[Object, Dict[str, float]],
                           simulator_state: Optional[Any] = None) -> State:
    """Build a State from per-object dictionaries of feature values.

    Each object's feature vector is laid out in the order given by
    obj.type.feature_names, so every feature name of the object's type must
    be present in its dictionary. A simulator_state may optionally be
    attached to the resulting State.
    """
    state_dict = {
        obj: np.array([feats[name] for name in obj.type.feature_names])
        for obj, feats in data.items()
    }
    return State(state_dict, simulator_state)
class _Geom2D(abc.ABC):
    """Abstract base for 2D shapes that can be drawn and hit-tested."""

    @abc.abstractmethod
    def plot(self, ax: plt.Axes, **kwargs: Any) -> None:
        """Draw this shape onto the given pyplot axes."""
        raise NotImplementedError("Override me!")

    @abc.abstractmethod
    def contains_point(self, x: float, y: float) -> bool:
        """Return whether the point (x, y) lies inside this shape."""
        raise NotImplementedError("Override me!")

    def intersects(self, other: _Geom2D) -> bool:
        """Return whether this shape and `other` intersect.

        Delegates to the module-level geom2ds_intersect dispatcher.
        """
        return geom2ds_intersect(self, other)
@dataclass(frozen=True)
class LineSegment(_Geom2D):
    """A helper class for visualizing and collision checking line segments
    from (x1, y1) to (x2, y2)."""
    x1: float
    y1: float
    x2: float
    y2: float

    def plot(self, ax: plt.Axes, **kwargs: Any) -> None:
        xs = [self.x1, self.x2]
        ys = [self.y1, self.y2]
        ax.plot(xs, ys, **kwargs)

    def contains_point(self, x: float, y: float) -> bool:
        # https://stackoverflow.com/questions/328107
        # The point is on the segment when the detour start -> point -> end
        # is (approximately) no longer than going start -> end directly.
        # An epsilon absorbs floating-point rounding.
        eps = 1e-6

        def _dist(px: float, py: float, qx: float, qy: float) -> float:
            return np.sqrt((px - qx)**2 + (py - qy)**2)

        start_to_pt = _dist(self.x1, self.y1, x, y)
        pt_to_end = _dist(x, y, self.x2, self.y2)
        start_to_end = _dist(self.x1, self.y1, self.x2, self.y2)
        detour = start_to_pt + pt_to_end - start_to_end
        return -eps < detour < eps
@dataclass(frozen=True)
class Circle(_Geom2D):
    """A helper class for visualizing and collision checking circles.

    For contains_point, points exactly on the circumference count as inside.
    """
    x: float
    y: float
    radius: float

    def plot(self, ax: plt.Axes, **kwargs: Any) -> None:
        ax.add_patch(patches.Circle((self.x, self.y), self.radius, **kwargs))

    def contains_point(self, x: float, y: float) -> bool:
        dx = x - self.x
        dy = y - self.y
        return dx**2 + dy**2 <= self.radius**2
@dataclass(frozen=True)
class Triangle(_Geom2D):
    """A helper class for visualizing and collision checking triangles.

    The vertices (x1, y1), (x2, y2), (x3, y3) may be listed in either
    winding order. Points on the boundary count as contained, regardless of
    the winding order.
    """
    x1: float
    y1: float
    x2: float
    y2: float
    x3: float
    y3: float

    def plot(self, ax: plt.Axes, **kwargs: Any) -> None:
        patch = patches.Polygon(
            [[self.x1, self.y1], [self.x2, self.y2], [self.x3, self.y3]],
            **kwargs)
        ax.add_patch(patch)

    def __post_init__(self) -> None:
        """Reject degenerate (collinear) triangles with a ValueError."""
        # Twice the signed area; it is zero for collinear vertices. This is
        # checked directly because comparing sums of square roots for
        # equality is unreliable under floating-point rounding.
        twice_area = ((self.x2 - self.x1) * (self.y3 - self.y1) -
                      (self.x3 - self.x1) * (self.y2 - self.y1))
        dist1 = np.sqrt((self.x1 - self.x2)**2 + (self.y1 - self.y2)**2)
        dist2 = np.sqrt((self.x2 - self.x3)**2 + (self.y2 - self.y3)**2)
        dist3 = np.sqrt((self.x3 - self.x1)**2 + (self.y3 - self.y1)**2)
        dists = sorted([dist1, dist2, dist3])
        # The triangle inequality always holds mathematically. A sum of the
        # two shorter sides that fails to exceed the longest can only come
        # from (near-)collinear vertices plus rounding, so it is reported as
        # degenerate instead of tripping an assertion.
        if twice_area == 0 or dists[0] + dists[1] <= dists[2]:
            raise ValueError("Degenerate triangle!")

    def contains_point(self, x: float, y: float) -> bool:
        # Adapted from https://stackoverflow.com/questions/2049582/.
        # Each value is a signed cross product telling which side of one
        # edge the point lies on; zero means the point is on that edge's
        # line. The point is inside (or on the boundary) iff the nonzero
        # values do not disagree in sign. Zero is deliberately neither
        # positive nor negative, so boundary handling does not depend on
        # the vertices' winding order.
        side1 = ((x - self.x2) * (self.y1 - self.y2) - (self.x1 - self.x2) *
                 (y - self.y2))
        side2 = ((x - self.x3) * (self.y2 - self.y3) - (self.x2 - self.x3) *
                 (y - self.y3))
        side3 = ((x - self.x1) * (self.y3 - self.y1) - (self.x3 - self.x1) *
                 (y - self.y1))
        has_neg = side1 < 0 or side2 < 0 or side3 < 0
        has_pos = side1 > 0 or side2 > 0 or side3 > 0
        return not (has_neg and has_pos)
@dataclass(frozen=True)
class Rectangle(_Geom2D):
    """A helper class for visualizing and collision checking rectangles.

    Following the convention in plt.Rectangle, the origin (x, y) is at the
    bottom left corner, and rotation is anti-clockwise about that point.
    Unlike plt.Rectangle, the angle is in radians.
    """
    x: float
    y: float
    width: float
    height: float
    theta: float  # in radians, between -np.pi and np.pi

    def __post_init__(self) -> None:
        assert -np.pi <= self.theta <= np.pi, "Expecting angle in [-pi, pi]."

    @functools.cached_property
    def vertices(self) -> List[Tuple[float, float]]:
        """Get the four corners of the rectangle.

        Order: the origin corner, the corner reached along the height, the
        opposite corner, then the corner reached along the width.
        """
        # Unit square corners, in the order documented above.
        unit_square = np.array([
            (0, 0),
            (0, 1),
            (1, 1),
            (1, 0),
        ])
        scaling = np.array([
            [self.width, 0],
            [0, self.height],
        ])
        cos_t, sin_t = np.cos(self.theta), np.sin(self.theta)
        rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
        offset = np.array([self.x, self.y])
        # Scale, then rotate about the origin corner, then translate.
        corners = offset + (unit_square @ scaling.T) @ rotation.T
        return [(cx, cy) for cx, cy in corners]

    @functools.cached_property
    def line_segments(self) -> List[LineSegment]:
        """Get the four sides of the rectangle.

        Each side joins consecutive vertices, wrapping from the last vertex
        back to the first.
        """
        corners = self.vertices
        next_corners = corners[1:] + corners[:1]
        return [
            LineSegment(x1, y1, x2, y2)
            for (x1, y1), (x2, y2) in zip(corners, next_corners)
        ]

    @functools.cached_property
    def center(self) -> Tuple[float, float]:
        """Get the point at the center of the rectangle."""
        mean_x, mean_y = np.mean(self.vertices, axis=0)
        return (mean_x, mean_y)

    @functools.cached_property
    def circumscribed_circle(self) -> Circle:
        """Return the Circle through all four corners.

        It is centered on the rectangle's center, with radius equal to half
        the diagonal.
        """
        cx, cy = self.center
        half_diagonal = np.sqrt((self.width / 2)**2 + (self.height / 2)**2)
        return Circle(cx, cy, half_diagonal)

    def contains_point(self, x: float, y: float) -> bool:
        # Express the point in the rectangle's own frame: translate so the
        # origin corner sits at (0, 0), then rotate by -theta. Boundary
        # points count as inside.
        cos_t, sin_t = np.cos(self.theta), np.sin(self.theta)
        inverse_rotation = np.array([[cos_t, sin_t], [-sin_t, cos_t]])
        local_x, local_y = np.array([x - self.x, y - self.y
                                     ]) @ inverse_rotation.T
        return 0 <= local_x <= self.width and \
            0 <= local_y <= self.height

    def rotate_about_point(self, x: float, y: float, rot: float) -> Rectangle:
        """Return a copy of this rectangle rotated CCW by `rot` radians about
        the point (x, y).

        The corners are rotated first, and the new theta is then recovered
        from the rotated corners.
        """
        pivot = np.array([x, y])
        cos_r, sin_r = np.cos(rot), np.sin(rot)
        rotation = np.array([[cos_r, -sin_r], [sin_r, cos_r]])
        # Move the pivot to the origin, rotate, then move it back.
        moved = (np.array(self.vertices) - pivot) @ rotation.T + pivot
        # Corner 0 is the new origin; corner 3 lies along the width, so the
        # direction from corner 0 to corner 3 gives the new angle.
        (lx, ly), _, _, (rx, ry) = moved
        new_theta = np.arctan2(ry - ly, rx - lx)
        rotated = Rectangle(lx, ly, self.width, self.height, new_theta)
        assert np.allclose(rotated.vertices, moved)
        return rotated

    def plot(self, ax: plt.Axes, **kwargs: Any) -> None:
        # plt.Rectangle expects degrees.
        angle = self.theta * 180 / np.pi
        ax.add_patch(
            patches.Rectangle((self.x, self.y), self.width, self.height,
                              angle, **kwargs))
def line_segments_intersect(seg1: LineSegment, seg2: LineSegment) -> bool:
    """Checks if two line segments intersect.

    Works by checking relative orientation: each segment must strictly
    straddle the infinite line through the other. Because the straddle tests
    use strict inequalities, segments that merely touch (an endpoint lying on
    the other segment) or that are collinear are reported as NOT
    intersecting.
    """

    def _orientation(a: Tuple[float, float], b: Tuple[float, float],
                     c: Tuple[float, float]) -> float:
        # Cross product of (a - b) and (a - c); its sign tells on which side
        # of the line through a and b the point c lies (zero: on the line).
        return (a[0] - b[0]) * (a[1] - c[1]) - (a[0] - c[0]) * (a[1] - b[1])

    def _straddles(p: float, q: float) -> bool:
        # Strictly opposite signs.
        return (q < 0 < p) or (p < 0 < q)

    start1, end1 = (seg1.x1, seg1.y1), (seg1.x2, seg1.y2)
    start2, end2 = (seg2.x1, seg2.y1), (seg2.x2, seg2.y2)
    seg1_vs_line2 = (_orientation(start2, end2, start1),
                     _orientation(start2, end2, end1))
    seg2_vs_line1 = (_orientation(start1, end1, start2),
                     _orientation(start1, end1, end2))
    return _straddles(*seg1_vs_line2) and _straddles(*seg2_vs_line1)
def circles_intersect(circ1: Circle, circ2: Circle) -> bool:
    """Checks if two circles intersect.

    The centers must be strictly closer than the sum of the radii, so
    externally tangent circles do not count as intersecting.
    """
    dx = circ1.x - circ2.x
    dy = circ1.y - circ2.y
    reach = circ1.radius + circ2.radius
    return dx**2 + dy**2 < reach**2
def rectangles_intersect(rect1: Rectangle, rect2: Rectangle) -> bool:
    """Checks if two rectangles intersect.

    Returns True if any pair of sides strictly cross, or if either
    rectangle's center lies inside the other rectangle.

    NOTE(review): since side crossings are strict and only centers are
    tested for containment, rectangles whose overlap is bounded only by
    touching/collinear sides may be reported as not intersecting.
    """
    # Cheap rejection: disjoint circumscribed circles imply disjoint
    # rectangles.
    if not circles_intersect(rect1.circumscribed_circle,
                             rect2.circumscribed_circle):
        return False
    # Any strict crossing between sides.
    for side1 in rect1.line_segments:
        for side2 in rect2.line_segments:
            if line_segments_intersect(side1, side2):
                return True
    # rect2's center lies inside rect1 (e.g., rect2 nested in rect1).
    center2_x, center2_y = rect2.center
    if rect1.contains_point(center2_x, center2_y):
        return True
    # rect1's center lies inside rect2 (e.g., rect1 nested in rect2).
    center1_x, center1_y = rect1.center
    if rect2.contains_point(center1_x, center1_y):
        return True
    return False
def line_segment_intersects_circle(seg: LineSegment,
                                   circ: Circle,
                                   ax: Optional[plt.Axes] = None) -> bool:
    """Checks if a line segment intersects a circle.

    If ax is not None, the circle, the segment, and the labeled points
    A (segment start), B (segment end), C (circle center), and D (closest
    point to C on the line through A and B) are drawn on it, which helps
    to check the computation by eye.
    """
    start = (seg.x1, seg.y1)
    end = (seg.x2, seg.y2)
    center = (circ.x, circ.y)
    # An endpoint inside the circle is an immediate hit.
    if circ.contains_point(*start) or circ.contains_point(*end):
        return True
    # Project (center - start) onto (end - start) to find the foot of the
    # perpendicular from the center to the infinite line through the segment.
    # NOTE: a zero-length segment divides by zero here (numpy yields nan).
    direction = np.subtract(end, start)
    to_center = np.subtract(center, start)
    offset = direction * np.dot(to_center, direction) / np.dot(
        direction, direction)
    foot = foot_x, foot_y = (start[0] + offset[0], start[1] + offset[1])
    if ax is not None:
        circ.plot(ax, color="red", alpha=0.5)
        seg.plot(ax, color="black", linewidth=2)
        for label, point in (("A", start), ("B", end), ("C", center),
                             ("D", foot)):
            ax.annotate(label, point)
    # With both endpoints outside the circle, the segment can only hit the
    # circle if the foot of the perpendicular is on the segment itself...
    if not seg.contains_point(foot_x, foot_y):
        return False
    # ...and that foot is inside the circle.
    return circ.contains_point(foot_x, foot_y)
def line_segment_intersects_rectangle(seg: LineSegment,
                                      rect: Rectangle) -> bool:
    """Checks if a line segment intersects a rectangle.

    True if either endpoint is inside the rectangle (boundary included), or
    if the segment strictly crosses one of the rectangle's sides.
    """
    endpoints = ((seg.x1, seg.y1), (seg.x2, seg.y2))
    if any(rect.contains_point(px, py) for px, py in endpoints):
        return True
    return any(line_segments_intersect(side, seg)
               for side in rect.line_segments)
def rectangle_intersects_circle(rect: Rectangle, circ: Circle) -> bool:
    """Checks if a rectangle intersects a circle.

    True if the circle's center is inside the rectangle, or if any side of
    the rectangle intersects the circle.
    """
    # Cheap rejection via the rectangle's circumscribed circle.
    if not circles_intersect(rect.circumscribed_circle, circ):
        return False
    if rect.contains_point(circ.x, circ.y):
        return True
    return any(
        line_segment_intersects_circle(side, circ)
        for side in rect.line_segments)
def geom2ds_intersect(geom1: _Geom2D, geom2: _Geom2D) -> bool:
    """Check if two 2D bodies intersect.

    Dispatches on the concrete types of both shapes. Every pairing of
    LineSegment, Circle, and Rectangle is supported; any other combination
    (e.g., involving a Triangle) raises NotImplementedError.
    """
    if isinstance(geom1, LineSegment):
        if isinstance(geom2, LineSegment):
            return line_segments_intersect(geom1, geom2)
        if isinstance(geom2, Circle):
            return line_segment_intersects_circle(geom1, geom2)
        if isinstance(geom2, Rectangle):
            return line_segment_intersects_rectangle(geom1, geom2)
    elif isinstance(geom1, Circle):
        if isinstance(geom2, LineSegment):
            return line_segment_intersects_circle(geom2, geom1)
        if isinstance(geom2, Rectangle):
            return rectangle_intersects_circle(geom2, geom1)
        if isinstance(geom2, Circle):
            return circles_intersect(geom1, geom2)
    elif isinstance(geom1, Rectangle):
        if isinstance(geom2, LineSegment):
            return line_segment_intersects_rectangle(geom2, geom1)
        if isinstance(geom2, Rectangle):
            return rectangles_intersect(geom1, geom2)
        if isinstance(geom2, Circle):
            return rectangle_intersects_circle(geom1, geom2)
    raise NotImplementedError("Intersection not implemented for geoms "
                              f"{geom1} and {geom2}")
@functools.lru_cache(maxsize=None)
def unify(atoms1: FrozenSet[LiftedOrGroundAtom],
          atoms2: FrozenSet[LiftedOrGroundAtom]) -> Tuple[bool, EntToEntSub]:
    """Return whether the given two sets of atoms can be unified.

    Also return the mapping between variables/objects in these atom sets;
    the fast path maps entities of atoms1 to entities of atoms2. This
    mapping is empty if the first return value is False.

    Results are memoized with an unbounded lru_cache, so both arguments
    must be hashable.
    """
    atoms_lst1 = sorted(atoms1)
    atoms_lst2 = sorted(atoms2)
    # Terminate quickly if there is a mismatch between predicates.
    preds1 = [atom.predicate for atom in atoms_lst1]
    preds2 = [atom.predicate for atom in atoms_lst2]
    if preds1 != preds2:
        return False, {}
    # Terminate quickly if there is a mismatch between numbers of entities.
    num1 = len({o for atom in atoms_lst1 for o in atom.entities})
    num2 = len({o for atom in atoms_lst2 for o in atom.entities})
    if num1 != num2:
        return False, {}
    # Try to get lucky with a one-to-one mapping obtained by pairing the
    # sorted atom lists position by position. The mapping must be a
    # bijection. An entity may occur many times (e.g., P(a), Q(a)) as long
    # as it is always paired with the same counterpart.
    subs12: EntToEntSub = {}
    subs21: EntToEntSub = {}
    success = True
    for atom1, atom2 in zip(atoms_lst1, atoms_lst2):
        if not success:
            break
        for v1, v2 in zip(atom1.entities, atom2.entities):
            if v1 in subs12 and subs12[v1] != v2:
                success = False
                break
            # Seeing v2 again is fine when it is paired with the same v1;
            # rejecting every repeat would needlessly force the slow
            # search below.
            if v2 in subs21 and subs21[v2] != v1:
                success = False
                break
            subs12[v1] = v2
            subs21[v2] = v1
    if success:
        return True, subs12
    # If all else fails, use search. The result is inverted so that it
    # presumably maps atoms1 entities to atoms2 entities, like the fast
    # path (assumes find_substitution returns the atoms2 -> atoms1
    # direction; TODO confirm).
    solved, sub = find_substitution(atoms_lst1, atoms_lst2)
    rev_sub = {v: k for k, v in sub.items()}
    return solved, rev_sub
@functools.lru_cache(maxsize=None)
def unify_preconds_effects_options(
        preconds1: FrozenSet[LiftedOrGroundAtom],
        preconds2: FrozenSet[LiftedOrGroundAtom],
        add_effects1: FrozenSet[LiftedOrGroundAtom],
        add_effects2: FrozenSet[LiftedOrGroundAtom],
        delete_effects1: FrozenSet[LiftedOrGroundAtom],
        delete_effects2: FrozenSet[LiftedOrGroundAtom],
        param_option1: ParameterizedOption, param_option2: ParameterizedOption,
        option_args1: Tuple[_TypedEntity, ...],
        option_args2: Tuple[_TypedEntity, ...]) -> Tuple[bool, EntToEntSub]:
    """Unify two (preconditions, add effects, delete effects, option)
    bundles by delegating to unify().

    Each bundle is flattened into a single atom set. Preconditions, add
    effects, and delete effects get the name prefixes "PRE-", "ADD-", and
    "DEL-", and the option arguments become one "OPT-ARGS" atom. This keeps
    atoms from different roles from being matched against each other. The
    first bundle's option atom is a GroundAtom and the second's a
    LiftedAtom. Bundles whose parameterized options differ never unify.
    """
    if param_option1 != param_option2:
        # Can't unify if the parameterized options are different.
        return False, {}

    def _option_args_pred(option_args: Tuple[_TypedEntity, ...]) -> Predicate:
        return Predicate("OPT-ARGS", [a.type for a in option_args],
                         _classifier=lambda s, o: False)  # dummy

    def _flatten(option_atom: LiftedOrGroundAtom,
                 preconds: FrozenSet[LiftedOrGroundAtom],
                 add_effects: FrozenSet[LiftedOrGroundAtom],
                 delete_effects: FrozenSet[LiftedOrGroundAtom]
                 ) -> FrozenSet[LiftedOrGroundAtom]:
        return (frozenset({option_atom})
                | frozenset(wrap_atom_predicates(preconds, "PRE-"))
                | frozenset(wrap_atom_predicates(add_effects, "ADD-"))
                | frozenset(wrap_atom_predicates(delete_effects, "DEL-")))

    all_atoms1 = _flatten(
        GroundAtom(_option_args_pred(option_args1), option_args1), preconds1,
        add_effects1, delete_effects1)
    all_atoms2 = _flatten(
        LiftedAtom(_option_args_pred(option_args2), option_args2), preconds2,
        add_effects2, delete_effects2)
    return unify(all_atoms1, all_atoms2)
def wrap_predicate(predicate: Predicate, prefix: str) -> Predicate:
    """Return a copy of `predicate` whose name is prefixed with `prefix`.

    The types are kept. NOTE: the classifier is removed; the new predicate
    gets a dummy classifier that always returns False.
    """

    def _dummy_classifier(s: State, o: Sequence[Object]) -> bool:
        del s, o  # unused
        return False

    return Predicate(prefix + predicate.name,
                     predicate.types,
                     _classifier=_dummy_classifier)
def wrap_atom_predicates(atoms: Collection[LiftedOrGroundAtom],
                         prefix: str) -> Set[LiftedOrGroundAtom]:
    """Return a new set of atoms whose predicates carry the given name
    prefix.

    Each atom is rebuilt with its own class (lifted or ground) and the same
    entities. NOTE: all the classifiers are removed (see wrap_predicate).
    """
    return {
        type(atom)(wrap_predicate(atom.predicate, prefix), atom.entities)
        for atom in atoms
    }
class LinearChainParameterizedOption(ParameterizedOption):
    """A parameterized option implemented via a sequence of "child"
    parameterized options.

    This class is meant to help ParameterizedOption manual design.

    The children are executed in order starting with the first in the
    sequence and transitioning when the terminal function of each child is
    hit.

    The children are assumed to chain together, so the initiable of the next
    child should always be True when the previous child terminates. If this
    is not the case, an AssertionError is raised.

    The children must all have the same types and params_space, which in
    turn become the types and params_space for this ParameterizedOption.

    The LinearChainParameterizedOption has memory, which stores the current
    child index ("current_child_index") and one memory dict per child
    ("child_memory").
    """

    def __init__(self, name: str,
                 children: Sequence[ParameterizedOption]) -> None:
        assert len(children) > 0
        self._children = children
        # Make sure that the types and params spaces are consistent.
        types = children[0].types
        params_space = children[0].params_space
        for child in self._children[1:]:
            assert types == child.types
            assert np.allclose(params_space.low, child.params_space.low)
            assert np.allclose(params_space.high, child.params_space.high)
        super().__init__(name,
                         types,
                         params_space,
                         policy=self._policy,
                         initiable=self._initiable,
                         terminal=self._terminal)

    def _initiable(self, state: State, memory: Dict, objects: Sequence[Object],
                   params: Array) -> bool:
        """Reset the chain to its first child and check that it can start."""
        # Initialize the current child to the first one.
        memory["current_child_index"] = 0
        # Create memory dicts for each child to avoid key collisions. One
        # example of a failure that arises without this is when using
        # multiple SingletonParameterizedOption instances, each of those
        # options would be referencing the same start_state in memory.
        memory["child_memory"] = [{} for _ in self._children]
        current_child = self._children[0]
        child_memory = memory["child_memory"][0]
        return current_child.initiable(state, child_memory, objects, params)

    def _policy(self, state: State, memory: Dict, objects: Sequence[Object],
                params: Array) -> Action:
        """Act with the current child, advancing to the next child first if
        the current one has terminated."""
        current_index = memory["current_child_index"]
        current_child = self._children[current_index]
        child_memory = memory["child_memory"][current_index]
        if current_child.terminal(state, child_memory, objects, params):
            # Move on to the next child.
            current_index += 1
            memory["current_child_index"] = current_index
            current_child = self._children[current_index]
            child_memory = memory["child_memory"][current_index]
            # initiable() must be called outside the assert: it can write
            # into child_memory (SingletonParameterizedOption stores its
            # start state there), and asserts are stripped under
            # `python -O`.
            next_is_initiable = current_child.initiable(
                state, child_memory, objects, params)
            assert next_is_initiable, \
                "Next child option in the chain is not initiable."
        return current_child.policy(state, child_memory, objects, params)

    def _terminal(self, state: State, memory: Dict, objects: Sequence[Object],
                  params: Array) -> bool:
        """The chain terminates only when its last child terminates."""
        current_index = memory["current_child_index"]
        if current_index < len(self._children) - 1:
            return False
        current_child = self._children[current_index]
        child_memory = memory["child_memory"][current_index]
        return current_child.terminal(state, child_memory, objects, params)
class SingletonParameterizedOption(ParameterizedOption):
    """A parameterized option that takes a single action and stops.

    For convenience:
        * Initiable defaults to always True.
        * Types defaults to [].
        * Params space defaults to Box(0, 1, (0, )).

    initiable() records the state it was called with in the option memory
    under "start_state". terminal() reports True for any State object that
    is not that very object (identity, not value comparison).
    """

    def __init__(
        self,
        name: str,
        policy: Callable[[State, Dict, Sequence[Object], Array], Action],
        types: Optional[Sequence[Type]] = None,
        params_space: Optional[Box] = None,
        initiable: Optional[Callable[[State, Dict, Sequence[Object], Array],
                                     bool]] = None
    ) -> None:
        resolved_types = [] if types is None else types
        resolved_space = Box(0, 1, (0, )) if params_space is None \
            else params_space

        def _always_initiable(_1: State, _2: Dict, _3: Sequence[Object],
                              _4: Array) -> bool:
            return True

        base_initiable = _always_initiable if initiable is None \
            else initiable

        def _initiable(state: State, memory: Dict, objects: Sequence[Object],
                       params: Array) -> bool:
            if "start_state" in memory:
                # Re-initiating is only allowed from an equivalent state.
                assert state.allclose(memory["start_state"])
            # Overwrite unconditionally: _terminal compares by identity,
            # so it must see this exact object.
            memory["start_state"] = state
            return base_initiable(state, memory, objects, params)

        def _terminal(state: State, memory: Dict, objects: Sequence[Object],
                      params: Array) -> bool:
            del objects, params  # unused
            assert "start_state" in memory, \
                "Must call initiable() before terminal()."
            # Any State object other than the recorded one is treated as
            # evidence that the single action has been executed.
            return state is not memory["start_state"]

        super().__init__(name,
                         resolved_types,
                         resolved_space,
                         policy=policy,
                         initiable=_initiable,
                         terminal=_terminal)
class BehaviorState(State):
    """A BEHAVIOR state.

    Besides the object-centric features, it carries a simulator state
    (described as the index of the temporary BEHAVIOR state folder), which
    allclose() ignores.
    """

    def allclose(self, other: State) -> bool:
        # Rewrap both sides as plain States built from the feature data
        # alone, so the simulator state plays no part in the comparison.
        stripped_self = State(self.data)
        stripped_other = State(other.data)
        return stripped_self.allclose(stripped_other)
class PyBulletState(State):
    """A PyBullet state whose simulator_state holds the robot joint
    positions, alongside the features exposed in the object-centric state."""

    @property
    def joint_positions(self) -> JointPositions:
        """The robot joint positions, stored as this state's
        simulator_state."""
        return cast(JointPositions, self.simulator_state)

    def allclose(self, other: State) -> bool:
        # Compare object features only; the joint positions are ignored.
        return State(self.data).allclose(State(other.data))

    def copy(self) -> State:
        # Copy the feature data via State.copy(), and copy the joint
        # positions into a fresh list so the copy does not alias them.
        return PyBulletState(super().copy().data,
                             list(self.joint_positions))
class Monitor(abc.ABC):
    """Abstract observer of the states and actions seen while a policy is
    executed in an environment."""

    @abc.abstractmethod
    def observe(self, state: State, action: Optional[Action]) -> None:
        """Receive a state together with the action about to be applied.

        The final state of a trajectory has no following action, so it is
        passed with ``action`` set to None.
        """
        raise NotImplementedError("Override me!")
def run_policy(
    policy: Callable[[State], Action],
    env: BaseEnv,
    train_or_test: str,
    task_idx: int,
    termination_function: Callable[[State], bool],
    max_num_steps: int,
    exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None,
    monitor: Optional[Monitor] = None
) -> Tuple[LowLevelTrajectory, Metrics]:
    """Execute a policy starting from the initial state of a train or test task
    in the environment. The task's goal is not used.

    Note that the environment internal state is updated (env.reset() and
    env.step() are called).

    Terminates when any of these conditions hold:
    (1) the termination_function returns True
    (2) max_num_steps is reached
    (3) policy() or step() raise an exception whose exact type is in
        exceptions_to_break_on (subclasses are not matched)

    Any other exception is re-raised, after giving the monitor (if any) a
    final (state, None) observation when it had not yet seen the current
    state.

    Note that in the case where the exception is raised in step, we exclude
    the last action from the returned trajectory to maintain the invariant
    that the trajectory states are of length one greater than the actions.

    Returns:
        The trajectory, and a metrics dict whose "policy_call_time" entry
        accumulates the seconds (measured with time.perf_counter()) spent in
        policy() calls that returned normally.
    """
    state = env.reset(train_or_test, task_idx)
    states = [state]
    actions: List[Action] = []
    metrics: Metrics = defaultdict(float)
    metrics["policy_call_time"] = 0.0
    # Set when a breaking exception happened after the monitor had already
    # observed the current state paired with an action; in that case the
    # closing (state, None) observation is skipped.
    exception_raised_in_step = False
    if not termination_function(state):
        for _ in range(max_num_steps):
            monitor_observed = False
            exception_raised_in_step = False
            try:
                start_time = time.perf_counter()
                act = policy(state)
                metrics["policy_call_time"] += time.perf_counter() - start_time
                # Note: it's important to call monitor.observe() before
                # env.step(), because the monitor may use the environment's
                # internal state.
                if monitor is not None:
                    monitor.observe(state, act)
                    monitor_observed = True
                state = env.step(act)
                actions.append(act)
                states.append(state)
            except Exception as e:
                # Exact type match, not isinstance().
                if exceptions_to_break_on is not None and \
                    type(e) in exceptions_to_break_on:
                    if monitor_observed:
                        exception_raised_in_step = True
                    break
                if monitor is not None and not monitor_observed:
                    monitor.observe(state, None)
                # Bare raise re-raises with the original traceback intact.
                raise
            if termination_function(state):
                break
    if monitor is not None and not exception_raised_in_step:
        monitor.observe(state, None)
    traj = LowLevelTrajectory(states, actions)
    return traj, metrics
def run_policy_with_simulator(
        policy: Callable[[State], Action],
        simulator: Callable[[State, Action], State],
        init_state: State,
        termination_function: Callable[[State], bool],
        max_num_steps: int,
        exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None,
        monitor: Optional[Monitor] = None) -> LowLevelTrajectory:
    """Execute a policy from a given initial state, using a simulator.

    *** This function should not be used with any core code, because we want
    to avoid the assumption of a simulator when possible. ***

    This is similar to run_policy, with three major differences:
    (1) The initial state `init_state` can be any state, not just the initial
    state of a train or test task. (2) A simulator (a function mapping a
    state and an action to the next state) is assumed. (3) Metrics are not
    returned.

    Note that the environment internal state is NOT updated.

    Terminates when any of these conditions hold:
    (1) the termination_function returns True
    (2) max_num_steps is reached
    (3) policy() or simulator() raise an exception whose exact type is in
        exceptions_to_break_on (subclasses are not matched)

    Any other exception is re-raised, after giving the monitor (if any) a
    final (state, None) observation when it had not yet seen the current
    state.

    Note that in the case where the exception is raised in simulator(), we
    exclude the last action from the returned trajectory to maintain the
    invariant that the trajectory states are of length one greater than the
    actions.
    """
    state = init_state
    states = [state]
    actions: List[Action] = []
    # Set when a breaking exception happened after the monitor had already
    # observed the current state paired with an action; in that case the
    # closing (state, None) observation is skipped.
    exception_raised_in_step = False
    if not termination_function(state):
        for _ in range(max_num_steps):
            monitor_observed = False
            exception_raised_in_step = False
            try:
                act = policy(state)
                if monitor is not None:
                    monitor.observe(state, act)
                    monitor_observed = True
                state = simulator(state, act)
                actions.append(act)
                states.append(state)
            except Exception as e:
                # Exact type match, not isinstance().
                if exceptions_to_break_on is not None and \
                    type(e) in exceptions_to_break_on:
                    if monitor_observed:
                        exception_raised_in_step = True
                    break
                if monitor is not None and not monitor_observed:
                    monitor.observe(state, None)
                # Bare raise re-raises with the original traceback intact.
                raise
            if termination_function(state):
                break
    if monitor is not None and not exception_raised_in_step:
        monitor.observe(state, None)
    traj = LowLevelTrajectory(states, actions)
    return traj
class ExceptionWithInfo(Exception):
    """Exception that carries an auxiliary ``info`` dict.

    When no dict is supplied, a fresh empty one is created per instance.
    """
    def __init__(self, message: str, info: Optional[Dict] = None) -> None:
        super().__init__(message)
        payload = {} if info is None else info
        assert isinstance(payload, dict)
        self.info = payload
class OptionExecutionFailure(ExceptionWithInfo):