Description
As part of ongoing work to revamp the IO of the library (#785), we would like to transition from sourcing Stan config info from comment lines in the Stan CSV files to the structured configuration JSON files (#714). Since #838, the config files are generated as part of running any Stan method and saved to the output dir along with the other files. This includes storing the location of the config in the RunSet object.
In line with the work in #844, which introduced pydantic as a dependency for parsing and validating data sourced from JSON files, we'd like to do the same here. An underlying challenge is that the config information is used across a number of different methods and not always in a consistent way. There is also a desire to maintain a flexible interface that can accommodate new or in-development outputs/methods. An ideal solution would standardize the IO configuration across methods without constraining future development.
The right path isn't obvious to me, so I'd like to use this thread to get some thoughts down and discuss ideas before moving forward with a particular implementation.
Where do we actually use the configuration in the library?
To my understanding, the primary place the Stan configuration is parsed is the from_csv() function, which reconstructs a Fit object from the Stan CSV files. When constructing a Fit object directly from a method like model.sample(), the stored configuration is sourced from the arguments passed by the user. (I believe the only place there is an explicit comparison between the output configuration and the user args is CmdStanMCMC._validate_csv_files().)
So, under the current approach, we really only need to parse and validate configuration insofar as it is used to create Fit objects from CmdStan output files (this would mean changing the from_csv function into a from_output_files function). The good news is that the set of fields required to do this is relatively small.
As a side note, I think there may be some value in validating user-passed arguments against the arguments parsed out of the config files. It could be the kind of thing that catches bugs that would otherwise go unnoticed.
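As a rough sketch of that idea (the helper name is hypothetical, and ConfigInfo refers to the pydantic model sketched in the next section):

```python
# Hypothetical helper: cross-check user-supplied arguments against the parsed config.
def check_args_against_config(user_args: dict, config: "ConfigInfo") -> None:
    for name, value in user_args.items():
        if value is None:
            continue  # the user left this argument at the CmdStan default
        reported = getattr(config, name, None)
        if reported is not None and reported != value:
            raise ValueError(
                f"Argument {name!r}: user passed {value!r}, "
                f"but the CmdStan config reports {reported!r}"
            )
```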
What should a configuration object look like? How should validation be done?
When it comes to using the configuration, one option is a single pydantic model covering all the different methods, such as:
```python
from typing import Literal

from pydantic import BaseModel, ConfigDict


class ConfigInfo(BaseModel):
    """Structured representation of a config JSON file"""

    model_config = ConfigDict(extra="allow")

    model_name: str
    stan_major_version: int
    stan_minor_version: int
    stan_patch_version: int
    method: Literal["sample", "optimize", "variational", "laplace", "pathfinder"]
    # Sample
    algorithm: str | None = None
    num_samples: int | None = None
    num_warmup: int | None = None
    save_warmup: bool = False
    thin: int = 1
    max_depth: int | None = None
    # Optimize
    save_iterations: bool = False
    jacobian: bool = False
    # Variational
    iter: int | None = None
    grad_samples: int | None = None
    elbo_samples: int | None = None
    eta: float | None = None
    tol_rel_obj: float | None = None
    eval_elbo: int | None = None
    output_samples: int | None = None
    # Laplace
    mode: str | None = None
    draws: int | None = None
    # Pathfinder
    num_draws: int | None = None
    num_paths: int | None = None
    psis_resample: bool | None = None
    calculate_lp: bool | None = None
```

This keeps everything together, but it co-mingles the various config properties, and any validation would need to be done in a method-aware manner. The `model_config = ConfigDict(extra="allow")` setting lets the model accept fields that are not explicitly declared, which handles there being many more fields in the output than the library needs (while still giving access to them if desired).
Something like this would probably be the easiest to integrate within the current library in the sense that it would require the fewest changes.
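For illustration, loading and validating a config file with this single model might look like the following minimal sketch (the file name and directory are made up for the example):

```python
from pathlib import Path

# Parse a CmdStan-generated config JSON into the model above.
# The path is purely illustrative.
config_path = Path("output") / "model_config.json"
config = ConfigInfo.model_validate_json(config_path.read_text())

print(config.method)       # e.g. "sample"
print(config.num_samples)  # None for non-sampling methods
# Fields not declared on the model are still available thanks to extra="allow":
print(config.model_extra)  # dict of the remaining config entries
```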
We could alternatively create separate models for each method, e.g.:
```python
class BaseConfig(BaseModel):
    """Common fields for all methods."""

    model_config = ConfigDict(extra="allow")

    model_name: str
    stan_major_version: int
    stan_minor_version: int
    stan_patch_version: int


class SampleConfig(BaseConfig):
    method: Literal["sample"] = "sample"
    algorithm: str
    num_samples: int
    num_warmup: int
    save_warmup: bool = False
    thin: int = 1
    max_depth: int | None = None


class OptimizeConfig(BaseConfig):
    method: Literal["optimize"] = "optimize"
    algorithm: str
    save_iterations: bool = False
    jacobian: bool = False

# ...
```

This, in combination with pydantic's discriminated union feature, would give us method-discriminated config objects that I think would be nicer to work with, and it would allow us to keep validation method-specific.
This would probably require more changes throughout the library, but I think the result would be easier to reason about.
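As a rough sketch of how the discriminated union could be wired up (the MethodConfig alias and parse_config helper are hypothetical names, and only two of the method models are shown):

```python
from typing import Annotated, Union

from pydantic import Field, TypeAdapter

# The `method` Literal declared on each subclass acts as the discriminator.
MethodConfig = Annotated[
    Union[SampleConfig, OptimizeConfig],  # ... plus the remaining method models
    Field(discriminator="method"),
]

_adapter = TypeAdapter(MethodConfig)


def parse_config(raw_json: str):
    """Return the config model whose `method` field matches the JSON."""
    return _adapter.validate_json(raw_json)
```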
Thinking about further changes
Since these IO changes are aligned with work toward a 2.0 release, I wonder if there are other organizational changes that could be made to make the library a bit easier to maintain and develop against. I find that not having a consistent interface/protocol for the different methods' outputs (like adhering to a StanFit protocol/ABC) is a pain point. Primarily, I think such a protocol could clean up the IO pipeline by making things more consistent across methods.
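To make that concrete, a StanFit protocol might look something like this (purely illustrative; none of these names or attributes are an existing API):

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class StanFit(Protocol):
    """Illustrative common interface for the objects returned by each method."""

    @property
    def config(self) -> "ConfigInfo":
        """The parsed configuration that produced this fit."""
        ...

    @property
    def column_names(self) -> tuple[str, ...]:
        """Names of the output columns."""
        ...

    def draws(self) -> np.ndarray:
        """The draws/estimates produced by the method."""
        ...
```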
In any case, I would appreciate some thoughts (@WardBrian) on what a reasonable approach here would be. And if you think a better approach is something I haven't suggested, I'm all ears.
In some sense I'm trying to close the gap between the state this library is in today and an ideal version that would be created from scratch today. Not so easy a task.