
Implementation discussion: using Stan config JSON throughout library and other changes #848

@amas0

Description

As part of ongoing work to revamp the IO of the library (#785), we would like to transition from sourcing Stan config info from comment lines in the Stan CSVs to the structured configuration JSON files (#714). Since #838, the config files are generated as part of running any Stan method and saved to the output directory along with the other files; the location of the config is also stored in the RunSet object.

In line with the work in #844, which introduced pydantic as a dependency for parsing and validating data sourced from JSON files, we'd like to do the same here. An underlying challenge is that the config information is used across a number of different methods and not always in a consistent way. There is also a desire to maintain a flexible interface that can accommodate new or in-development outputs/methods. An ideal solution would standardize the IO configuration across methods while not constraining future development.

The right path isn't obvious to me, so I'd like to use this thread to get some thoughts down and discuss ideas before moving forward with a particular implementation.

Where do we actually use the configuration in the library?

To my understanding, the primary use of parsing the Stan configuration is in the from_csv() function, to reconstruct a Fit object from the Stan CSV files. When a Fit object is constructed directly from a method like model.sample(), the stored configuration is sourced from the arguments passed by the user. (I believe the only place the output configuration is explicitly compared against the user args is CmdStanMCMC._validate_csv_files().)

So, under the current approach, we really only need to parse and validate the configuration insofar as it is used to create Fit objects from CmdStan output files (this would mean changing the from_csv function into a from_output_files function). The good news is that the set of fields required to do this is relatively small.
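
Purely as a strawman for that rename (nothing below is existing library API; the function name, arguments, and dispatch are all hypothetical), the entry point could look something like:

  import json
  from pathlib import Path

  def from_output_files(config_path: str, csv_paths: list[str]):
      """Hypothetical replacement for from_csv() that sources the run
      configuration from the config JSON rather than CSV comment lines."""
      config = json.loads(Path(config_path).read_text())
      method = config["method"]
      if method == "sample":
          ...  # construct the MCMC fit from csv_paths plus the parsed config
      elif method == "optimize":
          ...  # construct the optimization fit
      else:
          raise ValueError(f"unrecognized method: {method}")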

As a side note, I think there may be some value in validating user-passed arguments against the arguments parsed out of the config files. That could be the kind of check that catches bugs that would otherwise stay hidden.
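
A minimal sketch of that cross-check (the helper and its warning behavior are hypothetical, not existing library code):

  import warnings
  from typing import Any, Mapping

  def warn_on_config_mismatch(
      user_args: Mapping[str, Any], parsed_config: Mapping[str, Any]
  ) -> None:
      """Warn when a user-supplied argument disagrees with the run's config file."""
      for name, value in user_args.items():
          if name in parsed_config and parsed_config[name] != value:
              warnings.warn(
                  f"argument {name!r} was passed as {value!r} "
                  f"but the config file recorded {parsed_config[name]!r}"
              )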

What should a configuration object look like? How should validation be done?

When it comes to using the configuration, one way I see that this can be done is a single pydantic model for all the different methods, such as:

  from typing import Literal
  from pydantic import BaseModel, ConfigDict

  class ConfigInfo(BaseModel):
      """Structured representation of a config JSON file"""
      model_config = ConfigDict(extra="allow")

      model_name: str
      stan_major_version: int
      stan_minor_version: int
      stan_patch_version: int
      method: Literal["sample", "optimize", "variational", "laplace", "pathfinder"]

      # Sample
      algorithm: str | None = None
      num_samples: int | None = None
      num_warmup: int | None = None
      save_warmup: bool = False
      thin: int = 1
      max_depth: int | None = None

      # Optimize
      save_iterations: bool = False
      jacobian: bool = False

      # Variational
      iter: int | None = None
      grad_samples: int | None = None
      elbo_samples: int | None = None
      eta: float | None = None
      tol_rel_obj: float | None = None
      eval_elbo: int | None = None
      output_samples: int | None = None

      # Laplace
      mode: str | None = None
      draws: int | None = None

      # Pathfinder
      num_draws: int | None = None
      num_paths: int | None = None
      psis_resample: bool | None = None
      calculate_lp: bool | None = None

This keeps everything together, but it co-mingles the config properties of the various methods, and any validation would need to be method-aware. The model_config = ConfigDict(extra="allow") tells pydantic to accept fields that are not explicitly declared, which handles the fact that the output contains many more fields than the library needs (while still giving access to them if desired).
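
For example (a minimal sketch; the file path is just a placeholder), any fields present in the JSON but not declared on the model survive parsing and remain accessible:

  import json
  from pathlib import Path

  # Placeholder path; in practice this would come from the RunSet
  raw = json.loads(Path("output/config.json").read_text())
  config = ConfigInfo.model_validate(raw)

  config.method       # declared field, validated against the Literal
  config.model_extra  # dict of every field in the JSON that is not declared above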

Something like this would probably be the easiest to integrate within the current library in the sense that it would require the fewest changes.

We could alternatively create separate models for each method, e.g.:

  class BaseConfig(BaseModel):
      """Common fields for all methods."""
      model_config = ConfigDict(extra="allow")

      model_name: str
      stan_major_version: int
      stan_minor_version: int
      stan_patch_version: int

  class SampleConfig(BaseConfig):
      method: Literal["sample"] = "sample"
      algorithm: str
      num_samples: int
      num_warmup: int
      save_warmup: bool = False
      thin: int = 1
      max_depth: int | None = None

  class OptimizeConfig(BaseConfig):
      method: Literal["optimize"] = "optimize"
      algorithm: str
      save_iterations: bool = False
      jacobian: bool = False

  ...

which, in combination with pydantic's discriminated union feature, would give us method-discriminated config objects that I think would be nicer to work with. This would also keep validation method-specific.
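
For illustration, a rough sketch of how that could be wired up (the adapter name and the file path are placeholders, and only two union members are shown):

  from pathlib import Path
  from typing import Annotated

  from pydantic import Field, TypeAdapter

  # The remaining method configs would also be added to the union
  MethodConfig = Annotated[
      SampleConfig | OptimizeConfig,
      Field(discriminator="method"),
  ]
  config_adapter = TypeAdapter(MethodConfig)

  # validate_json dispatches on the "method" field and applies only the
  # validation rules of the matching model
  config = config_adapter.validate_json(Path("output/config.json").read_text())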

This would probably require more changes throughout the library, but I think it would be easier to reason about within the library.

Thinking about further changes

Since these IO changes are aligned with work toward a 2.0 release, I wonder if there are other organizational changes that would make the library a bit easier to maintain and develop against. One pain point for me is that the different methods' output objects don't share a consistent interface/protocol (e.g. adhering to a StanFit protocol/ABC). Primarily, I think introducing one could clean up the IO pipeline by making things more consistent across methods.
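
As a strawman (none of this exists in the library today; the protocol name and members are assumptions for discussion), something along these lines:

  from typing import Protocol

  import numpy as np

  class StanFit(Protocol):
      """Minimal interface every method's result object would adhere to."""

      @property
      def config(self) -> BaseConfig:
          """The parsed config model for the run (per the sketches above)."""
          ...

      @property
      def column_names(self) -> tuple[str, ...]:
          ...

      def draws(self) -> np.ndarray:
          ...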

In any case, I would appreciate some thoughts (@WardBrian) on what a reasonable approach would be. Or, if you think a better approach is something I haven't suggested, I'm all ears.

In some sense I'm trying to close the gap between the state that this library is in today and an ideal version that would be created from scratch today. Not so easy a task.
