Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1a40fd7
wip
lotif Jan 8, 2026
1d18580
wip
lotif Jan 8, 2026
e42e630
WIP moving forward with the ensemble attack code changes
lotif Jan 13, 2026
a46a010
WIP adding training and synthesizing code
lotif Jan 13, 2026
30c0ed3
More info on readme
lotif Jan 14, 2026
9464962
More ctgan changes
lotif Feb 23, 2026
e5c8fda
Adding the split data code
lotif Feb 24, 2026
8f10678
More config changes and bug fixes
lotif Feb 24, 2026
077d909
Removing ids dynamically
lotif Feb 25, 2026
b711fbd
Working!
lotif Feb 25, 2026
efdde68
Merge branch 'main' into marcelo/ensamble-ctgan
lotif Mar 3, 2026
1a38af2
Fixing indent on config file and adding some more information to the …
lotif Mar 3, 2026
af4f04e
Adding test attack model code
lotif Mar 4, 2026
5afb774
Small bug fixes
lotif Mar 5, 2026
e4ec793
Updates to readme and config file values
lotif Mar 5, 2026
1c13126
Small changes on configs and script bug fixes
lotif Mar 5, 2026
4e9a8c9
Adding the compute attack success script and fixing minor issues
lotif Mar 5, 2026
d83aabf
Cr by CodeRabbit and Sara
lotif Mar 9, 2026
a198fe9
Reducing the amount of training samples to 20k
lotif Mar 9, 2026
0416dbc
Merge branch 'main' into marcelo/ensamble-ctgan
lotif Mar 9, 2026
e69b07e
Change function name to avoid pytest thinking it's a test
lotif Mar 9, 2026
579d0f3
Merge remote-tracking branch 'origin/marcelo/ensamble-ctgan' into mar…
lotif Mar 9, 2026
5fa4fef
Fixing test assertions
lotif Mar 9, 2026
8b6bf10
Merge branch 'main' into marcelo/ensamble-ctgan
lotif Mar 9, 2026
a9369f6
Making population_all_with_challenge.csv into a constant and adding a…
lotif Mar 13, 2026
163bba8
Addressing last comments by Fatemeh
lotif Mar 16, 2026
bf805c1
Merge branch 'main' into marcelo/ensamble-ctgan
lotif Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/ensemble_attack/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Ensemble Attack

## Data Processing
As the first step of the attack, we need to collect and split the data. The input data collected from all the attacks provided by the MIDST Challenge should be stored in `data_paths.data_paths` as defined by `config.yaml`. You can download and unzip the resources from [this Google Drive link](https://drive.google.com/drive/folders/1rmJ_E6IzG25eCL3foYAb2jVmAstXktJ1?usp=drive_link). Note that you can safely remove the provided shadow models with the competition resources since they are not used in this attack.
As the first step of the attack, we need to collect and split the data. The input data collected from all the attacks provided by the MIDST Challenge should be stored in `data_paths.midst_data_path` as defined by [`configs/experiment_config.yaml`](configs/experiment_config.yaml). You can download and unzip the resources from [this Google Drive link](https://drive.google.com/drive/folders/1rmJ_E6IzG25eCL3foYAb2jVmAstXktJ1?usp=drive_link). Note that you can safely remove the provided shadow models with the competition resources since they are not used in this attack.

Make sure directories and JSON files specified in `data_paths` and `data_processing_config` configurations in `examples/ensemble_attack/config.yaml` exist.
Make sure directories and JSON files specified in `data_paths` and `data_processing_config` configurations in `examples/ensemble_attack/configs/experiment_config.yaml` exist.

To run the whole data processing pipeline, run `run_attack.py` and set `pipeline.run_data_processing` to `true` in `config.yaml`. It reads data from `data_paths.midst_data_path` specified in config, and populates `data_paths.population_data` and `data_paths.processed_attack_data_path` directories.
To run the whole data processing pipeline, run `run_attack.py` and set `pipeline.run_data_processing` to `true` in [`configs/experiment_config.yaml`](configs/experiment_config.yaml). It reads data from `data_paths.midst_data_path` specified in config, and populates `data_paths.population_path` and `data_paths.processed_attack_data_path` directories.

Data processing steps for the MIDST challenge provided resources according to Ensemble attack are as follows:

Expand Down
16 changes: 12 additions & 4 deletions examples/ensemble_attack/compute_attack_success.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def load_target_challenge_labels_and_probabilities(
test_prediction_probabilities = np.load(attack_result_file_path)

# Challenge labels are the true membership labels for the challenge points.
test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze()
if challenge_label_path.suffix == ".npy":
test_target = np.load(challenge_label_path).squeeze()
elif challenge_label_path.suffix == ".csv":
test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze()
else:
raise ValueError(f"Unsupported challenge label file type: {challenge_label_path}. Must be .npy or .csv.")

assert len(test_prediction_probabilities) == len(test_target), (
"Number of challenge labels must match number of prediction probabilities."
Expand Down Expand Up @@ -71,9 +76,12 @@ def compute_attack_success_for_given_targets(
predictions = []
targets = []
for target_id in target_ids:
# Override target model id in config as ``attack_probabilities_result_path`` and
# ``challenge_label_path`` are dependent on it and change in runtime.
target_model_config.target_model_id = target_id
# If there is a target model id in the config, override it with the current target id
if "target_model_id" in target_model_config:
# Override target model id in config as ``attack_probabilities_result_path`` and
# ``challenge_label_path`` are dependent on it and change in runtime.
target_model_config.target_model_id = target_id

# Load challenge labels and prediction probabilities
log(INFO, f"Loading challenge labels and prediction probabilities for target model ID {target_id}...")
test_target, test_prediction_probabilities = load_target_challenge_labels_and_probabilities(
Expand Down
7 changes: 2 additions & 5 deletions examples/ensemble_attack/configs/experiment_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ data_paths:
attack_evaluation_result_path: ${base_experiment_dir}/evaluation_results # Path where the attack (train phase) evaluation results will be stored (output)


model_paths:
metaclassifier_model_path: ${base_experiment_dir}/trained_models # Path where the trained metaclassifier model will be saved


# Dataset specific information used for processing in this example
data_processing_config:
midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input)
Expand Down Expand Up @@ -76,7 +72,7 @@ shadow_training:
training_json_config_paths: # Config json files used for tabddpm training on the trans table
table_domain_file_path: ${base_data_config_dir}/trans_domain.json
dataset_meta_file_path: ${base_data_config_dir}/dataset_meta.json
tabddpm_training_config_path: ${base_data_config_dir}/trans.json
training_config_path: ${base_data_config_dir}/trans.json
# Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name
# Also, training configs for each shadow model are created under shadow_models_data_path.
shadow_models_output_path: ${base_experiment_dir}/shadow_models_and_data
Expand Down Expand Up @@ -112,6 +108,7 @@ metaclassifier:
# Temporary. Might remove having an epoch parameter.
epochs: 1
meta_classifier_model_name: ${metaclassifier.model_type}_metaclassifier_model
metaclassifier_model_path: ${base_experiment_dir}/trained_models # Path where the trained metaclassifier model will be saved

attack_success_computation:
target_ids_to_test: [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] # List of target model IDs to compute the attack success for.
Expand Down
6 changes: 2 additions & 4 deletions examples/ensemble_attack/configs/original_attack_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ data_paths:
processed_attack_data_path: ${base_data_dir}/attack_data # Path where the processed attack real train and evaluation data is stored
attack_evaluation_result_path: ${base_example_dir}/attack_results # Path where the attack evaluation results will be stored

model_paths:
metaclassifier_model_path: ${base_example_dir}/trained_models # Path where the trained metaclassifier model will be saved

# Pipeline control
pipeline:
run_data_processing: true # Set this to false if you have already saved the processed data
Expand Down Expand Up @@ -58,7 +55,7 @@ shadow_training:
training_json_config_paths: # Config json files used for tabddpm training on the trans table
table_domain_file_path: ${base_example_dir}/data_configs/trans_domain.json
dataset_meta_file_path: ${base_example_dir}/data_configs/dataset_meta.json
tabddpm_training_config_path: ${base_example_dir}/data_configs/trans.json
training_config_path: ${base_example_dir}/data_configs/trans.json
# Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name
# Also, training configs for each shadow model are created under shadow_models_data_path.
shadow_models_output_path: ${base_data_dir}/shadow_models_and_data
Expand Down Expand Up @@ -93,6 +90,7 @@ metaclassifier:
# Temporary. Might remove having an epoch parameter.
epochs: 1
meta_classifier_model_name: ${metaclassifier.model_type}_metaclassifier_model
metaclassifier_model_path: ${base_example_dir}/trained_models # Path where the trained metaclassifier model will be saved


# General settings
Expand Down
42 changes: 24 additions & 18 deletions examples/ensemble_attack/real_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from enum import Enum
from logging import INFO
from pathlib import Path
from typing import Literal

import pandas as pd
from omegaconf import DictConfig
Expand All @@ -15,6 +14,10 @@
from midst_toolkit.common.logger import log


# File names for the collected population data produced by the data-collection
# pipeline. Centralized here as constants so downstream scripts (e.g. the
# metaclassifier training and run_attack entry points) reference the same names
# instead of hard-coding the strings.
COLLECTED_DATA_FILE_NAME = "population_all_with_challenge.csv"
COLLECTED_DATA_NO_CHALLENGE_FILE_NAME = "population_all_no_challenge.csv"


class AttackType(Enum):
"""Enum for the different attack types."""

Expand All @@ -32,7 +35,11 @@ class AttackType(Enum):
TABDDPM_100K = "tabddpm_trained_with_100k"


DatasetType = Literal["train", "challenge"]
class AttackDataset(Enum):
    """
    Kinds of dataset that can be collected for an attack.

    TRAIN refers to the data a (shadow/target) model was trained on, while
    CHALLENGE refers to the challenge points whose membership is being attacked.
    """

    # Value strings double as the dataset identifiers used when resolving
    # which file to load for a given attack folder.
    TRAIN = "train"
    CHALLENGE = "challenge"


def expand_ranges(ranges: list[tuple[int, int]]) -> list[int]:
Expand All @@ -56,7 +63,7 @@ def collect_midst_attack_data(
attack_type: AttackType,
data_dir: Path,
split_folder: str,
dataset: DatasetType,
dataset: AttackDataset,
data_processing_config: DictConfig,
) -> pd.DataFrame:
"""
Expand All @@ -74,20 +81,19 @@ def collect_midst_attack_data(
Returns:
pd.DataFrame: The specified dataset in this setting.
"""
assert dataset in {
"train",
"challenge",
}, "Only 'train' and 'challenge' collection is supported."
assert dataset in {AttackDataset.TRAIN, AttackDataset.CHALLENGE}, (
"Only 'train' and 'challenge' collections are supported."
)
# `data_id` is the folder numbering of each training or challenge dataset,
# and is defined with the provided config.
data_id = expand_ranges(data_processing_config.folder_ranges[split_folder])

# Get file name based on the kind of dataset to be collected (i.e. train vs challenge).
# TODO: Make the below parsing a bit more robust and less brittle
generation_name = attack_type.value.split("_")[0]
if dataset == "challenge":
if dataset == AttackDataset.CHALLENGE:
file_name = data_processing_config.challenge_data_file_name
else:
else: # dataset == AttackDataset.TRAIN
# Multi-table attacks have different file names.
file_name = (
data_processing_config.multi_table_train_data_file_name
Expand All @@ -110,7 +116,7 @@ def collect_midst_data(
midst_data_input_dir: Path,
attack_types: list[AttackType],
split_folders: list[str],
dataset: DatasetType,
dataset: AttackDataset,
data_processing_config: DictConfig,
) -> pd.DataFrame:
"""
Expand All @@ -133,7 +139,6 @@ def collect_midst_data(
Returns:
Collected train or challenge data as a dataframe.
"""
assert dataset in {"train", "challenge"}, "Only 'train' and 'challenge' collection is supported."
population = []
for attack_type in attack_types:
for split_folder in split_folders:
Expand Down Expand Up @@ -204,7 +209,7 @@ def collect_population_data_ensemble(
midst_data_input_dir,
population_attack_types,
split_folders=population_splits,
dataset="train",
dataset=AttackDataset.TRAIN,
data_processing_config=data_processing_config,
)

Expand All @@ -221,7 +226,8 @@ def collect_population_data_ensemble(
)

# Drop ids.
df_population_no_id = df_population.drop(columns=["trans_id", "account_id"])
id_columns = [c for c in df_population.columns if c.endswith("_id")]
df_population_no_id = df_population.drop(columns=id_columns)
# Save the population data
save_dataframe(df_population, save_dir, "population_all.csv")
save_dataframe(df_population_no_id, save_dir, "population_all_no_id.csv")
Expand All @@ -233,7 +239,7 @@ def collect_population_data_ensemble(
midst_data_input_dir,
attack_types=challenge_attack_types,
split_folders=challenge_splits,
dataset="challenge",
dataset=AttackDataset.CHALLENGE,
data_processing_config=data_processing_config,
)
log(INFO, f"Collected challenge data length: {len(df_challenge)} from splits: {challenge_splits}")
Expand All @@ -242,23 +248,23 @@ def collect_population_data_ensemble(

# Population data without the challenge points
df_population_no_challenge = df_population[~df_population["trans_id"].isin(df_challenge["trans_id"])]
save_dataframe(df_population_no_challenge, save_dir, "population_all_no_challenge.csv")
save_dataframe(df_population_no_challenge, save_dir, COLLECTED_DATA_NO_CHALLENGE_FILE_NAME)
# Remove ids
df_population_no_challenge_no_id = df_population_no_challenge.drop(columns=["trans_id", "account_id"])
save_dataframe(
df_population_no_challenge_no_id,
save_dir,
"population_all_no_challenge_no_id.csv",
f"{Path(COLLECTED_DATA_NO_CHALLENGE_FILE_NAME).stem}_no_id.csv",
)

# Population data with all the challenge points
df_population_with_challenge = pd.concat([df_population_no_challenge, df_challenge])
save_dataframe(df_population_with_challenge, save_dir, "population_all_with_challenge.csv")
save_dataframe(df_population_with_challenge, save_dir, COLLECTED_DATA_FILE_NAME)
# Remove ids
df_population_with_challenge_no_id = df_population_with_challenge.drop(columns=["trans_id", "account_id"])
save_dataframe(
df_population_with_challenge_no_id,
save_dir,
"population_all_with_challenge_no_id.csv",
f"{Path(COLLECTED_DATA_FILE_NAME).stem}_no_id.csv",
)
return df_population_with_challenge
4 changes: 2 additions & 2 deletions examples/ensemble_attack/run_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import examples.ensemble_attack.run_metaclassifier_training as meta_pipeline
import examples.ensemble_attack.run_shadow_model_training as shadow_pipeline
from examples.ensemble_attack.real_data_collection import collect_population_data_ensemble
from examples.ensemble_attack.real_data_collection import COLLECTED_DATA_FILE_NAME, collect_population_data_ensemble
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
from midst_toolkit.attacks.ensemble.process_split_data import process_split_data
from midst_toolkit.common.logger import log
Expand All @@ -33,7 +33,7 @@ def run_data_processing(config: DictConfig) -> None:
# is not enough.
original_population_data = load_dataframe(
Path(config.data_processing_config.original_population_data_path),
"population_all_with_challenge.csv",
COLLECTED_DATA_FILE_NAME,
)
log(INFO, "Running data processing pipeline...")
# Collect the real data from the MIDST challenge resources.
Expand Down
6 changes: 4 additions & 2 deletions examples/ensemble_attack/run_metaclassifier_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
from omegaconf import DictConfig

from examples.ensemble_attack.real_data_collection import COLLECTED_DATA_FILE_NAME
from midst_toolkit.attacks.ensemble.blending import BlendingPlusPlus, MetaClassifierType
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
from midst_toolkit.common.logger import log
Expand Down Expand Up @@ -80,9 +81,10 @@ def run_metaclassifier_training(
assert target_synthetic_data is not None, "Target model's synthetic data is missing."
target_synthetic_data = target_synthetic_data.copy()

data_file_name = config.data_file_name if "data_file_name" in config else COLLECTED_DATA_FILE_NAME
df_reference = load_dataframe(
Path(config.data_paths.population_path),
"population_all_with_challenge_no_id.csv",
f"{Path(data_file_name).stem}_no_id.csv",
)
log(
INFO,
Expand Down Expand Up @@ -125,7 +127,7 @@ def run_metaclassifier_training(
)

model_filename = config.metaclassifier.meta_classifier_model_name
model_path = Path(config.model_paths.metaclassifier_model_path) / f"{model_filename}.pkl"
model_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{model_filename}.pkl"
model_path.parent.mkdir(parents=True, exist_ok=True)
with open(model_path, "wb") as f:
pickle.dump(blending_attacker.trained_model, f)
Expand Down
Loading