Skip to content

Commit f969002

Browse files
Apply PR #17227: refactor(provider): use AuthService in auth flows
2 parents 2e5b416 + 3cfdb07 commit f969002

3 files changed

Lines changed: 200 additions & 96 deletions

File tree

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import { Effect, Layer, ServiceMap } from "effect"
2+
import { Instance } from "@/project/instance"
3+
import { Plugin } from "../plugin"
4+
import { filter, fromEntries, map, mapValues, pipe } from "remeda"
5+
import type { AuthOuathResult } from "@opencode-ai/plugin"
6+
import { NamedError } from "@opencode-ai/util/error"
7+
import * as Auth from "@/auth/service"
8+
import { ProviderID } from "./schema"
9+
import z from "zod"
10+
11+
const state = Instance.state(async () => {
12+
const methods = pipe(
13+
await Plugin.list(),
14+
filter((x) => x.auth?.provider !== undefined),
15+
map((x) => [x.auth!.provider, x.auth!] as const),
16+
fromEntries(),
17+
)
18+
return { methods, pending: {} as Record<string, AuthOuathResult> }
19+
})
20+
21+
export type Method = {
22+
type: "oauth" | "api"
23+
label: string
24+
}
25+
26+
export type Authorization = {
27+
url: string
28+
method: "auto" | "code"
29+
instructions: string
30+
}
31+
32+
export const OauthMissing = NamedError.create(
33+
"ProviderAuthOauthMissing",
34+
z.object({
35+
providerID: ProviderID.zod,
36+
}),
37+
)
38+
39+
export const OauthCodeMissing = NamedError.create(
40+
"ProviderAuthOauthCodeMissing",
41+
z.object({
42+
providerID: ProviderID.zod,
43+
}),
44+
)
45+
46+
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
47+
48+
export type ProviderAuthError =
49+
| Auth.AuthServiceError
50+
| InstanceType<typeof OauthMissing>
51+
| InstanceType<typeof OauthCodeMissing>
52+
| InstanceType<typeof OauthCallbackFailed>
53+
54+
export namespace ProviderAuthService {
55+
export interface Service {
56+
readonly methods: () => Effect.Effect<Record<string, Method[]>>
57+
readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect<Authorization | undefined>
58+
readonly callback: (input: {
59+
providerID: ProviderID
60+
method: number
61+
code?: string
62+
}) => Effect.Effect<void, ProviderAuthError>
63+
readonly api: (input: { providerID: ProviderID; key: string }) => Effect.Effect<void, Auth.AuthServiceError>
64+
}
65+
}
66+
67+
export class ProviderAuthService extends ServiceMap.Service<ProviderAuthService, ProviderAuthService.Service>()(
68+
"@opencode/ProviderAuth",
69+
) {
70+
static readonly layer = Layer.effect(
71+
ProviderAuthService,
72+
Effect.gen(function* () {
73+
const auth = yield* Auth.AuthService
74+
75+
const methods = Effect.fn("ProviderAuthService.methods")(() =>
76+
Effect.promise(() =>
77+
state().then((x) =>
78+
mapValues(x.methods, (y) =>
79+
y.methods.map(
80+
(z): Method => ({
81+
type: z.type,
82+
label: z.label,
83+
}),
84+
),
85+
),
86+
),
87+
),
88+
)
89+
90+
const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: {
91+
providerID: ProviderID
92+
method: number
93+
}) {
94+
const item = yield* Effect.promise(() => state().then((x) => x.methods[input.providerID]))
95+
const method = item.methods[input.method]
96+
if (method.type !== "oauth") return
97+
const result = yield* Effect.promise(() => method.authorize())
98+
yield* Effect.promise(() =>
99+
state().then((x) => {
100+
x.pending[input.providerID] = result
101+
}),
102+
)
103+
return {
104+
url: result.url,
105+
method: result.method,
106+
instructions: result.instructions,
107+
}
108+
})
109+
110+
const callback = Effect.fn("ProviderAuthService.callback")(function* (input: {
111+
providerID: ProviderID
112+
method: number
113+
code?: string
114+
}) {
115+
const match = yield* Effect.promise(() => state().then((x) => x.pending[input.providerID]))
116+
if (!match) return yield* Effect.fail(new OauthMissing({ providerID: input.providerID }))
117+
118+
const result =
119+
match.method === "code"
120+
? yield* Effect.gen(function* () {
121+
const code = input.code
122+
if (!code) return yield* Effect.fail(new OauthCodeMissing({ providerID: input.providerID }))
123+
return yield* Effect.promise(() => match.callback(code))
124+
})
125+
: yield* Effect.promise(() => match.callback())
126+
127+
if (!result || result.type !== "success") return yield* Effect.fail(new OauthCallbackFailed({}))
128+
129+
if ("key" in result) {
130+
yield* auth.set(input.providerID, {
131+
type: "api",
132+
key: result.key,
133+
})
134+
}
135+
136+
if ("refresh" in result) {
137+
yield* auth.set(input.providerID, {
138+
type: "oauth",
139+
access: result.access,
140+
refresh: result.refresh,
141+
expires: result.expires,
142+
...(result.accountId ? { accountId: result.accountId } : {}),
143+
})
144+
}
145+
})
146+
147+
const api = Effect.fn("ProviderAuthService.api")(function* (input: { providerID: ProviderID; key: string }) {
148+
yield* auth.set(input.providerID, {
149+
type: "api",
150+
key: input.key,
151+
})
152+
})
153+
154+
return ProviderAuthService.of({
155+
methods,
156+
authorize,
157+
callback,
158+
api,
159+
})
160+
}),
161+
)
162+
163+
static readonly defaultLayer = ProviderAuthService.layer.pipe(Layer.provide(Auth.AuthService.defaultLayer))
164+
}
Lines changed: 16 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
import { Instance } from "@/project/instance"
2-
import { Plugin } from "../plugin"
3-
import { map, filter, pipe, fromEntries, mapValues } from "remeda"
1+
import { Effect, ManagedRuntime } from "effect"
42
import z from "zod"
3+
54
import { fn } from "@/util/fn"
6-
import type { AuthOuathResult, Hooks } from "@opencode-ai/plugin"
7-
import { NamedError } from "@opencode-ai/util/error"
8-
import { Auth } from "@/auth"
5+
import * as S from "./auth-service"
96
import { ProviderID } from "./schema"
107

11-
export namespace ProviderAuth {
12-
const state = Instance.state(async () => {
13-
const methods = pipe(
14-
await Plugin.list(),
15-
filter((x) => x.auth?.provider !== undefined),
16-
map((x) => [x.auth!.provider, x.auth!] as const),
17-
fromEntries(),
18-
)
19-
return { methods, pending: {} as Record<string, AuthOuathResult> }
20-
})
8+
const rt = ManagedRuntime.make(S.ProviderAuthService.defaultLayer)
219

10+
function runPromise<A>(f: (service: S.ProviderAuthService.Service) => Effect.Effect<A, S.ProviderAuthError>) {
11+
return rt.runPromise(S.ProviderAuthService.use(f))
12+
}
13+
14+
export namespace ProviderAuth {
2215
export const Method = z
2316
.object({
2417
type: z.union([z.literal("oauth"), z.literal("api")]),
@@ -30,15 +23,7 @@ export namespace ProviderAuth {
3023
export type Method = z.infer<typeof Method>
3124

3225
export async function methods() {
33-
const s = await state().then((x) => x.methods)
34-
return mapValues(s, (x) =>
35-
x.methods.map(
36-
(y): Method => ({
37-
type: y.type,
38-
label: y.label,
39-
}),
40-
),
41-
)
26+
return runPromise((service) => service.methods())
4227
}
4328

4429
export const Authorization = z
@@ -57,19 +42,7 @@ export namespace ProviderAuth {
5742
providerID: ProviderID.zod,
5843
method: z.number(),
5944
}),
60-
async (input): Promise<Authorization | undefined> => {
61-
const auth = await state().then((s) => s.methods[input.providerID])
62-
const method = auth.methods[input.method]
63-
if (method.type === "oauth") {
64-
const result = await method.authorize()
65-
await state().then((s) => (s.pending[input.providerID] = result))
66-
return {
67-
url: result.url,
68-
method: result.method,
69-
instructions: result.instructions,
70-
}
71-
}
72-
},
45+
async (input): Promise<Authorization | undefined> => runPromise((service) => service.authorize(input)),
7346
)
7447

7548
export const callback = fn(
@@ -78,71 +51,18 @@ export namespace ProviderAuth {
7851
method: z.number(),
7952
code: z.string().optional(),
8053
}),
81-
async (input) => {
82-
const match = await state().then((s) => s.pending[input.providerID])
83-
if (!match) throw new OauthMissing({ providerID: input.providerID })
84-
let result
85-
86-
if (match.method === "code") {
87-
if (!input.code) throw new OauthCodeMissing({ providerID: input.providerID })
88-
result = await match.callback(input.code)
89-
}
90-
91-
if (match.method === "auto") {
92-
result = await match.callback()
93-
}
94-
95-
if (result?.type === "success") {
96-
if ("key" in result) {
97-
await Auth.set(input.providerID, {
98-
type: "api",
99-
key: result.key,
100-
})
101-
}
102-
if ("refresh" in result) {
103-
const info: Auth.Info = {
104-
type: "oauth",
105-
access: result.access,
106-
refresh: result.refresh,
107-
expires: result.expires,
108-
}
109-
if (result.accountId) {
110-
info.accountId = result.accountId
111-
}
112-
await Auth.set(input.providerID, info)
113-
}
114-
return
115-
}
116-
117-
throw new OauthCallbackFailed({})
118-
},
54+
async (input) => runPromise((service) => service.callback(input)),
11955
)
12056

12157
export const api = fn(
12258
z.object({
12359
providerID: ProviderID.zod,
12460
key: z.string(),
12561
}),
126-
async (input) => {
127-
await Auth.set(input.providerID, {
128-
type: "api",
129-
key: input.key,
130-
})
131-
},
132-
)
133-
134-
export const OauthMissing = NamedError.create(
135-
"ProviderAuthOauthMissing",
136-
z.object({
137-
providerID: ProviderID.zod,
138-
}),
139-
)
140-
export const OauthCodeMissing = NamedError.create(
141-
"ProviderAuthOauthCodeMissing",
142-
z.object({
143-
providerID: ProviderID.zod,
144-
}),
62+
async (input) => runPromise((service) => service.api(input)),
14563
)
14664

147-
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
65+
export import OauthMissing = S.OauthMissing
66+
export import OauthCodeMissing = S.OauthCodeMissing
67+
export import OauthCallbackFailed = S.OauthCallbackFailed
14868
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { afterEach, expect, test } from "bun:test"
2+
import { Auth } from "../../src/auth"
3+
import { ProviderAuth } from "../../src/provider/auth"
4+
import { ProviderID } from "../../src/provider/schema"
5+
6+
afterEach(async () => {
7+
await Auth.remove("test-provider-auth")
8+
})
9+
10+
test("ProviderAuth.api persists auth via AuthService", async () => {
11+
await ProviderAuth.api({
12+
providerID: ProviderID.make("test-provider-auth"),
13+
key: "sk-test",
14+
})
15+
16+
expect(await Auth.get("test-provider-auth")).toEqual({
17+
type: "api",
18+
key: "sk-test",
19+
})
20+
})

0 commit comments

Comments
 (0)