11package cliext
22
33import (
4- "bytes"
54 "context"
5+ "encoding/json"
66 "errors"
77 "fmt"
88 "os"
99 "path/filepath"
1010 "time"
1111
12- "github.com/BurntSushi/toml"
1312 "go.temporal.io/sdk/contrib/envconfig"
1413 "golang.org/x/oauth2"
1514)
1615
16+ // oauthConfigJSON is an intermediate struct for JSON serialization of OAuth config.
17+ type oauthConfigJSON struct {
18+ ClientID string `json:"client_id,omitempty"`
19+ ClientSecret string `json:"client_secret,omitempty"`
20+ TokenURL string `json:"token_url,omitempty"`
21+ AuthURL string `json:"auth_url,omitempty"`
22+ RedirectURL string `json:"redirect_url,omitempty"`
23+ AccessToken string `json:"access_token,omitempty"`
24+ RefreshToken string `json:"refresh_token,omitempty"`
25+ TokenType string `json:"token_type,omitempty"`
26+ ExpiresAt string `json:"expires_at,omitempty"`
27+ Scopes []string `json:"scopes,omitempty"`
28+ }
29+
1730// OAuthConfig combines OAuth client configuration with token information.
1831type OAuthConfig struct {
1932 // ClientConfig is the OAuth 2.0 client configuration.
@@ -84,52 +97,72 @@ func loadOAuthConfigFromFile(path string) (map[string]*OAuthConfig, error) {
8497 return nil , fmt .Errorf ("failed to read config file: %w" , err )
8598 }
8699
87- var raw rawConfigWithOAuth
88- if _ , err := toml .Decode (string (data ), & raw ); err != nil {
100+ // Use envconfig's FromTOML with AdditionalProfileFields to capture OAuth fields
101+ var conf envconfig.ClientConfig
102+ additional := make (map [string ]map [string ]any )
103+ if err := conf .FromTOML (data , envconfig.ClientConfigFromTOMLOptions {
104+ AdditionalProfileFields : additional ,
105+ }); err != nil {
89106 return nil , fmt .Errorf ("failed to parse config file: %w" , err )
90107 }
91108
92109 oauthByProfile := make (map [string ]* OAuthConfig )
93- for profileName , profile := range raw . Profile {
94- if profile == nil || profile . OAuth == nil {
95- oauthByProfile [ profileName ] = nil
110+ for profileName , fields := range additional {
111+ oauthRaw , ok := fields [ "oauth" ].( map [ string ] any )
112+ if ! ok {
96113 continue
97114 }
98- cfg := profile .OAuth
99-
100- // Parse expiry time if present
101- var expiry time.Time
102- if cfg .ExpiresAt != "" {
103- t , err := time .Parse (time .RFC3339 , cfg .ExpiresAt )
104- if err != nil {
105- return nil , fmt .Errorf ("failed to parse expires_at for profile %q: %w" , profileName , err )
106- }
107- expiry = t
108- }
109-
110- oauth := & OAuthConfig {
111- ClientConfig : & oauth2.Config {
112- ClientID : cfg .ClientID ,
113- ClientSecret : cfg .ClientSecret ,
114- RedirectURL : cfg .RedirectURL ,
115- Scopes : cfg .Scopes ,
116- Endpoint : oauth2.Endpoint {
117- AuthURL : cfg .AuthURL ,
118- TokenURL : cfg .TokenURL ,
119- },
120- },
121- Token : & oauth2.Token {
122- AccessToken : cfg .AccessToken ,
123- RefreshToken : cfg .RefreshToken ,
124- TokenType : cfg .TokenType ,
125- Expiry : expiry ,
126- },
115+ oauth , err := oauthConfigFromMap (oauthRaw )
116+ if err != nil {
117+ return nil , fmt .Errorf ("failed to parse oauth for profile %q: %w" , profileName , err )
127118 }
128119 oauthByProfile [profileName ] = oauth
129120 }
130121 return oauthByProfile , nil
131122}
132123
124+ // oauthConfigFromMap converts a map[string]any to OAuthConfig using JSON as intermediary.
125+ func oauthConfigFromMap (m map [string ]any ) (* OAuthConfig , error ) {
126+ data , err := json .Marshal (m )
127+ if err != nil {
128+ return nil , fmt .Errorf ("failed to marshal oauth config: %w" , err )
129+ }
130+
131+ var cfg oauthConfigJSON
132+ if err := json .Unmarshal (data , & cfg ); err != nil {
133+ return nil , fmt .Errorf ("failed to unmarshal oauth config: %w" , err )
134+ }
135+
136+ // Parse expiry time if present
137+ var expiry time.Time
138+ if cfg .ExpiresAt != "" {
139+ t , err := time .Parse (time .RFC3339 , cfg .ExpiresAt )
140+ if err != nil {
141+ return nil , fmt .Errorf ("failed to parse expires_at: %w" , err )
142+ }
143+ expiry = t
144+ }
145+
146+ return & OAuthConfig {
147+ ClientConfig : & oauth2.Config {
148+ ClientID : cfg .ClientID ,
149+ ClientSecret : cfg .ClientSecret ,
150+ RedirectURL : cfg .RedirectURL ,
151+ Scopes : cfg .Scopes ,
152+ Endpoint : oauth2.Endpoint {
153+ AuthURL : cfg .AuthURL ,
154+ TokenURL : cfg .TokenURL ,
155+ },
156+ },
157+ Token : & oauth2.Token {
158+ AccessToken : cfg .AccessToken ,
159+ RefreshToken : cfg .RefreshToken ,
160+ TokenType : cfg .TokenType ,
161+ Expiry : expiry ,
162+ },
163+ }, nil
164+ }
165+
133166// resolveConfigAndProfile resolves the config file path and profile name.
134167func resolveConfigAndProfile (configFilePath , profileName string , envLookup envconfig.EnvLookup ) (string , string , error ) {
135168 if envLookup == nil {
@@ -182,43 +215,45 @@ func StoreClientOAuth(opts StoreClientOAuthOptions) error {
182215 return err
183216 }
184217
185- // Read and parse existing file content.
218+ // Read and parse existing file content using envconfig .
186219 existingContent , err := os .ReadFile (configFilePath )
187220 if err != nil && ! errors .Is (err , os .ErrNotExist ) {
188221 return fmt .Errorf ("failed to read config file: %w" , err )
189222 }
190223
191- var existingRaw map [string ]any
224+ var conf envconfig.ClientConfig
225+ additional := make (map [string ]map [string ]any )
192226 if len (existingContent ) > 0 {
193- if _ , err := toml .Decode (string (existingContent ), & existingRaw ); err != nil {
227+ if err := conf .FromTOML (existingContent , envconfig.ClientConfigFromTOMLOptions {
228+ AdditionalProfileFields : additional ,
229+ }); err != nil {
194230 return fmt .Errorf ("failed to parse existing config: %w" , err )
195231 }
196232 }
197- if existingRaw == nil {
198- existingRaw = make (map [string ]any )
199- }
200233
201- // Load existing OAuth configs from the parsed content.
202- oauthByProfile , err := parseOAuthFromRaw (existingRaw )
203- if err != nil {
204- return fmt .Errorf ("failed to parse existing OAuth config: %w" , err )
234+ // Ensure the profile exists in the config.
235+ if conf .Profiles == nil {
236+ conf .Profiles = make (map [string ]* envconfig.ClientConfigProfile )
205237 }
206- if oauthByProfile == nil {
207- oauthByProfile = make ( map [ string ] * OAuthConfig )
238+ if conf . Profiles [ profileName ] == nil {
239+ conf . Profiles [ profileName ] = & envconfig. ClientConfigProfile {}
208240 }
209241
210- // Update the OAuth config for this profile.
211- oauthByProfile [profileName ] = opts .OAuth
212-
213- // Merge OAuth configs back into the raw structure.
214- if err := mergeOAuthIntoRaw (existingRaw , oauthByProfile ); err != nil {
215- return err
242+ // Update the OAuth config for this profile in additional fields.
243+ if additional [profileName ] == nil {
244+ additional [profileName ] = make (map [string ]any )
245+ }
246+ if opts .OAuth == nil {
247+ delete (additional [profileName ], "oauth" )
248+ } else {
249+ additional [profileName ]["oauth" ] = oauthConfigToMap (opts .OAuth )
216250 }
217251
218- // Marshal back to TOML.
219- var buf bytes.Buffer
220- enc := toml .NewEncoder (& buf )
221- if err := enc .Encode (existingRaw ); err != nil {
252+ // Marshal back to TOML using envconfig.
253+ data , err := conf .ToTOML (envconfig.ClientConfigToTOMLOptions {
254+ AdditionalProfileFields : additional ,
255+ })
256+ if err != nil {
222257 return fmt .Errorf ("failed to encode config: %w" , err )
223258 }
224259
@@ -227,122 +262,20 @@ func StoreClientOAuth(opts StoreClientOAuthOptions) error {
227262 return fmt .Errorf ("failed to create config directory: %w" , err )
228263 }
229264
230- if err := os .WriteFile (configFilePath , buf . Bytes () , 0600 ); err != nil {
265+ if err := os .WriteFile (configFilePath , data , 0600 ); err != nil {
231266 return fmt .Errorf ("failed to write config file: %w" , err )
232267 }
233268
234269 return nil
235270}
236271
237- func parseOAuthFromRaw (raw map [string ]any ) (map [string ]* OAuthConfig , error ) {
238- var parsed rawConfigWithOAuth
239-
240- // Re-encode and decode to convert map[string]any to our struct.
241- // This is simpler than manual type assertions for nested structures.
242- var buf bytes.Buffer
243- enc := toml .NewEncoder (& buf )
244- if err := enc .Encode (raw ); err != nil {
245- return nil , err
246- }
247- if _ , err := toml .Decode (buf .String (), & parsed ); err != nil {
248- return nil , err
249- }
250-
251- oauthByProfile := make (map [string ]* OAuthConfig )
252- for profileName , profile := range parsed .Profile {
253- if profile == nil || profile .OAuth == nil {
254- continue
255- }
256- cfg := profile .OAuth
257-
258- // Parse expiry time if present
259- var expiry time.Time
260- if cfg .ExpiresAt != "" {
261- t , err := time .Parse (time .RFC3339 , cfg .ExpiresAt )
262- if err != nil {
263- return nil , fmt .Errorf ("failed to parse expires_at for profile %q: %w" , profileName , err )
264- }
265- expiry = t
266- }
267-
268- oauth := & OAuthConfig {
269- ClientConfig : & oauth2.Config {
270- ClientID : cfg .ClientID ,
271- ClientSecret : cfg .ClientSecret ,
272- RedirectURL : cfg .RedirectURL ,
273- Scopes : cfg .Scopes ,
274- Endpoint : oauth2.Endpoint {
275- AuthURL : cfg .AuthURL ,
276- TokenURL : cfg .TokenURL ,
277- },
278- },
279- Token : & oauth2.Token {
280- AccessToken : cfg .AccessToken ,
281- RefreshToken : cfg .RefreshToken ,
282- TokenType : cfg .TokenType ,
283- Expiry : expiry ,
284- },
285- }
286- oauthByProfile [profileName ] = oauth
287- }
288- return oauthByProfile , nil
289- }
290-
291- // mergeOAuthIntoRaw merges OAuth configurations into a raw TOML structure.
292- func mergeOAuthIntoRaw (raw map [string ]any , oauthByProfile map [string ]* OAuthConfig ) error {
293- // Get or create the profile section.
294- profileSection , ok := raw ["profile" ].(map [string ]any )
295- if ! ok {
296- profileSection = make (map [string ]any )
297- raw ["profile" ] = profileSection
298- }
299-
300- // Update OAuth for each profile.
301- for profileName , oauth := range oauthByProfile {
302- profile , ok := profileSection [profileName ].(map [string ]any )
303- if ! ok {
304- profile = make (map [string ]any )
305- profileSection [profileName ] = profile
306- }
307-
308- if oauth == nil {
309- delete (profile , "oauth" )
310- } else {
311- profile ["oauth" ] = oauthConfigToTOML (oauth )
312- }
313- }
314-
315- return nil
316- }
317-
318- // oauthConfigTOML is the TOML representation of OAuthConfig.
319- type oauthConfigTOML struct {
320- ClientID string `toml:"client_id,omitempty"`
321- ClientSecret string `toml:"client_secret,omitempty"`
322- TokenURL string `toml:"token_url,omitempty"`
323- AuthURL string `toml:"auth_url,omitempty"`
324- RedirectURL string `toml:"redirect_url,omitempty"`
325- AccessToken string `toml:"access_token,omitempty"`
326- RefreshToken string `toml:"refresh_token,omitempty"`
327- TokenType string `toml:"token_type,omitempty"`
328- ExpiresAt string `toml:"expires_at,omitempty"`
329- Scopes []string `toml:"scopes,omitempty"`
330- }
331-
332- type rawProfileWithOAuth struct {
333- OAuth * oauthConfigTOML `toml:"oauth"`
334- }
335-
336- type rawConfigWithOAuth struct {
337- Profile map [string ]* rawProfileWithOAuth `toml:"profile"`
338- }
339-
340- // oauthConfigToTOML converts OAuthConfig to its TOML representation.
341- func oauthConfigToTOML (oauth * OAuthConfig ) * oauthConfigTOML {
272+ // oauthConfigToMap converts OAuthConfig to map[string]any using JSON as intermediary.
273+ func oauthConfigToMap (oauth * OAuthConfig ) map [string ]any {
342274 if oauth == nil || oauth .ClientConfig == nil || oauth .Token == nil {
343275 return nil
344276 }
345- result := & oauthConfigTOML {
277+
278+ cfg := oauthConfigJSON {
346279 ClientID : oauth .ClientConfig .ClientID ,
347280 ClientSecret : oauth .ClientConfig .ClientSecret ,
348281 TokenURL : oauth .ClientConfig .Endpoint .TokenURL ,
@@ -354,7 +287,18 @@ func oauthConfigToTOML(oauth *OAuthConfig) *oauthConfigTOML {
354287 Scopes : oauth .ClientConfig .Scopes ,
355288 }
356289 if ! oauth .Token .Expiry .IsZero () {
357- result .ExpiresAt = oauth .Token .Expiry .Format (time .RFC3339 )
290+ cfg .ExpiresAt = oauth .Token .Expiry .Format (time .RFC3339 )
291+ }
292+
293+ data , err := json .Marshal (cfg )
294+ if err != nil {
295+ return nil
296+ }
297+
298+ var result map [string ]any
299+ if err := json .Unmarshal (data , & result ); err != nil {
300+ return nil
358301 }
359302 return result
360303}
304+
0 commit comments