-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathget.ts
More file actions
78 lines (71 loc) · 3.28 KB
/
get.ts
File metadata and controls
78 lines (71 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import type { DataType, Network, Task } from '../index.js'
import { aggregator } from '../index.js'
import { ByzantineRobustAggregator } from './byzantine.js';
/**
 * Optional parameters accepted by `getAggregator`. Every field may be omitted.
 */
type AggregatorOptions = Partial<{
  /** Training scheme used to infer default network options; if undefined, fallback on task.trainingInformation.scheme */
  scheme: Task<DataType, Network>["trainingInformation"]["scheme"];
  /** Forwarded as-is to the MeanAggregator and ByzantineRobustAggregator constructors (unused by the SecureAggregator) */
  roundCutOff: number;
  /**
   * Number of contributions to wait for, interpreted according to `thresholdType`.
   * Forwarded to the MeanAggregator and ByzantineRobustAggregator constructors (unused by the SecureAggregator).
   */
  threshold: number;
  /**
   * Whether `threshold` is an absolute count of contributions or a fraction of the peers.
   * Forwarded to the MeanAggregator and ByzantineRobustAggregator constructors (unused by the SecureAggregator).
   */
  thresholdType: 'relative' | 'absolute';
}>
/**
 * Initializes an aggregator according to the task definition, the training scheme and the aggregator options.
 *
 * The kind of aggregator is chosen by `task.trainingInformation.aggregationStrategy`:
 * - `'mean'`: a MeanAggregator configured with the network options described below;
 * - `'byzantine'`: a ByzantineRobustAggregator configured with the network options described below and
 *   the task's `byzantineClippingRadius` (default 1.0), `maxIterations` (default 1) and `beta` (default 0.9);
 * - `'secure'`: a SecureAggregator built from `task.trainingInformation.maxShareValue`
 *   (only supported for decentralized learning);
 * - unset: a MeanAggregator, exactly as for `'mean'`.
 *
 * The network options (`roundCutOff`, `threshold`, `thresholdType`) are taken from `options` when provided.
 * Missing values are inferred from the training scheme, read from `options.scheme` with a fallback on
 * `task.trainingInformation.scheme`:
 * - federated learning or local training: the aggregator waits for a single contribution
 *   (absolute threshold of 1) to trigger a model update — the server's model update for federated
 *   learning, or our own contribution when training locally;
 * - decentralized learning: the aggregator waits for every node's contribution
 *   (relative threshold of 1, i.e. 100%).
 * `roundCutOff` defaults to 0.
 *
 * @param task The task object associated with the current training session
 * @param options Options passed down to the aggregator's constructor; keys explicitly set to
 *   `undefined` are treated as missing and fall back on their defaults
 * @returns The aggregator
 * @throws Error if the `'secure'` strategy is requested for a non-decentralized scheme,
 *   or if the aggregation strategy is not recognized
 */
export function getAggregator(
  task: Task<DataType, Network>,
  options: AggregatorOptions = {},
): aggregator.Aggregator {
  const scheme = options.scheme ?? task.trainingInformation.scheme

  // Resolve every field with `??` instead of spreading `options` over the defaults:
  // Partial<...> allows keys explicitly set to `undefined`, and a spread would let
  // those `undefined` values overwrite the defaults below.
  const networkOptions: Required<AggregatorOptions> = {
    scheme,
    roundCutOff: options.roundCutOff ?? 0,
    // 1 contribution for federated/local (the server's update or our own),
    // 100% of the peers for decentralized learning
    threshold: options.threshold ?? 1,
    thresholdType: options.thresholdType ?? (scheme === 'decentralized' ? 'relative' : 'absolute'),
  }

  const createMeanAggregator = (): aggregator.Aggregator =>
    new aggregator.MeanAggregator(
      networkOptions.roundCutOff,
      networkOptions.threshold,
      networkOptions.thresholdType,
    )

  switch (task.trainingInformation.aggregationStrategy) {
    case 'byzantine': {
      const { byzantineClippingRadius = 1.0, maxIterations = 1, beta = 0.9 } = task.trainingInformation
      return new ByzantineRobustAggregator(
        networkOptions.roundCutOff,
        networkOptions.threshold,
        networkOptions.thresholdType,
        byzantineClippingRadius,
        maxIterations,
        beta,
      )
    }
    case 'mean':
      return createMeanAggregator()
    case 'secure':
      if (scheme !== 'decentralized') {
        throw new Error('secure aggregation is currently supported for decentralized only')
      }
      return new aggregator.SecureAggregator(
        task.trainingInformation.maxShareValue
      )
    default: {
      // Reached when no strategy is set, or when an unexpected value slipped past the
      // static types (e.g. a task loaded from JSON). Previously this fell off the switch
      // and returned `undefined` despite the declared return type.
      const strategy: unknown = task.trainingInformation.aggregationStrategy
      if (strategy === undefined) {
        return createMeanAggregator()
      }
      throw new Error(`unknown aggregation strategy: ${String(strategy)}`)
    }
  }
}