diff --git a/.gitignore b/.gitignore
index 8451607..b508df9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,6 @@
 .venv/
-.DS_Store
\ No newline at end of file
+.DS_Store
+
+__pycache__/
+*.pyc
+.pytest_cache/
\ No newline at end of file
diff --git a/mta_ridership_project .ipynb b/mta_ridership_project .ipynb
index 2bdea6b..51891ec 100644
--- a/mta_ridership_project .ipynb
+++ b/mta_ridership_project .ipynb
@@ -11,17 +11,16 @@
     "**Dataset:** MTA Daily Ridership Data: Beginning 2020\n",
     "\n",
     "**Research Questions:**\n",
-    "\n",
     "1. What's the difference between weekday and weekend travel patterns?\n",
     "2. Do holidays and big events show up in the ridership numbers?\n",
-    "3. Which parts of the MTA system bounced back fastest after 2020?\n"
+    "3. Which parts of the MTA system bounced back fastest after 2020?"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## 1. Setup and Data Loading\n"
+    "## 1. Setup and Data Loading"
    ]
   },
   {
@@ -45,7 +44,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## 2. Data Cleaning\n"
+    "## 2. Data Cleaning"
    ]
   },
   {
@@ -55,8 +54,9 @@
    "outputs": [],
    "source": [
     "# Convert date column to datetime and sort\n",
-    "df[\"date\"] = pd.to_datetime(df[\"date\"])\n",
-    "df = df.sort_values(\"date\")\n",
+    "from utils import clean_mta_df\n",
+    "\n",
+    "df = clean_mta_df(df)\n",
     "\n",
     "print(f\"Date range: {df['date'].min()} to {df['date'].max()}\")"
    ]
@@ -65,7 +65,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## 3. Visualization: MTA Ridership Recovery by Transit Mode\n"
+    "## 3. Visualization: MTA Ridership Recovery by Transit Mode"
    ]
   },
   {
@@ -76,51 +76,21 @@
    "source": [
     "plt.figure(figsize=(14, 7))\n",
     "\n",
-    "plt.plot(\n",
-    " df[\"date\"],\n",
-    " df[\"subways_of_comparable_pre_pandemic_day\"],\n",
-    " label=\"Subway\",\n",
-    " alpha=0.8,\n",
-    " linewidth=1.2,\n",
-    ")\n",
-    "plt.plot(\n",
-    " df[\"date\"],\n",
-    " df[\"buses_of_comparable_pre_pandemic_day\"],\n",
-    " label=\"Bus\",\n",
-    " alpha=0.8,\n",
-    " linewidth=1.2,\n",
-    ")\n",
-    "plt.plot(\n",
-    " df[\"date\"],\n",
-    " df[\"lirr_of_comparable_pre_pandemic_day\"],\n",
-    " label=\"LIRR\",\n",
-    " alpha=0.8,\n",
-    " linewidth=1.2,\n",
-    ")\n",
-    "plt.plot(\n",
-    " df[\"date\"],\n",
-    " df[\"metro_north_of_comparable_pre_pandemic_day\"],\n",
-    " label=\"Metro-North\",\n",
-    " alpha=0.8,\n",
-    " linewidth=1.2,\n",
-    ")\n",
+    "plt.plot(df['date'], df['subways_of_comparable_pre_pandemic_day'], \n",
+    " label='Subway', alpha=0.8, linewidth=1.2)\n",
+    "plt.plot(df['date'], df['buses_of_comparable_pre_pandemic_day'], \n",
+    " label='Bus', alpha=0.8, linewidth=1.2)\n",
+    "plt.plot(df['date'], df['lirr_of_comparable_pre_pandemic_day'], \n",
+    " label='LIRR', alpha=0.8, linewidth=1.2)\n",
+    "plt.plot(df['date'], df['metro_north_of_comparable_pre_pandemic_day'], \n",
+    " label='Metro-North', alpha=0.8, linewidth=1.2)\n",
     "\n",
-    "plt.axhline(\n",
-    " y=1.0,\n",
-    " color=\"gray\",\n",
-    " linestyle=\"--\",\n",
-    " linewidth=1.5,\n",
-    " label=\"Pre-pandemic baseline (100%)\",\n",
-    ")\n",
+    "plt.axhline(y=1.0, color='gray', linestyle='--', linewidth=1.5, label='Pre-pandemic baseline (100%)')\n",
     "\n",
-    "plt.xlabel(\"Date\", fontsize=12)\n",
-    "plt.ylabel(\"% of Pre-Pandemic Ridership\", fontsize=12)\n",
-    "plt.title(\n",
-    " \"MTA Ridership Recovery: Subway vs Bus vs Commuter Rail (2020-Present)\",\n",
-    " fontsize=14,\n",
-    " fontweight=\"bold\",\n",
-    ")\n",
-    "plt.legend(loc=\"lower right\", fontsize=10)\n",
+    "plt.xlabel('Date', fontsize=12)\n",
+    "plt.ylabel('% of Pre-Pandemic Ridership', fontsize=12)\n",
+    "plt.title('MTA Ridership Recovery: Subway vs Bus vs Commuter Rail (2020-Present)', fontsize=14, fontweight='bold')\n",
+    "plt.legend(loc='lower right', fontsize=10)\n",
     "plt.grid(True, alpha=0.3)\n",
     "plt.ylim(0, 1.5)\n",
     "plt.tight_layout()\n",
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..6ffadac
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,94 @@
+import pandas as pd
+import pytest
+import matplotlib
+
+matplotlib.use("Agg")  # non-interactive backend for testing
+
+# Import the submodules explicitly: a bare `import matplotlib` is not
+# guaranteed to load `matplotlib.pyplot` or `matplotlib.figure`.
+import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
+
+from utils import clean_mta_df, plot_ridership_recovery
+
+
+# ---------- Tests for clean_mta_df ----------
+
+def test_clean_mta_df_converts_date_and_sorts():
+    """Test that string dates become datetimes and rows are sorted by date."""
+    df = pd.DataFrame({
+        "date": ["2020-01-02", "2020-01-01"],
+        "x": [2, 1],
+    })
+
+    out = clean_mta_df(df)
+
+    assert pd.api.types.is_datetime64_any_dtype(out["date"])
+    assert out["date"].is_monotonic_increasing
+    assert list(out["x"]) == [1, 2]
+
+
+def test_clean_mta_df_missing_date_raises():
+    """Test that KeyError is raised when the date column is missing."""
+    df = pd.DataFrame({"x": [1, 2]})
+    with pytest.raises(KeyError):
+        clean_mta_df(df)
+
+
+def test_clean_mta_df_does_not_modify_original():
+    """Test that the original DataFrame is not mutated."""
+    df = pd.DataFrame({
+        "date": ["2020-01-02", "2020-01-01"],
+        "x": [2, 1],
+    })
+    original_dates = list(df["date"])
+
+    clean_mta_df(df)
+
+    # original df should remain unchanged
+    assert list(df["date"]) == original_dates
+
+
+def test_clean_mta_df_already_sorted():
+    """Test that already-sorted data passes through correctly."""
+    df = pd.DataFrame({
+        "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
+        "x": [1, 2, 3],
+    })
+
+    out = clean_mta_df(df)
+
+    assert out["date"].is_monotonic_increasing
+    assert list(out["x"]) == [1, 2, 3]
+
+
+# ---------- Tests for plot_ridership_recovery ----------
+
+def _make_ridership_df():
+    """Helper: create a small valid ridership DataFrame for testing."""
+    return pd.DataFrame({
+        "date": pd.to_datetime(["2020-03-01", "2020-03-02", "2020-03-03"]),
+        "subways_of_comparable_pre_pandemic_day": [0.9, 0.5, 0.6],
+        "buses_of_comparable_pre_pandemic_day": [0.95, 0.6, 0.7],
+        "lirr_of_comparable_pre_pandemic_day": [0.85, 0.4, 0.5],
+        "metro_north_of_comparable_pre_pandemic_day": [0.88, 0.45, 0.55],
+    })
+
+
+def test_plot_ridership_recovery_returns_figure():
+    """Test that the function returns a matplotlib Figure without error."""
+    fig = plot_ridership_recovery(_make_ridership_df())
+    try:
+        assert isinstance(fig, Figure)
+    finally:
+        plt.close(fig)  # release the figure even if the assertion fails
+
+
+def test_plot_ridership_recovery_missing_column_raises():
+    """Test that KeyError is raised when a required column is missing."""
+    df = pd.DataFrame({
+        "date": pd.to_datetime(["2020-03-01"]),
+        "subways_of_comparable_pre_pandemic_day": [0.9],
+    })
+    with pytest.raises(KeyError):
+        plot_ridership_recovery(df)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..b7c698d
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,69 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# Ridership-share column -> legend label, in plotting order.
+_MODE_COLUMNS = {
+    "subways_of_comparable_pre_pandemic_day": "Subway",
+    "buses_of_comparable_pre_pandemic_day": "Bus",
+    "lirr_of_comparable_pre_pandemic_day": "LIRR",
+    "metro_north_of_comparable_pre_pandemic_day": "Metro-North",
+}
+
+
+def clean_mta_df(df: pd.DataFrame) -> pd.DataFrame:
+    """Return a copy of ``df`` with ``date`` parsed and rows sorted by date.
+
+    The ``date`` column is converted with ``pd.to_datetime``, rows are sorted
+    in ascending date order, and the index is reset. ``df`` itself is not
+    modified.
+
+    Raises:
+        KeyError: If ``df`` has no ``date`` column.
+    """
+    # Validate before copying so an unusable frame is rejected cheaply.
+    if "date" not in df.columns:
+        raise KeyError("Missing 'date' column")
+
+    out = df.copy()
+    out["date"] = pd.to_datetime(out["date"])
+    return out.sort_values("date").reset_index(drop=True)
+
+
+def plot_ridership_recovery(df: pd.DataFrame) -> plt.Figure:
+    """Plot MTA ridership recovery by transit mode as % of pre-pandemic levels.
+
+    Draws one line per transit mode against ``date``, plus a dashed
+    reference line at y=1.0.
+
+    Returns:
+        The newly created figure. It is not shown or closed here; the caller
+        owns it.
+
+    Raises:
+        KeyError: If ``date`` or any ridership column is missing from ``df``.
+    """
+    required_cols = ["date", *_MODE_COLUMNS]
+    missing = [c for c in required_cols if c not in df.columns]
+    if missing:
+        raise KeyError(f"Missing required columns: {missing}")
+
+    fig, ax = plt.subplots(figsize=(14, 7))
+
+    for column, label in _MODE_COLUMNS.items():
+        ax.plot(df["date"], df[column], label=label, alpha=0.8, linewidth=1.2)
+
+    ax.axhline(y=1.0, color="gray", linestyle="--", linewidth=1.5,
+               label="Pre-pandemic baseline (100%)")
+
+    ax.set_xlabel("Date", fontsize=12)
+    ax.set_ylabel("% of Pre-Pandemic Ridership", fontsize=12)
+    ax.set_title(
+        "MTA Ridership Recovery: Subway vs Bus vs Commuter Rail (2020-Present)",
+        fontsize=14, fontweight="bold",
+    )
+    ax.legend(loc="lower right", fontsize=10)
+    ax.grid(True, alpha=0.3)
+    ax.set_ylim(0, 1.5)
+    fig.tight_layout()
+
+    return fig