-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcompute_attack_success.py
More file actions
138 lines (113 loc) · 5.98 KB
/
compute_attack_success.py
File metadata and controls
138 lines (113 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Provided test prediction probabilities of several attacked target models,
this script computes and saves the attack success metric.
"""
from logging import INFO
from pathlib import Path
import hydra
import numpy as np
import pandas as pd
from omegaconf import DictConfig
from midst_toolkit.common.logger import log
from midst_toolkit.evaluation.privacy.mia_scoring import TprAtFpr
def load_target_challenge_labels_and_probabilities(
    metaclassifier_model_name: str, attack_results_path: Path, challenge_label_path: Path
) -> tuple[np.ndarray, np.ndarray]:
    """
    Loads and returns the challenge labels and test prediction probabilities for
    a given target model.

    Args:
        metaclassifier_model_name: Name of the metaclassifier model used in the attack. The
            probabilities are read from ``<metaclassifier_model_name>_test_pred_proba.npy``.
        attack_results_path: Path to the directory where attack results are saved.
        challenge_label_path: Path to the ``.npy`` or ``.csv`` file containing challenge labels.

    Return:
        A tuple containing:
            - test_target: Numpy array of true membership labels for the challenge points.
              Always at least 1-dimensional, even when there is a single challenge point.
            - test_prediction_probabilities: Numpy array of prediction probabilities
              outputted by the metaclassifier for the challenge points.

    Raises:
        AssertionError: If the probabilities file does not exist, or if the number of
            challenge labels differs from the number of prediction probabilities.
        ValueError: If ``challenge_label_path`` is neither a ``.npy`` nor a ``.csv`` file.
    """
    attack_result_file_path = attack_results_path / f"{metaclassifier_model_name}_test_pred_proba.npy"
    assert attack_result_file_path.exists(), (
        f"No file found at {attack_result_file_path}. Make sure the path is correct, or run the attack on the target model first."
    )

    # Load the attack results containing test prediction probabilities.
    test_prediction_probabilities = np.load(attack_result_file_path)

    # Challenge labels are the true membership labels for the challenge points.
    if challenge_label_path.suffix == ".npy":
        raw_labels = np.load(challenge_label_path)
    elif challenge_label_path.suffix == ".csv":
        raw_labels = pd.read_csv(challenge_label_path).to_numpy()
    else:
        raise ValueError(f"Unsupported challenge label file type: {challenge_label_path}. Must be .npy or .csv.")

    # ``squeeze`` drops singleton axes (e.g. the column axis of a one-column CSV), but it
    # also collapses a single label into a 0-d array, on which ``len`` fails.
    # ``atleast_1d`` restores the sample axis in that case and is a no-op otherwise.
    test_target = np.atleast_1d(np.squeeze(raw_labels))

    assert len(test_prediction_probabilities) == len(test_target), (
        "Number of challenge labels must match number of prediction probabilities."
    )

    return test_target, test_prediction_probabilities
def compute_attack_success_for_given_targets(
    target_model_config: DictConfig,
    target_ids: list[int],
    experiment_directory: Path,
    metaclassifier_model_name: str,
    fpr_threshold: float = 0.1,
) -> None:
    """
    Computes and saves the attack success metric given the test prediction probabilities
    of several attacked target models by concatenating the target models' targets and predictions.

    NOTE: This function does not compute the average success across all models but rather
    treats all predictions and labels together for metric computation.

    NOTE: When ``target_model_config`` has a ``target_model_id`` key, it is overwritten in place
    for every target ID and is left set to the last ID when the function returns.

    Args:
        target_model_config: Configuration object for target models. Its
            ``attack_probabilities_result_path`` and ``challenge_label_path`` entries are read
            once per target ID.
        target_ids: List of target model IDs to compute the attack success for. Must not be empty.
        experiment_directory: Path to the base experiment directory where results are saved.
        metaclassifier_model_name: Name of the metaclassifier model used in the attack.
        fpr_threshold: False positive rate at which the true positive rate is reported.
            Defaults to 0.1.

    Raises:
        ValueError: If ``target_ids`` is empty.
    """
    # Fail early with a readable message instead of the opaque error that concatenating
    # an empty list of arrays would produce further down.
    if len(target_ids) == 0:
        raise ValueError("`target_ids` must contain at least one target model ID.")

    predictions = []
    targets = []

    for target_id in target_ids:
        # If there is a target model id in the config, override it with the current target id
        if "target_model_id" in target_model_config:
            # Override target model id in config as ``attack_probabilities_result_path`` and
            # ``challenge_label_path`` are dependent on it and change in runtime.
            target_model_config.target_model_id = target_id

        # Load challenge labels and prediction probabilities
        log(INFO, f"Loading challenge labels and prediction probabilities for target model ID {target_id}...")
        test_target, test_prediction_probabilities = load_target_challenge_labels_and_probabilities(
            metaclassifier_model_name=metaclassifier_model_name,
            attack_results_path=Path(target_model_config.attack_probabilities_result_path),
            challenge_label_path=Path(target_model_config.challenge_label_path),
        )

        predictions.append(test_prediction_probabilities)
        targets.append(test_target)

    # Pool every target model's predictions and labels into one array each.
    all_predictions = np.concatenate(predictions)
    all_targets = np.concatenate(targets)

    assert len(all_predictions) == len(all_targets), "Number of predictions must match number of targets."

    # Compute TPR@FPR over the pooled predictions of all the target models
    tpr_at_fpr = TprAtFpr.get_tpr_at_fpr(all_targets, all_predictions, fpr_threshold=fpr_threshold)

    # Save the final attack success rate into a text file.
    metric_save_path = experiment_directory / f"attack_success_for_{metaclassifier_model_name}.txt"
    log(INFO, f"Saving attack success value of {tpr_at_fpr} TPR at FPR={fpr_threshold} to {metric_save_path}")

    # Explicit encoding keeps the output independent of the platform's default locale.
    with open(metric_save_path, "w", encoding="utf-8") as f:
        f.write(f"Final TPR at FPR={fpr_threshold}: {tpr_at_fpr:.4f}\n")
@hydra.main(config_path="configs", config_name="experiment_config", version_base=None)
def main(
    config: DictConfig,
) -> None:
    """
    Entry point: read the target model IDs from the config and compute the pooled
    attack success metric for them.

    Args:
        config: Configuration object loaded by Hydra from the ``experiment_config`` file
            in the ``configs`` directory.
    """
    requested_ids = config.attack_success_computation.target_ids_to_test
    assert requested_ids is not None, (
        "Please specify target model IDs to compute attack success for in the config by "
        "specifying `attack_success_computation.target_ids_to_test`."
    )

    # Materialise the config sequence as a plain Python list before handing it on.
    target_ids = list(requested_ids)
    log(INFO, f"Computing attack success for target model IDs: {target_ids}...")

    target_model_section = config.target_model
    base_directory = Path(config.base_experiment_dir)
    model_name = config.metaclassifier.meta_classifier_model_name

    compute_attack_success_for_given_targets(
        target_model_config=target_model_section,
        target_ids=target_ids,
        experiment_directory=base_directory,
        metaclassifier_model_name=model_name,
    )


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()