Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
"@prisma/config": "workspace:*",
"@prisma/dev": "0.20.0",
"@prisma/engines": "workspace:*",
"@prisma/studio-core": "0.13.1",
"@prisma/studio-core": "0.16.3",
"mysql2": "3.15.3",
"postgres": "3.4.7"
},
Expand Down
112 changes: 58 additions & 54 deletions packages/cli/src/Studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import { digest } from 'ohash'
import open from 'open'
import { dirname, extname, join, resolve } from 'pathe'
import { runtime } from 'std-env'
import { z } from 'zod'

import packageJson from '../package.json' assert { type: 'json' }
import { UserFacingError } from './utils/errors'
import { getPpgInfo } from './utils/ppgInfo'

/**
Expand Down Expand Up @@ -56,12 +56,8 @@ const DEFAULT_CONTENT_TYPE = 'application/octet-stream'
const ADAPTER_FILE_NAME = 'adapter.js'
const ADAPTER_FACTORY_FUNCTION_NAME = 'createAdapter'

const ACCELERATE_API_KEY_QUERY_PARAMETER = 'api_key'

const AccelerateAPIKeyPayloadSchema = z.object({
secure_key: z.string(),
tenant_id: z.string(),
})
const ACCELERATE_UNSUPPORTED_MESSAGE =
'Prisma Studio no longer supports Accelerate URLs (`prisma://` or `prisma+postgres://`). Use a direct database connection string instead.'

interface StudioStuff {
createExecutor(connectionString: string, relativeTo: string): Promise<Executor>
Expand All @@ -87,6 +83,14 @@ const PRISMA_ORM_SPECIFIC_QUERY_PARAMETERS = [
'statement_cache_size',
] as const

const PRISMA_ORM_SPECIFIC_MYSQL_QUERY_PARAMETERS = [
'connection_limit',
'pool_timeout',
'socket_timeout',
'sslaccept',
'sslidentity',
] as const

const POSTGRES_STUDIO_STUFF: StudioStuff = {
async createExecutor(connectionString) {
const postgresModule = await import('postgres')
Expand Down Expand Up @@ -184,54 +188,11 @@ Please use Node.js >=22.5, Deno >=2.2 or Bun >=1.0 or ensure you have the \`bett
},
postgres: POSTGRES_STUDIO_STUFF,
postgresql: POSTGRES_STUDIO_STUFF,
'prisma+postgres': {
async createExecutor(connectionString, relativeTo) {
const connectionURL = new URL(connectionString)

if (['localhost', '127.0.0.1', '[::1]'].includes(connectionURL.hostname)) {
// TODO: support `prisma dev` accelerate URLs.

throw new Error('The "prisma+postgres" protocol with localhost is not supported in Prisma Studio yet.')
}

const apiKey = connectionURL.searchParams.get(ACCELERATE_API_KEY_QUERY_PARAMETER)

if (!apiKey) {
throw new Error(
`\`${ACCELERATE_API_KEY_QUERY_PARAMETER}\` query parameter is missing in the provided "prisma+postgres" connection string.`,
)
}

const [, payload] = apiKey.split('.')

try {
const decodedPayload = AccelerateAPIKeyPayloadSchema.parse(
JSON.parse(Buffer.from(payload, 'base64').toString('utf-8')),
)

connectionURL.password = decodedPayload.secure_key
connectionURL.username = decodedPayload.tenant_id
} catch {
throw new Error(
`Invalid/outdated \`${ACCELERATE_API_KEY_QUERY_PARAMETER}\` query parameter in the provided "prisma+postgres" connection string. Please create a new API key and use the new connection string OR use a direct TCP connection string instead.`,
)
}

connectionURL.host = 'db.prisma.io:5432'
connectionURL.pathname = '/postgres'
connectionURL.protocol = 'postgres:'
connectionURL.searchParams.delete(ACCELERATE_API_KEY_QUERY_PARAMETER)
connectionURL.searchParams.set('sslmode', 'require')

return await POSTGRES_STUDIO_STUFF.createExecutor(connectionURL.toString(), relativeTo)
},
reExportAdapterScript: POSTGRES_STUDIO_STUFF.reExportAdapterScript,
},
mysql: {
async createExecutor(connectionString) {
const { createPool } = await import('mysql2/promise')

const pool = createPool(connectionString)
const pool = createPool(normalizeMySQLConnectionString(connectionString))

process.once('SIGINT', () => pool.end())
process.once('SIGTERM', () => pool.end())
Expand Down Expand Up @@ -323,21 +284,25 @@ ${bold('Examples')}
const connectionString = args['--url'] || config.datasource?.url

if (!connectionString) {
return new Error(
return new UserFacingError(
'No database URL found. Provide it via the `--url <url>` argument or define it in your Prisma config file as `datasource.url`.',
)
}

if (!URL.canParse(connectionString)) {
return new Error('The provided database URL is not valid.')
return new UserFacingError('The provided database URL is not valid.')
}

const protocol = new URL(connectionString).protocol.replace(':', '')

if (isAccelerateProtocol(protocol)) {
return new UserFacingError(ACCELERATE_UNSUPPORTED_MESSAGE)
}

const studioStuff = CONNECTION_STRING_PROTOCOL_TO_STUDIO_STUFF[protocol]

if (!studioStuff) {
return new Error(`Prisma Studio is not supported for the "${protocol}" protocol.`)
return new UserFacingError(`Prisma Studio is not supported for the "${protocol}" protocol.`)
}

const executor = await studioStuff.createExecutor(
Expand Down Expand Up @@ -476,6 +441,45 @@ function getUrlBasePath(url: string | undefined, configPath: string | null): str
return url ? process.cwd() : configPath ? dirname(configPath) : process.cwd()
}

function isAccelerateProtocol(protocol: string): boolean {
return protocol === 'prisma' || protocol === 'prisma+postgres'
}

function normalizeMySQLConnectionString(connectionString: string): string {
const connectionURL = new URL(connectionString)

const connectionLimit = connectionURL.searchParams.get('connection_limit')

if (connectionLimit && !connectionURL.searchParams.has('connectionLimit')) {
connectionURL.searchParams.set('connectionLimit', connectionLimit)
}

const sslAccept = connectionURL.searchParams.get('sslaccept')

if (sslAccept && !connectionURL.searchParams.has('ssl')) {
connectionURL.searchParams.set('ssl', JSON.stringify(prismaSslAcceptToMySQL2Ssl(sslAccept)))
}

for (const queryParameter of PRISMA_ORM_SPECIFIC_MYSQL_QUERY_PARAMETERS) {
connectionURL.searchParams.delete(queryParameter)
}

return connectionURL.toString()
}

function prismaSslAcceptToMySQL2Ssl(sslAccept: string): { rejectUnauthorized: boolean } {
switch (sslAccept) {
case 'strict':
return { rejectUnauthorized: true }
case 'accept_invalid_certs':
return { rejectUnauthorized: false }
default:
throw new Error(
`Unknown Prisma MySQL sslaccept value "${sslAccept}". Supported values are "strict" and "accept_invalid_certs".`,
)
}
}

// prettier-ignore
const INDEX_HTML =
`<!doctype html>
Expand Down
122 changes: 122 additions & 0 deletions packages/cli/src/__tests__/Studio.vitest.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import { defaultTestConfig } from '@prisma/config'
import { beforeEach, describe, expect, test, vi } from 'vitest'

const createPoolMock = vi.fn(() => ({ end: vi.fn() }))

vi.mock('mysql2/promise', () => {
return {
createPool: createPoolMock,
}
})

vi.mock('@hono/node-server', () => {
return {
serve: vi.fn(() => ({ close: vi.fn() })),
}
})

vi.mock('@prisma/studio-core/data/mysql2', () => {
return {
createMySQL2Executor: vi.fn(() => ({
execute: vi.fn(),
})),
}
})

vi.mock('@prisma/studio-core/data/bff', () => {
return {
serializeError: vi.fn(() => ({ message: 'mock-error' })),
}
})

vi.mock('@prisma/studio-core/data/node-sqlite', () => {
return {
createNodeSQLiteExecutor: vi.fn(() => ({
execute: vi.fn(),
})),
}
})

vi.mock('@prisma/studio-core/data/postgresjs', () => {
return {
createPostgresJSExecutor: vi.fn(() => ({
execute: vi.fn(),
})),
}
})

describe('Studio MySQL URL compatibility', () => {
beforeEach(() => {
vi.resetModules()
createPoolMock.mockClear()
})

test('converts sslaccept=strict to mysql2 ssl JSON', async () => {
const { Studio } = await import('../Studio')

await Studio.new().parse(
[
'--browser',
'none',
'--port',
'5555',
'--url',
'mysql://user:password@aws.connect.psdb.cloud/db?sslaccept=strict',
],
defaultTestConfig(),
)

expect(createPoolMock).toHaveBeenCalledTimes(1)

const passedUrl = new URL(createPoolMock.mock.calls[0][0])

expect(passedUrl.searchParams.get('sslaccept')).toBeNull()
expect(passedUrl.searchParams.get('ssl')).toBe('{"rejectUnauthorized":true}')
})

test('maps connection_limit to mysql2 connectionLimit', async () => {
const { Studio } = await import('../Studio')

await Studio.new().parse(
[
'--browser',
'none',
'--port',
'5555',
'--url',
'mysql://user:password@aws.connect.psdb.cloud/db?connection_limit=7',
],
defaultTestConfig(),
)

expect(createPoolMock).toHaveBeenCalledTimes(1)

const passedUrl = new URL(createPoolMock.mock.calls[0][0])

expect(passedUrl.searchParams.get('connection_limit')).toBeNull()
expect(passedUrl.searchParams.get('connectionLimit')).toBe('7')
})

test('converts sslaccept=accept_invalid_certs to mysql2 ssl JSON', async () => {
const { Studio } = await import('../Studio')

await Studio.new().parse(
[
'--browser',
'none',
'--port',
'5555',
'--url',
'mysql://user:password@aws.connect.psdb.cloud/db?sslaccept=accept_invalid_certs',
],
defaultTestConfig(),
)

expect(createPoolMock).toHaveBeenCalledTimes(1)

const passedUrl = new URL(createPoolMock.mock.calls[0][0])

expect(passedUrl.searchParams.get('sslaccept')).toBeNull()
expect(passedUrl.searchParams.get('ssl')).toBe('{"rejectUnauthorized":false}')
})
})
7 changes: 6 additions & 1 deletion packages/cli/src/bin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import { SubCommand } from './SubCommand'
import { Telemetry } from './Telemetry'
import { redactCommandArray } from './utils/checkpoint'
import { loadOrInitializeCommandState } from './utils/commandState'
import { UserFacingError } from './utils/errors'
import { loadConfig } from './utils/loadConfig'
import { Validate } from './Validate'
import { Version } from './Version'
Expand Down Expand Up @@ -167,7 +168,11 @@ async function main(): Promise<number> {
debug(`Execution time for executing "await cli.parse(commandArray)": ${cliExecElapsedTime} ms`)

if (result instanceof Error) {
console.error(result instanceof HelpError ? result.message : result)
if (result instanceof HelpError || result instanceof UserFacingError) {
console.error(result.message)
} else {
console.error(result)
}
return 1
}

Expand Down
12 changes: 11 additions & 1 deletion packages/cli/src/utils/errors.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { green } from 'kleur/colors'
import { bold, green, red } from 'kleur/colors'

export class EarlyAccessFlagError extends Error {
constructor() {
Expand All @@ -8,3 +8,13 @@ Please provide the ${green('--early-access')} flag to use this command.`,
)
}
}

/**
* Error intended to be rendered directly to the terminal without stack traces.
*/
export class UserFacingError extends Error {
constructor(message: string) {
super(`\n${bold(red('!'))} ${message}`)
this.name = 'UserFacingError'
}
}
Loading
Loading