diff --git a/AGENTS.md b/AGENTS.md index 509d36d2434b..b46901e7859d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -31,11 +31,13 @@ - **Docs & references**: `ARCHITECTURE.md` contains dependency graphs (requires GraphViz to regenerate), `docker/README.md` explains local DB setup, `docs/benchmarking.md` covers performance benchmarking, `examples/` provides sample apps, and `sandbox/` hosts debugging helpers like the DMMF explorer. - **Client architecture (Prisma 7)**: - - `ClientEngine` in `packages/client/src/runtime/core/engines/client/` orchestrates query execution using WASM query compiler. + - `ClientEngine` in `packages/client/src/runtime/core/engines/client/` orchestrates query execution using Wasm query compiler. - Two executor implementations: `LocalExecutor` (driver adapters, direct DB) and `RemoteExecutor` (Accelerate/Data Proxy). - `QueryInterpreter` class in `packages/client-engine-runtime/src/interpreter/query-interpreter.ts` executes query plans against `SqlQueryable` (driver adapter interface). - Query flow: `PrismaClient` → `ClientEngine.request()` → query compiler → `executor.execute()` → `QueryInterpreter.run()` → driver adapter. - `ExecutePlanParams` interface in `packages/client/src/runtime/core/engines/client/Executor.ts` defines what's passed through the execution chain. + - `TransactionManager` in `packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts` owns interactive transaction IDs and implements nested transactions using savepoints. Savepoint SQL is provider-specific (e.g. PostgreSQL uses `ROLLBACK TO SAVEPOINT `, MySQL/SQLite use `ROLLBACK TO `, SQL Server uses `SAVE TRANSACTION ` / `ROLLBACK TRANSACTION ` and has no release statement). + - `Transaction` in `packages/driver-adapter-utils` models savepoint behavior as async methods (`createSavepoint`, `rollbackToSavepoint`, optional `releaseSavepoint`) instead of returning SQL via `savepoint(action, name)`. `TransactionManager` expects adapter methods for savepoints and does not synthesize provider fallback SQL. - Fluent API `dataPath` is built in `packages/client/src/runtime/core/model/applyFluent.ts` by appending `['select', ]` on each hop; runtime unpacking in `packages/client/src/runtime/RequestHandler.ts` currently strips `'select'`/`'include'` segments before `deepGet`. - In extension context resolution, `dataPath` should be interpreted as selector/field pairs (`select|include`, relation field). Do not strip by raw string value or relation fields literally named `select`/`include` get dropped. diff --git a/eslint.config.cjs b/eslint.config.cjs index 601b580837be..ac154f9199a6 100644 --- a/eslint.config.cjs +++ b/eslint.config.cjs @@ -36,6 +36,7 @@ module.exports = [ '**/fixtures/**', '**/__fixtures__/**', '**/generated/**', + '**/.generated/**', '**/prism.ts', '**/charm.ts', '**/pnpm-lock.yaml', diff --git a/packages/adapter-better-sqlite3/src/better-sqlite3.ts b/packages/adapter-better-sqlite3/src/better-sqlite3.ts index b2a06b8acbda..16b239e0e640 100644 --- a/packages/adapter-better-sqlite3/src/better-sqlite3.ts +++ b/packages/adapter-better-sqlite3/src/better-sqlite3.ts @@ -160,6 +160,18 @@ class BetterSQLite3Transaction extends BetterSQLite3Queryable impleme this.#unlockParent() return Promise.resolve() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaBetterSqlite3Adapter extends BetterSQLite3Queryable implements SqlDriverAdapter { diff --git a/packages/adapter-d1/src/d1-http.ts b/packages/adapter-d1/src/d1-http.ts index 59f7d5cd2bd1..7a86b4c2e18c 100644 --- a/packages/adapter-d1/src/d1-http.ts +++ b/packages/adapter-d1/src/d1-http.ts @@ -199,6 +199,18 @@ class D1HttpTransaction extends D1HttpQueryable implements Transaction { async rollback(): Promise { debug(`[js::rollback]`) } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaD1HttpAdapter extends D1HttpQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-d1/src/d1-worker.ts b/packages/adapter-d1/src/d1-worker.ts index 95c5dfd9f08c..cab5bc9fa7e7 100644 --- a/packages/adapter-d1/src/d1-worker.ts +++ b/packages/adapter-d1/src/d1-worker.ts @@ -119,6 +119,18 @@ class D1WorkerTransaction extends D1WorkerQueryable implements Transa async rollback(): Promise { debug(`[js::rollback]`) } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaD1WorkerAdapter extends D1WorkerQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-libsql/src/libsql.ts b/packages/adapter-libsql/src/libsql.ts index 42dc88d4787b..8eab92839e86 100644 --- a/packages/adapter-libsql/src/libsql.ts +++ b/packages/adapter-libsql/src/libsql.ts @@ -129,6 +129,18 @@ class LibSqlTransaction extends LibSqlQueryable implements Tr this.#unlockParent() } } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaLibSqlAdapter extends LibSqlQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-mariadb/src/mariadb.ts b/packages/adapter-mariadb/src/mariadb.ts index 5bd797216fa0..3bacca68c47a 100644 --- a/packages/adapter-mariadb/src/mariadb.ts +++ b/packages/adapter-mariadb/src/mariadb.ts @@ -98,6 +98,18 @@ class MariaDbTransaction extends MariaDbQueryable implements this.cleanup?.() await this.client.end() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export type PrismaMariadbOptions = { diff --git a/packages/adapter-mssql/src/mssql.ts b/packages/adapter-mssql/src/mssql.ts index 2e4d828f7727..bb84c4dbe92c 100644 --- a/packages/adapter-mssql/src/mssql.ts +++ b/packages/adapter-mssql/src/mssql.ts @@ -117,6 +117,14 @@ class MssqlTransaction extends MssqlQueryable implements Transaction { release() } } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVE TRANSACTION ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TRANSACTION ${name}`, args: [], argTypes: [] }) + } } export type PrismaMssqlOptions = { diff --git a/packages/adapter-neon/src/neon.ts b/packages/adapter-neon/src/neon.ts index 73e08cf87d96..d420de44a9dc 100644 --- a/packages/adapter-neon/src/neon.ts +++ b/packages/adapter-neon/src/neon.ts @@ -158,6 +158,18 @@ class NeonTransaction extends NeonWsQueryable implements Transa this.cleanup?.() this.client.release() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export type PrismaNeonOptions = { diff --git a/packages/adapter-pg/src/pg.ts b/packages/adapter-pg/src/pg.ts index 702e4b3d9925..8d01ee593875 100644 --- a/packages/adapter-pg/src/pg.ts +++ b/packages/adapter-pg/src/pg.ts @@ -168,6 +168,18 @@ class PgTransaction extends PgQueryable implements Transactio this.cleanup?.() this.client.release() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export type PrismaPgOptions = { diff --git a/packages/adapter-planetscale/src/planetscale.ts b/packages/adapter-planetscale/src/planetscale.ts index fe25dc9cc4f1..5e62d7498fea 100644 --- a/packages/adapter-planetscale/src/planetscale.ts +++ b/packages/adapter-planetscale/src/planetscale.ts @@ -177,6 +177,18 @@ class PlanetScaleTransaction extends PlanetScaleQueryable { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaPlanetScaleAdapter extends PlanetScaleQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-ppg/src/ppg.ts b/packages/adapter-ppg/src/ppg.ts index 115adee8acb0..00057b52353b 100644 --- a/packages/adapter-ppg/src/ppg.ts +++ b/packages/adapter-ppg/src/ppg.ts @@ -188,6 +188,18 @@ class PrismaPostgresTransaction implements Transaction { } } + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + async executeRaw(params: SqlQuery): Promise { await this.#ensureBegun() return executeRawStatement(this.#session, params) diff --git a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts index 90dac6936689..e38862330c33 100644 --- a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts +++ b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts @@ -1,7 +1,6 @@ import timers from 'node:timers/promises' import type { SqlDriverAdapter, SqlQuery, SqlResultSet, Transaction } from '@prisma/driver-adapter-utils' -import { ok } from '@prisma/driver-adapter-utils' import { noopTracingHelper } from '../tracing' import { Options } from './transaction' @@ -31,14 +30,51 @@ class MockDriverAdapter implements SqlDriverAdapter { adapterName = 'mock-adapter' provider: SqlDriverAdapter['provider'] private readonly usePhantomQuery: boolean - - executeRawMock: jest.MockedFn<(params: SqlQuery) => Promise> = jest.fn().mockResolvedValue(ok(1)) - commitMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(ok(undefined)) - rollbackMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(ok(undefined)) - - constructor({ provider = 'postgres' as SqlDriverAdapter['provider'], usePhantomQuery = false } = {}) { + private readonly createSavepoint: Transaction['createSavepoint'] + private readonly rollbackToSavepoint: Transaction['rollbackToSavepoint'] + private readonly releaseSavepoint: Transaction['releaseSavepoint'] + + executeRawMock: jest.MockedFn<(params: SqlQuery) => Promise> = jest.fn().mockResolvedValue(1) + commitMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(undefined) + rollbackMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(undefined) + + constructor( + options: { + provider?: SqlDriverAdapter['provider'] + usePhantomQuery?: boolean + createSavepoint?: Transaction['createSavepoint'] + rollbackToSavepoint?: Transaction['rollbackToSavepoint'] + releaseSavepoint?: Transaction['releaseSavepoint'] + } = {}, + ) { + const { provider = 'postgres' as SqlDriverAdapter['provider'], usePhantomQuery = false } = options this.usePhantomQuery = usePhantomQuery this.provider = provider + + this.createSavepoint = Object.hasOwn(options, 'createSavepoint') + ? options.createSavepoint + : async (name) => { + const sql = this.provider === 'sqlserver' ? `SAVE TRANSACTION ${name}` : `SAVEPOINT ${name}` + await this.executeRawMock({ sql, args: [], argTypes: [] }) + } + this.rollbackToSavepoint = Object.hasOwn(options, 'rollbackToSavepoint') + ? options.rollbackToSavepoint + : async (name) => { + const sql = + this.provider === 'sqlserver' + ? `ROLLBACK TRANSACTION ${name}` + : this.provider === 'postgres' + ? `ROLLBACK TO SAVEPOINT ${name}` + : `ROLLBACK TO ${name}` + await this.executeRawMock({ sql, args: [], argTypes: [] }) + } + this.releaseSavepoint = Object.hasOwn(options, 'releaseSavepoint') + ? options.releaseSavepoint + : this.provider === 'sqlserver' + ? undefined + : async (name) => { + await this.executeRawMock({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } executeRaw(params: SqlQuery): Promise { @@ -62,15 +98,19 @@ class MockDriverAdapter implements SqlDriverAdapter { const commitMock = this.commitMock const rollbackMock = this.rollbackMock const usePhantomQuery = this.usePhantomQuery + const provider = this.provider const mockTransaction: Transaction = { adapterName: 'mock-adapter', - provider: 'postgres', + provider, options: { usePhantomQuery }, queryRaw: jest.fn().mockRejectedValue('Not implemented for test'), executeRaw: executeRawMock, commit: commitMock, rollback: rollbackMock, + createSavepoint: this.createSavepoint, + rollbackToSavepoint: this.rollbackToSavepoint, + releaseSavepoint: this.releaseSavepoint, } return new Promise((resolve) => @@ -113,6 +153,281 @@ test('transaction executes normally', async () => { await expect(transactionManager.rollbackTransaction(id)).rejects.toBeInstanceOf(TransactionClosedError) }) +test('nested commit only closes at the outermost level', async () => { + const driverAdapter = new MockDriverAdapter() + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + const nested = await transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + + expect(nested.id).toBe(id) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + // Inner commit should not close the underlying transaction. + await transactionManager.commitTransaction(id) + expect(driverAdapter.commitMock).not.toHaveBeenCalled() + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + // Outer commit closes. + await transactionManager.commitTransaction(id) + expect(driverAdapter.commitMock).toHaveBeenCalledTimes(1) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toEqual('COMMIT') +}) + +test('nested rollback uses a savepoint and keeps the outer transaction open', async () => { + const driverAdapter = new MockDriverAdapter() + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + await transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + // Inner rollback should not close the underlying transaction. + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.rollbackMock).not.toHaveBeenCalled() + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TO SAVEPOINT prisma_sp_\d+$/) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + // Outer commit still closes. + await transactionManager.commitTransaction(id) + expect(driverAdapter.commitMock).toHaveBeenCalledTimes(1) + expect(driverAdapter.executeRawMock.mock.calls[3][0].sql).toEqual('COMMIT') +}) + +test('nested starts are serialized when the first savepoint fails', async () => { + const driverAdapter = new MockDriverAdapter() + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + let savepointCallCount = 0 + let rejectFirstSavepoint: ((error: Error) => void) | undefined + let resolveFirstSavepointStarted: () => void + const firstSavepointStarted = new Promise((resolve) => { + resolveFirstSavepointStarted = resolve + }) + + driverAdapter.executeRawMock.mockImplementation((query) => { + if (query.sql.startsWith('SAVEPOINT ')) { + savepointCallCount += 1 + if (savepointCallCount === 1) { + resolveFirstSavepointStarted() + return new Promise((_, reject) => { + rejectFirstSavepoint = reject + }) + } + } + + return Promise.resolve(1) + }) + + const nested1 = transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + const nested2 = transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + + await firstSavepointStarted + await Promise.resolve() + expect(savepointCallCount).toBe(1) + + rejectFirstSavepoint?.(new Error('savepoint failed')) + await expect(nested1).rejects.toThrow('savepoint failed') + + await expect(nested2).resolves.toEqual({ id }) + + const savepointQueries = driverAdapter.executeRawMock.mock.calls + .map((call) => call[0].sql) + .filter((sql) => sql.startsWith('SAVEPOINT ')) + const successfulSavepointName = savepointQueries[1].slice('SAVEPOINT '.length) + + await transactionManager.rollbackTransaction(id) + + const rollbackToSavepointQuery = driverAdapter.executeRawMock.mock.calls + .map((call) => call[0].sql) + .find((sql) => sql.startsWith('ROLLBACK TO SAVEPOINT ')) + + expect(rollbackToSavepointQuery).toEqual(`ROLLBACK TO SAVEPOINT ${successfulSavepointName}`) + + // Close top-level transaction. + await transactionManager.rollbackTransaction(id) +}) + +test('nested savepoints use sqlserver syntax', async () => { + const driverAdapter = new MockDriverAdapter({ provider: 'sqlserver' }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVE TRANSACTION prisma_sp_\d+$/) + + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TRANSACTION prisma_sp_\d+$/) + + // No release savepoint query on SQL Server. + expect(driverAdapter.executeRawMock.mock.calls).toHaveLength(2) + await transactionManager.commitTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toEqual('COMMIT') +}) + +test('nested savepoints use mysql syntax', async () => { + const driverAdapter = new MockDriverAdapter({ provider: 'mysql' }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TO prisma_sp_\d+$/) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.commitTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[3][0].sql).toEqual('COMMIT') +}) + +test('nested savepoints use sqlite syntax', async () => { + const driverAdapter = new MockDriverAdapter({ provider: 'sqlite' }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TO prisma_sp_\d+$/) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.commitTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[3][0].sql).toEqual('COMMIT') +}) + +test('nested savepoints use adapter-provided methods when available', async () => { + const createSavepoint = jest.fn>, [string]>(async () => {}) + const rollbackToSavepoint = jest.fn>, [string]>( + async () => {}, + ) + const releaseSavepoint = jest.fn>, [string]>(async () => {}) + + const driverAdapter = new MockDriverAdapter({ + provider: 'postgres', + createSavepoint, + rollbackToSavepoint, + releaseSavepoint, + }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + await transactionManager.rollbackTransaction(id) + await transactionManager.commitTransaction(id) + + expect(createSavepoint).toHaveBeenCalledTimes(1) + expect(rollbackToSavepoint).toHaveBeenCalledTimes(1) + expect(releaseSavepoint).toHaveBeenCalledTimes(1) + + const savepointName = createSavepoint.mock.calls[0][0] + expect(rollbackToSavepoint).toHaveBeenCalledWith(savepointName) + expect(releaseSavepoint).toHaveBeenCalledWith(savepointName) + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toEqual('COMMIT') +}) + +test('nested savepoint release can be omitted by adapter', async () => { + const createSavepoint = jest.fn>, [string]>(async () => {}) + + const driverAdapter = new MockDriverAdapter({ + provider: 'postgres', + createSavepoint, + releaseSavepoint: undefined, + }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + await transactionManager.commitTransaction(id) + await transactionManager.commitTransaction(id) + + expect(createSavepoint).toHaveBeenCalledTimes(1) + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toEqual('COMMIT') +}) + +test('missing createSavepoint fails nested transaction start', async () => { + const driverAdapter = new MockDriverAdapter({ createSavepoint: undefined }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + await expect(transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id })).rejects.toThrow( + 'createSavepoint is not implemented', + ) +}) + +test('missing rollbackToSavepoint fails nested rollback', async () => { + const driverAdapter = new MockDriverAdapter({ rollbackToSavepoint: undefined }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + await expect(transactionManager.rollbackTransaction(id)).rejects.toThrow('rollbackToSavepoint is not implemented') +}) + test('transaction is rolled back', async () => { const driverAdapter = new MockDriverAdapter() const transactionManager = new TransactionManager({ diff --git a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts index f0722567d5fd..e1cbe7fdba93 100644 --- a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts +++ b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts @@ -28,6 +28,10 @@ type TransactionWrapper = { timeout: number | undefined startedAt: number transaction?: Transaction + operationQueue: Promise + depth: number + savepoints: string[] + savepointCounter: number } & TransactionState type TransactionState = @@ -102,6 +106,36 @@ export class TransactionManager { } async #startTransactionImpl(options: Options): Promise { + if (options.newTxId) { + return await this.#withActiveTransactionLock(options.newTxId, 'start', async (existing) => { + if (existing.status !== 'running') { + throw new TransactionInternalConsistencyError( + `Transaction in invalid state ${existing.status} when starting a nested transaction.`, + ) + } + if (!existing.transaction) { + throw new TransactionInternalConsistencyError( + `Transaction missing underlying driver transaction when starting a nested transaction.`, + ) + } + + existing.depth += 1 + + const savepointName = this.#nextSavepointName(existing) + existing.savepoints.push(savepointName) + try { + await this.#requiredCreateSavepoint(existing.transaction)(savepointName) + } catch (e) { + // Keep state consistent if creating the savepoint fails. + existing.depth -= 1 + existing.savepoints.pop() + throw e + } + + return { id: existing.id } + }) + } + const transaction: TransactionWrapper = { id: await randomUUID(), status: 'waiting', @@ -109,6 +143,10 @@ export class TransactionManager { timeout: options.timeout, startedAt: Date.now(), transaction: undefined, + operationQueue: Promise.resolve(), + depth: 1, + savepoints: [], + savepointCounter: 0, } // Start timeout to wait for transaction to be started. @@ -167,15 +205,53 @@ export class TransactionManager { async commitTransaction(transactionId: string): Promise { return await this.tracingHelper.runInChildSpan('commit_transaction', async () => { - const txw = this.#getActiveOrClosingTransaction(transactionId, 'commit') - await this.#closeTransaction(txw, 'committed') + await this.#withActiveTransactionLock(transactionId, 'commit', async (txw) => { + if (txw.depth > 1) { + if (!txw.transaction) throw new TransactionNotFoundError() + const savepointName = txw.savepoints.at(-1) + if (!savepointName) { + throw new TransactionInternalConsistencyError( + `Missing savepoint for nested commit. Depth: ${txw.depth}, transactionId: ${txw.id}`, + ) + } + try { + await this.#releaseSavepoint(txw.transaction, savepointName) + } finally { + // Keep internal state consistent even if releasing the savepoint fails. + txw.savepoints.pop() + txw.depth -= 1 + } + return + } + await this.#closeTransaction(txw, 'committed') + }) }) } async rollbackTransaction(transactionId: string): Promise { return await this.tracingHelper.runInChildSpan('rollback_transaction', async () => { - const txw = this.#getActiveOrClosingTransaction(transactionId, 'rollback') - await this.#closeTransaction(txw, 'rolled_back') + await this.#withActiveTransactionLock(transactionId, 'rollback', async (txw) => { + if (txw.depth > 1) { + if (!txw.transaction) throw new TransactionNotFoundError() + const savepointName = txw.savepoints.at(-1) + if (!savepointName) { + throw new TransactionInternalConsistencyError( + `Missing savepoint for nested rollback. Depth: ${txw.depth}, transactionId: ${txw.id}`, + ) + } + + try { + await this.#requiredRollbackToSavepoint(txw.transaction)(savepointName) + await this.#releaseSavepoint(txw.transaction, savepointName) + } finally { + // Keep internal state consistent even if rollback/release fails. + txw.savepoints.pop() + txw.depth -= 1 + } + return + } + await this.#closeTransaction(txw, 'rolled_back') + }) }) } @@ -228,7 +304,53 @@ export class TransactionManager { async cancelAllTransactions(): Promise { // TODO: call `map` on the iterator directly without collecting it into an array first // once we drop support for Node.js 18 and 20. - await Promise.allSettled([...this.transactions.values()].map((tx) => this.#closeTransaction(tx, 'rolled_back'))) + await Promise.allSettled( + [...this.transactions.values()].map((tx) => + this.#runSerialized(tx, async () => { + const current = this.transactions.get(tx.id) + if (current) { + await this.#closeTransaction(current, 'rolled_back') + } + }), + ), + ) + } + + #nextSavepointName(transaction: TransactionWrapper): string { + return `prisma_sp_${transaction.savepointCounter++}` + } + + #requiredCreateSavepoint(transaction: Transaction): (name: string) => Promise { + if (transaction.createSavepoint) { + return transaction.createSavepoint.bind(transaction) + } + + throw new TransactionManagerError( + `Nested transactions are not supported by adapter "${transaction.adapterName}" (${transaction.provider}): createSavepoint is not implemented.`, + ) + } + + #requiredRollbackToSavepoint(transaction: Transaction): (name: string) => Promise { + if (transaction.rollbackToSavepoint) { + return transaction.rollbackToSavepoint.bind(transaction) + } + + throw new TransactionManagerError( + `Nested transactions are not supported by adapter "${transaction.adapterName}" (${transaction.provider}): rollbackToSavepoint is not implemented.`, + ) + } + + async #releaseSavepoint(transaction: Transaction, name: string): Promise { + if (transaction.releaseSavepoint) { + await transaction.releaseSavepoint(name) + } + } + + #debugTransactionAlreadyClosedOnTimeout(transactionId: string): void { + // Transaction was already committed or rolled back when timeout happened. + // Should normally not happen as timeout is cancelled when transaction is committed or rolled back. + // No further action needed though. + debug('Transaction already committed or rolled back when timeout happened.', transactionId) } #startTransactionTimeout(transactionId: string, timeout: number | undefined): NodeJS.Timeout | undefined { @@ -237,20 +359,57 @@ export class TransactionManager { debug('Transaction timed out.', { transactionId, timeoutStartedAt, timeout }) const tx = this.transactions.get(transactionId) - if (tx && ['running', 'waiting'].includes(tx.status)) { - await this.#closeTransaction(tx, 'timed_out') - } else { - // Transaction was already committed or rolled back when timeout happened. - // Should normally not happen as timeout is cancelled when transaction is committed or rolled back. - // No further action needed though. - debug('Transaction already committed or rolled back when timeout happened.', transactionId) + if (!tx) { + this.#debugTransactionAlreadyClosedOnTimeout(transactionId) + return } + + await this.#runSerialized(tx, async () => { + const current = this.transactions.get(transactionId) + if (current && ['running', 'waiting'].includes(current.status)) { + await this.#closeTransaction(current, 'timed_out') + } else { + this.#debugTransactionAlreadyClosedOnTimeout(transactionId) + } + }) }, timeout) timer?.unref?.() return timer } + // Any operation that mutates or closes a transaction must run through this lock so + // status/savepoint/depth checks and updates happen against a stable view of state. + async #withActiveTransactionLock( + transactionId: string, + operation: string, + callback: (tx: TransactionWrapper) => Promise, + ): Promise { + const tx = this.#getActiveOrClosingTransaction(transactionId, operation) + return await this.#runSerialized(tx, async () => { + const current = this.#getActiveOrClosingTransaction(transactionId, operation) + return await callback(current) + }) + } + + // Serializes operations per transaction id to prevent interleaving across awaits. + // This avoids races where one operation mutates savepoint/depth state while another + // operation is suspended, which could otherwise corrupt cleanup logic. + async #runSerialized(tx: TransactionWrapper, callback: () => Promise): Promise { + const previousOperation = tx.operationQueue + let releaseOperationLock!: () => void + tx.operationQueue = new Promise((resolve) => { + releaseOperationLock = resolve + }) + + await previousOperation + try { + return await callback() + } finally { + releaseOperationLock() + } + } + async #closeTransaction(tx: TransactionWrapper, status: 'committed' | 'rolled_back' | 'timed_out'): Promise { const createClosingPromise = async () => { debug('Closing transaction.', { transactionId: tx.id, status }) diff --git a/packages/client-engine-runtime/src/transaction-manager/transaction.ts b/packages/client-engine-runtime/src/transaction-manager/transaction.ts index 45c33b489f12..f9e763785c20 100644 --- a/packages/client-engine-runtime/src/transaction-manager/transaction.ts +++ b/packages/client-engine-runtime/src/transaction-manager/transaction.ts @@ -9,6 +9,7 @@ export type Options = { /// Transaction isolation level isolationLevel?: IsolationLevel + newTxId?: string } export type TransactionInfo = { diff --git a/packages/client-generator-js/src/TSClient/PrismaClient.ts b/packages/client-generator-js/src/TSClient/PrismaClient.ts index f03b904ab10b..1d834dda459c 100644 --- a/packages/client-generator-js/src/TSClient/PrismaClient.ts +++ b/packages/client-generator-js/src/TSClient/PrismaClient.ts @@ -233,9 +233,7 @@ function interactiveTransactionDefinition(context: GenerateContext) { const callbackType = ts .functionType() - .addParameter( - ts.parameter('prisma', ts.omit(ts.namedType('PrismaClient'), ts.namedType('runtime.ITXClientDenyList'))), - ) + .addParameter(ts.parameter('prisma', ts.omit(ts.namedType('PrismaClient'), itxTransactionClientDenyList(context)))) .setReturnType(returnType) const method = ts @@ -248,6 +246,14 @@ function interactiveTransactionDefinition(context: GenerateContext) { return ts.stringify(method, { indentLevel: 1, newLine: 'leading' }) } +function itxTransactionClientDenyList(context: GenerateContext) { + if (context.provider === 'mongodb') { + return ts.unionType([ts.namedType('runtime.ITXClientDenyList'), ts.stringLiteral('$transaction')]) + } + + return ts.namedType('runtime.ITXClientDenyList') +} + function queryRawDefinition(context: GenerateContext) { // we do not generate `$queryRaw...` definitions if not supported if (!context.dmmf.mappings.otherOperations.write.includes('queryRaw')) { @@ -476,6 +482,8 @@ get ${methodName}(): Prisma.${m.model}Delegate<${generics.join(', ')}>;` } public toTS(): string { const clientOptions = this.buildClientOptions() + const transactionClientDenyList = + this.context.provider === 'mongodb' ? "runtime.ITXClientDenyList | '$transaction'" : 'runtime.ITXClientDenyList' return `${clientExtensionsDefinitions(this.context)} export type DefaultPrismaClient = PrismaClient @@ -545,7 +553,7 @@ export function getLogLevel(log: Array): LogLevel | un /** * \`PrismaClient\` proxy available in interactive transactions. */ -export type TransactionClient = Omit +export type TransactionClient = Omit ` } diff --git a/packages/client-generator-ts/src/TSClient/PrismaClient.ts b/packages/client-generator-ts/src/TSClient/PrismaClient.ts index 06f6723d8198..9e1afef9567e 100644 --- a/packages/client-generator-ts/src/TSClient/PrismaClient.ts +++ b/packages/client-generator-ts/src/TSClient/PrismaClient.ts @@ -73,9 +73,7 @@ function interactiveTransactionDefinition(context: GenerateContext) { const callbackType = ts .functionType() - .addParameter( - ts.parameter('prisma', tsx.omit(ts.namedType('PrismaClient'), ts.namedType('runtime.ITXClientDenyList'))), - ) + .addParameter(ts.parameter('prisma', tsx.omit(ts.namedType('PrismaClient'), itxTransactionClientDenyList(context)))) .setReturnType(returnType) const method = ts @@ -88,6 +86,14 @@ function interactiveTransactionDefinition(context: GenerateContext) { return ts.stringify(method, { indentLevel: 1, newLine: 'leading' }) } +function itxTransactionClientDenyList(context: GenerateContext) { + if (!context.isSqlProvider()) { + return ts.unionType([ts.namedType('runtime.ITXClientDenyList'), ts.stringLiteral('$transaction')]) + } + + return ts.namedType('runtime.ITXClientDenyList') +} + function queryRawDefinition(context: GenerateContext) { // we do not generate `$queryRaw...` definitions if not supported if (!context.dmmf.mappings.otherOperations.write.includes('queryRaw')) { diff --git a/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts b/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts index 6d39ae97bae6..78458b4cbe68 100644 --- a/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts +++ b/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts @@ -31,6 +31,10 @@ export function createPrismaNamespaceFile(context: GenerateContext, options: TSC const fieldRefs = context.dmmf.schema.fieldRefTypes.prisma?.map((type) => new FieldRefInput(type).toTS()) ?? [] + const transactionClientDenyList = context.isSqlProvider() + ? 'runtime.ITXClientDenyList' + : "runtime.ITXClientDenyList | '$transaction'" + return `${jsDocHeader} ${imports.join('\n')} @@ -136,7 +140,7 @@ export type PrismaAction = /** * \`PrismaClient\` proxy available in interactive transactions. */ -export type TransactionClient = Omit +export type TransactionClient = Omit ` } diff --git a/packages/client/src/runtime/core/engines/common/types/Transaction.ts b/packages/client/src/runtime/core/engines/common/types/Transaction.ts index e338bf683ec4..f33b25ce42d0 100644 --- a/packages/client/src/runtime/core/engines/common/types/Transaction.ts +++ b/packages/client/src/runtime/core/engines/common/types/Transaction.ts @@ -11,6 +11,12 @@ export type Options = { /** Transaction isolation level */ isolationLevel?: IsolationLevel + + /** + * Used for nested interactive transactions. When provided, the engine may + * re-use an existing open transaction instead of opening a new one. + */ + newTxId?: string } export type InteractiveTransactionInfo = { diff --git a/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts b/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts index 1cd43aaf6ac6..046f4634d05e 100644 --- a/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts +++ b/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts @@ -1,4 +1,4 @@ -const denylist = ['$connect', '$disconnect', '$on', '$transaction', '$extends'] as const +const denylist = ['$connect', '$disconnect', '$on', '$use', '$extends'] as const export const itxClientDenyList = denylist as ReadonlyArray diff --git a/packages/client/src/runtime/getPrismaClient.ts b/packages/client/src/runtime/getPrismaClient.ts index 11f5d3df3687..b9f5c35f7f83 100644 --- a/packages/client/src/runtime/getPrismaClient.ts +++ b/packages/client/src/runtime/getPrismaClient.ts @@ -216,7 +216,67 @@ type EventCallback = [E] extends ['beforeExit'] ? (event: EngineEvent) => void : never -const TX_ID = Symbol.for('prisma.client.transaction.id') +const TX_SCOPE_CONTEXT = Symbol.for('prisma.client.transaction.scope_context') + +type ItxScopeState = { + stack: string[] +} + +type TopLevelItxScopeContext = { + kind: 'top-level' +} + +type NestedItxScopeContext = { + kind: 'nested' + txId: string + scopeId: string + scopeState: ItxScopeState +} + +type ItxScopeContext = TopLevelItxScopeContext | NestedItxScopeContext + +function getItxScopeContext(client: object): ItxScopeContext { + const symbolStorage = client as Record + const context = symbolStorage[TX_SCOPE_CONTEXT] + + if (context === undefined) { + return { kind: 'top-level' } + } + + if (isNestedItxScopeContext(context)) { + return context + } + + throw new Error('Internal error: inconsistent transaction scope context.') +} + +function isNestedItxScopeContext(value: unknown): value is NestedItxScopeContext { + if (typeof value !== 'object' || value === null) { + return false + } + + const objectValue = value as Record + return ( + objectValue['kind'] === 'nested' && + typeof objectValue['txId'] === 'string' && + typeof objectValue['scopeId'] === 'string' && + isItxScopeState(objectValue['scopeState']) + ) +} + +function isItxScopeState(value: unknown): value is ItxScopeState { + if (typeof value !== 'object' || value === null) { + return false + } + + return Array.isArray(value['stack']) +} + +function createItxScopeId(): string { + return typeof globalThis.crypto?.randomUUID === 'function' + ? globalThis.crypto.randomUUID() + : `${Date.now()}-${Math.random().toString(16).slice(2)}` +} const BatchTxIdCounter = { id: 0, @@ -675,46 +735,109 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client */ async _transactionWithCallback({ callback, - options, + options = {}, }: { callback: (client: Client) => Promise options?: Options }) { + const itxContext = getItxScopeContext(this) + const isNested = itxContext.kind === 'nested' + const scopeState: ItxScopeState = isNested ? itxContext.scopeState : { stack: [] } + const scopeStack = scopeState.stack + + const scopeId = createItxScopeId() + + if (isNested) { + // Only the currently-active (innermost) scope can start another nested scope. + // This prevents sibling nested transactions from running concurrently. + const activeScope = scopeStack.at(-1) + if (activeScope !== itxContext.scopeId) { + throw new Error('Concurrent nested transactions are not supported') + } + + // Re-use the underlying transaction in the engine by reusing the same transaction id. + options.newTxId = itxContext.txId + } + scopeStack.push(scopeId) + const headers = { traceparent: this._tracingHelper.getTraceParent() } const optionsWithDefaults: Options = { maxWait: options?.maxWait ?? this._engineConfig.transactionOptions.maxWait, timeout: options?.timeout ?? this._engineConfig.transactionOptions.timeout, isolationLevel: options?.isolationLevel ?? this._engineConfig.transactionOptions.isolationLevel, + newTxId: options.newTxId, + } + let info: Transaction.InteractiveTransactionInfo + try { + info = await this._engine.transaction('start', headers, optionsWithDefaults) + } catch (e) { + // Ensure we don't leave the scope stack dirty if starting the transaction fails. + if (scopeStack.at(-1) === scopeId) scopeStack.pop() + throw e } - const info = await this._engine.transaction('start', headers, optionsWithDefaults) let result: unknown try { // execute user logic with a proxied the client const transaction = { kind: 'itx', ...info } as const - result = await callback(this._createItxClient(transaction)) + result = await callback(this._createItxClient(transaction, scopeId, scopeState)) + + if (isNested) { + // Don't allow closing a transaction if we are not the active scope. + if (scopeStack.at(-1) !== scopeId) { + throw new Error('Nested transactions must be closed in reverse order of creation.') + } + } else if (scopeStack.length !== 1) { + throw new Error('Cannot close transaction while a nested transaction is still active.') + } - // it went well, then we commit the transaction await this._engine.transaction('commit', headers, info) } catch (e: any) { - // it went bad, then we rollback the transaction - await this._engine.transaction('rollback', headers, info).catch(() => {}) + // If we try to close out-of-order (e.g. un-awaited nested transaction), + // the TransactionManager depth will be > 1 and a single rollback would + // only roll back to the latest savepoint, leaving the top-level transaction + // open and leaking it. Force rollback to depth 0 in that case. + const isOrderViolation = scopeStack.at(-1) !== scopeId + const rollbackScopeCount = isOrderViolation ? Math.max(1, scopeStack.length) : 1 + for (let i = 0; i < rollbackScopeCount; i++) { + await this._engine.transaction('rollback', headers, info).catch((rollbackError) => { + debug('rollback attempt %d/%d failed: %O', i + 1, rollbackScopeCount, rollbackError) + }) + } throw e // silent rollback, throw original error + } finally { + if (scopeStack.at(-1) === scopeId) { + scopeStack.pop() + } else { + // Reset the scope stack to avoid poisoning this transaction context after an ordering violation. + scopeStack.length = 0 + } } return result } - _createItxClient(transaction: PrismaPromiseInteractiveTransaction): Client { + _createItxClient( + transaction: PrismaPromiseInteractiveTransaction, + scopeId: string, + scopeState: ItxScopeState, + ): Client { + const itxScopeContext: NestedItxScopeContext = { + kind: 'nested', + txId: transaction.id, + scopeId, + scopeState, + } + return createCompositeProxy( applyModelsAndClientExtensions( createCompositeProxy(unApplyModelsAndClientExtensions(this), [ - addProperty('_appliedParent', () => this._appliedParent._createItxClient(transaction)), + addProperty('_appliedParent', () => this._appliedParent._createItxClient(transaction, scopeId, scopeState)), addProperty('_createPrismaPromise', () => createPrismaPromiseFactory(transaction)), - addProperty(TX_ID, () => transaction.id), + addProperty(TX_SCOPE_CONTEXT, () => itxScopeContext), ]), ), [removeProperties(itxClientDenyList)], @@ -738,6 +861,13 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client 'Cloudflare D1 does not support interactive transactions. We recommend you to refactor your queries with that limitation in mind, and use batch transactions with `prisma.$transactions([])` where applicable.', ) } + } else if (config.activeProvider === 'mongodb' && getItxScopeContext(this).kind === 'nested') { + callback = () => { + throw new PrismaClientValidationError( + `The ${config.activeProvider} provider does not support nested transactions`, + { clientVersion: this._clientVersion }, + ) + } } else { callback = () => this._transactionWithCallback({ callback: input, options }) } diff --git a/packages/client/tests/functional/extensions/defineExtension.ts b/packages/client/tests/functional/extensions/defineExtension.ts index 3b77de252b3a..47fc5cbe6fd9 100644 --- a/packages/client/tests/functional/extensions/defineExtension.ts +++ b/packages/client/tests/functional/extensions/defineExtension.ts @@ -220,7 +220,6 @@ function itxWithinGenericExtension() { void xclient.$transaction((tx) => { expectTypeOf(tx).toHaveProperty('helperMethod') - expectTypeOf(tx).not.toHaveProperty('$transaction') expectTypeOf(tx).not.toHaveProperty('$extends') return Promise.resolve() }) diff --git a/packages/client/tests/functional/extensions/itx.ts b/packages/client/tests/functional/extensions/itx.ts index d925942519f8..aed17b556edc 100644 --- a/packages/client/tests/functional/extensions/itx.ts +++ b/packages/client/tests/functional/extensions/itx.ts @@ -267,7 +267,7 @@ testMatrix.setupTestSuite( if (isTransaction) { expect(ctx.$connect).toBeUndefined() expect(ctx.$disconnect).toBeUndefined() - expect(ctx.$transaction).toBeUndefined() + expect(ctx.$transaction).toBeDefined() expect(ctx.$extends).toBeUndefined() } else { expect(ctx.$connect).toBeDefined() diff --git a/packages/client/tests/functional/interactive-transactions/tests.ts b/packages/client/tests/functional/interactive-transactions/tests.ts index 2de3b13e4d65..83cbcd7983ec 100644 --- a/packages/client/tests/functional/interactive-transactions/tests.ts +++ b/packages/client/tests/functional/interactive-transactions/tests.ts @@ -220,11 +220,424 @@ testMatrix.setupTestSuite( await expect(result).resolves.toHaveLength(2) }) + testIf(provider === Providers.MONGODB)('mongodb: nested transactions are not available in types', async () => { + await prisma.$transaction((tx) => { + // For MongoDB, the transaction-bound client type should not expose `$transaction`. + // We keep this as a type-only assertion: at runtime, this is just a property access. + // @ts-test-if: provider !== Providers.MONGODB + void tx.$transaction + return Promise.resolve() + }) + }) + + /** + * If a parent transaction is rolled back, the child transaction should also rollback. + * This is only supported on SQL providers. + */ + testIf(provider !== Providers.MONGODB)('sql: nested rollback', async () => { + const email1 = `user_${copycat.uuid(101)}@website.com` + const email2 = `user_${copycat.uuid(102)}@website.com` + + await expect( + prisma.$transaction(async (tx) => { + await tx.user.create({ + data: { + email: email1, + }, + }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ + data: { + email: email2, + }, + }) + }) + + throw new Error('Rollback') + }), + ).rejects.toThrow(/Rollback/) + + const users = await prisma.user.findMany({ + where: { + email: { + in: [email1, email2], + }, + }, + }) + + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: nested rollback restores parent state (savepoints, 3 levels)', + async () => { + const emailA = `user_${copycat.uuid(151)}@website.com` + const emailB = `user_${copycat.uuid(152)}@website.com` + const emailC = `user_${copycat.uuid(153)}@website.com` + + const outerPromise = prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: emailA } }) + + const afterOuterInsert = await tx.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(afterOuterInsert).toHaveLength(1) + + try { + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: emailB } }) + + const afterInnerInsert = await tx2.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(afterInnerInsert).toHaveLength(2) + + try { + await tx2.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: emailC } }) + const afterGrandchildInsert = await tx3.user.findMany({ + where: { email: { in: [emailA, emailB, emailC] } }, + }) + expect(afterGrandchildInsert).toHaveLength(3) + throw new Error('grandchild rollback') + }) + } catch (e) { + expect(e).toMatchObject({ message: 'grandchild rollback' }) + } + + const afterGrandchildRollback = await tx2.user.findMany({ + where: { email: { in: [emailA, emailB, emailC] } }, + }) + expect(afterGrandchildRollback).toHaveLength(2) + + throw new Error('child rollback') + }) + } catch (e) { + expect(e).toMatchObject({ message: 'child rollback' }) + } + + const afterChildRollback = await tx.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(afterChildRollback).toHaveLength(1) + + throw new Error('outer rollback') + }) + + await expect(outerPromise).rejects.toThrow('outer rollback') + + const users = await prisma.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(users).toHaveLength(0) + }, + ) + + testIf(provider !== Providers.MONGODB)('sql: nested commit keeps state (savepoints, 3 levels)', async () => { + const emailA = `user_${copycat.uuid(161)}@website.com` + const emailB = `user_${copycat.uuid(162)}@website.com` + const emailC = `user_${copycat.uuid(163)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: emailA } }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: emailB } }) + + await tx2.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: emailC } }) + + const inside = await tx3.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(inside).toHaveLength(3) + }) + }) + + const insideOuter = await tx.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(insideOuter).toHaveLength(3) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [emailA, emailB, emailC] } }, + orderBy: { email: 'asc' }, + }) + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: disallow concurrent nested transactions', async () => { + const result = prisma.$transaction(async (tx) => { + const email1 = `user_${copycat.uuid(201)}@website.com` + const email2 = `user_${copycat.uuid(202)}@website.com` + const email3 = `user_${copycat.uuid(203)}@website.com` + const email4 = `user_${copycat.uuid(204)}@website.com` + const email5 = `user_${copycat.uuid(205)}@website.com` + const email6 = `user_${copycat.uuid(206)}@website.com` + const email7 = `user_${copycat.uuid(207)}@website.com` + + await Promise.all([ + tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email1 } }) + await tx2.user.create({ data: { email: email2 } }) + await tx2.user.create({ data: { email: email3 } }) + }), + tx.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: email4 } }) + await tx3.user.create({ data: { email: email5 } }) + }), + tx.$transaction(async (tx4) => { + await tx4.user.create({ data: { email: email6 } }) + await tx4.user.create({ data: { email: email7 } }) + }), + ]) + }) + + await expect(result).rejects.toThrow('Concurrent nested transactions are not supported') + + // No partial writes should be visible after the outer transaction fails. + const users = await prisma.user.findMany() + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: allow nested transactions in concurrent top-level transactions', + async () => { + const a1 = `user_${copycat.uuid(401)}@website.com` + const a2 = `user_${copycat.uuid(402)}@website.com` + const b1 = `user_${copycat.uuid(403)}@website.com` + const b2 = `user_${copycat.uuid(404)}@website.com` + + await Promise.all([ + prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: a1 } }) + await delay(25) + await tx.$transaction(async (tx2) => { + await delay(25) + await tx2.user.create({ data: { email: a2 } }) + }) + }), + prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: b1 } }) + await delay(25) + await tx.$transaction(async (tx2) => { + await delay(25) + await tx2.user.create({ data: { email: b2 } }) + }) + }), + ]) + + const users = await prisma.user.findMany({ where: { email: { in: [a1, a2, b1, b2] } } }) + expect(users).toHaveLength(4) + }, + ) + + testIf(provider !== Providers.MONGODB)('sql: nested commit keeps outer transaction open', async () => { + const email1 = `user_${copycat.uuid(211)}@website.com` + const email2 = `user_${copycat.uuid(212)}@website.com` + const email3 = `user_${copycat.uuid(213)}@website.com` + + const users = await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: email1 } }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email2 } }) + }) + + // If nested commit incorrectly closes the underlying transaction, + // this query or the final commit will fail. + await tx.user.create({ data: { email: email3 } }) + + return tx.user.findMany({ + where: { email: { in: [email1, email2, email3] } }, + orderBy: { email: 'asc' }, + }) + }) + + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: sequential nested transactions work', async () => { + const email1 = `user_${copycat.uuid(221)}@website.com` + const email2 = `user_${copycat.uuid(222)}@website.com` + const email3 = `user_${copycat.uuid(223)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email1 } }) + }) + + await tx.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: email2 } }) + }) + + await tx.user.create({ data: { email: email3 } }) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [email1, email2, email3] } }, + }) + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: deep nesting (3 levels) works', async () => { + const email1 = `user_${copycat.uuid(231)}@website.com` + const email2 = `user_${copycat.uuid(232)}@website.com` + const email3 = `user_${copycat.uuid(233)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: email1 } }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email2 } }) + + await tx2.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: email3 } }) + }) + }) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [email1, email2, email3] } }, + }) + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: nested rollback can be caught and outer can continue', async () => { + const outerEmail1 = `user_${copycat.uuid(241)}@website.com` + const innerEmail = `user_${copycat.uuid(242)}@website.com` + const outerEmail2 = `user_${copycat.uuid(243)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: outerEmail1 } }) + + try { + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: innerEmail } }) + throw new Error('inner rollback') + }) + } catch (e) { + expect(e).toMatchObject({ message: 'inner rollback' }) + } + + await tx.user.create({ data: { email: outerEmail2 } }) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [outerEmail1, innerEmail, outerEmail2] } }, + orderBy: { email: 'asc' }, + }) + + expect(users.map((u) => u.email)).toEqual([outerEmail1, outerEmail2].sort()) + }) + + testIf(provider !== Providers.MONGODB)('sql: enforce order for nested transactions', async () => { + const result = prisma.$transaction(async (tx) => { + const nested = tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: `user_${copycat.uuid(301)}@website.com` } }) + await delay(50) + }) + nested.catch(() => {}) // avoid unhandled rejection in this test + + await tx.user.create({ data: { email: `user_${copycat.uuid(302)}@website.com` } }) + }) + + await expect(result).rejects.toThrow('Cannot close transaction while a nested transaction is still active.') + + const users = await prisma.user.findMany() + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: child fails if parent tries to commit before child finishes', + async () => { + const email = `user_${copycat.uuid(521)}@website.com` + let child: Promise | undefined + + const parent = prisma.$transaction((tx) => { + child = tx.$transaction(async (tx2) => { + await delay(50) + await tx2.user.create({ data: { email } }) + }) + child.catch(() => {}) // prevent unhandled rejection if parent fails fast + + // Parent returns immediately (commit attempt) without awaiting the child. + return Promise.resolve() + }) + + await expect(parent).rejects.toThrow('Cannot close transaction while a nested transaction is still active.') + await expect(child).rejects.toMatchObject({ + code: 'P2028', + }) + + const users = await prisma.user.findMany({ where: { email } }) + expect(users).toHaveLength(0) + }, + ) + + testIf(provider !== Providers.MONGODB)('sql: child fails if parent rolls back before child finishes', async () => { + const email = `user_${copycat.uuid(501)}@website.com` + let child: Promise | undefined + + const parent = prisma.$transaction((tx) => { + child = tx.$transaction(async (tx2) => { + await delay(50) + await tx2.user.create({ data: { email } }) + }) + child.catch(() => {}) // prevent unhandled rejection if parent fails fast + + return Promise.reject(new Error('parent rollback')) + }) + + await expect(parent).rejects.toThrow('parent rollback') + await expect(child).rejects.toMatchObject({ + code: 'P2028', + }) + + const users = await prisma.user.findMany({ where: { email } }) + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: child fails if nested parent closes before grandchild finishes', + async () => { + const email = `user_${copycat.uuid(511)}@website.com` + let grandchild: Promise | undefined + + const parent = prisma.$transaction(async (tx) => { + await tx.$transaction((tx2) => { + grandchild = tx2.$transaction(async (tx3) => { + await delay(50) + await tx3.user.create({ data: { email } }) + }) + grandchild.catch(() => {}) // prevent unhandled rejection if parent fails fast + + // Intentionally don't await `grandchild` to simulate incorrect ordering. + // The parent nested transaction should fail to close and the whole top-level tx should rollback. + return Promise.resolve() + }) + }) + + await expect(parent).rejects.toThrow('Nested transactions must be closed in reverse order of creation.') + await expect(grandchild).rejects.toMatchObject({ + code: 'P2028', + }) + + const users = await prisma.user.findMany({ where: { email } }) + expect(users).toHaveLength(0) + }, + ) + + testIf(provider === Providers.MONGODB)('mongodb: disallow nested transactions at runtime', async () => { + const result = prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: 'user_1@website.com' } }) + // Nested transactions are intentionally not available in types for MongoDB. + // Bypass the type system to assert runtime behavior. + await (tx as any).$transaction(async (tx2: any) => { + await tx2.user.create({ data: { email: 'user_2@website.com' } }) + }) + }) + + await expect(result).rejects.toThrow('The mongodb provider does not support nested transactions') + const users = await prisma.user.findMany() + expect(users).toHaveLength(0) + }) + /** * We don't allow certain methods to be called in a transaction */ test('forbidden', async () => { - const forbidden = ['$connect', '$disconnect', '$on', '$transaction'] + const forbidden = ['$connect', '$disconnect', '$on', '$use'] expect.assertions(forbidden.length + 1) const result = prisma.$transaction((prisma) => { @@ -241,25 +654,32 @@ testMatrix.setupTestSuite( * If one of the query fails, all queries should cancel */ test('rollback query', async () => { + const email1 = 'user_1@website.com' const result = prisma.$transaction(async (prisma) => { await prisma.user.create({ data: { id: copycat.uuid(1).replaceAll('-', '').slice(-24), - email: 'user_1@website.com', + email: email1, }, }) await prisma.user.create({ data: { id: copycat.uuid(2).replaceAll('-', '').slice(-24), - email: 'user_1@website.com', + email: email1, }, }) }) await expect(result).rejects.toMatchPrismaErrorSnapshot() - const users = await prisma.user.findMany() + const users = await prisma.user.findMany({ + where: { + email: { + equals: email1, + }, + }, + }) expect(users.length).toBe(0) }) diff --git a/packages/driver-adapter-utils/src/binder.ts b/packages/driver-adapter-utils/src/binder.ts index bfb6f3fb6051..c17ce6072ded 100644 --- a/packages/driver-adapter-utils/src/binder.ts +++ b/packages/driver-adapter-utils/src/binder.ts @@ -113,7 +113,7 @@ export const bindAdapter = ( // *.bind(transaction) is required to preserve the `this` context of functions whose // execution is delegated to napi.rs. const bindTransaction = (errorRegistry: ErrorRegistryInternal, transaction: Transaction): ErrorCapturingTransaction => { - return { + const boundTransaction: ErrorCapturingTransaction = { adapterName: transaction.adapterName, provider: transaction.provider, options: transaction.options, @@ -122,6 +122,20 @@ const bindTransaction = (errorRegistry: ErrorRegistryInternal, transaction: Tran commit: wrapAsync(errorRegistry, transaction.commit.bind(transaction)), rollback: wrapAsync(errorRegistry, transaction.rollback.bind(transaction)), } + + if (transaction.createSavepoint) { + boundTransaction.createSavepoint = wrapAsync(errorRegistry, transaction.createSavepoint.bind(transaction)) + } + + if (transaction.rollbackToSavepoint) { + boundTransaction.rollbackToSavepoint = wrapAsync(errorRegistry, transaction.rollbackToSavepoint.bind(transaction)) + } + + if (transaction.releaseSavepoint) { + boundTransaction.releaseSavepoint = wrapAsync(errorRegistry, transaction.releaseSavepoint.bind(transaction)) + } + + return boundTransaction } function wrapAsync( diff --git a/packages/driver-adapter-utils/src/types.ts b/packages/driver-adapter-utils/src/types.ts index 5943031f8940..744291caa2b7 100644 --- a/packages/driver-adapter-utils/src/types.ts +++ b/packages/driver-adapter-utils/src/types.ts @@ -289,6 +289,18 @@ export interface Transaction extends AdapterInfo, SqlQueryable { * Roll back the transaction. */ rollback(): Promise + /** + * Creates a savepoint within the currently running transaction. + */ + createSavepoint?(name: string): Promise + /** + * Rolls back transaction state to a previously created savepoint. + */ + rollbackToSavepoint?(name: string): Promise + /** + * Releases a previously created savepoint. Optional because not every connector supports this operation. + */ + releaseSavepoint?(name: string): Promise } /** diff --git a/packages/query-plan-executor/src/logic/adapter.test.ts b/packages/query-plan-executor/src/logic/adapter.test.ts index 90efa20738ad..800ffedf6b71 100644 --- a/packages/query-plan-executor/src/logic/adapter.test.ts +++ b/packages/query-plan-executor/src/logic/adapter.test.ts @@ -424,6 +424,9 @@ describe('createAdapter wrapper error handling', () => { rollback: vi.fn(), executeRaw: vi.fn(), queryRaw: vi.fn(), + createSavepoint: vi.fn(), + rollbackToSavepoint: vi.fn(), + releaseSavepoint: vi.fn(), } mockAdapter = { @@ -466,6 +469,18 @@ describe('createAdapter wrapper error handling', () => { return tx.queryRaw(query) }, }, + { + method: 'createSavepoint' as const, + execute: async (tx: Transaction) => tx.createSavepoint?.('sp1'), + }, + { + method: 'rollbackToSavepoint' as const, + execute: async (tx: Transaction) => tx.rollbackToSavepoint?.('sp1'), + }, + { + method: 'releaseSavepoint' as const, + execute: async (tx: Transaction) => tx.releaseSavepoint?.('sp1'), + }, ])('wraps transaction.$method() to sanitize errors', async ({ method, execute }) => { const error = new Error(`${method} failed for postgresql://user:pass@host:5432/db`) ;(mockTransaction as any)[method] = vi.fn().mockRejectedValue(error) @@ -493,6 +508,9 @@ describe('createAdapter wrapper error handling', () => { mockTransaction.rollback = vi.fn().mockResolvedValue(undefined) mockTransaction.executeRaw = vi.fn().mockResolvedValue(1) mockTransaction.queryRaw = vi.fn().mockResolvedValue(mockResult) + mockTransaction.createSavepoint = vi.fn().mockResolvedValue(undefined) + mockTransaction.rollbackToSavepoint = vi.fn().mockResolvedValue(undefined) + mockTransaction.releaseSavepoint = vi.fn().mockResolvedValue(undefined) const wrappedFactory = createAdapter('postgresql://user:pass@host:5432/db', [ { @@ -509,12 +527,18 @@ describe('createAdapter wrapper error handling', () => { await expect(tx.rollback()).resolves.toBeUndefined() expect(await tx.executeRaw(query)).toBe(1) expect(await tx.queryRaw(query)).toEqual(mockResult) + await expect(tx.createSavepoint?.('sp1')).resolves.toBeUndefined() + await expect(tx.rollbackToSavepoint?.('sp1')).resolves.toBeUndefined() + await expect(tx.releaseSavepoint?.('sp1')).resolves.toBeUndefined() /* eslint-disable @typescript-eslint/unbound-method */ expect(mockTransaction.commit).toHaveBeenCalledOnce() expect(mockTransaction.rollback).toHaveBeenCalledOnce() expect(mockTransaction.executeRaw).toHaveBeenCalledWith(query) expect(mockTransaction.queryRaw).toHaveBeenCalledWith(query) + expect(mockTransaction.createSavepoint).toHaveBeenCalledWith('sp1') + expect(mockTransaction.rollbackToSavepoint).toHaveBeenCalledWith('sp1') + expect(mockTransaction.releaseSavepoint).toHaveBeenCalledWith('sp1') /* eslint-enable @typescript-eslint/unbound-method */ }) diff --git a/packages/query-plan-executor/src/logic/adapter.ts b/packages/query-plan-executor/src/logic/adapter.ts index 3f28334dd2cd..e2cc8d93214c 100644 --- a/packages/query-plan-executor/src/logic/adapter.ts +++ b/packages/query-plan-executor/src/logic/adapter.ts @@ -118,7 +118,8 @@ function createConnectionStringRegex(protocols: string[]) { function wrapFactory(protocols: string[], factory: SqlDriverAdapterFactory): SqlDriverAdapterFactory { return { - ...factory, + adapterName: factory.adapterName, + provider: factory.provider, connect: () => factory.connect().then(wrapAdapter.bind(null, protocols), rethrowSanitizedError.bind(null, protocols)), } @@ -126,7 +127,8 @@ function wrapFactory(protocols: string[], factory: SqlDriverAdapterFactory): Sql function wrapAdapter(protocols: string[], adapter: SqlDriverAdapter): SqlDriverAdapter { return { - ...adapter, + adapterName: adapter.adapterName, + provider: adapter.provider, dispose: () => adapter.dispose().catch(rethrowSanitizedError.bind(null, protocols)), executeRaw: (query) => adapter.executeRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), queryRaw: (query) => adapter.queryRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), @@ -141,10 +143,21 @@ function wrapAdapter(protocols: string[], adapter: SqlDriverAdapter): SqlDriverA function wrapTransaction(protocols: string[], tx: Transaction): Transaction { return { - ...tx, + adapterName: tx.adapterName, + provider: tx.provider, + options: tx.options, commit: () => tx.commit().catch(rethrowSanitizedError.bind(null, protocols)), rollback: () => tx.rollback().catch(rethrowSanitizedError.bind(null, protocols)), executeRaw: (query) => tx.executeRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), queryRaw: (query) => tx.queryRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), + createSavepoint: tx.createSavepoint + ? (name) => tx.createSavepoint!(name).catch(rethrowSanitizedError.bind(null, protocols)) + : undefined, + rollbackToSavepoint: tx.rollbackToSavepoint + ? (name) => tx.rollbackToSavepoint!(name).catch(rethrowSanitizedError.bind(null, protocols)) + : undefined, + releaseSavepoint: tx.releaseSavepoint + ? (name) => tx.releaseSavepoint!(name).catch(rethrowSanitizedError.bind(null, protocols)) + : undefined, } } diff --git a/packages/query-plan-executor/src/server/schemas.ts b/packages/query-plan-executor/src/server/schemas.ts index 145f6719b058..6124f3c1230b 100644 --- a/packages/query-plan-executor/src/server/schemas.ts +++ b/packages/query-plan-executor/src/server/schemas.ts @@ -50,6 +50,7 @@ export const TransactionStartRequestBody = z.object({ isolationLevel: z .enum(['READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ', 'SNAPSHOT', 'SERIALIZABLE']) .optional(), + newTxId: z.string().optional(), }) export type TransactionStartRequestBody = z.infer