diff --git a/v3-examples/training-examples/breast-cancer-detection.ipynb b/v3-examples/training-examples/breast-cancer-detection.ipynb
new file mode 100644
index 0000000000..7ba442ab81
--- /dev/null
+++ b/v3-examples/training-examples/breast-cancer-detection.ipynb
@@ -0,0 +1,656 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "42b5e80b-ad1d-4335-a1f7-10a91127e3dc"
+    }
+   },
+   "source": [
+    "# Breast Cancer Prediction (SageMaker V3)\n",
+    "_**Predict Breast Cancer using SageMaker's Linear-Learner with features derived from images of Breast Mass**_\n",
+    "\n",
+    "---\n",
+    "\n",
+    "This notebook has been migrated to use SageMaker Python SDK V3 interfaces.\n",
+    "\n",
+    "---\n",
+    "\n",
+    "## Contents\n",
+    "\n",
+    "1. [Background](#Background)\n",
+    "1. [Setup](#Setup)\n",
+    "1. [Data](#Data)\n",
+    "1. [Train](#Train)\n",
+    "1. [Host](#Host)\n",
+    "1. [Predict](#Predict)\n",
+    "1. [Extensions](#Extensions)\n",
+    "\n",
+    "---\n",
+    "\n",
+    "## Background\n",
+    "This notebook illustrates how one can use SageMaker's algorithms for solving applications which require `linear models` for prediction. For this illustration, we have taken an example for breast cancer prediction using UCI's breast cancer diagnostic data set available at https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Diagnostic%29. The data set is also available on Kaggle at https://www.kaggle.com/uciml/breast-cancer-wisconsin-data. The purpose here is to use this data set to build a predictive model of whether a breast mass image indicates a benign or malignant tumor. The data set will be used to illustrate:\n",
+    "\n",
+    "* Basic setup for using SageMaker V3.\n",
+    "* Converting datasets to the CSV format expected by the Amazon SageMaker Linear Learner algorithm and uploading them to S3.\n",
+    "* Training SageMaker's linear learner on the data set using ModelTrainer.\n",
+    "* Hosting the trained model using V3 resources.\n",
+    "* Scoring using the trained model.\n",
+    "\n",
+    "---\n",
+    "\n",
+    "## Setup\n",
+    "\n",
+    "Let's start by specifying:\n",
+    "\n",
+    "* The SageMaker role arn used to give learning and hosting access to your data.\n",
+    "* The S3 bucket that you want to use for training and storing model objects."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "isConfigCell": true,
+    "nbpresent": {
+     "id": "6427e831-8f89-45c0-b150-0b134397d79a"
+    },
+    "tags": [
+     "parameters"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import boto3\n",
+    "import re\n",
+    "\n",
+    "# V3 imports\n",
+    "from sagemaker.core.helper.session_helper import Session, get_execution_role\n",
+    "from sagemaker.core import image_uris\n",
+    "from sagemaker.train.model_trainer import ModelTrainer\n",
+    "from sagemaker.train.configs import InputData, Compute\n",
+    "from sagemaker.core.resources import Model, EndpointConfig, Endpoint\n",
+    "\n",
+    "# Initialize V3 session\n",
+    "sagemaker_session = Session()\n",
+    "role = get_execution_role()\n",
+    "region = sagemaker_session.boto_region_name\n",
+    "\n",
+    "# S3 bucket for saving code and model artifacts.\n",
+    "bucket = sagemaker_session.default_bucket()\n",
+    "\n",
+    "prefix = 'sagemaker/DEMO-breast-cancer-prediction-v3'  # place to upload training files within the bucket"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "b2548d66-6f8f-426f-9cda-7a3cd1459abd"
+    }
+   },
+   "source": [
+    "Now we'll import the Python libraries we'll need."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "nbpresent": {
+     "id": "bb88eea9-27f3-4e47-9133-663911ea09a9"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import time\n",
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "142777ae-c072-448e-b941-72bc75735d01"
+    }
+   },
+   "source": [
+    "---\n",
+    "## Data\n",
+    "\n",
+    "Data Source: https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data\n",
+    "             https://www.kaggle.com/uciml/breast-cancer-wisconsin-data\n",
+    "\n",
+    "Let's download the data, save it in the local folder with the name data.csv, and take a look at it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "nbpresent": {
+     "id": "f8976dad-6897-4c7e-8c95-ae2f53070ef5"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "data = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data', header = None)\n",
+    "\n",
+    "# specify columns extracted from wdbc.names\n",
+    "data.columns = [\"id\",\"diagnosis\",\"radius_mean\",\"texture_mean\",\"perimeter_mean\",\"area_mean\",\"smoothness_mean\",\n",
+    "                \"compactness_mean\",\"concavity_mean\",\"concave points_mean\",\"symmetry_mean\",\"fractal_dimension_mean\",\n",
+    "                \"radius_se\",\"texture_se\",\"perimeter_se\",\"area_se\",\"smoothness_se\",\"compactness_se\",\"concavity_se\",\n",
+    "                \"concave points_se\",\"symmetry_se\",\"fractal_dimension_se\",\"radius_worst\",\"texture_worst\",\n",
+    "                \"perimeter_worst\",\"area_worst\",\"smoothness_worst\",\"compactness_worst\",\"concavity_worst\",\n",
+    "                \"concave points_worst\",\"symmetry_worst\",\"fractal_dimension_worst\"]\n",
+    "\n",
+    "# save the data\n",
+    "data.to_csv(\"data.csv\", sep=',', index=False)\n",
+    "\n",
+    "# print the shape of the data file\n",
+    "print(data.shape)\n",
+    "\n",
+    "# show the top few rows\n",
+    "display(data.head())\n",
+    "\n",
+    "# describe the data object\n",
+    "display(data.describe())\n",
+    "\n",
+    "# we will also summarize the categorical field diagnosis\n",
+    "display(data.diagnosis.value_counts())\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Key observations:\n",
+    "* The data has 569 observations and 32 columns.\n",
+    "* The first field is 'id'.\n",
+    "* The second field, 'diagnosis', is an indicator of the actual diagnosis ('M' = Malignant; 'B' = Benign).\n",
+    "* There are 30 other numeric features available for prediction."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create Features and Labels\n",
+    "#### Split the data into 80% training, 10% validation and 10% testing.\n",
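+    "\n",
+    "Note that `np.random.rand` produces a different split on every run; seed the generator first (e.g. `np.random.seed(0)`) if you need the split to be reproducible."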
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rand_split = np.random.rand(len(data))\n",
+    "train_list = rand_split < 0.8\n",
+    "val_list = (rand_split >= 0.8) & (rand_split < 0.9)\n",
+    "test_list = rand_split >= 0.9\n",
+    "\n",
+    "data_train = data[train_list]\n",
+    "data_val = data[val_list]\n",
+    "data_test = data[test_list]\n",
+    "\n",
+    "train_y = ((data_train.iloc[:,1] == 'M') + 0).to_numpy()\n",
+    "train_X = data_train.iloc[:,2:].to_numpy()\n",
+    "\n",
+    "val_y = ((data_val.iloc[:,1] == 'M') + 0).to_numpy()\n",
+    "val_X = data_val.iloc[:,2:].to_numpy()\n",
+    "\n",
+    "test_y = ((data_test.iloc[:,1] == 'M') + 0).to_numpy()\n",
+    "test_X = data_test.iloc[:,2:].to_numpy()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "ff9d10f9-b611-423b-80da-6dcdafd1c8b9"
+    }
+   },
+   "source": [
+    "Now, we'll convert the datasets to CSV format and upload to S3. Linear Learner expects the label in the first column."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "nbpresent": {
+     "id": "cd8e3431-79d9-40b6-91d1-d67cd61894e7"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Training data - label in first column\n",
+    "train_df = pd.DataFrame(train_X)\n",
+    "train_df.insert(0, 'label', train_y)\n",
+    "train_file = 'linear_train.csv'\n",
+    "train_df.to_csv(train_file, header=False, index=False)\n",
+    "\n",
+    "# Upload to S3\n",
+    "boto3.Session().resource('s3').Bucket(bucket).Object(\n",
+    "    os.path.join(prefix, 'train', train_file)\n",
+    ").upload_file(train_file)\n",
+    "\n",
+    "train_s3_uri = f\"s3://{bucket}/{prefix}/train/{train_file}\"\n",
+    "print(f\"Training data uploaded to: {train_s3_uri}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "71cbcebd-a2a5-419e-8e50-b2bc0909f564"
+    }
+   },
+   "source": [
+    "Next we'll convert and upload the validation dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "nbpresent": {
+     "id": "bd113b8e-adc1-4091-a26f-a426149fe604"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Validation data - label in first column\n",
+    "val_df = pd.DataFrame(val_X)\n",
+    "val_df.insert(0, 'label', val_y)\n",
+    "validation_file = 'linear_validation.csv'\n",
+    "val_df.to_csv(validation_file, header=False, index=False)\n",
+    "\n",
+    "# Upload to S3\n",
+    "boto3.Session().resource('s3').Bucket(bucket).Object(\n",
+    "    os.path.join(prefix, 'validation', validation_file)\n",
+    ").upload_file(validation_file)\n",
+    "\n",
+    "validation_s3_uri = f\"s3://{bucket}/{prefix}/validation/{validation_file}\"\n",
+    "print(f\"Validation data uploaded to: {validation_s3_uri}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "f3b125ad-a2d5-464c-8cfa-bd203034eee4"
+    }
+   },
+   "source": [
+    "---\n",
+    "## Train\n",
+    "\n",
+    "Now we can begin to specify our linear model using SageMaker V3's ModelTrainer. Amazon SageMaker's Linear Learner actually fits many models in parallel, each with slightly different hyperparameters, and then returns the one with the best fit. This functionality is automatically enabled. We can influence it using parameters like:\n",
+    "\n",
+    "- `num_models` to increase the total number of models run. The specified parameters will always be one of those models, but the algorithm also chooses models with nearby parameter values in order to find a solution that may be more optimal. 
In this case, we're going to use the max of 32.\n",
+    "- `loss` which controls how we penalize mistakes in our model estimates. For this case, let's use absolute loss as we haven't spent much time cleaning the data, and absolute loss will be less sensitive to outliers.\n",
+    "- `wd` or `l1` which control regularization. Regularization can prevent model overfitting by preventing our estimates from becoming too finely tuned to the training data, which can actually hurt generalizability. In this case, we'll leave these parameters at their default \"auto\" though."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Get the container image for linear-learner"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# V3: Use image_uris.retrieve instead of get_image_uri\n",
+    "container = image_uris.retrieve(\n",
+    "    framework='linear-learner',\n",
+    "    region=region\n",
+    ")\n",
+    "print(f\"Using container: {container}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Create and train the model using ModelTrainer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "nbpresent": {
+     "id": "397fb60a-c48b-453f-88ea-4d832b70c919"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "\n",
+    "# V3: Use ModelTrainer instead of boto3 create_training_job\n",
+    "trainer = ModelTrainer(\n",
+    "    training_image=container,\n",
+    "    role=role,\n",
+    "    compute=Compute(\n",
+    "        instance_count=1,\n",
+    "        instance_type=\"ml.c4.2xlarge\",\n",
+    "        volume_size_in_gb=10\n",
+    "    ),\n",
+    "    hyperparameters={\n",
+    "        \"feature_dim\": \"30\",\n",
+    "        \"mini_batch_size\": \"100\",\n",
+    "        \"predictor_type\": \"regressor\",\n",
+    "        \"epochs\": \"10\",\n",
+    "        \"num_models\": \"32\",\n",
+    "        \"loss\": \"absolute_loss\"\n",
+    "    },\n",
+    "    sagemaker_session=sagemaker_session\n",
+    ")\n",
+    "\n",
+    "# Train the model\n",
+    "trainer.train(\n",
+    "    input_data_config=[\n",
+    "        InputData(\n",
+    "            channel_name=\"train\",\n",
+    "            data_source=train_s3_uri,\n",
+    "            content_type=\"text/csv\"\n",
+    "        ),\n",
+    "        InputData(\n",
+    "            channel_name=\"validation\",\n",
+    "            data_source=validation_s3_uri,\n",
+    "            content_type=\"text/csv\"\n",
+    "        )\n",
+    "    ],\n",
+    "    wait=True,\n",
+    "    logs=True\n",
+    ")\n",
+    "\n",
+    "# Get the training job from the trainer\n",
+    "training_job = trainer._latest_training_job\n",
+    "print(f\"Training job completed: {training_job.training_job_name}\")\n",
+    "print(f\"Model artifacts: {training_job.model_artifacts.s3_model_artifacts}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "nbpresent": {
+     "id": "2adcc348-9ab5-4a8a-8139-d0ecd740208a"
+    }
+   },
+   "source": [
+    "---\n",
+    "## Host\n",
+    "\n",
+    "Now that we've trained the linear algorithm on our data, let's set up a model which can later be hosted using V3 resources."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "nbpresent": {
+     "id": "c88fb868-01d2-4991-8953-28814c022bdc"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# V3: Create Model using resources\n",
+    "model_name = f\"breast-cancer-model-{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}\"\n",
+    "\n",
+    "model = Model.create(\n",
+    "    model_name=model_name,\n",
+    "    execution_role_arn=role,\n",
+    "    primary_container={\n",
+    "        'image': container,\n",
+    "        'model_data_url': training_job.model_artifacts.s3_model_artifacts\n",
+    "    },\n",
+    "    session=sagemaker_session.boto_session\n",
+    ")\n",
+    "\n",
+    "print(f\"Model created: {model.model_name}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Create Endpoint Configuration"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# V3: Create EndpointConfig\n",
+    "endpoint_config_name = f\"breast-cancer-config-{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}\"\n",
+    "\n",
+    "endpoint_config = EndpointConfig.create(\n",
+    "    endpoint_config_name=endpoint_config_name,\n",
+    "    production_variants=[{\n",
+    "        'variant_name': 'AllTraffic',\n",
+    "        'model_name': model.model_name,\n",
+    "        'instance_type': 'ml.m4.xlarge',\n",
+    "        'initial_instance_count': 1\n",
+    "    }],\n",
+    "    session=sagemaker_session.boto_session\n",
+    ")\n",
+    "\n",
+    "print(f\"Endpoint config created: {endpoint_config.endpoint_config_name}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Create and Deploy Endpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "\n",
+    "# V3: Create Endpoint\n",
+    "endpoint_name = f\"breast-cancer-endpoint-{time.strftime('%Y%m%d%H%M', time.gmtime())}\"\n",
+    "\n",
+    "endpoint = Endpoint.create(\n",
+    "    endpoint_name=endpoint_name,\n",
+    "    endpoint_config_name=endpoint_config.endpoint_config_name,\n",
+    "    session=sagemaker_session.boto_session\n",
+    ")\n",
+    "\n",
+    "print(f\"Endpoint created: {endpoint.endpoint_name}\")\n",
+    "print(\"Waiting for endpoint to be in service...\")\n",
+    "\n",
+    "# Wait for endpoint to be ready\n",
+    "endpoint.wait_for_status('InService')\n",
+    "\n",
+    "print(f\"Endpoint is ready: {endpoint.endpoint_status}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Predict\n",
+    "### Predict on Test Data\n",
+    "\n",
+    "Now that we have our hosted endpoint, we can generate statistical predictions from it. Let's predict on our test dataset to understand how accurate our model is.\n",
+    "\n",
+    "There are many metrics to measure classification accuracy. Common examples include:\n",
+    "- Precision\n",
+    "- Recall\n",
+    "- F1 measure\n",
+    "- Area under the ROC curve - AUC\n",
+    "- Total Classification Accuracy\n",
+    "- Mean Absolute Error\n",
+    "\n",
+    "For our example, we'll keep things simple and use total classification accuracy as our metric of choice. We will also evaluate Mean Absolute Error (MAE), as the linear-learner has been optimized using this metric, not necessarily because it is a relevant metric from an application point of view. 
We'll compare the performance of the linear-learner against a naive benchmark that predicts the majority class observed in the training data set for every test instance.\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Function to convert an array to a csv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import io\n",
+    "def np2csv(arr):\n",
+    "    csv = io.BytesIO()\n",
+    "    np.savetxt(csv, arr, delimiter=',', fmt='%g')\n",
+    "    return csv.getvalue().decode().rstrip()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we'll invoke the endpoint to get predictions using V3."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# V3: Use endpoint.invoke instead of runtime.sagemaker\n",
+    "payload = np2csv(test_X)\n",
+    "\n",
+    "response = endpoint.invoke(\n",
+    "    body=payload,\n",
+    "    content_type='text/csv'\n",
+    ")\n",
+    "\n",
+    "result = json.loads(response.body.read().decode())\n",
+    "test_pred = np.array([r['score'] for r in result['predictions']])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's compare the linear learner's mean absolute prediction error against a baseline that predicts the training median for every instance (for 0/1 labels, the median is the majority class, and it is the constant prediction that minimizes MAE)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_mae_linear = np.mean(np.abs(test_y - test_pred))\n",
+    "test_mae_baseline = np.mean(np.abs(test_y - np.median(train_y)))  ## training median as baseline predictor\n",
+    "\n",
+    "print(\"Test MAE Baseline :\", round(test_mae_baseline, 3))\n",
+    "print(\"Test MAE Linear:\", round(test_mae_linear, 3))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's compute classification accuracy using a threshold of 0.5 on the predicted scores, and compare it against the majority-class prediction from the training data set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_pred_class = (test_pred > 0.5) + 0\n",
+    "test_pred_baseline = np.repeat(np.median(train_y), len(test_y))\n",
+    "\n",
+    "prediction_accuracy = np.mean((test_y == test_pred_class))*100\n",
+    "baseline_accuracy = np.mean((test_y == test_pred_baseline))*100\n",
+    "\n",
+    "print(\"Prediction Accuracy:\", round(prediction_accuracy, 1), \"%\")\n",
+    "print(\"Baseline Accuracy:\", round(baseline_accuracy, 1), \"%\")"
+   ]
+  },
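+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The other metrics listed at the start of this section can also be computed from `test_y`, the thresholded classes, and the raw scores. A minimal sketch using scikit-learn (assuming it is installed in this kernel):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn import metrics\n",
+    "\n",
+    "# Precision, recall and F1 use the thresholded class labels\n",
+    "print('Precision:', round(metrics.precision_score(test_y, test_pred_class), 3))\n",
+    "print('Recall   :', round(metrics.recall_score(test_y, test_pred_class), 3))\n",
+    "print('F1       :', round(metrics.f1_score(test_y, test_pred_class), 3))\n",
+    "\n",
+    "# AUC is computed from the raw scores rather than the thresholded labels\n",
+    "print('AUC      :', round(metrics.roc_auc_score(test_y, test_pred), 3))"
+   ]
+  },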
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Cleanup\n",
+    "Run the cell below to delete the endpoint once you are done."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# V3: Use endpoint.delete()\n",
+    "endpoint.delete()\n",
+    "print(f\"Endpoint {endpoint_name} deleted\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Extensions\n",
+    "\n",
+    "- Our linear model does a good job of predicting breast cancer, with an overall accuracy close to 92%. We can re-run the model with different values of the hyperparameters, loss functions, etc. and see whether that yields more accurate out-of-sample predictions.\n",
+    "- We also did not do much feature engineering. We can create additional features by considering cross-products/interactions of multiple features, squaring features or raising them to higher powers to induce non-linear effects, etc. If we expand the features using non-linear terms and interactions, we can then tune the regularization parameter to optimize the expanded model and hence generate improved forecasts.\n",
+    "- As a further extension, we can use many of the non-linear models available through SageMaker, such as XGBoost, MXNet, etc.\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "celltoolbar": "Tags",
+  "kernelspec": {
+   "display_name": "conda_python3",
+   "language": "python",
+   "name": "conda_python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  },
+  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the License). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the license file accompanying this file. This file is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/v3-examples/training-examples/jumpstart-object-detection.ipynb b/v3-examples/training-examples/jumpstart-object-detection.ipynb
new file mode 100644
index 0000000000..030cdf3f25
--- /dev/null
+++ b/v3-examples/training-examples/jumpstart-object-detection.ipynb
@@ -0,0 +1,733 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Amazon SageMaker JumpStart Object Detection for Bird Species\n",
+    "\n",
+    "## Introduction\n",
+    "\n",
+    "Object detection is the process of identifying and localizing objects in an image. A typical object detection solution takes an image as input and provides a bounding box on the image where an object of interest is found. It also identifies what type of object the box encapsulates.\n",
+    "\n",
+    "This notebook is an end-to-end example showing how Amazon SageMaker JumpStart can be used to train an object detection model on a custom dataset. We use the [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset, which contains images of 200 bird species with bounding box annotations.\n",
+    "\n",
+    "### What is JumpStart?\n",
+    "\n",
+    "JumpStart provides pre-trained models that can be fine-tuned on your custom data without writing training scripts. 
This notebook uses the **Faster R-CNN with ResNet-50** backbone, which is a popular object detection architecture that balances accuracy and speed.\n",
+    "\n",
+    "### What You'll Learn\n",
+    "\n",
+    "- How to prepare custom data in COCO format for object detection\n",
+    "- How to train a JumpStart model on your own dataset\n",
+    "- How to deploy and test an object detection model\n",
+    "- How to visualize detection results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import time\n",
+    "from sagemaker.core.helper.session_helper import Session, get_execution_role\n",
+    "from sagemaker.train.model_trainer import ModelTrainer\n",
+    "from sagemaker.train.configs import InputData\n",
+    "from sagemaker.core.jumpstart.configs import JumpStartConfig\n",
+    "\n",
+    "sagemaker_session = Session()\n",
+    "bucket = sagemaker_session.default_bucket()\n",
+    "role = get_execution_role()\n",
+    "\n",
+    "print(f'Bucket: {bucket}')\n",
+    "print(f'Role: {role}')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Download and Prepare CUB Dataset\n",
+    "\n",
+    "The [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset contains 11,788 images across 200 bird species. Each species comes with around 60 images, with a typical size of about 350 pixels by 500 pixels. Bounding boxes are provided for each bird in the image.\n",
+    "\n",
+    "For this demonstration, we'll use a subset of 5 bird species to keep training time manageable. The same approach works for all 200 species."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Download CUB-200-2011 dataset\n",
+    "!wget -q https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz\n",
+    "!tar -xzf CUB_200_2011.tgz\n",
+    "print('Dataset downloaded and extracted')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Create COCO Format Annotations\n",
+    "\n",
+    "JumpStart object detection models expect data in COCO (Common Objects in Context) format, which is a standard format for object detection datasets.\n",
+    "\n",
+    "### COCO Format Structure\n",
+    "\n",
+    "A COCO dataset consists of:\n",
+    "- **images**: List of image metadata (id, filename, width, height)\n",
+    "- **annotations**: List of bounding boxes with category labels\n",
+    "- **categories**: List of object categories (id, name)\n",
+    "\n",
+    "### Important Format Requirements\n",
+    "\n",
+    "1. **Category IDs start at 1**: COCO reserves category 0 for background, so your object categories should be 1, 2, 3, etc.\n",
+    "2. **Bounding box format**: Use corner coordinates `[x_min, y_min, x_max, y_max]` (note that standard COCO files store `[x, y, width, height]`; we use corners here because that is what the PyTorch-based model expects)\n",
+    "3. **Coordinate validation**: Ensure all bounding boxes are within image bounds and have positive dimensions"
+   ]
+  },
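+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before building the real annotations, here is a minimal, hypothetical example of the structure described above (the file and category names are made up for illustration):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal, hypothetical COCO-style example matching the requirements above\n",
+    "example_coco = {\n",
+    "    'images': [{'id': 1, 'file_name': 'bird_0001.jpg', 'width': 500, 'height': 350}],\n",
+    "    'annotations': [{\n",
+    "        'id': 0,\n",
+    "        'image_id': 1,\n",
+    "        'category_id': 1,  # 1-indexed: category 0 is reserved for background\n",
+    "        'bbox': [60.0, 40.0, 420.0, 310.0],  # corner format [x_min, y_min, x_max, y_max]\n",
+    "        'area': (420.0 - 60.0) * (310.0 - 40.0),\n",
+    "        'iscrowd': 0\n",
+    "    }],\n",
+    "    'categories': [{'id': 1, 'name': 'bird_class_1'}]\n",
+    "}\n",
+    "print(json.dumps(example_coco, indent=2))"
+   ]
+  },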
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from PIL import Image\n",
+    "\n",
+    "def create_coco_dataset(base_dir, num_classes=5):\n",
+    "    \"\"\"Create COCO format annotations with 1-indexed categories.\"\"\"\n",
+    "    images_dict = {}\n",
+    "    paths_dict = {}\n",
+    "    with open(os.path.join(base_dir, 'images.txt')) as f:\n",
+    "        for line in f:\n",
+    "            img_id, img_path = line.strip().split(' ', 1)\n",
+    "            images_dict[img_id] = os.path.basename(img_path)\n",
+    "            paths_dict[img_id] = img_path\n",
+    "    \n",
+    "    bboxes_dict = {}\n",
+    "    with open(os.path.join(base_dir, 'bounding_boxes.txt')) as f:\n",
+    "        for line in f:\n",
+    "            parts = line.strip().split()\n",
+    "            img_id = parts[0]\n",
+    "            x, y, w, h = map(float, parts[1:5])\n",
+    "            bboxes_dict[img_id] = [x, y, w, h]\n",
+    "    \n",
+    "    labels_dict = {}\n",
+    "    with open(os.path.join(base_dir, 'image_class_labels.txt')) as f:\n",
+    "        for line in f:\n",
+    "            img_id, class_id = line.strip().split()\n",
+    "            labels_dict[img_id] = int(class_id)  # Keep 1-indexed\n",
+    "    \n",
+    "    split_dict = {}\n",
+    "    with open(os.path.join(base_dir, 'train_test_split.txt')) as f:\n",
+    "        for line in f:\n",
+    "            img_id, is_train = line.strip().split()\n",
+    "            split_dict[img_id] = int(is_train) == 1\n",
+    "    \n",
+    "    valid_classes = list(range(1, num_classes + 1))  # 1-indexed: [1, 2, 3, 4, 5]\n",
+    "    \n",
+    "    # Combine train and val into single dataset\n",
+    "    coco = {\n",
+    "        'images': [],\n",
+    "        'annotations': [],\n",
+    "        'categories': [{'id': i, 'name': f'bird_class_{i}'} for i in valid_classes]\n",
+    "    }\n",
+    "    \n",
+    "    ann_id = 0\n",
+    "    skipped = 0\n",
+    "    \n",
+    "    for img_id in sorted(images_dict.keys()):\n",
+    "        class_id = labels_dict[img_id]\n",
+    "        if class_id not in valid_classes:\n",
+    "            continue\n",
+    "        \n",
+    "        # Get image dimensions (the relative path was recorded in the first pass)\n",
+    "        img_full_path = os.path.join(base_dir, 'images', paths_dict[img_id])\n",
+    "        img = Image.open(img_full_path)\n",
+    "        width, height = img.size\n",
+    "        \n",
+    "        x, y, w, h = bboxes_dict[img_id]\n",
+    "        \n",
+    "        # Fix negative dimensions\n",
+    "        if w < 0:\n",
+    "            x = x + w\n",
+    "            w = abs(w)\n",
+    "        if h < 0:\n",
+    "            y = y + h\n",
+    "            h = abs(h)\n",
+    "        \n",
+    "        # Clamp to image bounds\n",
+    "        x = max(0, x)\n",
+    "        y = max(0, y)\n",
+    "        w = min(w, width - x)\n",
+    "        h = min(h, height - y)\n",
+    "        \n",
+    "        # Skip invalid boxes\n",
+    "        if w <= 0 or h <= 0:\n",
+    "            print(f'Skipping image {img_id} with invalid bbox: [{x}, {y}, {w}, {h}]')\n",
+    "            skipped += 1\n",
+    "            continue\n",
+    "        \n",
+    "        # Add image and annotation\n",
+    "        coco['images'].append({\n",
+    "            'id': int(img_id),\n",
+    "            'file_name': images_dict[img_id],\n",
+    "            'width': width,\n",
+    "            'height': height\n",
+    "        })\n",
+    "        \n",
+    "        coco['annotations'].append({\n",
+    "            'id': ann_id,\n",
+    "            'image_id': int(img_id),\n",
+    "            'category_id': class_id,\n",
+    "            'bbox': [x, y, x + w, y + h],  # Convert to [x_min, y_min, x_max, y_max] for PyTorch\n",
+    "            'area': w * h,\n",
+    "            'iscrowd': 0\n",
+    "        })\n",
+    "        ann_id += 1\n",
+    "    \n",
+    "    os.makedirs('annotations', exist_ok=True)\n",
+    "    with open('annotations/combined.json', 'w') as f:\n",
+    "        json.dump(coco, f)\n",
+    "    \n",
+    "    print(f'Total: {len(coco[\"images\"])} images, 
{len(coco[\"annotations\"])} annotations, {skipped} skipped')\n", + " return coco\n", + "\n", + "coco_data = create_coco_dataset('CUB_200_2011', num_classes=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Prepare Flat Image Structure\n", + "\n", + "The CUB dataset organizes images in nested folders by species (e.g., `001.Black_footed_Albatross/image1.jpg`). However, JumpStart expects all images in a single flat directory.\n", + "\n", + "We'll copy all images to a flat directory structure where each image filename is unique. The COCO annotations file will reference these filenames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "# Create flat image directory\n", + "if os.path.exists('flat_images'):\n", + " shutil.rmtree('flat_images')\n", + "os.makedirs('flat_images', exist_ok=True)\n", + "\n", + "# Copy all images to flat directory\n", + "print('Creating flat image structure...')\n", + "for root, dirs, files in os.walk('CUB_200_2011/images'):\n", + " for file in files:\n", + " if file.endswith(('.jpg', '.jpeg', '.png')):\n", + " src = os.path.join(root, file)\n", + " dst = os.path.join('flat_images', file)\n", + " shutil.copy2(src, dst)\n", + "\n", + "print(f'Copied {len(os.listdir(\"flat_images\"))} images to flat directory')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Upload to S3\n", + "\n", + "SageMaker training jobs read data from S3. We need to upload our prepared dataset in the following structure:\n", + "\n", + "```\n", + "s3://bucket/prefix/train/\n", + "├── images/\n", + "│ ├── image1.jpg\n", + "│ ├── image2.jpg\n", + "│ └── ...\n", + "└── annotations.json\n", + "```\n", + "\n", + "**Important**: JumpStart expects all data in a single `training` channel. The training script will automatically split it 80% for training and 20% for validation. Do not create separate train and validation folders.\n", + "\n", + "We use a timestamped prefix to ensure we're using fresh data and not cached versions from previous runs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use timestamped prefix to avoid caching\n", + "prefix = f'jumpstart-od-birds-{int(time.time())}'\n", + "train_s3 = f's3://{bucket}/{prefix}/train'\n", + "\n", + "print(f'Uploading to: {train_s3}')\n", + "\n", + "# Upload images\n", + "!aws s3 sync flat_images {train_s3}/images/ --quiet\n", + "\n", + "# Upload combined annotations\n", + "!aws s3 cp annotations/combined.json {train_s3}/annotations.json\n", + "\n", + "print('Upload complete!')\n", + "print(f'\\nData location: {train_s3}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Verify Upload\n", + "\n", + "Before starting a training job, it's good practice to verify that:\n", + "1. Data was uploaded successfully to S3\n", + "2. The annotations file is valid JSON\n", + "3. The number of images and annotations match expectations\n", + "\n", + "This helps catch issues early before spending time and money on training." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Verify S3 data\n",
+    "!aws s3 ls {train_s3}/\n",
+    "\n",
+    "# Check annotation count\n",
+    "result = !aws s3 cp {train_s3}/annotations.json - | python3 -c \"import json, sys; d=json.load(sys.stdin); print(f'Images: {len(d[\\\"images\\\"])}, Annotations: {len(d[\\\"annotations\\\"])}')\"\n",
+    "print(result[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 6. Train the Model\n",
+    "\n",
+    "Now we're ready to train our object detection model using JumpStart.\n",
+    "\n",
+    "### Model Selection\n",
+    "\n",
+    "We use `pytorch-od1-fasterrcnn-resnet50-fpn`, which is a Faster R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN). This architecture is well-suited for detecting objects of various sizes.\n",
+    "\n",
+    "### Training Configuration\n",
+    "\n",
+    "The `ModelTrainer.from_jumpstart_config()` method automatically configures:\n",
+    "- The training container image\n",
+    "- Default hyperparameters optimized for the model\n",
+    "- Instance type for training\n",
+    "\n",
+    "We only need to provide:\n",
+    "- The model ID\n",
+    "- The S3 location of our training data\n",
+    "- A base name for the training job\n",
+    "\n",
+    "The training process will:\n",
+    "1. Load the pre-trained Faster R-CNN model\n",
+    "2. Replace the final detection layer to match our 5 bird categories\n",
+    "3. Fine-tune the model on our bird images\n",
+    "4. Save the trained model artifacts to S3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Select JumpStart model\n",
+    "model_id = 'pytorch-od1-fasterrcnn-resnet50-fpn'\n",
+    "\n",
+    "js_config = JumpStartConfig(model_id=model_id)\n",
+    "\n",
+    "# Create input data config - ONLY training channel\n",
+    "train_input = InputData(\n",
+    "    channel_name='training',\n",
+    "    data_source=train_s3\n",
+    ")\n",
+    "\n",
+    "# Create trainer\n",
+    "trainer = ModelTrainer.from_jumpstart_config(\n",
+    "    jumpstart_config=js_config,\n",
+    "    base_job_name='jumpstart-od-birds',\n",
+    "    sagemaker_session=sagemaker_session,\n",
+    "    input_data_config=[train_input]  # Only training channel\n",
+    ")\n",
+    "\n",
+    "print('Trainer created')\n",
+    "print(f'Hyperparameters: {trainer.hyperparameters}')\n",
+    "print(f'Input data: {train_input}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Start training\n",
+    "trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import boto3\n",
+    "# Inspect the training job status; if the job failed, the CloudWatch logs below help diagnose why\n",
+    "training_job = trainer._latest_training_job\n",
+    "print(f\"Training job name: {training_job.training_job_name}\")\n",
+    "print(f\"Training job status: {training_job.training_job_status}\")\n",
+    "if getattr(training_job, 'failure_reason', None):\n",
+    "    print(f\"\\nFailure reason: {training_job.failure_reason}\")\n",
+    "\n",
+    "# Get CloudWatch logs\n",
+    "logs_client = boto3.client('logs', region_name=sagemaker_session.boto_region_name)\n",
+    "\n",
+    "log_group = '/aws/sagemaker/TrainingJobs'\n",
+    "\n",
+    "try:\n",
+    "    # List log streams for this job\n",
+    "    streams = logs_client.describe_log_streams(\n",
+    "        logGroupName=log_group,\n",
+    "        logStreamNamePrefix=training_job.training_job_name\n",
+    "    )\n",
+    "    \n",
+    "    if streams['logStreams']:\n",
+    "        stream_name = streams['logStreams'][0]['logStreamName']\n",
+    "        print(f\"\\nLog stream: 
{stream_name}\")\n", + " \n", + " # Get last 100 log events\n", + " events = logs_client.get_log_events(\n", + " logGroupName=log_group,\n", + " logStreamName=stream_name,\n", + " limit=100,\n", + " startFromHead=False\n", + " )\n", + " \n", + " print(\"\\n=== Last 100 log lines ===\")\n", + " for event in events['events']:\n", + " print(event['message'])\n", + "except Exception as e:\n", + " print(f\"Could not fetch logs: {e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Deploy the Model\n", + "\n", + "After training completes, we need to deploy the model to an endpoint for real-time inference.\n", + "\n", + "### Deployment Process\n", + "\n", + "Deployment involves three steps:\n", + "\n", + "1. **Create a Model**: Defines the model artifacts and inference container image\n", + "2. **Create an Endpoint Configuration**: Specifies the instance type and count\n", + "3. **Create an Endpoint**: Deploys the model to a running instance\n", + "\n", + "### Instance Selection\n", + "\n", + "We use `ml.g4dn.xlarge`, which is a GPU instance. Object detection models with deep neural networks require GPU acceleration for fast inference. CPU instances would be too slow for practical use.\n", + "\n", + "The endpoint will remain running until you delete it, so remember to clean up when done to avoid charges." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get model artifacts from training job\n", + "training_job = trainer._latest_training_job\n", + "training_job.refresh()\n", + "model_data = training_job.model_artifacts.s3_model_artifacts\n", + "\n", + "print(f'Model artifacts: {model_data}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.core import image_uris\n", + "\n", + "image = image_uris.retrieve(\n", + " framework='pytorch',\n", + " region=sagemaker_session.boto_region_name,\n", + " image_scope='inference',\n", + " instance_type='ml.g4dn.xlarge',\n", + " version='1.8.1',\n", + " py_version='py3',\n", + ")\n", + "print(image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.core.resources import Model, EndpointConfig, Endpoint\n", + "\n", + "# Create model\n", + "model = Model.create(\n", + " model_name=f'jumpstart-od-birds-{int(time.time())}',\n", + " execution_role_arn=role,\n", + " primary_container={\n", + " 'image': image,\n", + " 'model_data_url': model_data\n", + " }\n", + ")\n", + "\n", + "# Create endpoint config\n", + "endpoint_config = EndpointConfig.create(\n", + " endpoint_config_name=f'jumpstart-od-birds-config-{int(time.time())}',\n", + " production_variants=[{\n", + " 'variant_name': 'AllTraffic',\n", + " 'model_name': model.model_name,\n", + " 'instance_type': 'ml.g4dn.xlarge',\n", + " 'initial_instance_count': 1\n", + " }]\n", + ")\n", + "\n", + "# Create endpoint\n", + "endpoint = Endpoint.create(\n", + " endpoint_name=f'jumpstart-od-birds-{int(time.time())}',\n", + " endpoint_config_name=endpoint_config.endpoint_config_name\n", + ")\n", + "\n", + "endpoint.wait_for_status('InService')\n", + "print(f'Endpoint: {endpoint.endpoint_name}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Test the Model\n", + "\n", + "With our model deployed, we can test it by sending images and examining the predictions.\n", + "\n", + "### How Inference Works\n", + "\n", + "1. 
**Load an image** from the test set as raw bytes\n",
+    "2. **Send to endpoint** using the SageMaker Runtime client\n",
+    "3. **Parse the response** which contains detected objects\n",
+    "\n",
+    "### Understanding the Response\n",
+    "\n",
+    "The model returns three parallel arrays:\n",
+    "- `normalized_boxes`: Bounding box coordinates in [0, 1] range as `[x_min, y_min, x_max, y_max]`\n",
+    "- `classes`: Category IDs (1-5 for our bird species)\n",
+    "- `scores`: Confidence scores (0-1) for each detection\n",
+    "\n",
+    "Each index represents one detected object. For example:\n",
+    "- `normalized_boxes[0]` = `[0.2, 0.3, 0.8, 0.9]` (bird occupies 20-80% width, 30-90% height)\n",
+    "- `classes[0]` = `1` (bird species 1)\n",
+    "- `scores[0]` = `0.95` (95% confident)\n",
+    "\n",
+    "Let's test with a sample bird image from our dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def visualize_detection(img_file, dets, classes=[], thresh=0.5):\n",
+    "    \"\"\"\n",
+    "    Visualize detections with bounding boxes.\n",
+    "    \n",
+    "    Parameters:\n",
+    "    - img_file: path to image\n",
+    "    - dets: detections as [[class, score, x_min, y_min, x_max, y_max], ...]\n",
+    "    - classes: list of class names\n",
+    "    - thresh: confidence threshold\n",
+    "    \"\"\"\n",
+    "    import random\n",
+    "    import matplotlib.pyplot as plt\n",
+    "    import matplotlib.image as mpimg\n",
+    "    \n",
+    "    img = mpimg.imread(img_file)\n",
+    "    plt.imshow(img)\n",
+    "    height, width = img.shape[:2]\n",
+    "    colors = {}\n",
+    "    num_detections = 0\n",
+    "    \n",
+    "    for det in dets:\n",
+    "        klass, score, x0, y0, x1, y1 = det\n",
+    "        if score < thresh:\n",
+    "            continue\n",
+    "        num_detections += 1\n",
+    "        cls_id = int(klass)\n",
+    "        \n",
+    "        if cls_id not in colors:\n",
+    "            colors[cls_id] = (random.random(), random.random(), random.random())\n",
+    "        \n",
+    "        xmin = int(x0 * width)\n",
+    "        ymin = int(y0 * height)\n",
+    "        xmax = int(x1 * width)\n",
+    "        ymax = int(y1 * height)\n",
+    "        \n",
+    "        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, \n",
+    "                             fill=False, edgecolor=colors[cls_id], linewidth=3.5)\n",
+    "        plt.gca().add_patch(rect)\n",
+    "        \n",
+    "        class_name = classes[cls_id - 1] if classes and cls_id <= len(classes) else str(cls_id)\n",
+    "        print(f'{class_name}, {score:.3f}')\n",
+    "        plt.gca().text(xmin, ymin - 2, f'{class_name} {score:.3f}',\n",
+    "                       bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n",
+    "                       fontsize=12, color='white')\n",
+    "    \n",
+    "    print(f'Number of detections: {num_detections}')\n",
+    "    plt.show()\n",
+    "\n",
+    "def predict_and_visualize(img_file, endpoint_name, thresh=0.5):\n",
+    "    \"\"\"Run inference and visualize results.\"\"\"\n",
+    "    runtime_client = sagemaker_session.sagemaker_runtime_client\n",
+    "    \n",
+    "    with open(img_file, 'rb') as f:\n",
+    "        img_bytes = f.read()\n",
+    "    \n",
+    "    response = runtime_client.invoke_endpoint(\n",
+    "        EndpointName=endpoint_name,\n",
+    "        ContentType='application/x-image',\n",
+    "        Body=img_bytes\n",
+    "    )\n",
+    "    \n",
+    "    result = json.loads(response['Body'].read())\n",
+    "    \n",
+    "    # Convert to detection format: [class, score, x_min, y_min, x_max, y_max]\n",
+    "    dets = []\n",
+    "    for bbox, cls, score in zip(result['normalized_boxes'], result['classes'], result['scores']):\n",
+    "        if score > thresh:  # Filter with the caller's threshold rather than a hardcoded value\n",
+    "            x_min, y_min, x_max, y_max = bbox\n",
+    "            dets.append([cls, score, x_min, y_min, x_max, y_max])\n",
+    "    \n",
+    "    class_names = [f'bird_class_{i}' for i in range(1, 6)]\n",
+    "    
visualize_detection(img_file, dets, class_names, thresh)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download Test Images\n", + "\n", + "Let's download some bird images that the model hasn't seen during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import urllib.request\n", + "\n", + "# Download test images\n", + "test_image_urls = {\n", + " 'multi-goldfinch-1.jpg': 'https://t3.ftcdn.net/jpg/01/44/64/36/500_F_144643697_GJRUBtGc55KYSMpyg1Kucb9yJzvMQooW.jpg',\n", + " 'hummingbird-1.jpg': 'http://res.freestockphotos.biz/pictures/17/17875-hummingbird-close-up-pv.jpg'\n", + "}\n", + "\n", + "for filename, url in test_image_urls.items():\n", + " if not os.path.exists(filename):\n", + " print(f'Downloading {filename}...')\n", + " urllib.request.urlretrieve(url, filename)\n", + "\n", + "print('Downloaded 2 test images')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test_model(endpoint_name):\n", + " \"\"\"Test model with downloaded bird images.\"\"\"\n", + " test_images = [\n", + " 'hummingbird-1.jpg',\n", + " 'multi-goldfinch-1.jpg',\n", + " ]\n", + " \n", + " for img in test_images:\n", + " if os.path.exists(img):\n", + " print(f'\\nTesting: {img}')\n", + " predict_and_visualize(img, endpoint_name, thresh=0.4)\n", + "\n", + "# Test the model\n", + "test_model(endpoint.endpoint_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Cleanup\n", + "\n", + "To avoid ongoing charges, delete the endpoint when you're done testing.\n", + "\n", + "The endpoint runs on a GPU instance which incurs hourly charges even when not in use. Always clean up resources after experimentation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Delete the endpoint\n", + "endpoint.delete()\n", + "print(f'Deleted endpoint: {endpoint.endpoint_name}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py3.10.14", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/v3-examples/training-examples/neural-topic-model.ipynb b/v3-examples/training-examples/neural-topic-model.ipynb new file mode 100644 index 0000000000..464ab8e077 --- /dev/null +++ b/v3-examples/training-examples/neural-topic-model.ipynb @@ -0,0 +1,765 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# An Introduction to SageMaker Neural Topic Model (V3)\n", + "\n", + "***Unsupervised representation learning and topic extraction using Neural Topic Model***\n", + "\n", + "**This notebook has been migrated to SageMaker Python SDK V3**\n", + "\n", + "1. [Introduction](#Introduction)\n", + "1. [Data Preparation](#Data-Preparation)\n", + "1. [Model Training](#Model-Training)\n", + "1. [Model Hosting and Inference](#Model-Hosting-and-Inference)\n", + "1. 
[Model Exploration](#Model-Exploration)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "# Introduction\n",
+    "\n",
+    "Amazon SageMaker Neural Topic Model (NTM) is an unsupervised learning algorithm that attempts to describe a set of observations as a mixture of distinct categories. NTM is most commonly used to discover a user-specified number of topics shared by documents within a text corpus. Here each observation is a document, the features are the presence (or occurrence count) of each word, and the categories are the topics. Since the method is unsupervised, the topics are not specified upfront and are not guaranteed to align with how a human may naturally categorize documents. The topics are learned as a probability distribution over the words that occur in each document. Each document, in turn, is described as a mixture of topics.\n",
+    "\n",
+    "In this notebook, we will use the Amazon SageMaker NTM algorithm to train a model on the [20NewsGroups](https://archive.ics.uci.edu/ml/datasets/Twenty+Newsgroups) data set. This data set has been widely used as a topic modeling benchmark.\n",
+    "\n",
+    "The main goals of this notebook are as follows:\n",
+    "\n",
+    "1. learn how to obtain and store data for use in Amazon SageMaker,\n",
+    "2. create an Amazon SageMaker training job on a data set to produce an NTM model,\n",
+    "3. use the model to perform inference with an Amazon SageMaker endpoint,\n",
+    "4. explore the trained model and visualize the learned topics.\n",
+    "\n",
+    "If you would like to know more please check out the [SageMaker Neural Topic Model Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/ntm.html)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "# Data Preparation\n",
+    "\n",
+    "The 20Newsgroups data set is a collection of approximately 20,000 newsgroup documents, partitioned (nearly) evenly across 20 different newsgroups. This collection has become a popular data set for experiments in text applications of machine learning techniques, such as text classification and text clustering. Here, we will see what topics we can learn from this set of documents with NTM."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Fetching Data Set\n",
+    "\n",
+    "First let's define the folder that will hold the data and clear out any content in it that might be left over from previous experiments."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "data_dir = '20_newsgroups'\n", + "if os.path.exists(data_dir):\n", + " shutil.rmtree(data_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!curl -O https://archive.ics.uci.edu/ml/machine-learning-databases/20newsgroups-mld/20_newsgroups.tar.gz" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!tar -xzf 20_newsgroups.tar.gz\n", + "!ls 20_newsgroups" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "folders = [os.path.join(data_dir,f) for f in sorted(os.listdir(data_dir)) if os.path.isdir(os.path.join(data_dir, f))]\n", + "file_list = [os.path.join(d,f) for d in folders for f in os.listdir(d)]\n", + "print('Number of documents:', len(file_list))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.datasets._twenty_newsgroups import strip_newsgroup_header, strip_newsgroup_quoting, strip_newsgroup_footer\n", + "data = []\n", + "for f in file_list:\n", + " with open(f, 'rb') as fin:\n", + " content = fin.read().decode('latin1') \n", + " content = strip_newsgroup_header(content)\n", + " content = strip_newsgroup_quoting(content)\n", + " content = strip_newsgroup_footer(content) \n", + " data.append(content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## From Plain Text to Bag-of-Words (BOW)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install nltk\n", + "import nltk\n", + "nltk.download('punkt')\n", + "nltk.download('punkt_tab')\n", + "nltk.download('wordnet')\n", + "from nltk import word_tokenize \n", + "from nltk.stem import WordNetLemmatizer \n", + "import re\n", + "token_pattern = re.compile(r\"(?u)\\b\\w\\w+\\b\")\n", + "class LemmaTokenizer(object):\n", + " def __init__(self):\n", + " self.wnl = WordNetLemmatizer()\n", + " def __call__(self, doc):\n", + " return [self.wnl.lemmatize(t) for t in word_tokenize(doc) if len(t) >= 2 and re.match(\"[a-z].*\",t) \n", + " and re.match(token_pattern, t)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import numpy as np\n", + "from sklearn.feature_extraction.text import CountVectorizer\n", + "vocab_size = 2000\n", + "print('Tokenizing and counting, this may take a few minutes...')\n", + "start_time = time.time()\n", + "vectorizer = CountVectorizer(input='content', analyzer='word', stop_words='english',\n", + " tokenizer=LemmaTokenizer(), max_features=vocab_size, max_df=0.95, min_df=2)\n", + "vectors = vectorizer.fit_transform(data)\n", + "vocab_list = vectorizer.get_feature_names_out()\n", + "print('vocab size:', len(vocab_list))\n", + "print('vectors shape:', vectors.shape)\n", + "\n", + "idx = np.arange(vectors.shape[0])\n", + "np.random.shuffle(idx)\n", + "vectors = vectors[idx]\n", + "\n", + "print('Done. 
Time elapsed: {:.2f}s'.format(time.time() - start_time))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "threshold = 25\n", + "vectors = vectors[np.array(vectors.sum(axis=1)>threshold).reshape(-1,)]\n", + "print('removed short docs (<{} words)'.format(threshold)) \n", + "print(vectors.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import scipy.sparse as sparse\n", + "vectors = sparse.csr_matrix(vectors, dtype=np.float32)\n", + "print(type(vectors), vectors.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_train = int(0.8 * vectors.shape[0])\n", + "\n", + "train_vectors = vectors[:n_train, :]\n", + "test_vectors = vectors[n_train:, :]\n", + "\n", + "n_test = test_vectors.shape[0]\n", + "val_vectors = test_vectors[:n_test//2, :]\n", + "test_vectors = test_vectors[n_test//2:, :]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(train_vectors.shape, test_vectors.shape, val_vectors.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Store Data on S3\n", + "\n", + "**V3 Migration Note**: In V3, we use CSV format instead of RecordIO Protobuf format since `sagemaker.amazon.common` module is not available." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup AWS Credentials" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "from sagemaker.core.helper.session_helper import Session, get_execution_role\n", + "\n", + "sagemaker_session = Session()\n", + "role = get_execution_role()\n", + "region = sagemaker_session.boto_region_name\n", + "bucket = sagemaker_session.default_bucket()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prefix = '20newsgroups'\n", + "\n", + "train_prefix = os.path.join(prefix, 'train')\n", + "val_prefix = os.path.join(prefix, 'val')\n", + "output_prefix = os.path.join(prefix, 'output')\n", + "\n", + "s3_train_data = os.path.join('s3://', bucket, train_prefix)\n", + "s3_val_data = os.path.join('s3://', bucket, val_prefix)\n", + "output_path = os.path.join('s3://', bucket, output_prefix)\n", + "print('Training set location', s3_train_data)\n", + "print('Validation set location', s3_val_data)\n", + "print('Trained model will be saved at', output_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**V3 Migration**: Convert sparse matrices to CSV format and upload to S3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Delete old files from S3\n", + "s3 = boto3.resource('s3')\n", + "bucket_obj = s3.Bucket(bucket)\n", + "\n", + "for obj in bucket_obj.objects.filter(Prefix='20newsgroups/train/'):\n", + " obj.delete()\n", + "for obj in bucket_obj.objects.filter(Prefix='20newsgroups/val/'):\n", + " obj.delete()\n", + "\n", + "print(\"Deleted old files. 
Now re-run the upload cells.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Upload CSV with extra column for NTM bug\n", + "def upload_csv_with_extra_col(sparray, bucket, prefix, template, n_parts):\n", + " chunk_size = sparray.shape[0] // n_parts\n", + " for i in range(n_parts):\n", + " start = i * chunk_size\n", + " end = (i + 1) * chunk_size if i + 1 < n_parts else sparray.shape[0]\n", + " \n", + " chunk = sparray[start:end].toarray().astype(int)\n", + " fname = template.format(i)\n", + " \n", + " with open(fname, 'w') as f:\n", + " for row in chunk:\n", + " # Add extra 0 column at the end for NTM bug\n", + " f.write(','.join(map(str, row)) + ',0\\n')\n", + " \n", + " s3_key = os.path.join(prefix, fname)\n", + " boto3.resource('s3').Bucket(bucket).upload_file(fname, s3_key)\n", + " print(f'Uploaded: s3://{bucket}/{s3_key}')\n", + " os.remove(fname)\n", + "\n", + "upload_csv_with_extra_col(train_vectors, bucket, train_prefix, 'train_part{}.csv', 8)\n", + "upload_csv_with_extra_col(val_vectors, bucket, val_prefix, 'val_part{}.csv', 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.core import image_uris\n", + "\n", + "container = image_uris.retrieve(\n", + " framework='ntm',\n", + " region=region\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.train.model_trainer import ModelTrainer\n", + "from sagemaker.train.configs import InputData, Compute\n", + "\n", + "num_topics = 20\n", + "\n", + "print(f'Training with feature_dim={train_vectors.shape[1]}')\n", + "trainer = ModelTrainer(\n", + " training_image=container,\n", + " role=role,\n", + " compute=Compute(\n", + " instance_count=2,\n", + " instance_type='ml.c4.xlarge'\n", + " ),\n", + " hyperparameters={\n", + " 'num_topics': str(num_topics),\n", + " 'feature_dim': '2000',\n", + " 'mini_batch_size': '128',\n", + " 'epochs': '100',\n", + " 'num_patience_epochs': '5',\n", + " 'tolerance': '0.001'\n", + " },\n", + " sagemaker_session=sagemaker_session\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.core.shapes.shapes import S3DataSource\n", + "\n", + "training_job = trainer.train(\n", + " input_data_config=[\n", + " InputData(\n", + " channel_name='train',\n", + " data_source=S3DataSource(\n", + " s3_data_type='S3Prefix',\n", + " s3_uri=s3_train_data,\n", + " s3_data_distribution_type='ShardedByS3Key'\n", + " ),\n", + " content_type='text/csv'\n", + " ),\n", + " InputData(\n", + " channel_name='validation',\n", + " data_source=s3_val_data,\n", + " content_type='text/csv'\n", + " )\n", + " ],\n", + " wait=True,\n", + " logs=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "training_job = trainer._latest_training_job\n", + "\n", + "print('Training job name: {}'.format(training_job.training_job_name))\n", + "print('Training job status: {}'.format(training_job.training_job_status))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Hosting and Inference\n", + "\n", + "**V3 Migration**: Using resource classes (`Model`, `EndpointConfig`, `Endpoint`) instead of `deploy()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Model Hosting and Inference\n",
+    "\n",
+    "**V3 Migration**: Using resource classes (`Model`, `EndpointConfig`, `Endpoint`) instead of `deploy()`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sagemaker.core.resources import Model, EndpointConfig, Endpoint\n",
+    "import time\n",
+    "\n",
+    "# Timestamped names keep repeated runs from colliding\n",
+    "model_name = f\"ntm-model-{int(time.time())}\"\n",
+    "endpoint_config_name = f\"ntm-endpoint-config-{int(time.time())}\"\n",
+    "endpoint_name = f\"ntm-endpoint-{int(time.time())}\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create model from training job artifacts\n",
+    "model = Model.create(\n",
+    "    model_name=model_name,\n",
+    "    execution_role_arn=role,\n",
+    "    primary_container={\n",
+    "        'image': container,\n",
+    "        'model_data_url': training_job.model_artifacts.s3_model_artifacts\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "print(f'Model created: {model.model_name}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create endpoint configuration\n",
+    "endpoint_config = EndpointConfig.create(\n",
+    "    endpoint_config_name=endpoint_config_name,\n",
+    "    production_variants=[{\n",
+    "        'variant_name': 'AllTraffic',\n",
+    "        'model_name': model.model_name,\n",
+    "        'initial_instance_count': 1,\n",
+    "        'instance_type': 'ml.m4.xlarge'\n",
+    "    }]\n",
+    ")\n",
+    "\n",
+    "print(f'Endpoint config created: {endpoint_config.endpoint_config_name}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create endpoint and wait until it is in service\n",
+    "endpoint = Endpoint.create(\n",
+    "    endpoint_name=endpoint_name,\n",
+    "    endpoint_config_name=endpoint_config.endpoint_config_name\n",
+    ")\n",
+    "\n",
+    "print(f'Endpoint created: {endpoint.endpoint_name}')\n",
+    "endpoint.wait_for_status('InService')\n",
+    "print('Endpoint is ready!')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print('Endpoint name: {}'.format(endpoint.endpoint_name))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "## Data Serialization/Deserialization\n",
+    "\n",
+    "**V3 Migration**: Using `endpoint.invoke()` method"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import numpy as np\n",
+    "\n",
+    "def np2csv(arr):\n",
+    "    # Serialize a 2-D array as CSV, one document per line\n",
+    "    csv = '\\n'.join([','.join([str(x) for x in row]) for row in arr])\n",
+    "    return csv\n",
+    "\n",
+    "test_data = np.array(test_vectors.todense())\n",
+    "payload = np2csv(test_data[:5])\n",
+    "\n",
+    "response = endpoint.invoke(\n",
+    "    body=payload,\n",
+    "    content_type='text/csv'\n",
+    ")\n",
+    "\n",
+    "results = json.loads(response.body.read().decode())\n",
+    "print(results)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "predictions = np.array([prediction['topic_weights'] for prediction in results['predictions']])\n",
+    "print(predictions)"
+   ]
+  },
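+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Each row of `predictions` should be (approximately) a probability distribution over the `num_topics` topics. As a quick sanity check, the minimal sketch below, using only NumPy, verifies the row sums and reports the dominant topic for each of the five test documents."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sanity check on the returned topic weights (NumPy only)\n",
+    "print('Row sums (should be close to 1):', predictions.sum(axis=1))\n",
+    "\n",
+    "# Dominant topic per document\n",
+    "for doc_idx, topic_idx in enumerate(predictions.argmax(axis=1)):\n",
+    "    print(f'Document {doc_idx}: topic {topic_idx} '\n",
+    "          f'(weight {predictions[doc_idx, topic_idx]:.3f})')"
+   ]
+  },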
+ "plt.xlabel('Topic ID', fontsize=fs+2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Stop / Close the Endpoint\n", + "\n", + "**V3 Migration**: Using `endpoint.delete()` method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "endpoint.delete()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# Model Exploration\n", + "\n", + "The trained NTM model contains learned topic representations. We can download and explore the model artifacts to understand the topics discovered in the 20 newsgroups dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get training job reference\n", + "training_job = trainer._latest_training_job\n", + "print(f\"Training job: {training_job.training_job_name}\")\n", + "print(f\"Status: {training_job.training_job_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get topic distributions from endpoint\n", + "import numpy as np\n", + "import json\n", + "\n", + "# Sample diverse documents from test set\n", + "test_data = np.array(test_vectors.todense())\n", + "sample_size = min(500, test_data.shape[0])\n", + "sample_indices = np.linspace(0, test_data.shape[0]-1, sample_size, dtype=int)\n", + "test_sample = test_data[sample_indices]\n", + "\n", + "print(f\"Using {sample_size} diverse samples\")\n", + "\n", + "# Get predictions\n", + "payload = np2csv(test_sample)\n", + "response = endpoint.invoke(body=payload, content_type=\"text/csv\")\n", + "results = json.loads(response.body.read().decode())\n", + "\n", + "# Extract topic distributions\n", + "topic_distributions = np.array([pred[\"topic_weights\"] for pred in results[\"predictions\"]])\n", + "print(f\"Topic distributions shape: {topic_distributions.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract distinctive words per topic\n", + "from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS\n", + "\n", + "custom_stops = set([\n", + " \"don\", \"just\", \"think\", \"people\", \"like\", \"know\", \"time\", \"does\", \"said\",\n", + " \"did\", \"way\", \"say\", \"good\", \"right\", \"ve\", \"ll\", \"didn\", \"doesn\", \"isn\",\n", + " \"wasn\", \"aren\", \"god\", \"religion\", \"believe\", \"point\", \"things\", \"thing\",\n", + " \"make\", \"want\", \"going\", \"really\", \"question\", \"post\", \"better\", \"claim\"\n", + "])\n", + "all_stops = ENGLISH_STOP_WORDS.union(custom_stops)\n", + "\n", + "def get_distinctive_words(topic_idx, n_words=20):\n", + " topic_strengths = topic_distributions[:, topic_idx]\n", + " high_threshold = np.percentile(topic_strengths, 85)\n", + " high_mask = topic_strengths > high_threshold\n", + " low_threshold = np.percentile(topic_strengths, 50)\n", + " low_mask = topic_strengths < low_threshold\n", + " \n", + " if high_mask.sum() < 5 or low_mask.sum() < 5:\n", + " return []\n", + " \n", + " high_docs = test_sample[high_mask]\n", + " low_docs = test_sample[low_mask]\n", + " high_freq = (high_docs > 0).sum(axis=0) / high_mask.sum()\n", + " low_freq = (low_docs > 0).sum(axis=0) / low_mask.sum()\n", + " diff = high_freq - low_freq\n", + " \n", + " filtered_words = []\n", + " for idx in np.argsort(diff)[::-1]:\n", + " word = vocab_list[idx]\n", + " if word.lower() not in all_stops and len(word) > 2 and diff[idx] > 0:\n", + 
" filtered_words.append((word, diff[idx]))\n", + " if len(filtered_words) >= n_words:\n", + " break\n", + " return filtered_words\n", + "\n", + "print(\"Distinctive Topic Words:\")\n", + "print(\"=\"*60)\n", + "for topic_idx in range(min(5, topic_distributions.shape[1])):\n", + " words = get_distinctive_words(topic_idx, 10)\n", + " if words:\n", + " print(f\"\\nTopic {topic_idx}: \" + \", \".join([word for word, _ in words]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize topic distributions\n", + "import matplotlib.pyplot as plt\n", + "\n", + "avg_topic_strength = topic_distributions.mean(axis=0)\n", + "plt.figure(figsize=(12, 6))\n", + "plt.bar(range(len(avg_topic_strength)), avg_topic_strength)\n", + "plt.xlabel(\"Topic Index\")\n", + "plt.ylabel(\"Average Strength\")\n", + "plt.title(\"Average Topic Strength Across Documents\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create word clouds for all topics\n", + "!pip install wordcloud -q\n", + "from wordcloud import WordCloud\n", + "\n", + "fig, axes = plt.subplots(5, 4, figsize=(20, 25))\n", + "for idx, ax in enumerate(axes.flat):\n", + " if idx < topic_distributions.shape[1]:\n", + " words = get_distinctive_words(idx, 50)\n", + " if words:\n", + " word_freq = {word: score for word, score in words}\n", + " wc = WordCloud(width=600, height=400,\n", + " background_color=\"white\",\n", + " colormap=\"tab20\").generate_from_frequencies(word_freq)\n", + " ax.imshow(wc, interpolation=\"bilinear\")\n", + " ax.set_title(f\"Topic {idx}\", fontsize=14, fontweight=\"bold\")\n", + " ax.axis(\"off\")\n", + " else:\n", + " ax.axis(\"off\")\n", + "plt.tight_layout()\n", + "plt.savefig(\"ntm_topics.png\", dpi=150, bbox_inches=\"tight\")\n", + "plt.show()\n", + "print(f\"Generated word clouds for {topic_distributions.shape[1]} topics\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}