Skip to content

Commit f96699e

Browse files
committed
Updated the buffer
1 parent 1d73c00 commit f96699e

1 file changed

Lines changed: 47 additions & 35 deletions

File tree

libemg/discrete.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import torch.nn.functional as F
33
import torch
4-
from playsound import playsound
5-
import matplotlib.pyplot as plt
64
from libemg.feature_extractor import FeatureExtractor
75
from libemg.utils import get_windows
86
import pyautogui
@@ -11,69 +9,84 @@
119

1210
class DiscreteControl:
1311
"""
14-
A class for detecting gestures using the Teager-Kaiser energy operator on EMG signals.
12+
The temporary discrete control class for interfacing the cross-user Myo model (TODO: add link to the public model repository).
13+
The model currently supports 5 gestures: Close, Flexion, Extension, Open, Pinch.
14+
These gestures can be mapped to keyboard keys for controlling applications.
15+
16+
Parameters
17+
----------
18+
odh: OnlineDataHandler
19+
The online data handler object for streaming EMG data.
20+
window_size: int
21+
The window size (in samples) to use for splitting up each template.
22+
increment: int
23+
The increment size (in samples) for the sliding window.
24+
model: torch.nn.Module
25+
The trained PyTorch model for gesture classification.
26+
buffer: int, optional
27+
The size of the prediction buffer to use for mode filtering. Default is 5.
28+
template_size: int, optional
29+
The size of each EMG template (in samples). Default is 250 (~1.25 s at the Myo Armband's 200 Hz EMG sampling rate).
30+
min_template_size: int, optional
31+
The minimum number of samples required before starting to make predictions (helps reduce the delay needed between subsequent gestures). Default is 150.
32+
key_mapping: dict, optional
33+
A dictionary mapping gesture names to keyboard keys. Default maps 'Close' to 'c', 'Flexion' to 'f', 'Extension' to 'e', 'Open' to 'o', and 'Pinch' to 'p'.
34+
debug: bool, optional
35+
If True, enables debug mode with additional print statements. Default is True.
1536
"""
16-
17-
def __init__(self, odh, window_size, increment, threshold=100, buffer=20, subject=None, model=None):
37+
def __init__(self, odh, window_size, increment, model, buffer=5, template_size=250, min_template_size=150, key_mapping={'Close':'c', 'Flexion':'f', 'Extension':'e', 'Open':'o', 'Pinch':'p'}, debug=True):
1838
self.odh = odh
1939
self.window_size = window_size
2040
self.increment = increment
21-
self.threshold = threshold
2241
self.buffer_size = buffer
23-
self.subject = subject
2442
self.model = model
43+
self.template_size = template_size
44+
self.min_template_size = min_template_size
45+
self.key_mapping = key_mapping
46+
self.debug = debug
2547

2648
def run(self):
2749
"""
2850
Main loop for gesture detection.
29-
Continuously monitors EMG data and detects gestures based on energy thresholds.
51+
Runs a sliding window over incoming EMG data and makes predictions based on the trained model.
3052
"""
3153
gesture_mapping = ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch']
32-
expected_count = 250
54+
expected_count = self.min_template_size
55+
buffer = []
3356

3457
while True:
35-
buffer = []
36-
3758
# Get and process EMG data
38-
data, counts = self.odh.get_data(self.window_size)
59+
_, counts = self.odh.get_data(self.window_size)
3960
if counts['emg'][0][0] >= expected_count:
40-
data, counts = self.odh.get_data(250)
61+
data, counts = self.odh.get_data(self.template_size)
4162
emg = data['emg'][::-1]
42-
feats = self.get_features([emg], 10, 5, None, None)
43-
pred, _ = self.predict(feats[0])
63+
feats = self._get_features([emg], self.window_size, self.increment, None, None)
64+
pred, _ = self._predict(feats[0])
4465
buffer.append(pred)
45-
mode_pred = statistics.mode(buffer[-20:])
66+
mode_pred = statistics.mode(buffer[-self.buffer_size:])
4667
if mode_pred != 0:
47-
print(str(time.time()) + ' ' + gesture_mapping[mode_pred])
48-
self.key_press(mode_pred, gesture_mapping)
68+
if self.debug:
69+
print(str(time.time()) + ' ' + gesture_mapping[mode_pred])
70+
self._key_press(mode_pred, gesture_mapping)
4971
self.odh.reset()
50-
expected_count = 250
72+
expected_count = self.min_template_size
5173
buffer = []
5274
else:
5375
expected_count += 10
5476

55-
def key_press(self, pred, mapping):
56-
if mapping[pred] == 'Close':
57-
pyautogui.press('c')
58-
elif mapping[pred] == 'Flexion':
59-
pyautogui.press('f')
60-
elif mapping[pred] == 'Extension':
61-
pyautogui.press('e')
62-
elif mapping[pred] == 'Open':
63-
pyautogui.press('o')
64-
elif mapping[pred] == 'Pinch':
65-
pyautogui.press('p')
66-
playsound('Other/click.wav')
77+
def _key_press(self, pred, mapping):
78+
if mapping[pred] in self.key_mapping:
79+
pyautogui.press(self.key_mapping[mapping[pred]])
6780

68-
def predict(self, gest, device='cpu'):
69-
g_tensor = torch.tensor([gest], dtype=torch.float32).to(device)
81+
def _predict(self, gest, device='cpu'):
82+
g_tensor = torch.tensor(np.expand_dims(np.array(gest, dtype=np.float32), axis=0), dtype=torch.float32).to(device)
7083
with torch.no_grad():
7184
output = self.model.forward_once(g_tensor)
7285
pred = output.argmax(dim=1).item()
7386
prob = F.softmax(output, dim=1).max().item()
7487
return pred, prob
7588

76-
def get_features(self, data, window_size, window_inc, feats, feat_dic):
89+
def _get_features(self, data, window_size, window_inc, feats, feat_dic):
7790
fe = FeatureExtractor()
7891
data = np.array([get_windows(d, window_size, window_inc) for d in data], dtype='object')
7992
if feats is None:
@@ -83,5 +96,4 @@ def get_features(self, data, window_size, window_inc, feats, feat_dic):
8396
else:
8497
feats = np.array([fe.extract_features(feats, np.array(d, dtype='float'), array=True) for d in data], dtype='object')
8598
feats = np.nan_to_num(feats, copy=True, nan=0, posinf=0, neginf=0)
86-
# expected shape: (NFiles,) -> (Time, channel)
8799
return feats

0 commit comments

Comments
 (0)