Skip to content

Commit d804676

Browse files
committed
Working for CHI
1 parent b5248a5 commit d804676

2 files changed

Lines changed: 99 additions & 30 deletions

File tree

libemg/discrete.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
import torch.nn.functional as F
3+
import torch
4+
from playsound import playsound
5+
import matplotlib.pyplot as plt
6+
from libemg.feature_extractor import FeatureExtractor
7+
from libemg.utils import get_windows
8+
import pyautogui
9+
import time
10+
import statistics
11+
12+
class DiscreteControl:
13+
"""
14+
A class for detecting gestures using the Teager-Kaiser energy operator on EMG signals.
15+
"""
16+
17+
def __init__(self, odh, window_size, increment, threshold=100, buffer=20, subject=None, model=None):
18+
self.odh = odh
19+
self.window_size = window_size
20+
self.increment = increment
21+
self.threshold = threshold
22+
self.buffer_size = buffer
23+
self.subject = subject
24+
self.model = model
25+
26+
def run(self):
27+
"""
28+
Main loop for gesture detection.
29+
Continuously monitors EMG data and detects gestures based on energy thresholds.
30+
"""
31+
gesture_mapping = ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch']
32+
expected_count = 250
33+
34+
while True:
35+
buffer = []
36+
37+
# Get and process EMG data
38+
data, counts = self.odh.get_data(self.window_size)
39+
if counts['emg'][0][0] >= expected_count:
40+
data, counts = self.odh.get_data(250)
41+
emg = data['emg'][::-1]
42+
feats = self.get_features([emg], 10, 5, None, None)
43+
pred, _ = self.predict(feats[0])
44+
buffer.append(pred)
45+
mode_pred = statistics.mode(buffer[-20:])
46+
if mode_pred != 0:
47+
print(str(time.time()) + ' ' + gesture_mapping[mode_pred])
48+
self.key_press(mode_pred, gesture_mapping)
49+
self.odh.reset()
50+
expected_count = 250
51+
buffer = []
52+
else:
53+
expected_count += 10
54+
55+
def key_press(self, pred, mapping):
56+
if mapping[pred] == 'Close':
57+
pyautogui.press('c')
58+
elif mapping[pred] == 'Flexion':
59+
pyautogui.press('f')
60+
elif mapping[pred] == 'Extension':
61+
pyautogui.press('e')
62+
elif mapping[pred] == 'Open':
63+
pyautogui.press('o')
64+
elif mapping[pred] == 'Pinch':
65+
pyautogui.press('p')
66+
playsound('Other/click.wav')
67+
68+
def predict(self, gest, device='cpu'):
69+
g_tensor = torch.tensor([gest], dtype=torch.float32).to(device)
70+
with torch.no_grad():
71+
output = self.model.forward_once(g_tensor)
72+
pred = output.argmax(dim=1).item()
73+
prob = F.softmax(output, dim=1).max().item()
74+
return pred, prob
75+
76+
def get_features(self, data, window_size, window_inc, feats, feat_dic):
77+
fe = FeatureExtractor()
78+
data = np.array([get_windows(d, window_size, window_inc) for d in data], dtype='object')
79+
if feats is None:
80+
return data
81+
if feat_dic is not None:
82+
feats = np.array([fe.extract_features(feats, d, array=True, feature_dic=feat_dic) for d in data], dtype='object')
83+
else:
84+
feats = np.array([fe.extract_features(feats, np.array(d, dtype='float'), array=True) for d in data], dtype='object')
85+
feats = np.nan_to_num(feats, copy=True, nan=0, posinf=0, neginf=0)
86+
# expected shape: (NFiles,) -> (Time, channel)
87+
return feats

libemg/emg_predictor.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import re
2828
from matplotlib.animation import FuncAnimation
2929
from functools import partial
30-
from typing import Callable
3130

3231
from libemg.utils import get_windows
3332
from libemg.environments.controllers import RegressorController, ClassifierController
@@ -308,34 +307,21 @@ def add_majority_vote(self, num_samples=5):
308307

309308
def add_velocity(self, train_windows, train_labels,
310309
velocity_metric_handle = None,
311-
velocity_mapping_handle: str | None | Callable[[int], int] = None):
310+
velocity_mapping_handle = None):
312311
"""Adds velocity (i.e., proportional) control where a multiplier is generated for the level of contraction intensity.
313312
314313
Note, that when using this optional, ramp contractions should be captured for training.
315314
316315
Parameters
317316
-----------
318-
train_windows: np.ndarray
319-
The training windows extracted from the offline data handler.
320-
train_labels: np.ndarray
321-
The labels associated with the train windows. Allows for per class proportional control mapping.
322-
velocity_mapping_handle: function or a string (valid options: "SIGMOID", "SQUARED", "LOG", "RELU")
323-
A function that maps the proportionality (bounded between 0-1) to some other value.
324317
"""
325-
if isinstance(velocity_mapping_handle, str):
326-
if not velocity_mapping_handle in ['SIGMOID', 'SQUARED', 'LOG', 'RELU']:
327-
print("Invalid velocity mapping... Defaulting to linear.")
328-
else:
329-
if velocity_mapping_handle == 'SQUARED':
330-
self.velocity_mapping_handle = lambda x: x**2
331-
# TODO: Fill out rest
332-
else:
333-
self.velocity_metric_handle = velocity_metric_handle
334-
self.velocity_mapping_handle = velocity_mapping_handle
335-
318+
self.velocity_metric_handle = velocity_metric_handle
319+
self.velocity_mapping_handle = velocity_mapping_handle
336320
self.velocity = True
321+
337322
self.th_min_dic, self.th_max_dic = self._set_up_velocity_control(train_windows, train_labels)
338323

324+
339325

340326
'''
341327
---------------------- Private Helper Functions ----------------------
@@ -634,7 +620,10 @@ def __init__(self,
634620
["adapt_flag", (1,1), np.int32],
635621
["active_flag", (1,1), np.int8]
636622
]
637-
smm_items.extend(required_smm_items)
623+
current_smm_tags = [item[0] for item in smm_items]
624+
for smm_item in required_smm_items:
625+
if smm_item[0] not in current_smm_tags:
626+
smm_items.append(smm_item)
638627
self.smm = smm
639628
self.smm_items = smm_items
640629

@@ -763,7 +752,7 @@ def _run_helper(self):
763752
model_input = None
764753
for mod in self.odh.modalities:
765754
# todo: features for each modality can be different
766-
mod_features = fe.extract_features(self.features, window[mod], array=True)
755+
mod_features = fe.extract_features(self.features, window[mod], feature_dic=self.predictor.feature_params, array=True)
767756
if model_input is None:
768757
model_input = mod_features
769758
else:
@@ -963,13 +952,8 @@ def insert_classifier_output(data):
963952
insert_classifier_output)
964953
self.options['model_smm_writes'] += 1
965954

966-
if self.output_format == "predictions":
967-
message = str(prediction) + calculated_velocity + '\n'
968-
elif self.output_format == "probabilities":
969-
message = ' '.join([f'{i:.2f}' for i in probabilities[0]]) + calculated_velocity + " " + str(time_stamp)
970-
else:
971-
raise ValueError(f"Unexpected value for output_format. Accepted values are 'predictions' and 'probabilities'. Got: {self.output_format}.")
972-
955+
message = str(prediction) + " " + str(np.abs(np.array(window['emg'])).mean(axis=2).mean()) + str(calculated_velocity)
956+
973957
if not self.tcp:
974958
self.sock.sendto(bytes(message, 'utf-8'), (self.ip, self.port))
975959
else:
@@ -997,7 +981,6 @@ def visualize(self, max_len=50, legend=None):
997981
cmap = cm.get_cmap('turbo', num_classes)
998982

999983
controller = ClassifierController(output_format=self.output_format, num_classes=num_classes, ip=self.ip, port=self.port)
1000-
controller.start()
1001984

1002985
if legend is not None:
1003986
for i in range(num_classes):
@@ -1184,7 +1167,6 @@ def visualize(self, max_len = 50, legend = False):
11841167
ax.set_ylabel('Prediction')
11851168

11861169
controller = RegressorController(ip=self.ip, port=self.port)
1187-
controller.start()
11881170

11891171
# Wait for controller to start receiving data
11901172
predictions = None

0 commit comments

Comments
 (0)