Unpickling RandomForestClassifier with cuml.accel #7627

@betatim

Description

Another discussion topic. Unpickling a random forest with cuml.accel enabled takes considerably more time than with it turned off.
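For reference, the accelerator can be toggled per run without editing the script itself. A minimal sketch of the two invocations being compared (assuming the benchmark below is saved as `bench.py`):

```shell
# Run the same script with the accelerator enabled:
python -m cuml.accel bench.py

# And without it, as plain scikit-learn:
python bench.py
```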

I instructed Cursor to create a script that trains forests with trees of different "shapes", aiming for trees that are deep and narrow versus trees that are shallow and wide. I'm not sure how achievable that is without a lot of doctoring of the training dataset, but I don't think it matters too much: the different tree shapes are just my guess at what could influence pickling time. The script is in the "details" at the end. I've read through it to check it is not totally bananas, and it seems alright. I've not (yet) experimented with changing the hyper-parameters to see what happens (e.g. do we really need the min_samples_split setting, or is it all about max_depth?)
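A quick way to probe that hyper-parameter question without the full benchmark is to fit small forests that vary one knob at a time and compare node counts. A minimal sketch (dataset and parameter values are illustrative, not taken from the script below):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)

def total_nodes(**params):
    """Fit a small forest and return the summed node count across trees."""
    rf = RandomForestClassifier(n_estimators=10, random_state=0, **params).fit(X, y)
    return sum(t.tree_.node_count for t in rf.estimators_)

# Relax only min_samples_split at a fixed depth, then only max_depth,
# to see which knob moves the node count more on this dataset.
base = total_nodes(max_depth=10, min_samples_split=50)
split = total_nodes(max_depth=10, min_samples_split=2)
depth = total_nodes(max_depth=30, min_samples_split=50)
print(base, split, depth)
```

On most datasets both knobs matter, but their relative effect depends on how separable the classes are, which is why eyeballing a sweep like this is cheaper than rerunning the whole benchmark.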

The average depth for the "deep" config is quite different in the two cases (~20 vs ~30). My guess would have been that lower depth / fewer nodes means faster (un)pickling, so I don't think this difference explains the slower (un)pickling we see with cuml.accel enabled.
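The size side of that guess is easy to sanity-check on small forests: the pickled payload of a scikit-learn forest is dominated by the per-node arrays in each `tree_`, so byte size should track node count. A minimal sketch (dataset and parameters are illustrative):

```python
import pickle
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)

def nodes_and_bytes(max_depth):
    """Return (total node count, pickled size in bytes) for a small forest."""
    rf = RandomForestClassifier(
        n_estimators=10, max_depth=max_depth, random_state=0
    ).fit(X, y)
    n_nodes = sum(t.tree_.node_count for t in rf.estimators_)
    return n_nodes, len(pickle.dumps(rf))

shallow = nodes_and_bytes(max_depth=3)    # few nodes, small payload
deep = nodes_and_bytes(max_depth=None)    # grown out fully, large payload
print(shallow, deep)
```

Whether unpickle *time* scales the same way under cuml.accel is exactly what the benchmark below measures; size alone doesn't settle it.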

scikit-learn with cuml.accel on:

====================================================================================================
BENCHMARK RESULTS
====================================================================================================
Config        Trees  AvgDepth  TotalNodes  AvgLeaves     Pickle(ms)   Unpickle(ms)   Size(MB)
----------------------------------------------------------------------------------------------------
shallow          50       5.0        1550       16.0    0.68±0.25     2.89±0.73        0.17
shallow         100       5.0        3100       16.0    0.88±0.17     3.92±0.48        0.33
shallow         200       5.0        6200       16.0    1.53±0.33     7.14±0.87        0.66
shallow         300       5.0        9300       16.0    2.31±0.48    21.96±22.78       0.99
deep             50      19.8     1056182    10562.3   48.23±22.05  505.60±40.33     104.77
deep            100      19.8     2113350    10567.2  135.02±7.35   889.58±6.10      209.63
deep            200      19.7     4228960    10572.9  215.88±88.97 2042.09±60.36     419.49
deep            300      19.7     6342006    10570.5  325.77±122.84 3044.53±39.58     629.09
medium           50      10.9       84526      845.8    2.34±0.25    20.78±0.56        8.40
medium          100      10.8      168594      843.5   11.47±13.38   41.65±0.80       16.75
medium          200      10.7      338528      846.8    9.39±1.20    86.75±0.72       33.63
medium          300      10.8      507654      846.6   14.92±2.07   126.44±1.47       50.43
====================================================================================================

Just scikit-learn:

====================================================================================================
BENCHMARK RESULTS
====================================================================================================
Config        Trees  AvgDepth  TotalNodes  AvgLeaves     Pickle(ms)   Unpickle(ms)   Size(MB)
----------------------------------------------------------------------------------------------------
shallow          50       4.0        1550       16.0    0.52±0.10     0.42±0.03        0.17
shallow         100       4.0        3100       16.0    0.99±0.17     1.00±0.37        0.34
shallow         200       4.0        6200       16.0    2.13±0.26     1.62±0.12        0.68
shallow         300       4.0        9300       16.0    3.38±0.54     2.76±0.73        1.02
deep             50      29.6      982974     9830.2   82.56±0.86    65.01±3.83       97.51
deep            100      29.6     1964450     9822.8  165.29±0.85   137.92±2.09      194.87
deep            200      29.5     3931726     9829.8  329.58±8.15   290.88±2.51      390.03
deep            300      29.6     5898986     9832.1  491.37±10.54  433.18±0.64      585.18
medium           50      10.0       79518      795.7    3.32±2.26     1.57±0.74        7.90
medium          100      10.0      157966      790.3    5.65±2.10     2.94±1.04       15.70
medium          200      10.0      316650      792.1   23.33±4.01    13.85±10.37      31.47
medium          300      10.0      475614      793.2   39.48±2.13    35.01±7.52       47.27
====================================================================================================
Cursor-generated benchmark script

I've not made any edits to this, except removing a very small `n_estimators` entry and adding the 300 one. I've read through it and it looks good enough.

#!/usr/bin/env python3
"""Benchmark pickle/unpickle times for sklearn RandomForestClassifier

This script benchmarks forests with different tree shapes:
- Shallow/wide trees: low max_depth, high min_samples_split/leaf
- Deep/narrow trees: high max_depth, low min_samples_split/leaf
- Medium trees: balanced parameters

Each shape is tested with varying numbers of trees (n_estimators).
"""

import pickle
import time
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


# Forest configurations controlling tree shape
FOREST_CONFIGS = {
    "shallow": {
        "max_depth": 4,
        "min_samples_split": 50,
        "min_samples_leaf": 20,
    },
    "deep": {
        "max_depth": 30,
        "min_samples_split": 2,
        "min_samples_leaf": 1,
    },
    "medium": {
        "max_depth": 10,
        "min_samples_split": 10,
        "min_samples_leaf": 5,
    },
}

# Number of trees to test
N_ESTIMATORS_LIST = [50, 100, 200, 300]

# Benchmark parameters
N_RUNS = 5
RANDOM_STATE = 42


def generate_dataset(n_samples=50_000, n_features=100, n_informative=50, n_classes=5):
    """Generate a classification dataset suitable for tree benchmarking."""
    print(f"Generating dataset: {n_samples} samples, {n_features} features, "
          f"{n_informative} informative, {n_classes} classes...")
    
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=n_features - n_informative - n_classes,
        n_classes=n_classes,
        random_state=RANDOM_STATE,
    )
    return X.astype(np.float32), y.astype(np.int32)


def get_tree_stats(model):
    """Extract statistics from a fitted random forest.
    
    Returns:
        tuple: (average_depth, total_nodes, avg_leaves)
    """
    depths = []
    total_nodes = 0
    total_leaves = 0
    
    for tree in model.estimators_:
        depths.append(tree.get_depth())
        total_nodes += tree.tree_.node_count
        total_leaves += tree.tree_.n_leaves
    
    return np.mean(depths), total_nodes, total_leaves / len(model.estimators_)


def benchmark_pickle(model, n_runs=N_RUNS):
    """Measure pickle and unpickle times using in-memory serialization.
    
    Returns:
        tuple: (mean_pickle_time, std_pickle_time, mean_unpickle_time, 
                std_unpickle_time, serialized_size_bytes)
    """
    pickle_times = []
    unpickle_times = []
    data = None
    
    for _ in range(n_runs):
        # Pickle to bytes
        start = time.perf_counter()
        data = pickle.dumps(model)
        pickle_times.append(time.perf_counter() - start)
        
        # Unpickle from bytes
        start = time.perf_counter()
        _ = pickle.loads(data)
        unpickle_times.append(time.perf_counter() - start)
    
    return (
        np.mean(pickle_times),
        np.std(pickle_times),
        np.mean(unpickle_times),
        np.std(unpickle_times),
        len(data),
    )


def train_and_benchmark(X, y, config_name, config_params, n_estimators):
    """Train a RandomForest and benchmark its pickle/unpickle times.
    
    Returns:
        dict: Results including timing and tree statistics
    """
    # Create and train model
    model = RandomForestClassifier(
        n_estimators=n_estimators,
        n_jobs=-1,
        random_state=RANDOM_STATE,
        **config_params,
    )
    
    train_start = time.perf_counter()
    model.fit(X, y)
    train_time = time.perf_counter() - train_start
    
    # Get tree statistics
    avg_depth, total_nodes, avg_leaves = get_tree_stats(model)
    
    # Benchmark pickle/unpickle
    pickle_mean, pickle_std, unpickle_mean, unpickle_std, size_bytes = benchmark_pickle(model)
    
    return {
        "config": config_name,
        "n_estimators": n_estimators,
        "train_time_s": train_time,
        "avg_depth": avg_depth,
        "total_nodes": total_nodes,
        "avg_leaves": avg_leaves,
        "pickle_mean_ms": pickle_mean * 1000,
        "pickle_std_ms": pickle_std * 1000,
        "unpickle_mean_ms": unpickle_mean * 1000,
        "unpickle_std_ms": unpickle_std * 1000,
        "size_mb": size_bytes / (1024 * 1024),
    }


def print_results_table(results):
    """Print results in a formatted table."""
    print("\n" + "=" * 100)
    print("BENCHMARK RESULTS")
    print("=" * 100)
    
    header = (
        f"{'Config':<12} {'Trees':>6} {'AvgDepth':>9} {'TotalNodes':>11} "
        f"{'AvgLeaves':>10} {'Pickle(ms)':>14} {'Unpickle(ms)':>14} {'Size(MB)':>10}"
    )
    print(header)
    print("-" * 100)
    
    for r in results:
        row = (
            f"{r['config']:<12} {r['n_estimators']:>6} {r['avg_depth']:>9.1f} "
            f"{r['total_nodes']:>11} {r['avg_leaves']:>10.1f} "
            f"{r['pickle_mean_ms']:>7.2f}±{r['pickle_std_ms']:<5.2f} "
            f"{r['unpickle_mean_ms']:>7.2f}±{r['unpickle_std_ms']:<5.2f} "
            f"{r['size_mb']:>10.2f}"
        )
        print(row)
    
    print("=" * 100)


def main():
    print("=" * 60)
    print("RandomForest Pickle/Unpickle Benchmark")
    print("=" * 60)
    print(f"Runs per config: {N_RUNS}")
    print()
    
    # Generate dataset
    X, y = generate_dataset()
    print(f"Dataset shape: X={X.shape}, y={y.shape}")
    print(f"Dataset memory: {X.nbytes / (1024**2):.2f} MB")
    print()
    
    # Run benchmarks
    results = []
    total_configs = len(FOREST_CONFIGS) * len(N_ESTIMATORS_LIST)
    current = 0
    
    for config_name, config_params in FOREST_CONFIGS.items():
        print(f"\n--- Testing '{config_name}' tree configuration ---")
        print(f"    max_depth={config_params['max_depth']}, "
              f"min_samples_split={config_params['min_samples_split']}, "
              f"min_samples_leaf={config_params['min_samples_leaf']}")
        
        for n_estimators in N_ESTIMATORS_LIST:
            current += 1
            print(f"\n[{current}/{total_configs}] {config_name} with {n_estimators} trees...")
            
            result = train_and_benchmark(X, y, config_name, config_params, n_estimators)
            results.append(result)
            
            print(f"    Train: {result['train_time_s']:.2f}s | "
                  f"Pickle: {result['pickle_mean_ms']:.2f}ms | "
                  f"Unpickle: {result['unpickle_mean_ms']:.2f}ms | "
                  f"Size: {result['size_mb']:.2f}MB")
    
    # Print final results table
    print_results_table(results)
    
    # Summary comparison
    print("\nSUMMARY: Shallow vs Deep comparison (at 100 trees)")
    shallow_100 = next(r for r in results if r["config"] == "shallow" and r["n_estimators"] == 100)
    deep_100 = next(r for r in results if r["config"] == "deep" and r["n_estimators"] == 100)
    
    print(f"  Shallow: {shallow_100['total_nodes']:,} nodes, "
          f"{shallow_100['pickle_mean_ms']:.2f}ms pickle, "
          f"{shallow_100['size_mb']:.2f}MB")
    print(f"  Deep:    {deep_100['total_nodes']:,} nodes, "
          f"{deep_100['pickle_mean_ms']:.2f}ms pickle, "
          f"{deep_100['size_mb']:.2f}MB")
    print(f"  Ratio:   {deep_100['total_nodes']/shallow_100['total_nodes']:.1f}x nodes, "
          f"{deep_100['pickle_mean_ms']/shallow_100['pickle_mean_ms']:.1f}x pickle time, "
          f"{deep_100['size_mb']/shallow_100['size_mb']:.1f}x size")


if __name__ == "__main__":
    main()
