Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,30 @@ print(pgd.compute(generated)) # {'pgd': ..., 'pgd_descriptor': ..., 'subscores':

`pgd_descriptor` provides the best descriptor used to report the final score.

By default, PGD uses TabPFN v2 weights. To use TabPFN v2.5 weights instead, pass a custom classifier. The v2.5 weights are hosted on a gated Hugging Face repository ([Prior-Labs/tabpfn_2_5](https://huggingface.co/Prior-Labs/tabpfn_2_5)) and require authentication:
By default, PGD uses TabPFN v2.5 weights. The v2.5 weights are hosted on a gated Hugging Face repository ([Prior-Labs/tabpfn_2_5](https://huggingface.co/Prior-Labs/tabpfn_2_5)) and require authentication:

```bash
pip install huggingface_hub
huggingface-cli login
```

Then:
Alternatively, you can use TabPFN v2.0 weights, which are licensed under the Prior Labs License (Apache 2.0 with an additional attribution clause) and permit commercial use. The v2.5 weights, in contrast, use a non-commercial license that prohibits commercial and production use without a separate enterprise license from Prior Labs:

```python
from tabpfn import TabPFNClassifier
from polygraph.metrics import StandardPGD

classifier = TabPFNClassifier.create_default_for_version(
"v2.5", device="auto", n_estimators=4
)
classifier = TabPFNClassifier(device="auto", n_estimators=4)
pgd = StandardPGD(reference, classifier=classifier)
print(pgd.compute(generated))
```

A logistic regression classifier can also be used as a lightweight alternative, although it yields a looser bound in practice:

```python
from sklearn.linear_model import LogisticRegression
from polygraph.metrics import StandardPGD

pgd = StandardPGD(reference, classifier=LogisticRegression())
```

#### Validity, uniqueness and novelty
Expand Down
9 changes: 7 additions & 2 deletions polygraph/metrics/base/polygraphdiscrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from sklearn.preprocessing import StandardScaler
from packaging.version import Version
from tabpfn import TabPFNClassifier
from tabpfn.classifier import ModelVersion

from polygraph import GraphType
from polygraph.metrics.base.interface import GenerationMetric
Expand Down Expand Up @@ -107,7 +108,7 @@ def default_classifier() -> TabPFNClassifier:
"""Create the default TabPFN classifier used by PGD.

Returns:
A TabPFNClassifier with default settings (auto device, 4
A TabPFNClassifier with v2.5 weights (auto device, 4
estimators). Requires ``tabpfn >= 2.0.9``.
"""
tabpfn_ver = Version(version("tabpfn"))
Expand All @@ -116,7 +117,11 @@ def default_classifier() -> TabPFNClassifier:
"TabPFN >= 2.0.9 is required. "
"Install with `pip install 'tabpfn>=2.0.9'`."
)
return TabPFNClassifier(device="auto", n_estimators=4)
return TabPFNClassifier.create_default_for_version(
ModelVersion.V2_5,
device="auto",
n_estimators=4,
)


class PolyGraphDiscrepancyResult(TypedDict):
Expand Down
Loading