|
27 | 27 | import re |
28 | 28 | from matplotlib.animation import FuncAnimation |
29 | 29 | from functools import partial |
30 | | -from typing import Callable |
31 | 30 |
|
32 | 31 | from libemg.utils import get_windows |
33 | 32 | from libemg.environments.controllers import RegressorController, ClassifierController |
@@ -308,34 +307,21 @@ def add_majority_vote(self, num_samples=5): |
308 | 307 |
|
309 | 308 | def add_velocity(self, train_windows, train_labels, |
310 | 309 | velocity_metric_handle = None, |
311 | | - velocity_mapping_handle: str | None | Callable[[int], int] = None): |
| 310 | + velocity_mapping_handle = None): |
312 | 311 | """Adds velocity (i.e., proportional) control where a multiplier is generated for the level of contraction intensity. |
313 | 312 |
|
314 | 313 | Note, that when using this optional, ramp contractions should be captured for training. |
315 | 314 |
|
316 | 315 | Parameters |
317 | 316 | ----------- |
318 | | - train_windows: np.ndarray |
319 | | - The training windows extracted from the offline data handler. |
320 | | - train_labels: np.ndarray |
321 | | - The labels associated with the train windows. Allows for per class proportional control mapping. |
322 | | - velocity_mapping_handle: function or a string (valid options: "SIGMOID", "SQUARED", "LOG", "RELU") |
323 | | - A function that maps the proportionality (bounded between 0-1) to some other value. |
324 | 317 | """ |
325 | | - if isinstance(velocity_mapping_handle, str): |
326 | | - if not velocity_mapping_handle in ['SIGMOID', 'SQUARED', 'LOG', 'RELU']: |
327 | | - print("Invalid velocity mapping... Defaulting to linear.") |
328 | | - else: |
329 | | - if velocity_mapping_handle == 'SQUARED': |
330 | | - self.velocity_mapping_handle = lambda x: x**2 |
331 | | - # TODO: Fill out rest |
332 | | - else: |
333 | | - self.velocity_metric_handle = velocity_metric_handle |
334 | | - self.velocity_mapping_handle = velocity_mapping_handle |
335 | | - |
| 318 | + self.velocity_metric_handle = velocity_metric_handle |
| 319 | + self.velocity_mapping_handle = velocity_mapping_handle |
336 | 320 | self.velocity = True |
| 321 | + |
337 | 322 | self.th_min_dic, self.th_max_dic = self._set_up_velocity_control(train_windows, train_labels) |
338 | 323 |
|
| 324 | + |
339 | 325 |
|
340 | 326 | ''' |
341 | 327 | ---------------------- Private Helper Functions ---------------------- |
@@ -634,7 +620,10 @@ def __init__(self, |
634 | 620 | ["adapt_flag", (1,1), np.int32], |
635 | 621 | ["active_flag", (1,1), np.int8] |
636 | 622 | ] |
637 | | - smm_items.extend(required_smm_items) |
| 623 | + current_smm_tags = [item[0] for item in smm_items] |
| 624 | + for smm_item in required_smm_items: |
| 625 | + if smm_item[0] not in current_smm_tags: |
| 626 | + smm_items.append(smm_item) |
638 | 627 | self.smm = smm |
639 | 628 | self.smm_items = smm_items |
640 | 629 |
|
@@ -763,7 +752,7 @@ def _run_helper(self): |
763 | 752 | model_input = None |
764 | 753 | for mod in self.odh.modalities: |
765 | 754 | # todo: features for each modality can be different |
766 | | - mod_features = fe.extract_features(self.features, window[mod], array=True) |
| 755 | + mod_features = fe.extract_features(self.features, window[mod], feature_dic=self.predictor.feature_params, array=True) |
767 | 756 | if model_input is None: |
768 | 757 | model_input = mod_features |
769 | 758 | else: |
@@ -963,13 +952,8 @@ def insert_classifier_output(data): |
963 | 952 | insert_classifier_output) |
964 | 953 | self.options['model_smm_writes'] += 1 |
965 | 954 |
|
966 | | - if self.output_format == "predictions": |
967 | | - message = str(prediction) + calculated_velocity + '\n' |
968 | | - elif self.output_format == "probabilities": |
969 | | - message = ' '.join([f'{i:.2f}' for i in probabilities[0]]) + calculated_velocity + " " + str(time_stamp) |
970 | | - else: |
971 | | - raise ValueError(f"Unexpected value for output_format. Accepted values are 'predictions' and 'probabilities'. Got: {self.output_format}.") |
972 | | - |
| 955 | + message = str(prediction) + " " + str(np.abs(np.array(window['emg'])).mean(axis=2).mean()) + str(calculated_velocity) |
| 956 | + |
973 | 957 | if not self.tcp: |
974 | 958 | self.sock.sendto(bytes(message, 'utf-8'), (self.ip, self.port)) |
975 | 959 | else: |
@@ -997,7 +981,6 @@ def visualize(self, max_len=50, legend=None): |
997 | 981 | cmap = cm.get_cmap('turbo', num_classes) |
998 | 982 |
|
999 | 983 | controller = ClassifierController(output_format=self.output_format, num_classes=num_classes, ip=self.ip, port=self.port) |
1000 | | - controller.start() |
1001 | 984 |
|
1002 | 985 | if legend is not None: |
1003 | 986 | for i in range(num_classes): |
@@ -1184,7 +1167,6 @@ def visualize(self, max_len = 50, legend = False): |
1184 | 1167 | ax.set_ylabel('Prediction') |
1185 | 1168 |
|
1186 | 1169 | controller = RegressorController(ip=self.ip, port=self.port) |
1187 | | - controller.start() |
1188 | 1170 |
|
1189 | 1171 | # Wait for controller to start receiving data |
1190 | 1172 | predictions = None |
|
0 commit comments