-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsimulated.py
More file actions
49 lines (40 loc) · 1.2 KB
/
simulated.py
File metadata and controls
49 lines (40 loc) · 1.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from benchopt import BaseDataset, safe_import_context
from benchopt.datasets.simulated import make_correlated_data
with safe_import_context() as import_ctx:
preprocess_data = import_ctx.import_from("utils", "preprocess_data")
class Dataset(BaseDataset):
    """Simulated regression data drawn from ``make_correlated_data``.

    ``rho`` is passed straight to the generator. ``n_signals`` is turned
    into a ``density`` ratio relative to ``n_features`` before being
    forwarded (presumably the fraction of non-zero true coefficients --
    confirm against benchopt's ``make_correlated_data`` documentation).
    When ``standardize`` is true, ``X`` and ``y`` are also passed through
    the project's ``preprocess_data`` helper before being returned.
    """

    name = "Simulated"

    # Parameter grid explored by the benchmark. The first key names three
    # parameters and each value is a 3-tuple; presumably benchopt unpacks
    # each tuple so the three sizes vary together -- confirm in benchopt docs.
    parameters = {
        "n_samples, n_features, n_signals": [
            (10_000, 200, 20),  # more samples than features
            (200, 10_000, 20),  # more features than samples
        ],
        "rho": [0, 0.5],
        "standardize": [True, False],
    }

    def __init__(
        self,
        n_samples=10,
        n_features=50,
        n_signals=5,
        rho=0,
        random_state=27,
        standardize=True,
    ):
        # Each constructor argument is kept under an attribute of the same
        # name; get_data reads them back later.
        self.standardize = standardize
        self.rho = rho
        self.random_state = random_state
        self.n_signals = n_signals
        self.n_features = n_features
        self.n_samples = n_samples

    def get_data(self):
        """Generate the design matrix and target for this configuration.

        Returns
        -------
        dict
            Mapping with keys ``"X"`` and ``"y"``. Both arrays come from
            ``make_correlated_data`` and, when ``self.standardize`` is true,
            are replaced by the pair returned from ``preprocess_data``.
        """
        # Ratio of requested signals to features, forwarded as ``density``.
        signal_density = self.n_signals / self.n_features

        # The generator returns three values; only the first two are used.
        design, response, _ = make_correlated_data(
            self.n_samples,
            self.n_features,
            rho=self.rho,
            density=signal_density,
            random_state=self.random_state,
        )

        features, target = (
            preprocess_data(design, response)
            if self.standardize
            else (design, response)
        )
        return {"X": features, "y": target}