|
44 | 44 | level3_actions = pd.read_csv("./v" + version + "/lev3/results/actions_" + control + "_param=0.0.csv",index_col=0,parse_dates=True) |
45 | 45 | level3_states = pd.read_csv("./v" + version + "/lev3/results/states_" + control + "_param=0.0.csv",index_col=0,parse_dates=True) |
46 | 46 |
|
| 47 | +# don't plot the states with "C" in them |
| 48 | +states_to_plot = [state for state in env.config['states'] if "C" not in state[0]] |
47 | 49 |
|
48 | | - |
49 | | -plots_high = max(len(env.config['action_space']) , len(env.config['states'])) |
| 50 | +plots_high = max(len(env.config['action_space']) , len(states_to_plot)) |
50 | 51 |
|
51 | 52 | fig = plt.figure(figsize=(10,2*plots_high)) |
52 | 53 | gs = GridSpec(plots_high,2,figure=fig) |
|
73 | 74 | ax.annotate(str(env.config['action_space'][idx]), xy=(0.5, 0.8), xycoords='axes fraction', ha='center', va='center',fontsize='xx-large') |
74 | 75 |
|
75 | 76 | # plot the states |
76 | | -for idx in range(len(env.config['states'])): |
| 77 | +for idx in range(len(states_to_plot)): |
77 | 78 | ax = fig.add_subplot(gs[idx,1] ) |
78 | | - ax.plot(uncontrolled_states.index, uncontrolled_states[str(env.config['states'][idx])], label='Uncontrolled',color='black',alpha=0.6) |
79 | | - ax.plot(level1_states.index, level1_states[str(env.config['states'][idx])], label='Level 1',color='blue',alpha=0.6) |
80 | | - ax.plot(level2_states.index, level2_states[str(env.config['states'][idx])], label='Level 2',color='green',alpha=0.6) |
81 | | - ax.plot(level3_states.index, level3_states[str(env.config['states'][idx])], label='Level 3',color='red',alpha=0.6) |
| 79 | + ax.plot(uncontrolled_states.index, uncontrolled_states[str(states_to_plot[idx])], label='Uncontrolled',color='black',alpha=0.6) |
| 80 | + ax.plot(level1_states.index, level1_states[str(states_to_plot[idx])], label='Level 1',color='blue',alpha=0.6) |
| 81 | + ax.plot(level2_states.index, level2_states[str(states_to_plot[idx])], label='Level 2',color='green',alpha=0.6) |
| 82 | + ax.plot(level3_states.index, level3_states[str(states_to_plot[idx])], label='Level 3',color='red',alpha=0.6) |
82 | 83 |
|
83 | | - ax.annotate(str(env.config['states'][idx]), xy=(0.5, 0.8), xycoords='axes fraction', ha='center', va='center',fontsize='xx-large') |
| 84 | + ax.annotate(str(states_to_plot[idx]), xy=(0.5, 0.8), xycoords='axes fraction', ha='center', va='center',fontsize='xx-large') |
84 | 85 |
|
85 | 86 |
|
86 | 87 |
|
87 | | - if idx == len(env.config['states']) - 1: |
| 88 | + if idx == len(states_to_plot) - 1: |
88 | 89 | ax.set_xlabel("time") |
89 | 90 | # just add ticks in the beginning, middle, and end of the index |
90 | 91 | ax.set_xticks([level1_states.index[0],level1_states.index[int(len(level1_states.index)/2)],level1_states.index[-1]]) |
91 | 92 |
|
92 | 93 | if idx == 0: |
93 | 94 | ax.set_title("States") |
94 | | - if idx != len(env.config['states']) - 1: # not the last row |
| 95 | + if idx != len(states_to_plot) - 1: # not the last row |
95 | 96 | ax.set_xticks([]) |
96 | 97 | ax.set_xticklabels([]) |
97 | 98 |
|
98 | | - if idx == len(env.config['states']) - 2: # second to last row, for the legend |
| 99 | + if idx == len(states_to_plot) - 2: # second to last row, for the legend |
99 | 100 | ax = fig.add_subplot(gs[idx,0]) |
100 | 101 | ax.plot(uncontrolled_states.index[0:2], np.zeros((2,1)), label = 'Uncontrolled',color='black',alpha=0.6) |
101 | 102 | ax.plot(level1_states.index[0:2], np.zeros((2,1)), label = 'Level 1',color='blue',alpha=0.6) |
|
0 commit comments