Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .claude/formatting.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ Do not suppress errors with workarounds like `# type: ignore`:

# Formatting Details

- **autoflake**: Removes unused imports. Excludes `*_pb2.py*` and `__init__.py`.
- **isort**: Sorts imports (black profile).
- **black**: Code formatter (line length 88). Excludes `*_pb2.py*`.
- **ruff check**: Removes unused imports (`F401`) and sorts imports (`I`). Excludes `*_pb2.py*` and ignores `F401` in
`__init__.py`.
- **ruff format**: Code formatter (line length 88, black-compatible). Excludes `*_pb2.py*`.
- **mdformat**: Markdown formatter (wrap 120, tables extension).

**Note:** `make format` is NOT a pre-commit hook — pre-commit only runs whitespace and EOF fixes. Always run
Expand Down
10 changes: 4 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,8 @@ unit_test_scala: clean_build_files_scala
unit_test: precondition_tests unit_test_py unit_test_scala

check_format_py:
uv run autoflake --check --config pyproject.toml ${PYTHON_DIRS}
uv run isort --check-only --settings-path=pyproject.toml ${PYTHON_DIRS}
uv run black --check --config=pyproject.toml ${PYTHON_DIRS}
uv run ruff check --config pyproject.toml ${PYTHON_DIRS}
Comment thread
svij-sc marked this conversation as resolved.
uv run ruff format --check --config pyproject.toml ${PYTHON_DIRS}

check_format_scala:
( cd scala; sbt "scalafmtCheckAll; scalafixAll --check"; )
Expand Down Expand Up @@ -129,9 +128,8 @@ mock_assets:
uv run python -m gigl.src.mocking.dataset_asset_mocking_suite --resource_config_uri="deployment/configs/e2e_cicd_resource_config.yaml" --env test

format_py:
uv run autoflake --config pyproject.toml ${PYTHON_DIRS}
uv run isort --settings-path=pyproject.toml ${PYTHON_DIRS}
uv run black --config=pyproject.toml ${PYTHON_DIRS}
uv run ruff check --fix --config pyproject.toml ${PYTHON_DIRS}
uv run ruff format --config pyproject.toml ${PYTHON_DIRS}

format_scala:
# We run "clean" before the formatting because otherwise some "scalafix.sbt.ScalafixFailed: NoFilesError" may get thrown after switching branches...
Expand Down
149 changes: 79 additions & 70 deletions examples/MAG240M/fetch_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,33 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"import os \n",
"import multiprocessing\n",
"import numpy as np\n",
"import os\n",
"from io import BytesIO\n",
"from fastavro import writer\n",
"from google.cloud import storage\n",
"from google.cloud import bigquery\n",
"from typing import Optional\n",
"\n",
"from ogb.lsc import MAG240MDataset\n",
"import numpy as np\n",
"from common import MAG240_DATASET_PATH\n",
"from fastavro import writer\n",
"from google.cloud import bigquery, storage\n",
"from ogb.lsc import MAG240MDataset\n",
"\n",
"from gigl.common import GcsUri\n",
"from gigl.common.utils.gcs import GcsUtils\n",
"from gigl.src.common.utils.bq import BqUtils\n",
"from gigl.common import GcsUri\n",
"\n",
"BASE_GCS_BUCKET_PATH = (\n",
" \"CHANGE THIS WITH YOUR GCS PATH\" # WARN: CHANGE THIS WITH YOUR GCS PATH\n",
")\n",
"BASE_BQ_PATH = \"CHANGE THIS WITH YOUR BQ PATH\" # WARN: CHANGE THIS WITH YOUR BQ PATH\n",
"\n",
"BASE_GCS_BUCKET_PATH = \"CHANGE THIS WITH YOUR GCS PATH\" # WARN: CHANGE THIS WITH YOUR GCS PATH\n",
"BASE_BQ_PATH = \"CHANGE THIS WITH YOUR BQ PATH\" # WARN: CHANGE THIS WITH YOUR BQ PATH\n",
"\n",
"def get_gcs_path_for_asset(asset_name):\n",
" return f\"{BASE_GCS_BUCKET_PATH}/{asset_name}\"\n",
"\n",
"\n",
"def get_bq_path_for_asset(asset_name):\n",
" return f\"{BASE_BQ_PATH}_{asset_name}\"\n"
" return f\"{BASE_BQ_PATH}_{asset_name}\""
]
},
{
Expand All @@ -60,14 +62,15 @@
" # !pip install -U ogb # We will pull the dataset from ogb: https://ogb.stanford.edu/docs/lsc/\n",
" # Example of how this dataset can be used: https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/gnn.py\n",
" # WARNING: This code block can take hours to even days to run if the dataset is not available already locally\n",
"    # This is likely because upstream servers are slow as even on cloud instances with 100GB/s + network, \n",
"    # This is likely because upstream servers are slow as even on cloud instances with 100GB/s + network,\n",
" # the download is slow and took half a day.\n",
" print(\"Fetching MAG240M dataset and storing it in \", MAG240_DATASET_PATH)\n",
" if not os.path.exists(MAG240_DATASET_PATH):\n",
" os.makedirs(MAG240_DATASET_PATH)\n",
" dataset = MAG240MDataset(root = MAG240_DATASET_PATH)\n",
" dataset = MAG240MDataset(root=MAG240_DATASET_PATH)\n",
" return dataset\n",
"\n",
"\n",
"dataset = fetch_dataset()"
]
},
Expand All @@ -77,9 +80,9 @@
"metadata": {},
"outputs": [],
"source": [
"author_writes_paper = dataset.edge_index('author', 'paper')\n",
"author_affiliated_with_institution = dataset.edge_index('author', 'institution')\n",
"paper_cites_paper = dataset.edge_index('paper', 'paper')\n",
"author_writes_paper = dataset.edge_index(\"author\", \"paper\")\n",
"author_affiliated_with_institution = dataset.edge_index(\"author\", \"institution\")\n",
"paper_cites_paper = dataset.edge_index(\"paper\", \"paper\")\n",
"\n",
"paper_feat = dataset.all_paper_feat\n",
"paper_label = dataset.all_paper_label\n",
Expand All @@ -92,63 +95,46 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Define all the Avro schemas\n",
"fields_node_paper = [\n",
" {\"name\": \"paper\", \"type\": \"int\"},\n",
"]\n",
"for featIdx in range(dataset.num_paper_features):\n",
" fields_node_paper.append({\"name\": f\"feat_{featIdx}\", \"type\": \"float\"})\n",
"\n",
"schema_node_paper = {\n",
" \"type\": \"record\",\n",
" \"name\": \"Paper\",\n",
" \"fields\": fields_node_paper\n",
"}\n",
"schema_node_paper = {\"type\": \"record\", \"name\": \"Paper\", \"fields\": fields_node_paper}\n",
"\n",
"schema_node_paper_label = {\n",
" \"type\": \"record\",\n",
" \"name\": \"PaperLabel\",\n",
" \"fields\": [\n",
" {\"name\": \"paper\", \"type\": \"int\"},\n",
" {\"name\": \"label\", \"type\": \"float\"}\n",
" ]\n",
" \"fields\": [{\"name\": \"paper\", \"type\": \"int\"}, {\"name\": \"label\", \"type\": \"float\"}],\n",
"}\n",
"\n",
"schema_node_paper_year = {\n",
" \"type\": \"record\",\n",
" \"name\": \"PaperYear\",\n",
" \"fields\": [\n",
" {\"name\": \"paper\", \"type\": \"int\"},\n",
" {\"name\": \"year\", \"type\": \"int\"}\n",
" ]\n",
" \"fields\": [{\"name\": \"paper\", \"type\": \"int\"}, {\"name\": \"year\", \"type\": \"int\"}],\n",
"}\n",
"\n",
"schema_edge_author_writes_paper = {\n",
" \"type\": \"record\",\n",
" \"name\": \"AuthorWritesPaper\",\n",
" \"fields\": [\n",
" {\"name\": \"author\", \"type\": \"int\"},\n",
" {\"name\": \"paper\", \"type\": \"int\"}\n",
" ]\n",
" \"fields\": [{\"name\": \"author\", \"type\": \"int\"}, {\"name\": \"paper\", \"type\": \"int\"}],\n",
"}\n",
"\n",
"schema_edge_author_afil_with_institution = {\n",
" \"type\": \"record\",\n",
" \"name\": \"AuthorAfiliatedWithInstitution\",\n",
" \"fields\": [\n",
" {\"name\": \"author\", \"type\": \"int\"},\n",
" {\"name\": \"institution\", \"type\": \"int\"}\n",
" ]\n",
" {\"name\": \"institution\", \"type\": \"int\"},\n",
" ],\n",
"}\n",
"\n",
"schema_edge_paper_cites_paper = {\n",
" \"type\": \"record\",\n",
" \"name\": \"PaperCitesPaper\",\n",
" \"fields\": [\n",
" {\"name\": \"src\", \"type\": \"int\"},\n",
" {\"name\": \"dst\", \"type\": \"int\"}\n",
" ]\n",
" \"fields\": [{\"name\": \"src\", \"type\": \"int\"}, {\"name\": \"dst\", \"type\": \"int\"}],\n",
"}"
]
},
Expand All @@ -161,73 +147,91 @@
"# We will write to avro format, flushing the buffer to a new numbered file whenever it exceeds max_buffer_bytes (150MB by default)\n",
"class BufferedGCSAvroWriter:\n",
" def __init__(\n",
" self, \n",
" schema: dict, \n",
" gcs_bucket_path: str, \n",
" max_buffer_bytes=1.5e+8, # 150MB\n",
" ):\n",
" self,\n",
" schema: dict,\n",
" gcs_bucket_path: str,\n",
" max_buffer_bytes=1.5e8, # 150MB\n",
" ):\n",
" self.schema = schema\n",
" self.gcs_bucket_path = gcs_bucket_path\n",
" self.max_buffer_bytes = max_buffer_bytes\n",
" self.buffer = BytesIO()\n",
" self.file_index = 0\n",
"\n",
" \n",
" def flush(self):\n",
" # Reset buffer position to the beginning\n",
" self.buffer.seek(0)\n",
" # Initialize GCS client and upload the file\n",
" storage_client = storage.Client()\n",
" \n",
"\n",
" split_path = self.gcs_bucket_path.split(\"/\")\n",
" bucket_name = split_path[2]\n",
" blob_name_prefix = \"/\".join(split_path[3:])\n",
" destination_blob_name = f\"{blob_name_prefix}_{self.file_index}.avro\"\n",
" print(f\"Avro file will be uploaded to gs://{bucket_name}/{destination_blob_name}\")\n",
" print(\n",
" f\"Avro file will be uploaded to gs://{bucket_name}/{destination_blob_name}\"\n",
" )\n",
"\n",
" bucket = storage_client.bucket(bucket_name)\n",
" blob = bucket.blob(destination_blob_name)\n",
" blob.upload_from_file(self.buffer, content_type=\"application/octet-stream\")\n",
" print(f\"Avro file uploaded to gs://{bucket_name}/{destination_blob_name}\")\n",
" self.buffer = BytesIO()\n",
" self.file_index += 1\n",
" \n",
"\n",
" def write(self, record: dict):\n",
" # Write Avro data to the buffer using fastavro\n",
" writer(self.buffer, self.schema, [record])\n",
" if self.buffer.tell() > self.max_buffer_bytes:\n",
" self.flush()\n",
"\n",
"def __write_table_to_gcs_proc(schema: dict, edge_table: np.ndarray, gcs_bucket_path: str, proc_num: int):\n",
"\n",
"def __write_table_to_gcs_proc(\n",
" schema: dict, edge_table: np.ndarray, gcs_bucket_path: str, proc_num: int\n",
"):\n",
" print(f\"Begin writing edge table to GCS with proc_num: {proc_num}\")\n",
" gcs_path_for_writer = f\"{gcs_bucket_path}/{proc_num}\"\n",
" avro_writer = BufferedGCSAvroWriter(schema=schema, gcs_bucket_path=gcs_path_for_writer)\n",
" avro_writer = BufferedGCSAvroWriter(\n",
" schema=schema, gcs_bucket_path=gcs_path_for_writer\n",
" )\n",
" field_names = [field_info[\"name\"] for field_info in schema[\"fields\"]]\n",
" print(f\"Will write the following fields: {field_names}\")\n",
" \n",
"\n",
" for edge in edge_table.T:\n",
" obj = {field_names[i]: edge[i] for i in range(len(field_names))}\n",
" avro_writer.write(obj)\n",
" print(f\"Proc {proc_num} finished writing edge table to GCS.\")\n",
" avro_writer.flush()\n",
"\n",
"\n",
"def write_edge_table_to_gcs(schema: dict, edge_table: np.ndarray, gcs_bucket_path: str):\n",
" gcs_utils = GcsUtils()\n",
" print(f\"Clearing GCS path {gcs_bucket_path}\")\n",
" gcs_utils.delete_files_in_bucket_dir(gcs_path = GcsUri(gcs_bucket_path))\n",
" gcs_utils.delete_files_in_bucket_dir(gcs_path=GcsUri(gcs_bucket_path))\n",
" num_procs = 10\n",
" edge_table_chunks = np.array_split(edge_table, 10, axis=1)\n",
" with multiprocessing.Pool(processes=num_procs) as pool:\n",
" pool.starmap(__write_table_to_gcs_proc, [\n",
" (schema, edge_table_chunk, gcs_bucket_path, i)\n",
" for i, edge_table_chunk in enumerate(edge_table_chunks)\n",
" ])\n",
" pool.starmap(\n",
" __write_table_to_gcs_proc,\n",
" [\n",
" (schema, edge_table_chunk, gcs_bucket_path, i)\n",
" for i, edge_table_chunk in enumerate(edge_table_chunks)\n",
" ],\n",
" )\n",
"\n",
"\n",
"def __write_node_table_to_gcs_proc(schema: dict, node_table: np.ndarray, gcs_bucket_path: str, proc_num: int, enumerate_starting_from: Optional[int] = None):\n",
"def __write_node_table_to_gcs_proc(\n",
" schema: dict,\n",
" node_table: np.ndarray,\n",
" gcs_bucket_path: str,\n",
" proc_num: int,\n",
" enumerate_starting_from: Optional[int] = None,\n",
"):\n",
" print(f\"Begin writing node table to GCS with proc_num: {proc_num}\")\n",
" gcs_path_for_writer = f\"{gcs_bucket_path}/{proc_num}\"\n",
" avro_writer = BufferedGCSAvroWriter(schema=schema, gcs_bucket_path=gcs_path_for_writer)\n",
" avro_writer = BufferedGCSAvroWriter(\n",
" schema=schema, gcs_bucket_path=gcs_path_for_writer\n",
" )\n",
" field_names = [field_info[\"name\"] for field_info in schema[\"fields\"]]\n",
" print(f\"Will write the following fields: {field_names}\")\n",
"    # If node table dim is 1, then carefully manage the enumeration\n",
Expand All @@ -239,7 +243,7 @@
" if curr_count is not None:\n",
" obj = {field_names[0]: curr_count}\n",
" for i in range(1, len(field_names)):\n",
" obj[field_names[i]] = node[i-1]\n",
" obj[field_names[i]] = node[i - 1]\n",
" curr_count += 1\n",
" else:\n",
" obj = {field_names[i]: node[i] for i in range(len(field_names))}\n",
Expand All @@ -248,20 +252,22 @@
" print(f\"Proc {proc_num} finished writing node table to GCS.\")\n",
" avro_writer.flush()\n",
"\n",
"\n",
"def write_node_table_to_gcs(schema: dict, node_table: np.ndarray, gcs_bucket_path: str):\n",
" gcs_utils = GcsUtils()\n",
" print(f\"Clearing GCS path {gcs_bucket_path}\")\n",
" gcs_utils.delete_files_in_bucket_dir(gcs_path = GcsUri(gcs_bucket_path))\n",
" gcs_utils.delete_files_in_bucket_dir(gcs_path=GcsUri(gcs_bucket_path))\n",
" num_procs = 10\n",
" node_table_chunks = np.array_split(node_table, 10, axis=0)\n",
" chunk_sizes = [len(chunk) for chunk in node_table_chunks]\n",
" with multiprocessing.Pool(processes=num_procs) as pool:\n",
" pool.starmap(__write_node_table_to_gcs_proc, [\n",
" (schema, node_table_chunk, gcs_bucket_path, i, sum(chunk_sizes[:i]))\n",
" for i, node_table_chunk in enumerate(node_table_chunks)\n",
" ])\n",
"\n",
" \n",
" pool.starmap(\n",
" __write_node_table_to_gcs_proc,\n",
" [\n",
" (schema, node_table_chunk, gcs_bucket_path, i, sum(chunk_sizes[:i]))\n",
" for i, node_table_chunk in enumerate(node_table_chunks)\n",
" ],\n",
" )\n",
"\n",
"\n",
"class AvroToBqWriter:\n",
Expand All @@ -278,7 +284,7 @@
" print(f\"Finished writing to BQ table {table_id}; with result: {result}\")\n",
"\n",
" num_rows = bq_utils.count_number_of_rows_in_bq_table(table_id)\n",
" print(f\"Loaded {num_rows} rows to {table_id}\")\n"
" print(f\"Loaded {num_rows} rows to {table_id}\")"
]
},
{
Expand All @@ -289,9 +295,12 @@
"source": [
"edge_assets_to_export = [\n",
" (\"author_writes_paper\", author_writes_paper, schema_edge_author_writes_paper),\n",
" (\"author_affiliated_with_institution\", author_affiliated_with_institution, schema_edge_author_afil_with_institution),\n",
" (\n",
" \"author_affiliated_with_institution\",\n",
" author_affiliated_with_institution,\n",
" schema_edge_author_afil_with_institution,\n",
" ),\n",
" (\"paper_cites_paper\", paper_cites_paper, schema_edge_paper_cites_paper),\n",
" \n",
"]\n",
"\n",
"for asset_name, edge_table, schema in edge_assets_to_export:\n",
Expand Down
Loading