def test_clean_mta_df_converts_date_and_sorts():
    """Dates are parsed to datetimes and rows come back in ascending date order."""
    df = pd.DataFrame({
        "date": ["2020-01-02", "2020-01-01"],
        "x": [2, 1],
    })

    out = clean_mta_df(df)

    # Ask pandas for the dtype family instead of string-matching the dtype name.
    assert pd.api.types.is_datetime64_any_dtype(out["date"])

    assert out["date"].is_monotonic_increasing

    # The other columns must travel with their dates.
    assert list(out["x"]) == [1, 2]

    # clean_mta_df resets the index after sorting; pin that behaviour too.
    assert list(out.index) == [0, 1]


def test_clean_mta_df_missing_date_raises():
    """A frame without a ``date`` column is rejected with KeyError."""
    df = pd.DataFrame({"x": [1, 2]})
    with pytest.raises(KeyError):
        clean_mta_df(df)
def clean_mta_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with ``date`` parsed and rows in date order.

    The input frame is never modified. The result has a fresh 0..n-1 index,
    and rows that share the same date keep their original relative order.

    Args:
        df: Frame containing a ``date`` column with values that
            ``pd.to_datetime`` can parse.

    Returns:
        A new DataFrame sorted ascending by ``date``.

    Raises:
        KeyError: If ``df`` has no ``date`` column.
    """
    # Validate before copying so a bad input fails fast without duplicating data.
    if "date" not in df.columns:
        raise KeyError("Missing 'date' column")

    out = df.copy()
    out["date"] = pd.to_datetime(out["date"])

    # The default sort algorithm is not stable; mergesort keeps rows with equal
    # dates in their original order so the output is deterministic.
    out = out.sort_values("date", kind="mergesort").reset_index(drop=True)

    return out
def test_clean_mta_df_does_not_modify_original():
    """Test that the original DataFrame is not mutated."""
    df = pd.DataFrame({
        "date": ["2020-01-02", "2020-01-01"],
        "x": [2, 1],
    })
    snapshot = df.copy(deep=True)

    clean_mta_df(df)

    # Compare the whole frame (values, dtypes, index), not just the date values.
    pd.testing.assert_frame_equal(df, snapshot)


def test_clean_mta_df_already_sorted():
    """Test that already-sorted data passes through correctly."""
    df = pd.DataFrame({
        "date": ["2020-01-01", "2020-01-02", "2020-01-03"],
        "x": [1, 2, 3],
    })

    out = clean_mta_df(df)

    assert out["date"].is_monotonic_increasing
    assert list(out["x"]) == [1, 2, 3]


# ---------- Tests for plot_ridership_recovery ----------

def _make_ridership_df():
    """Helper: create a small valid ridership DataFrame for testing."""
    return pd.DataFrame({
        "date": pd.to_datetime(["2020-03-01", "2020-03-02", "2020-03-03"]),
        "subways_of_comparable_pre_pandemic_day": [0.9, 0.5, 0.6],
        "buses_of_comparable_pre_pandemic_day": [0.95, 0.6, 0.7],
        "lirr_of_comparable_pre_pandemic_day": [0.85, 0.4, 0.5],
        "metro_north_of_comparable_pre_pandemic_day": [0.88, 0.45, 0.55],
    })


def test_plot_ridership_recovery_returns_figure():
    """Test that the function returns a matplotlib Figure without error."""
    # Import the submodules explicitly: a bare ``import matplotlib`` does not
    # guarantee that ``matplotlib.figure`` / ``matplotlib.pyplot`` are loaded.
    import matplotlib.pyplot as plt
    from matplotlib.figure import Figure

    try:
        fig = plot_ridership_recovery(_make_ridership_df())
        assert isinstance(fig, Figure)
    finally:
        # Release figures even when the call or the assertion fails.
        plt.close("all")


def test_plot_ridership_recovery_missing_column_raises():
    """Test that KeyError is raised when a required column is missing."""
    df = pd.DataFrame({
        "date": pd.to_datetime(["2020-03-01"]),
        "subways_of_comparable_pre_pandemic_day": [0.9],
    })
    with pytest.raises(KeyError):
        plot_ridership_recovery(df)
def plot_ridership_recovery(df: pd.DataFrame) -> plt.Figure:
    """Plot MTA ridership recovery by transit mode against the pre-pandemic baseline.

    One line is drawn per transit mode, plus a dashed reference line at 1.0.
    Values are plotted as fractions of the comparable pre-pandemic day
    (1.0 == 100%), which is what the baseline and the 0-1.5 y-range imply;
    the y-axis ticks are rendered as percentages to match the axis label.

    Args:
        df: Frame with a ``date`` column and the four
            ``*_of_comparable_pre_pandemic_day`` columns listed below.

    Returns:
        The matplotlib Figure holding the chart. It is created through
        ``plt.subplots``, so the caller decides when to show, save or close it.

    Raises:
        KeyError: If any required column is absent; the message lists them.
    """
    from matplotlib.ticker import PercentFormatter

    # (column, legend label) for each plotted mode, in drawing order.
    series = (
        ("subways_of_comparable_pre_pandemic_day", "Subway"),
        ("buses_of_comparable_pre_pandemic_day", "Bus"),
        ("lirr_of_comparable_pre_pandemic_day", "LIRR"),
        ("metro_north_of_comparable_pre_pandemic_day", "Metro-North"),
    )

    # Validate before creating a figure so a bad frame leaves no open figure.
    required_cols = ["date", *(column for column, _ in series)]
    missing = [c for c in required_cols if c not in df.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    fig, ax = plt.subplots(figsize=(14, 7))

    for column, label in series:
        ax.plot(df["date"], df[column], label=label, alpha=0.8, linewidth=1.2)

    ax.axhline(y=1.0, color="gray", linestyle="--", linewidth=1.5,
               label="Pre-pandemic baseline (100%)")

    ax.set_xlabel("Date", fontsize=12)
    ax.set_ylabel("% of Pre-Pandemic Ridership", fontsize=12)
    ax.set_title(
        "MTA Ridership Recovery: Subway vs Bus vs Commuter Rail (2020-Present)",
        fontsize=14, fontweight="bold",
    )
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.5)
    # Show 0.5 as "50%" so the ticks agree with the "%" axis label.
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0))
    fig.tight_layout()

    return fig