From 6091e02500a5458c87b51e12d436e717a14ad911 Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Thu, 19 Feb 2026 16:06:44 +0000 Subject: [PATCH] feat: add support for nested transaction rollbacks via savepoints in sql (#21678) ~~This adds a high level test that should pass if/when https://github.com/prisma/prisma-engines/pull/4375 is merged.~~ Transactions are now handled inside this repository, so the update to prisma-engines is no longer required. This change adds support for handling rollbacks in nested transactions in SQL databases. Specifically, the inner transaction should be rolled back if the outer transaction fails. To do this we keep track of the transaction ID and transaction depth so we can re-use an existing open transaction in the underlying engine. This change also allows the use of the `$transaction` method on an interactive transaction client. ## Summary by CodeRabbit * **New Features** * Nested interactive transactions for SQL databases (Postgres, MySQL, SQLite, SQL Server) via savepoints; MongoDB remains without nested-transaction support. * **Tests** * Added extensive nested-transaction tests covering multi-level nesting, savepoint behavior, rollback/commit propagation, concurrency and provider-specific syntax. * **Documentation** * Updated runtime documentation to describe nested transaction behavior and provider-specific savepoint semantics. --------- Co-authored-by: Jacek Malec --- AGENTS.md | 4 +- eslint.config.cjs | 1 + .../src/better-sqlite3.ts | 12 + packages/adapter-d1/src/d1-http.ts | 12 + packages/adapter-d1/src/d1-worker.ts | 12 + packages/adapter-libsql/src/libsql.ts | 12 + packages/adapter-mariadb/src/mariadb.ts | 12 + packages/adapter-mssql/src/mssql.ts | 8 + packages/adapter-neon/src/neon.ts | 12 + packages/adapter-pg/src/pg.ts | 12 + .../adapter-planetscale/src/planetscale.ts | 12 + packages/adapter-ppg/src/ppg.ts | 12 + .../transaction-manager.test.ts | 331 +++++++++++++- .../transaction-manager.ts | 183 +++++++- .../src/transaction-manager/transaction.ts | 1 + .../src/TSClient/PrismaClient.ts | 16 +- .../src/TSClient/PrismaClient.ts | 12 +- .../file-generators/PrismaNamespaceFile.ts | 6 +- .../core/engines/common/types/Transaction.ts | 6 + .../core/types/exported/itxClientDenyList.ts | 2 +- .../client/src/runtime/getPrismaClient.ts | 150 +++++- .../functional/extensions/defineExtension.ts | 1 - .../client/tests/functional/extensions/itx.ts | 2 +- .../interactive-transactions/tests.ts | 428 +++++++++++++++++- packages/driver-adapter-utils/src/binder.ts | 16 +- packages/driver-adapter-utils/src/types.ts | 12 + .../src/logic/adapter.test.ts | 24 + .../query-plan-executor/src/logic/adapter.ts | 19 +- .../query-plan-executor/src/server/schemas.ts | 1 + 29 files changed, 1281 insertions(+), 50 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 509d36d2434b..b46901e7859d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -31,11 +31,13 @@ - **Docs & references**: `ARCHITECTURE.md` contains dependency graphs (requires GraphViz to regenerate), `docker/README.md` explains local DB setup, `docs/benchmarking.md` covers performance benchmarking, `examples/` provides sample apps, and `sandbox/` hosts debugging helpers like the DMMF explorer. - **Client architecture (Prisma 7)**: - - `ClientEngine` in `packages/client/src/runtime/core/engines/client/` orchestrates query execution using WASM query compiler. + - `ClientEngine` in `packages/client/src/runtime/core/engines/client/` orchestrates query execution using Wasm query compiler. - Two executor implementations: `LocalExecutor` (driver adapters, direct DB) and `RemoteExecutor` (Accelerate/Data Proxy). - `QueryInterpreter` class in `packages/client-engine-runtime/src/interpreter/query-interpreter.ts` executes query plans against `SqlQueryable` (driver adapter interface). - Query flow: `PrismaClient` → `ClientEngine.request()` → query compiler → `executor.execute()` → `QueryInterpreter.run()` → driver adapter. - `ExecutePlanParams` interface in `packages/client/src/runtime/core/engines/client/Executor.ts` defines what's passed through the execution chain. + - `TransactionManager` in `packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts` owns interactive transaction IDs and implements nested transactions using savepoints. Savepoint SQL is provider-specific (e.g. PostgreSQL uses `ROLLBACK TO SAVEPOINT `, MySQL/SQLite use `ROLLBACK TO `, SQL Server uses `SAVE TRANSACTION ` / `ROLLBACK TRANSACTION ` and has no release statement). + - `Transaction` in `packages/driver-adapter-utils` models savepoint behavior as async methods (`createSavepoint`, `rollbackToSavepoint`, optional `releaseSavepoint`) instead of returning SQL via `savepoint(action, name)`. `TransactionManager` expects adapter methods for savepoints and does not synthesize provider fallback SQL. - Fluent API `dataPath` is built in `packages/client/src/runtime/core/model/applyFluent.ts` by appending `['select', ]` on each hop; runtime unpacking in `packages/client/src/runtime/RequestHandler.ts` currently strips `'select'`/`'include'` segments before `deepGet`. - In extension context resolution, `dataPath` should be interpreted as selector/field pairs (`select|include`, relation field). Do not strip by raw string value or relation fields literally named `select`/`include` get dropped. diff --git a/eslint.config.cjs b/eslint.config.cjs index 601b580837be..ac154f9199a6 100644 --- a/eslint.config.cjs +++ b/eslint.config.cjs @@ -36,6 +36,7 @@ module.exports = [ '**/fixtures/**', '**/__fixtures__/**', '**/generated/**', + '**/.generated/**', '**/prism.ts', '**/charm.ts', '**/pnpm-lock.yaml', diff --git a/packages/adapter-better-sqlite3/src/better-sqlite3.ts b/packages/adapter-better-sqlite3/src/better-sqlite3.ts index b2a06b8acbda..16b239e0e640 100644 --- a/packages/adapter-better-sqlite3/src/better-sqlite3.ts +++ b/packages/adapter-better-sqlite3/src/better-sqlite3.ts @@ -160,6 +160,18 @@ class BetterSQLite3Transaction extends BetterSQLite3Queryable impleme this.#unlockParent() return Promise.resolve() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaBetterSqlite3Adapter extends BetterSQLite3Queryable implements SqlDriverAdapter { diff --git a/packages/adapter-d1/src/d1-http.ts b/packages/adapter-d1/src/d1-http.ts index 59f7d5cd2bd1..7a86b4c2e18c 100644 --- a/packages/adapter-d1/src/d1-http.ts +++ b/packages/adapter-d1/src/d1-http.ts @@ -199,6 +199,18 @@ class D1HttpTransaction extends D1HttpQueryable implements Transaction { async rollback(): Promise { debug(`[js::rollback]`) } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaD1HttpAdapter extends D1HttpQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-d1/src/d1-worker.ts b/packages/adapter-d1/src/d1-worker.ts index 95c5dfd9f08c..cab5bc9fa7e7 100644 --- a/packages/adapter-d1/src/d1-worker.ts +++ b/packages/adapter-d1/src/d1-worker.ts @@ -119,6 +119,18 @@ class D1WorkerTransaction extends D1WorkerQueryable implements Transa async rollback(): Promise { debug(`[js::rollback]`) } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaD1WorkerAdapter extends D1WorkerQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-libsql/src/libsql.ts b/packages/adapter-libsql/src/libsql.ts index 42dc88d4787b..8eab92839e86 100644 --- a/packages/adapter-libsql/src/libsql.ts +++ b/packages/adapter-libsql/src/libsql.ts @@ -129,6 +129,18 @@ class LibSqlTransaction extends LibSqlQueryable implements Tr this.#unlockParent() } } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaLibSqlAdapter extends LibSqlQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-mariadb/src/mariadb.ts b/packages/adapter-mariadb/src/mariadb.ts index 5bd797216fa0..3bacca68c47a 100644 --- a/packages/adapter-mariadb/src/mariadb.ts +++ b/packages/adapter-mariadb/src/mariadb.ts @@ -98,6 +98,18 @@ class MariaDbTransaction extends MariaDbQueryable implements this.cleanup?.() await this.client.end() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export type PrismaMariadbOptions = { diff --git a/packages/adapter-mssql/src/mssql.ts b/packages/adapter-mssql/src/mssql.ts index 2e4d828f7727..bb84c4dbe92c 100644 --- a/packages/adapter-mssql/src/mssql.ts +++ b/packages/adapter-mssql/src/mssql.ts @@ -117,6 +117,14 @@ class MssqlTransaction extends MssqlQueryable implements Transaction { release() } } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVE TRANSACTION ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TRANSACTION ${name}`, args: [], argTypes: [] }) + } } export type PrismaMssqlOptions = { diff --git a/packages/adapter-neon/src/neon.ts b/packages/adapter-neon/src/neon.ts index 73e08cf87d96..d420de44a9dc 100644 --- a/packages/adapter-neon/src/neon.ts +++ b/packages/adapter-neon/src/neon.ts @@ -158,6 +158,18 @@ class NeonTransaction extends NeonWsQueryable implements Transa this.cleanup?.() this.client.release() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export type PrismaNeonOptions = { diff --git a/packages/adapter-pg/src/pg.ts b/packages/adapter-pg/src/pg.ts index 702e4b3d9925..8d01ee593875 100644 --- a/packages/adapter-pg/src/pg.ts +++ b/packages/adapter-pg/src/pg.ts @@ -168,6 +168,18 @@ class PgTransaction extends PgQueryable implements Transactio this.cleanup?.() this.client.release() } + + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export type PrismaPgOptions = { diff --git a/packages/adapter-planetscale/src/planetscale.ts b/packages/adapter-planetscale/src/planetscale.ts index fe25dc9cc4f1..5e62d7498fea 100644 --- a/packages/adapter-planetscale/src/planetscale.ts +++ b/packages/adapter-planetscale/src/planetscale.ts @@ -177,6 +177,18 @@ class PlanetScaleTransaction extends PlanetScaleQueryable { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } export class PrismaPlanetScaleAdapter extends PlanetScaleQueryable implements SqlDriverAdapter { diff --git a/packages/adapter-ppg/src/ppg.ts b/packages/adapter-ppg/src/ppg.ts index 115adee8acb0..00057b52353b 100644 --- a/packages/adapter-ppg/src/ppg.ts +++ b/packages/adapter-ppg/src/ppg.ts @@ -188,6 +188,18 @@ class PrismaPostgresTransaction implements Transaction { } } + async createSavepoint(name: string): Promise { + await this.executeRaw({ sql: `SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async rollbackToSavepoint(name: string): Promise { + await this.executeRaw({ sql: `ROLLBACK TO SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + + async releaseSavepoint(name: string): Promise { + await this.executeRaw({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } + async executeRaw(params: SqlQuery): Promise { await this.#ensureBegun() return executeRawStatement(this.#session, params) diff --git a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts index 90dac6936689..e38862330c33 100644 --- a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts +++ b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.test.ts @@ -1,7 +1,6 @@ import timers from 'node:timers/promises' import type { SqlDriverAdapter, SqlQuery, SqlResultSet, Transaction } from '@prisma/driver-adapter-utils' -import { ok } from '@prisma/driver-adapter-utils' import { noopTracingHelper } from '../tracing' import { Options } from './transaction' @@ -31,14 +30,51 @@ class MockDriverAdapter implements SqlDriverAdapter { adapterName = 'mock-adapter' provider: SqlDriverAdapter['provider'] private readonly usePhantomQuery: boolean - - executeRawMock: jest.MockedFn<(params: SqlQuery) => Promise> = jest.fn().mockResolvedValue(ok(1)) - commitMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(ok(undefined)) - rollbackMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(ok(undefined)) - - constructor({ provider = 'postgres' as SqlDriverAdapter['provider'], usePhantomQuery = false } = {}) { + private readonly createSavepoint: Transaction['createSavepoint'] + private readonly rollbackToSavepoint: Transaction['rollbackToSavepoint'] + private readonly releaseSavepoint: Transaction['releaseSavepoint'] + + executeRawMock: jest.MockedFn<(params: SqlQuery) => Promise> = jest.fn().mockResolvedValue(1) + commitMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(undefined) + rollbackMock: jest.MockedFn<() => Promise> = jest.fn().mockResolvedValue(undefined) + + constructor( + options: { + provider?: SqlDriverAdapter['provider'] + usePhantomQuery?: boolean + createSavepoint?: Transaction['createSavepoint'] + rollbackToSavepoint?: Transaction['rollbackToSavepoint'] + releaseSavepoint?: Transaction['releaseSavepoint'] + } = {}, + ) { + const { provider = 'postgres' as SqlDriverAdapter['provider'], usePhantomQuery = false } = options this.usePhantomQuery = usePhantomQuery this.provider = provider + + this.createSavepoint = Object.hasOwn(options, 'createSavepoint') + ? options.createSavepoint + : async (name) => { + const sql = this.provider === 'sqlserver' ? `SAVE TRANSACTION ${name}` : `SAVEPOINT ${name}` + await this.executeRawMock({ sql, args: [], argTypes: [] }) + } + this.rollbackToSavepoint = Object.hasOwn(options, 'rollbackToSavepoint') + ? options.rollbackToSavepoint + : async (name) => { + const sql = + this.provider === 'sqlserver' + ? `ROLLBACK TRANSACTION ${name}` + : this.provider === 'postgres' + ? `ROLLBACK TO SAVEPOINT ${name}` + : `ROLLBACK TO ${name}` + await this.executeRawMock({ sql, args: [], argTypes: [] }) + } + this.releaseSavepoint = Object.hasOwn(options, 'releaseSavepoint') + ? options.releaseSavepoint + : this.provider === 'sqlserver' + ? undefined + : async (name) => { + await this.executeRawMock({ sql: `RELEASE SAVEPOINT ${name}`, args: [], argTypes: [] }) + } } executeRaw(params: SqlQuery): Promise { @@ -62,15 +98,19 @@ class MockDriverAdapter implements SqlDriverAdapter { const commitMock = this.commitMock const rollbackMock = this.rollbackMock const usePhantomQuery = this.usePhantomQuery + const provider = this.provider const mockTransaction: Transaction = { adapterName: 'mock-adapter', - provider: 'postgres', + provider, options: { usePhantomQuery }, queryRaw: jest.fn().mockRejectedValue('Not implemented for test'), executeRaw: executeRawMock, commit: commitMock, rollback: rollbackMock, + createSavepoint: this.createSavepoint, + rollbackToSavepoint: this.rollbackToSavepoint, + releaseSavepoint: this.releaseSavepoint, } return new Promise((resolve) => @@ -113,6 +153,281 @@ test('transaction executes normally', async () => { await expect(transactionManager.rollbackTransaction(id)).rejects.toBeInstanceOf(TransactionClosedError) }) +test('nested commit only closes at the outermost level', async () => { + const driverAdapter = new MockDriverAdapter() + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + const nested = await transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + + expect(nested.id).toBe(id) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + // Inner commit should not close the underlying transaction. + await transactionManager.commitTransaction(id) + expect(driverAdapter.commitMock).not.toHaveBeenCalled() + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + // Outer commit closes. + await transactionManager.commitTransaction(id) + expect(driverAdapter.commitMock).toHaveBeenCalledTimes(1) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toEqual('COMMIT') +}) + +test('nested rollback uses a savepoint and keeps the outer transaction open', async () => { + const driverAdapter = new MockDriverAdapter() + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + await transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + // Inner rollback should not close the underlying transaction. + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.rollbackMock).not.toHaveBeenCalled() + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TO SAVEPOINT prisma_sp_\d+$/) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + // Outer commit still closes. + await transactionManager.commitTransaction(id) + expect(driverAdapter.commitMock).toHaveBeenCalledTimes(1) + expect(driverAdapter.executeRawMock.mock.calls[3][0].sql).toEqual('COMMIT') +}) + +test('nested starts are serialized when the first savepoint fails', async () => { + const driverAdapter = new MockDriverAdapter() + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + let savepointCallCount = 0 + let rejectFirstSavepoint: ((error: Error) => void) | undefined + let resolveFirstSavepointStarted: () => void + const firstSavepointStarted = new Promise((resolve) => { + resolveFirstSavepointStarted = resolve + }) + + driverAdapter.executeRawMock.mockImplementation((query) => { + if (query.sql.startsWith('SAVEPOINT ')) { + savepointCallCount += 1 + if (savepointCallCount === 1) { + resolveFirstSavepointStarted() + return new Promise((_, reject) => { + rejectFirstSavepoint = reject + }) + } + } + + return Promise.resolve(1) + }) + + const nested1 = transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + const nested2 = transactionManager.startTransaction({ + ...TRANSACTION_OPTIONS, + newTxId: id, + }) + + await firstSavepointStarted + await Promise.resolve() + expect(savepointCallCount).toBe(1) + + rejectFirstSavepoint?.(new Error('savepoint failed')) + await expect(nested1).rejects.toThrow('savepoint failed') + + await expect(nested2).resolves.toEqual({ id }) + + const savepointQueries = driverAdapter.executeRawMock.mock.calls + .map((call) => call[0].sql) + .filter((sql) => sql.startsWith('SAVEPOINT ')) + const successfulSavepointName = savepointQueries[1].slice('SAVEPOINT '.length) + + await transactionManager.rollbackTransaction(id) + + const rollbackToSavepointQuery = driverAdapter.executeRawMock.mock.calls + .map((call) => call[0].sql) + .find((sql) => sql.startsWith('ROLLBACK TO SAVEPOINT ')) + + expect(rollbackToSavepointQuery).toEqual(`ROLLBACK TO SAVEPOINT ${successfulSavepointName}`) + + // Close top-level transaction. + await transactionManager.rollbackTransaction(id) +}) + +test('nested savepoints use sqlserver syntax', async () => { + const driverAdapter = new MockDriverAdapter({ provider: 'sqlserver' }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVE TRANSACTION prisma_sp_\d+$/) + + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TRANSACTION prisma_sp_\d+$/) + + // No release savepoint query on SQL Server. + expect(driverAdapter.executeRawMock.mock.calls).toHaveLength(2) + await transactionManager.commitTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toEqual('COMMIT') +}) + +test('nested savepoints use mysql syntax', async () => { + const driverAdapter = new MockDriverAdapter({ provider: 'mysql' }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TO prisma_sp_\d+$/) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.commitTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[3][0].sql).toEqual('COMMIT') +}) + +test('nested savepoints use sqlite syntax', async () => { + const driverAdapter = new MockDriverAdapter({ provider: 'sqlite' }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toMatch(/^SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.rollbackTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[1][0].sql).toMatch(/^ROLLBACK TO prisma_sp_\d+$/) + expect(driverAdapter.executeRawMock.mock.calls[2][0].sql).toMatch(/^RELEASE SAVEPOINT prisma_sp_\d+$/) + + await transactionManager.commitTransaction(id) + expect(driverAdapter.executeRawMock.mock.calls[3][0].sql).toEqual('COMMIT') +}) + +test('nested savepoints use adapter-provided methods when available', async () => { + const createSavepoint = jest.fn>, [string]>(async () => {}) + const rollbackToSavepoint = jest.fn>, [string]>( + async () => {}, + ) + const releaseSavepoint = jest.fn>, [string]>(async () => {}) + + const driverAdapter = new MockDriverAdapter({ + provider: 'postgres', + createSavepoint, + rollbackToSavepoint, + releaseSavepoint, + }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + await transactionManager.rollbackTransaction(id) + await transactionManager.commitTransaction(id) + + expect(createSavepoint).toHaveBeenCalledTimes(1) + expect(rollbackToSavepoint).toHaveBeenCalledTimes(1) + expect(releaseSavepoint).toHaveBeenCalledTimes(1) + + const savepointName = createSavepoint.mock.calls[0][0] + expect(rollbackToSavepoint).toHaveBeenCalledWith(savepointName) + expect(releaseSavepoint).toHaveBeenCalledWith(savepointName) + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toEqual('COMMIT') +}) + +test('nested savepoint release can be omitted by adapter', async () => { + const createSavepoint = jest.fn>, [string]>(async () => {}) + + const driverAdapter = new MockDriverAdapter({ + provider: 'postgres', + createSavepoint, + releaseSavepoint: undefined, + }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + await transactionManager.commitTransaction(id) + await transactionManager.commitTransaction(id) + + expect(createSavepoint).toHaveBeenCalledTimes(1) + expect(driverAdapter.executeRawMock.mock.calls[0][0].sql).toEqual('COMMIT') +}) + +test('missing createSavepoint fails nested transaction start', async () => { + const driverAdapter = new MockDriverAdapter({ createSavepoint: undefined }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + + await expect(transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id })).rejects.toThrow( + 'createSavepoint is not implemented', + ) +}) + +test('missing rollbackToSavepoint fails nested rollback', async () => { + const driverAdapter = new MockDriverAdapter({ rollbackToSavepoint: undefined }) + const transactionManager = new TransactionManager({ + driverAdapter, + transactionOptions: TRANSACTION_OPTIONS, + tracingHelper: noopTracingHelper, + }) + + const id = await startTransaction(transactionManager) + await transactionManager.startTransaction({ ...TRANSACTION_OPTIONS, newTxId: id }) + + await expect(transactionManager.rollbackTransaction(id)).rejects.toThrow('rollbackToSavepoint is not implemented') +}) + test('transaction is rolled back', async () => { const driverAdapter = new MockDriverAdapter() const transactionManager = new TransactionManager({ diff --git a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts index f0722567d5fd..e1cbe7fdba93 100644 --- a/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts +++ b/packages/client-engine-runtime/src/transaction-manager/transaction-manager.ts @@ -28,6 +28,10 @@ type TransactionWrapper = { timeout: number | undefined startedAt: number transaction?: Transaction + operationQueue: Promise + depth: number + savepoints: string[] + savepointCounter: number } & TransactionState type TransactionState = @@ -102,6 +106,36 @@ export class TransactionManager { } async #startTransactionImpl(options: Options): Promise { + if (options.newTxId) { + return await this.#withActiveTransactionLock(options.newTxId, 'start', async (existing) => { + if (existing.status !== 'running') { + throw new TransactionInternalConsistencyError( + `Transaction in invalid state ${existing.status} when starting a nested transaction.`, + ) + } + if (!existing.transaction) { + throw new TransactionInternalConsistencyError( + `Transaction missing underlying driver transaction when starting a nested transaction.`, + ) + } + + existing.depth += 1 + + const savepointName = this.#nextSavepointName(existing) + existing.savepoints.push(savepointName) + try { + await this.#requiredCreateSavepoint(existing.transaction)(savepointName) + } catch (e) { + // Keep state consistent if creating the savepoint fails. + existing.depth -= 1 + existing.savepoints.pop() + throw e + } + + return { id: existing.id } + }) + } + const transaction: TransactionWrapper = { id: await randomUUID(), status: 'waiting', @@ -109,6 +143,10 @@ export class TransactionManager { timeout: options.timeout, startedAt: Date.now(), transaction: undefined, + operationQueue: Promise.resolve(), + depth: 1, + savepoints: [], + savepointCounter: 0, } // Start timeout to wait for transaction to be started. @@ -167,15 +205,53 @@ export class TransactionManager { async commitTransaction(transactionId: string): Promise { return await this.tracingHelper.runInChildSpan('commit_transaction', async () => { - const txw = this.#getActiveOrClosingTransaction(transactionId, 'commit') - await this.#closeTransaction(txw, 'committed') + await this.#withActiveTransactionLock(transactionId, 'commit', async (txw) => { + if (txw.depth > 1) { + if (!txw.transaction) throw new TransactionNotFoundError() + const savepointName = txw.savepoints.at(-1) + if (!savepointName) { + throw new TransactionInternalConsistencyError( + `Missing savepoint for nested commit. Depth: ${txw.depth}, transactionId: ${txw.id}`, + ) + } + try { + await this.#releaseSavepoint(txw.transaction, savepointName) + } finally { + // Keep internal state consistent even if releasing the savepoint fails. + txw.savepoints.pop() + txw.depth -= 1 + } + return + } + await this.#closeTransaction(txw, 'committed') + }) }) } async rollbackTransaction(transactionId: string): Promise { return await this.tracingHelper.runInChildSpan('rollback_transaction', async () => { - const txw = this.#getActiveOrClosingTransaction(transactionId, 'rollback') - await this.#closeTransaction(txw, 'rolled_back') + await this.#withActiveTransactionLock(transactionId, 'rollback', async (txw) => { + if (txw.depth > 1) { + if (!txw.transaction) throw new TransactionNotFoundError() + const savepointName = txw.savepoints.at(-1) + if (!savepointName) { + throw new TransactionInternalConsistencyError( + `Missing savepoint for nested rollback. Depth: ${txw.depth}, transactionId: ${txw.id}`, + ) + } + + try { + await this.#requiredRollbackToSavepoint(txw.transaction)(savepointName) + await this.#releaseSavepoint(txw.transaction, savepointName) + } finally { + // Keep internal state consistent even if rollback/release fails. + txw.savepoints.pop() + txw.depth -= 1 + } + return + } + await this.#closeTransaction(txw, 'rolled_back') + }) }) } @@ -228,7 +304,53 @@ export class TransactionManager { async cancelAllTransactions(): Promise { // TODO: call `map` on the iterator directly without collecting it into an array first // once we drop support for Node.js 18 and 20. - await Promise.allSettled([...this.transactions.values()].map((tx) => this.#closeTransaction(tx, 'rolled_back'))) + await Promise.allSettled( + [...this.transactions.values()].map((tx) => + this.#runSerialized(tx, async () => { + const current = this.transactions.get(tx.id) + if (current) { + await this.#closeTransaction(current, 'rolled_back') + } + }), + ), + ) + } + + #nextSavepointName(transaction: TransactionWrapper): string { + return `prisma_sp_${transaction.savepointCounter++}` + } + + #requiredCreateSavepoint(transaction: Transaction): (name: string) => Promise { + if (transaction.createSavepoint) { + return transaction.createSavepoint.bind(transaction) + } + + throw new TransactionManagerError( + `Nested transactions are not supported by adapter "${transaction.adapterName}" (${transaction.provider}): createSavepoint is not implemented.`, + ) + } + + #requiredRollbackToSavepoint(transaction: Transaction): (name: string) => Promise { + if (transaction.rollbackToSavepoint) { + return transaction.rollbackToSavepoint.bind(transaction) + } + + throw new TransactionManagerError( + `Nested transactions are not supported by adapter "${transaction.adapterName}" (${transaction.provider}): rollbackToSavepoint is not implemented.`, + ) + } + + async #releaseSavepoint(transaction: Transaction, name: string): Promise { + if (transaction.releaseSavepoint) { + await transaction.releaseSavepoint(name) + } + } + + #debugTransactionAlreadyClosedOnTimeout(transactionId: string): void { + // Transaction was already committed or rolled back when timeout happened. + // Should normally not happen as timeout is cancelled when transaction is committed or rolled back. + // No further action needed though. + debug('Transaction already committed or rolled back when timeout happened.', transactionId) } #startTransactionTimeout(transactionId: string, timeout: number | undefined): NodeJS.Timeout | undefined { @@ -237,20 +359,57 @@ export class TransactionManager { debug('Transaction timed out.', { transactionId, timeoutStartedAt, timeout }) const tx = this.transactions.get(transactionId) - if (tx && ['running', 'waiting'].includes(tx.status)) { - await this.#closeTransaction(tx, 'timed_out') - } else { - // Transaction was already committed or rolled back when timeout happened. - // Should normally not happen as timeout is cancelled when transaction is committed or rolled back. - // No further action needed though. - debug('Transaction already committed or rolled back when timeout happened.', transactionId) + if (!tx) { + this.#debugTransactionAlreadyClosedOnTimeout(transactionId) + return } + + await this.#runSerialized(tx, async () => { + const current = this.transactions.get(transactionId) + if (current && ['running', 'waiting'].includes(current.status)) { + await this.#closeTransaction(current, 'timed_out') + } else { + this.#debugTransactionAlreadyClosedOnTimeout(transactionId) + } + }) }, timeout) timer?.unref?.() return timer } + // Any operation that mutates or closes a transaction must run through this lock so + // status/savepoint/depth checks and updates happen against a stable view of state. + async #withActiveTransactionLock( + transactionId: string, + operation: string, + callback: (tx: TransactionWrapper) => Promise, + ): Promise { + const tx = this.#getActiveOrClosingTransaction(transactionId, operation) + return await this.#runSerialized(tx, async () => { + const current = this.#getActiveOrClosingTransaction(transactionId, operation) + return await callback(current) + }) + } + + // Serializes operations per transaction id to prevent interleaving across awaits. + // This avoids races where one operation mutates savepoint/depth state while another + // operation is suspended, which could otherwise corrupt cleanup logic. + async #runSerialized(tx: TransactionWrapper, callback: () => Promise): Promise { + const previousOperation = tx.operationQueue + let releaseOperationLock!: () => void + tx.operationQueue = new Promise((resolve) => { + releaseOperationLock = resolve + }) + + await previousOperation + try { + return await callback() + } finally { + releaseOperationLock() + } + } + async #closeTransaction(tx: TransactionWrapper, status: 'committed' | 'rolled_back' | 'timed_out'): Promise { const createClosingPromise = async () => { debug('Closing transaction.', { transactionId: tx.id, status }) diff --git a/packages/client-engine-runtime/src/transaction-manager/transaction.ts b/packages/client-engine-runtime/src/transaction-manager/transaction.ts index 45c33b489f12..f9e763785c20 100644 --- a/packages/client-engine-runtime/src/transaction-manager/transaction.ts +++ b/packages/client-engine-runtime/src/transaction-manager/transaction.ts @@ -9,6 +9,7 @@ export type Options = { /// Transaction isolation level isolationLevel?: IsolationLevel + newTxId?: string } export type TransactionInfo = { diff --git a/packages/client-generator-js/src/TSClient/PrismaClient.ts b/packages/client-generator-js/src/TSClient/PrismaClient.ts index f03b904ab10b..1d834dda459c 100644 --- a/packages/client-generator-js/src/TSClient/PrismaClient.ts +++ b/packages/client-generator-js/src/TSClient/PrismaClient.ts @@ -233,9 +233,7 @@ function interactiveTransactionDefinition(context: GenerateContext) { const callbackType = ts .functionType() - .addParameter( - ts.parameter('prisma', ts.omit(ts.namedType('PrismaClient'), ts.namedType('runtime.ITXClientDenyList'))), - ) + .addParameter(ts.parameter('prisma', ts.omit(ts.namedType('PrismaClient'), itxTransactionClientDenyList(context)))) .setReturnType(returnType) const method = ts @@ -248,6 +246,14 @@ function interactiveTransactionDefinition(context: GenerateContext) { return ts.stringify(method, { indentLevel: 1, newLine: 'leading' }) } +function itxTransactionClientDenyList(context: GenerateContext) { + if (context.provider === 'mongodb') { + return ts.unionType([ts.namedType('runtime.ITXClientDenyList'), ts.stringLiteral('$transaction')]) + } + + return ts.namedType('runtime.ITXClientDenyList') +} + function queryRawDefinition(context: GenerateContext) { // we do not generate `$queryRaw...` definitions if not supported if (!context.dmmf.mappings.otherOperations.write.includes('queryRaw')) { @@ -476,6 +482,8 @@ get ${methodName}(): Prisma.${m.model}Delegate<${generics.join(', ')}>;` } public toTS(): string { const clientOptions = this.buildClientOptions() + const transactionClientDenyList = + this.context.provider === 'mongodb' ? "runtime.ITXClientDenyList | '$transaction'" : 'runtime.ITXClientDenyList' return `${clientExtensionsDefinitions(this.context)} export type DefaultPrismaClient = PrismaClient @@ -545,7 +553,7 @@ export function getLogLevel(log: Array): LogLevel | un /** * \`PrismaClient\` proxy available in interactive transactions. */ -export type TransactionClient = Omit +export type TransactionClient = Omit ` } diff --git a/packages/client-generator-ts/src/TSClient/PrismaClient.ts b/packages/client-generator-ts/src/TSClient/PrismaClient.ts index 06f6723d8198..9e1afef9567e 100644 --- a/packages/client-generator-ts/src/TSClient/PrismaClient.ts +++ b/packages/client-generator-ts/src/TSClient/PrismaClient.ts @@ -73,9 +73,7 @@ function interactiveTransactionDefinition(context: GenerateContext) { const callbackType = ts .functionType() - .addParameter( - ts.parameter('prisma', tsx.omit(ts.namedType('PrismaClient'), ts.namedType('runtime.ITXClientDenyList'))), - ) + .addParameter(ts.parameter('prisma', tsx.omit(ts.namedType('PrismaClient'), itxTransactionClientDenyList(context)))) .setReturnType(returnType) const method = ts @@ -88,6 +86,14 @@ function interactiveTransactionDefinition(context: GenerateContext) { return ts.stringify(method, { indentLevel: 1, newLine: 'leading' }) } +function itxTransactionClientDenyList(context: GenerateContext) { + if (!context.isSqlProvider()) { + return ts.unionType([ts.namedType('runtime.ITXClientDenyList'), ts.stringLiteral('$transaction')]) + } + + return ts.namedType('runtime.ITXClientDenyList') +} + function queryRawDefinition(context: GenerateContext) { // we do not generate `$queryRaw...` definitions if not supported if (!context.dmmf.mappings.otherOperations.write.includes('queryRaw')) { diff --git a/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts b/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts index 6d39ae97bae6..78458b4cbe68 100644 --- a/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts +++ b/packages/client-generator-ts/src/TSClient/file-generators/PrismaNamespaceFile.ts @@ -31,6 +31,10 @@ export function createPrismaNamespaceFile(context: GenerateContext, options: TSC const fieldRefs = context.dmmf.schema.fieldRefTypes.prisma?.map((type) => new FieldRefInput(type).toTS()) ?? [] + const transactionClientDenyList = context.isSqlProvider() + ? 'runtime.ITXClientDenyList' + : "runtime.ITXClientDenyList | '$transaction'" + return `${jsDocHeader} ${imports.join('\n')} @@ -136,7 +140,7 @@ export type PrismaAction = /** * \`PrismaClient\` proxy available in interactive transactions. */ -export type TransactionClient = Omit +export type TransactionClient = Omit ` } diff --git a/packages/client/src/runtime/core/engines/common/types/Transaction.ts b/packages/client/src/runtime/core/engines/common/types/Transaction.ts index e338bf683ec4..f33b25ce42d0 100644 --- a/packages/client/src/runtime/core/engines/common/types/Transaction.ts +++ b/packages/client/src/runtime/core/engines/common/types/Transaction.ts @@ -11,6 +11,12 @@ export type Options = { /** Transaction isolation level */ isolationLevel?: IsolationLevel + + /** + * Used for nested interactive transactions. When provided, the engine may + * re-use an existing open transaction instead of opening a new one. + */ + newTxId?: string } export type InteractiveTransactionInfo = { diff --git a/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts b/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts index 1cd43aaf6ac6..046f4634d05e 100644 --- a/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts +++ b/packages/client/src/runtime/core/types/exported/itxClientDenyList.ts @@ -1,4 +1,4 @@ -const denylist = ['$connect', '$disconnect', '$on', '$transaction', '$extends'] as const +const denylist = ['$connect', '$disconnect', '$on', '$use', '$extends'] as const export const itxClientDenyList = denylist as ReadonlyArray diff --git a/packages/client/src/runtime/getPrismaClient.ts b/packages/client/src/runtime/getPrismaClient.ts index 11f5d3df3687..b9f5c35f7f83 100644 --- a/packages/client/src/runtime/getPrismaClient.ts +++ b/packages/client/src/runtime/getPrismaClient.ts @@ -216,7 +216,67 @@ type EventCallback = [E] extends ['beforeExit'] ? (event: EngineEvent) => void : never -const TX_ID = Symbol.for('prisma.client.transaction.id') +const TX_SCOPE_CONTEXT = Symbol.for('prisma.client.transaction.scope_context') + +type ItxScopeState = { + stack: string[] +} + +type TopLevelItxScopeContext = { + kind: 'top-level' +} + +type NestedItxScopeContext = { + kind: 'nested' + txId: string + scopeId: string + scopeState: ItxScopeState +} + +type ItxScopeContext = TopLevelItxScopeContext | NestedItxScopeContext + +function getItxScopeContext(client: object): ItxScopeContext { + const symbolStorage = client as Record + const context = symbolStorage[TX_SCOPE_CONTEXT] + + if (context === undefined) { + return { kind: 'top-level' } + } + + if (isNestedItxScopeContext(context)) { + return context + } + + throw new Error('Internal error: inconsistent transaction scope context.') +} + +function isNestedItxScopeContext(value: unknown): value is NestedItxScopeContext { + if (typeof value !== 'object' || value === null) { + return false + } + + const objectValue = value as Record + return ( + objectValue['kind'] === 'nested' && + typeof objectValue['txId'] === 'string' && + typeof objectValue['scopeId'] === 'string' && + isItxScopeState(objectValue['scopeState']) + ) +} + +function isItxScopeState(value: unknown): value is ItxScopeState { + if (typeof value !== 'object' || value === null) { + return false + } + + return Array.isArray(value['stack']) +} + +function createItxScopeId(): string { + return typeof globalThis.crypto?.randomUUID === 'function' + ? globalThis.crypto.randomUUID() + : `${Date.now()}-${Math.random().toString(16).slice(2)}` +} const BatchTxIdCounter = { id: 0, @@ -675,46 +735,109 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client */ async _transactionWithCallback({ callback, - options, + options = {}, }: { callback: (client: Client) => Promise options?: Options }) { + const itxContext = getItxScopeContext(this) + const isNested = itxContext.kind === 'nested' + const scopeState: ItxScopeState = isNested ? itxContext.scopeState : { stack: [] } + const scopeStack = scopeState.stack + + const scopeId = createItxScopeId() + + if (isNested) { + // Only the currently-active (innermost) scope can start another nested scope. + // This prevents sibling nested transactions from running concurrently. + const activeScope = scopeStack.at(-1) + if (activeScope !== itxContext.scopeId) { + throw new Error('Concurrent nested transactions are not supported') + } + + // Re-use the underlying transaction in the engine by reusing the same transaction id. + options.newTxId = itxContext.txId + } + scopeStack.push(scopeId) + const headers = { traceparent: this._tracingHelper.getTraceParent() } const optionsWithDefaults: Options = { maxWait: options?.maxWait ?? this._engineConfig.transactionOptions.maxWait, timeout: options?.timeout ?? this._engineConfig.transactionOptions.timeout, isolationLevel: options?.isolationLevel ?? this._engineConfig.transactionOptions.isolationLevel, + newTxId: options.newTxId, + } + let info: Transaction.InteractiveTransactionInfo + try { + info = await this._engine.transaction('start', headers, optionsWithDefaults) + } catch (e) { + // Ensure we don't leave the scope stack dirty if starting the transaction fails. + if (scopeStack.at(-1) === scopeId) scopeStack.pop() + throw e } - const info = await this._engine.transaction('start', headers, optionsWithDefaults) let result: unknown try { // execute user logic with a proxied the client const transaction = { kind: 'itx', ...info } as const - result = await callback(this._createItxClient(transaction)) + result = await callback(this._createItxClient(transaction, scopeId, scopeState)) + + if (isNested) { + // Don't allow closing a transaction if we are not the active scope. + if (scopeStack.at(-1) !== scopeId) { + throw new Error('Nested transactions must be closed in reverse order of creation.') + } + } else if (scopeStack.length !== 1) { + throw new Error('Cannot close transaction while a nested transaction is still active.') + } - // it went well, then we commit the transaction await this._engine.transaction('commit', headers, info) } catch (e: any) { - // it went bad, then we rollback the transaction - await this._engine.transaction('rollback', headers, info).catch(() => {}) + // If we try to close out-of-order (e.g. un-awaited nested transaction), + // the TransactionManager depth will be > 1 and a single rollback would + // only roll back to the latest savepoint, leaving the top-level transaction + // open and leaking it. Force rollback to depth 0 in that case. + const isOrderViolation = scopeStack.at(-1) !== scopeId + const rollbackScopeCount = isOrderViolation ? Math.max(1, scopeStack.length) : 1 + for (let i = 0; i < rollbackScopeCount; i++) { + await this._engine.transaction('rollback', headers, info).catch((rollbackError) => { + debug('rollback attempt %d/%d failed: %O', i + 1, rollbackScopeCount, rollbackError) + }) + } throw e // silent rollback, throw original error + } finally { + if (scopeStack.at(-1) === scopeId) { + scopeStack.pop() + } else { + // Reset the scope stack to avoid poisoning this transaction context after an ordering violation. + scopeStack.length = 0 + } } return result } - _createItxClient(transaction: PrismaPromiseInteractiveTransaction): Client { + _createItxClient( + transaction: PrismaPromiseInteractiveTransaction, + scopeId: string, + scopeState: ItxScopeState, + ): Client { + const itxScopeContext: NestedItxScopeContext = { + kind: 'nested', + txId: transaction.id, + scopeId, + scopeState, + } + return createCompositeProxy( applyModelsAndClientExtensions( createCompositeProxy(unApplyModelsAndClientExtensions(this), [ - addProperty('_appliedParent', () => this._appliedParent._createItxClient(transaction)), + addProperty('_appliedParent', () => this._appliedParent._createItxClient(transaction, scopeId, scopeState)), addProperty('_createPrismaPromise', () => createPrismaPromiseFactory(transaction)), - addProperty(TX_ID, () => transaction.id), + addProperty(TX_SCOPE_CONTEXT, () => itxScopeContext), ]), ), [removeProperties(itxClientDenyList)], @@ -738,6 +861,13 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client 'Cloudflare D1 does not support interactive transactions. We recommend you to refactor your queries with that limitation in mind, and use batch transactions with `prisma.$transactions([])` where applicable.', ) } + } else if (config.activeProvider === 'mongodb' && getItxScopeContext(this).kind === 'nested') { + callback = () => { + throw new PrismaClientValidationError( + `The ${config.activeProvider} provider does not support nested transactions`, + { clientVersion: this._clientVersion }, + ) + } } else { callback = () => this._transactionWithCallback({ callback: input, options }) } diff --git a/packages/client/tests/functional/extensions/defineExtension.ts b/packages/client/tests/functional/extensions/defineExtension.ts index 3b77de252b3a..47fc5cbe6fd9 100644 --- a/packages/client/tests/functional/extensions/defineExtension.ts +++ b/packages/client/tests/functional/extensions/defineExtension.ts @@ -220,7 +220,6 @@ function itxWithinGenericExtension() { void xclient.$transaction((tx) => { expectTypeOf(tx).toHaveProperty('helperMethod') - expectTypeOf(tx).not.toHaveProperty('$transaction') expectTypeOf(tx).not.toHaveProperty('$extends') return Promise.resolve() }) diff --git a/packages/client/tests/functional/extensions/itx.ts b/packages/client/tests/functional/extensions/itx.ts index d925942519f8..aed17b556edc 100644 --- a/packages/client/tests/functional/extensions/itx.ts +++ b/packages/client/tests/functional/extensions/itx.ts @@ -267,7 +267,7 @@ testMatrix.setupTestSuite( if (isTransaction) { expect(ctx.$connect).toBeUndefined() expect(ctx.$disconnect).toBeUndefined() - expect(ctx.$transaction).toBeUndefined() + expect(ctx.$transaction).toBeDefined() expect(ctx.$extends).toBeUndefined() } else { expect(ctx.$connect).toBeDefined() diff --git a/packages/client/tests/functional/interactive-transactions/tests.ts b/packages/client/tests/functional/interactive-transactions/tests.ts index 2de3b13e4d65..83cbcd7983ec 100644 --- a/packages/client/tests/functional/interactive-transactions/tests.ts +++ b/packages/client/tests/functional/interactive-transactions/tests.ts @@ -220,11 +220,424 @@ testMatrix.setupTestSuite( await expect(result).resolves.toHaveLength(2) }) + testIf(provider === Providers.MONGODB)('mongodb: nested transactions are not available in types', async () => { + await prisma.$transaction((tx) => { + // For MongoDB, the transaction-bound client type should not expose `$transaction`. + // We keep this as a type-only assertion: at runtime, this is just a property access. + // @ts-test-if: provider !== Providers.MONGODB + void tx.$transaction + return Promise.resolve() + }) + }) + + /** + * If a parent transaction is rolled back, the child transaction should also rollback. + * This is only supported on SQL providers. + */ + testIf(provider !== Providers.MONGODB)('sql: nested rollback', async () => { + const email1 = `user_${copycat.uuid(101)}@website.com` + const email2 = `user_${copycat.uuid(102)}@website.com` + + await expect( + prisma.$transaction(async (tx) => { + await tx.user.create({ + data: { + email: email1, + }, + }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ + data: { + email: email2, + }, + }) + }) + + throw new Error('Rollback') + }), + ).rejects.toThrow(/Rollback/) + + const users = await prisma.user.findMany({ + where: { + email: { + in: [email1, email2], + }, + }, + }) + + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: nested rollback restores parent state (savepoints, 3 levels)', + async () => { + const emailA = `user_${copycat.uuid(151)}@website.com` + const emailB = `user_${copycat.uuid(152)}@website.com` + const emailC = `user_${copycat.uuid(153)}@website.com` + + const outerPromise = prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: emailA } }) + + const afterOuterInsert = await tx.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(afterOuterInsert).toHaveLength(1) + + try { + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: emailB } }) + + const afterInnerInsert = await tx2.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(afterInnerInsert).toHaveLength(2) + + try { + await tx2.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: emailC } }) + const afterGrandchildInsert = await tx3.user.findMany({ + where: { email: { in: [emailA, emailB, emailC] } }, + }) + expect(afterGrandchildInsert).toHaveLength(3) + throw new Error('grandchild rollback') + }) + } catch (e) { + expect(e).toMatchObject({ message: 'grandchild rollback' }) + } + + const afterGrandchildRollback = await tx2.user.findMany({ + where: { email: { in: [emailA, emailB, emailC] } }, + }) + expect(afterGrandchildRollback).toHaveLength(2) + + throw new Error('child rollback') + }) + } catch (e) { + expect(e).toMatchObject({ message: 'child rollback' }) + } + + const afterChildRollback = await tx.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(afterChildRollback).toHaveLength(1) + + throw new Error('outer rollback') + }) + + await expect(outerPromise).rejects.toThrow('outer rollback') + + const users = await prisma.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(users).toHaveLength(0) + }, + ) + + testIf(provider !== Providers.MONGODB)('sql: nested commit keeps state (savepoints, 3 levels)', async () => { + const emailA = `user_${copycat.uuid(161)}@website.com` + const emailB = `user_${copycat.uuid(162)}@website.com` + const emailC = `user_${copycat.uuid(163)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: emailA } }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: emailB } }) + + await tx2.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: emailC } }) + + const inside = await tx3.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(inside).toHaveLength(3) + }) + }) + + const insideOuter = await tx.user.findMany({ where: { email: { in: [emailA, emailB, emailC] } } }) + expect(insideOuter).toHaveLength(3) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [emailA, emailB, emailC] } }, + orderBy: { email: 'asc' }, + }) + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: disallow concurrent nested transactions', async () => { + const result = prisma.$transaction(async (tx) => { + const email1 = `user_${copycat.uuid(201)}@website.com` + const email2 = `user_${copycat.uuid(202)}@website.com` + const email3 = `user_${copycat.uuid(203)}@website.com` + const email4 = `user_${copycat.uuid(204)}@website.com` + const email5 = `user_${copycat.uuid(205)}@website.com` + const email6 = `user_${copycat.uuid(206)}@website.com` + const email7 = `user_${copycat.uuid(207)}@website.com` + + await Promise.all([ + tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email1 } }) + await tx2.user.create({ data: { email: email2 } }) + await tx2.user.create({ data: { email: email3 } }) + }), + tx.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: email4 } }) + await tx3.user.create({ data: { email: email5 } }) + }), + tx.$transaction(async (tx4) => { + await tx4.user.create({ data: { email: email6 } }) + await tx4.user.create({ data: { email: email7 } }) + }), + ]) + }) + + await expect(result).rejects.toThrow('Concurrent nested transactions are not supported') + + // No partial writes should be visible after the outer transaction fails. + const users = await prisma.user.findMany() + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: allow nested transactions in concurrent top-level transactions', + async () => { + const a1 = `user_${copycat.uuid(401)}@website.com` + const a2 = `user_${copycat.uuid(402)}@website.com` + const b1 = `user_${copycat.uuid(403)}@website.com` + const b2 = `user_${copycat.uuid(404)}@website.com` + + await Promise.all([ + prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: a1 } }) + await delay(25) + await tx.$transaction(async (tx2) => { + await delay(25) + await tx2.user.create({ data: { email: a2 } }) + }) + }), + prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: b1 } }) + await delay(25) + await tx.$transaction(async (tx2) => { + await delay(25) + await tx2.user.create({ data: { email: b2 } }) + }) + }), + ]) + + const users = await prisma.user.findMany({ where: { email: { in: [a1, a2, b1, b2] } } }) + expect(users).toHaveLength(4) + }, + ) + + testIf(provider !== Providers.MONGODB)('sql: nested commit keeps outer transaction open', async () => { + const email1 = `user_${copycat.uuid(211)}@website.com` + const email2 = `user_${copycat.uuid(212)}@website.com` + const email3 = `user_${copycat.uuid(213)}@website.com` + + const users = await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: email1 } }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email2 } }) + }) + + // If nested commit incorrectly closes the underlying transaction, + // this query or the final commit will fail. + await tx.user.create({ data: { email: email3 } }) + + return tx.user.findMany({ + where: { email: { in: [email1, email2, email3] } }, + orderBy: { email: 'asc' }, + }) + }) + + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: sequential nested transactions work', async () => { + const email1 = `user_${copycat.uuid(221)}@website.com` + const email2 = `user_${copycat.uuid(222)}@website.com` + const email3 = `user_${copycat.uuid(223)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email1 } }) + }) + + await tx.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: email2 } }) + }) + + await tx.user.create({ data: { email: email3 } }) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [email1, email2, email3] } }, + }) + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: deep nesting (3 levels) works', async () => { + const email1 = `user_${copycat.uuid(231)}@website.com` + const email2 = `user_${copycat.uuid(232)}@website.com` + const email3 = `user_${copycat.uuid(233)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: email1 } }) + + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: email2 } }) + + await tx2.$transaction(async (tx3) => { + await tx3.user.create({ data: { email: email3 } }) + }) + }) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [email1, email2, email3] } }, + }) + expect(users).toHaveLength(3) + }) + + testIf(provider !== Providers.MONGODB)('sql: nested rollback can be caught and outer can continue', async () => { + const outerEmail1 = `user_${copycat.uuid(241)}@website.com` + const innerEmail = `user_${copycat.uuid(242)}@website.com` + const outerEmail2 = `user_${copycat.uuid(243)}@website.com` + + await prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: outerEmail1 } }) + + try { + await tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: innerEmail } }) + throw new Error('inner rollback') + }) + } catch (e) { + expect(e).toMatchObject({ message: 'inner rollback' }) + } + + await tx.user.create({ data: { email: outerEmail2 } }) + }) + + const users = await prisma.user.findMany({ + where: { email: { in: [outerEmail1, innerEmail, outerEmail2] } }, + orderBy: { email: 'asc' }, + }) + + expect(users.map((u) => u.email)).toEqual([outerEmail1, outerEmail2].sort()) + }) + + testIf(provider !== Providers.MONGODB)('sql: enforce order for nested transactions', async () => { + const result = prisma.$transaction(async (tx) => { + const nested = tx.$transaction(async (tx2) => { + await tx2.user.create({ data: { email: `user_${copycat.uuid(301)}@website.com` } }) + await delay(50) + }) + nested.catch(() => {}) // avoid unhandled rejection in this test + + await tx.user.create({ data: { email: `user_${copycat.uuid(302)}@website.com` } }) + }) + + await expect(result).rejects.toThrow('Cannot close transaction while a nested transaction is still active.') + + const users = await prisma.user.findMany() + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: child fails if parent tries to commit before child finishes', + async () => { + const email = `user_${copycat.uuid(521)}@website.com` + let child: Promise | undefined + + const parent = prisma.$transaction((tx) => { + child = tx.$transaction(async (tx2) => { + await delay(50) + await tx2.user.create({ data: { email } }) + }) + child.catch(() => {}) // prevent unhandled rejection if parent fails fast + + // Parent returns immediately (commit attempt) without awaiting the child. + return Promise.resolve() + }) + + await expect(parent).rejects.toThrow('Cannot close transaction while a nested transaction is still active.') + await expect(child).rejects.toMatchObject({ + code: 'P2028', + }) + + const users = await prisma.user.findMany({ where: { email } }) + expect(users).toHaveLength(0) + }, + ) + + testIf(provider !== Providers.MONGODB)('sql: child fails if parent rolls back before child finishes', async () => { + const email = `user_${copycat.uuid(501)}@website.com` + let child: Promise | undefined + + const parent = prisma.$transaction((tx) => { + child = tx.$transaction(async (tx2) => { + await delay(50) + await tx2.user.create({ data: { email } }) + }) + child.catch(() => {}) // prevent unhandled rejection if parent fails fast + + return Promise.reject(new Error('parent rollback')) + }) + + await expect(parent).rejects.toThrow('parent rollback') + await expect(child).rejects.toMatchObject({ + code: 'P2028', + }) + + const users = await prisma.user.findMany({ where: { email } }) + expect(users).toHaveLength(0) + }) + + testIf(provider !== Providers.MONGODB)( + 'sql: child fails if nested parent closes before grandchild finishes', + async () => { + const email = `user_${copycat.uuid(511)}@website.com` + let grandchild: Promise | undefined + + const parent = prisma.$transaction(async (tx) => { + await tx.$transaction((tx2) => { + grandchild = tx2.$transaction(async (tx3) => { + await delay(50) + await tx3.user.create({ data: { email } }) + }) + grandchild.catch(() => {}) // prevent unhandled rejection if parent fails fast + + // Intentionally don't await `grandchild` to simulate incorrect ordering. + // The parent nested transaction should fail to close and the whole top-level tx should rollback. + return Promise.resolve() + }) + }) + + await expect(parent).rejects.toThrow('Nested transactions must be closed in reverse order of creation.') + await expect(grandchild).rejects.toMatchObject({ + code: 'P2028', + }) + + const users = await prisma.user.findMany({ where: { email } }) + expect(users).toHaveLength(0) + }, + ) + + testIf(provider === Providers.MONGODB)('mongodb: disallow nested transactions at runtime', async () => { + const result = prisma.$transaction(async (tx) => { + await tx.user.create({ data: { email: 'user_1@website.com' } }) + // Nested transactions are intentionally not available in types for MongoDB. + // Bypass the type system to assert runtime behavior. + await (tx as any).$transaction(async (tx2: any) => { + await tx2.user.create({ data: { email: 'user_2@website.com' } }) + }) + }) + + await expect(result).rejects.toThrow('The mongodb provider does not support nested transactions') + const users = await prisma.user.findMany() + expect(users).toHaveLength(0) + }) + /** * We don't allow certain methods to be called in a transaction */ test('forbidden', async () => { - const forbidden = ['$connect', '$disconnect', '$on', '$transaction'] + const forbidden = ['$connect', '$disconnect', '$on', '$use'] expect.assertions(forbidden.length + 1) const result = prisma.$transaction((prisma) => { @@ -241,25 +654,32 @@ testMatrix.setupTestSuite( * If one of the query fails, all queries should cancel */ test('rollback query', async () => { + const email1 = 'user_1@website.com' const result = prisma.$transaction(async (prisma) => { await prisma.user.create({ data: { id: copycat.uuid(1).replaceAll('-', '').slice(-24), - email: 'user_1@website.com', + email: email1, }, }) await prisma.user.create({ data: { id: copycat.uuid(2).replaceAll('-', '').slice(-24), - email: 'user_1@website.com', + email: email1, }, }) }) await expect(result).rejects.toMatchPrismaErrorSnapshot() - const users = await prisma.user.findMany() + const users = await prisma.user.findMany({ + where: { + email: { + equals: email1, + }, + }, + }) expect(users.length).toBe(0) }) diff --git a/packages/driver-adapter-utils/src/binder.ts b/packages/driver-adapter-utils/src/binder.ts index bfb6f3fb6051..c17ce6072ded 100644 --- a/packages/driver-adapter-utils/src/binder.ts +++ b/packages/driver-adapter-utils/src/binder.ts @@ -113,7 +113,7 @@ export const bindAdapter = ( // *.bind(transaction) is required to preserve the `this` context of functions whose // execution is delegated to napi.rs. const bindTransaction = (errorRegistry: ErrorRegistryInternal, transaction: Transaction): ErrorCapturingTransaction => { - return { + const boundTransaction: ErrorCapturingTransaction = { adapterName: transaction.adapterName, provider: transaction.provider, options: transaction.options, @@ -122,6 +122,20 @@ const bindTransaction = (errorRegistry: ErrorRegistryInternal, transaction: Tran commit: wrapAsync(errorRegistry, transaction.commit.bind(transaction)), rollback: wrapAsync(errorRegistry, transaction.rollback.bind(transaction)), } + + if (transaction.createSavepoint) { + boundTransaction.createSavepoint = wrapAsync(errorRegistry, transaction.createSavepoint.bind(transaction)) + } + + if (transaction.rollbackToSavepoint) { + boundTransaction.rollbackToSavepoint = wrapAsync(errorRegistry, transaction.rollbackToSavepoint.bind(transaction)) + } + + if (transaction.releaseSavepoint) { + boundTransaction.releaseSavepoint = wrapAsync(errorRegistry, transaction.releaseSavepoint.bind(transaction)) + } + + return boundTransaction } function wrapAsync( diff --git a/packages/driver-adapter-utils/src/types.ts b/packages/driver-adapter-utils/src/types.ts index 5943031f8940..744291caa2b7 100644 --- a/packages/driver-adapter-utils/src/types.ts +++ b/packages/driver-adapter-utils/src/types.ts @@ -289,6 +289,18 @@ export interface Transaction extends AdapterInfo, SqlQueryable { * Roll back the transaction. */ rollback(): Promise + /** + * Creates a savepoint within the currently running transaction. + */ + createSavepoint?(name: string): Promise + /** + * Rolls back transaction state to a previously created savepoint. + */ + rollbackToSavepoint?(name: string): Promise + /** + * Releases a previously created savepoint. Optional because not every connector supports this operation. + */ + releaseSavepoint?(name: string): Promise } /** diff --git a/packages/query-plan-executor/src/logic/adapter.test.ts b/packages/query-plan-executor/src/logic/adapter.test.ts index 90efa20738ad..800ffedf6b71 100644 --- a/packages/query-plan-executor/src/logic/adapter.test.ts +++ b/packages/query-plan-executor/src/logic/adapter.test.ts @@ -424,6 +424,9 @@ describe('createAdapter wrapper error handling', () => { rollback: vi.fn(), executeRaw: vi.fn(), queryRaw: vi.fn(), + createSavepoint: vi.fn(), + rollbackToSavepoint: vi.fn(), + releaseSavepoint: vi.fn(), } mockAdapter = { @@ -466,6 +469,18 @@ describe('createAdapter wrapper error handling', () => { return tx.queryRaw(query) }, }, + { + method: 'createSavepoint' as const, + execute: async (tx: Transaction) => tx.createSavepoint?.('sp1'), + }, + { + method: 'rollbackToSavepoint' as const, + execute: async (tx: Transaction) => tx.rollbackToSavepoint?.('sp1'), + }, + { + method: 'releaseSavepoint' as const, + execute: async (tx: Transaction) => tx.releaseSavepoint?.('sp1'), + }, ])('wraps transaction.$method() to sanitize errors', async ({ method, execute }) => { const error = new Error(`${method} failed for postgresql://user:pass@host:5432/db`) ;(mockTransaction as any)[method] = vi.fn().mockRejectedValue(error) @@ -493,6 +508,9 @@ describe('createAdapter wrapper error handling', () => { mockTransaction.rollback = vi.fn().mockResolvedValue(undefined) mockTransaction.executeRaw = vi.fn().mockResolvedValue(1) mockTransaction.queryRaw = vi.fn().mockResolvedValue(mockResult) + mockTransaction.createSavepoint = vi.fn().mockResolvedValue(undefined) + mockTransaction.rollbackToSavepoint = vi.fn().mockResolvedValue(undefined) + mockTransaction.releaseSavepoint = vi.fn().mockResolvedValue(undefined) const wrappedFactory = createAdapter('postgresql://user:pass@host:5432/db', [ { @@ -509,12 +527,18 @@ describe('createAdapter wrapper error handling', () => { await expect(tx.rollback()).resolves.toBeUndefined() expect(await tx.executeRaw(query)).toBe(1) expect(await tx.queryRaw(query)).toEqual(mockResult) + await expect(tx.createSavepoint?.('sp1')).resolves.toBeUndefined() + await expect(tx.rollbackToSavepoint?.('sp1')).resolves.toBeUndefined() + await expect(tx.releaseSavepoint?.('sp1')).resolves.toBeUndefined() /* eslint-disable @typescript-eslint/unbound-method */ expect(mockTransaction.commit).toHaveBeenCalledOnce() expect(mockTransaction.rollback).toHaveBeenCalledOnce() expect(mockTransaction.executeRaw).toHaveBeenCalledWith(query) expect(mockTransaction.queryRaw).toHaveBeenCalledWith(query) + expect(mockTransaction.createSavepoint).toHaveBeenCalledWith('sp1') + expect(mockTransaction.rollbackToSavepoint).toHaveBeenCalledWith('sp1') + expect(mockTransaction.releaseSavepoint).toHaveBeenCalledWith('sp1') /* eslint-enable @typescript-eslint/unbound-method */ }) diff --git a/packages/query-plan-executor/src/logic/adapter.ts b/packages/query-plan-executor/src/logic/adapter.ts index 3f28334dd2cd..e2cc8d93214c 100644 --- a/packages/query-plan-executor/src/logic/adapter.ts +++ b/packages/query-plan-executor/src/logic/adapter.ts @@ -118,7 +118,8 @@ function createConnectionStringRegex(protocols: string[]) { function wrapFactory(protocols: string[], factory: SqlDriverAdapterFactory): SqlDriverAdapterFactory { return { - ...factory, + adapterName: factory.adapterName, + provider: factory.provider, connect: () => factory.connect().then(wrapAdapter.bind(null, protocols), rethrowSanitizedError.bind(null, protocols)), } @@ -126,7 +127,8 @@ function wrapFactory(protocols: string[], factory: SqlDriverAdapterFactory): Sql function wrapAdapter(protocols: string[], adapter: SqlDriverAdapter): SqlDriverAdapter { return { - ...adapter, + adapterName: adapter.adapterName, + provider: adapter.provider, dispose: () => adapter.dispose().catch(rethrowSanitizedError.bind(null, protocols)), executeRaw: (query) => adapter.executeRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), queryRaw: (query) => adapter.queryRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), @@ -141,10 +143,21 @@ function wrapAdapter(protocols: string[], adapter: SqlDriverAdapter): SqlDriverA function wrapTransaction(protocols: string[], tx: Transaction): Transaction { return { - ...tx, + adapterName: tx.adapterName, + provider: tx.provider, + options: tx.options, commit: () => tx.commit().catch(rethrowSanitizedError.bind(null, protocols)), rollback: () => tx.rollback().catch(rethrowSanitizedError.bind(null, protocols)), executeRaw: (query) => tx.executeRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), queryRaw: (query) => tx.queryRaw(query).catch(rethrowSanitizedError.bind(null, protocols)), + createSavepoint: tx.createSavepoint + ? (name) => tx.createSavepoint!(name).catch(rethrowSanitizedError.bind(null, protocols)) + : undefined, + rollbackToSavepoint: tx.rollbackToSavepoint + ? (name) => tx.rollbackToSavepoint!(name).catch(rethrowSanitizedError.bind(null, protocols)) + : undefined, + releaseSavepoint: tx.releaseSavepoint + ? (name) => tx.releaseSavepoint!(name).catch(rethrowSanitizedError.bind(null, protocols)) + : undefined, } } diff --git a/packages/query-plan-executor/src/server/schemas.ts b/packages/query-plan-executor/src/server/schemas.ts index 145f6719b058..6124f3c1230b 100644 --- a/packages/query-plan-executor/src/server/schemas.ts +++ b/packages/query-plan-executor/src/server/schemas.ts @@ -50,6 +50,7 @@ export const TransactionStartRequestBody = z.object({ isolationLevel: z .enum(['READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ', 'SNAPSHOT', 'SERIALIZABLE']) .optional(), + newTxId: z.string().optional(), }) export type TransactionStartRequestBody = z.infer