Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/client-engine-runtime/bench/sample-query-plans.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ export const JOIN_PLAN: QueryPlanNode = {
isRelationUnique: false,
},
],
canAssumeStrictEquality: true,
},
},
structure: {
Expand Down Expand Up @@ -231,6 +232,7 @@ export const DEEP_JOIN_PLAN: QueryPlanNode = {
isRelationUnique: true,
},
],
canAssumeStrictEquality: true,
},
},
children: [
Expand Down Expand Up @@ -278,13 +280,15 @@ export const DEEP_JOIN_PLAN: QueryPlanNode = {
isRelationUnique: false,
},
],
canAssumeStrictEquality: true,
},
},
on: [['id', 'authorId']],
parentField: 'posts',
isRelationUnique: false,
},
],
canAssumeStrictEquality: true,
},
},
structure: {
Expand Down
2 changes: 1 addition & 1 deletion packages/client-engine-runtime/src/events.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export type QueryEvent = {
timestamp: Date
query: string
params: unknown[]
params: readonly unknown[]
duration: number
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ function paginateSingleList(list: {}[], { cursor, skip, take }: Pagination): {}[
/*
* Generate a key string for a record based on the values of the specified fields.
*/
export function getRecordKey(record: {}, fields: string[]): string {
return JSON.stringify(fields.map((field) => record[field]))
export function getRecordKey(record: {}, fields: readonly string[], mappers?: ((value: unknown) => unknown)[]): string {
const array = fields.map((field, index) =>
mappers?.[index] ? (record[field] !== null ? mappers[index](record[field]) : null) : record[field],
)
return JSON.stringify(array)
}
95 changes: 78 additions & 17 deletions packages/client-engine-runtime/src/interpreter/query-interpreter.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { ConnectionInfo, SqlQuery, SqlQueryable, SqlResultSet } from '@prisma/driver-adapter-utils'
import type { SqlCommenterPlugin, SqlCommenterQueryInfo } from '@prisma/sqlcommenter'
import { klona } from 'klona'

import { QueryEvent } from '../events'
import { FieldInitializer, FieldOperation, InMemoryOps, JoinExpression, QueryPlanNode } from '../query-plan'
Expand All @@ -8,7 +9,7 @@ import { appendSqlComment, buildSqlComment } from '../sql-commenter'
import { type TracingHelper, withQuerySpanAndEvent } from '../tracing'
import { type TransactionManager } from '../transaction-manager/transaction-manager'
import { rethrowAsUserFacing, rethrowAsUserFacingRawError } from '../user-facing-error'
import { assertNever } from '../utils'
import { assertNever, DeepReadonly, DeepUnreadonly } from '../utils'
import { applyDataMap } from './data-mapper'
import { GeneratorRegistry, GeneratorRegistrySnapshot } from './generators'
import { getRecordKey, processRecords } from './in-memory-processing'
Expand Down Expand Up @@ -89,7 +90,7 @@ export class QueryInterpreter {
})
}

async run(queryPlan: QueryPlanNode, options: QueryRuntimeOptions): Promise<unknown> {
async run(queryPlan: DeepReadonly<QueryPlanNode>, options: QueryRuntimeOptions): Promise<unknown> {
const { value } = await this.interpretNode(queryPlan, {
...options,
generators: this.#generators.snapshot(),
Expand All @@ -98,7 +99,10 @@ export class QueryInterpreter {
return value
}

private async interpretNode(node: QueryPlanNode, context: QueryRuntimeContext): Promise<IntermediateValue> {
private async interpretNode(
node: DeepReadonly<QueryPlanNode>,
context: QueryRuntimeContext,
): Promise<IntermediateValue> {
switch (node.type) {
case 'value': {
return {
Expand Down Expand Up @@ -163,7 +167,7 @@ export class QueryInterpreter {
const commentedQuery = applyComments(query, context.sqlCommenter)
sum += await this.#withQuerySpanAndEvent(commentedQuery, context.queryable, () =>
context.queryable
.executeRaw(commentedQuery)
.executeRaw(cloneObject(commentedQuery))
.catch((err) =>
node.args.type === 'rawSql' ? rethrowAsUserFacingRawError(err) : rethrowAsUserFacing(err),
),
Expand All @@ -181,7 +185,7 @@ export class QueryInterpreter {
const commentedQuery = applyComments(query, context.sqlCommenter)
const result = await this.#withQuerySpanAndEvent(commentedQuery, context.queryable, () =>
context.queryable
.queryRaw(commentedQuery)
.queryRaw(cloneObject(commentedQuery))
.catch((err) =>
node.args.type === 'rawSql' ? rethrowAsUserFacingRawError(err) : rethrowAsUserFacing(err),
),
Expand Down Expand Up @@ -236,14 +240,14 @@ export class QueryInterpreter {
return { value: null, lastInsertId }
}

const children = (await Promise.all(
const children = await Promise.all(
node.args.children.map(async (joinExpr) => ({
joinExpr,
childRecords: (await this.interpretNode(joinExpr.child, context)).value,
})),
)) satisfies JoinExpressionWithRecords[]
)

return { value: attachChildrenToParents(parent, children), lastInsertId }
return { value: attachChildrenToParents(parent, children, node.args.canAssumeStrictEquality), lastInsertId }
}

case 'transaction': {
Expand Down Expand Up @@ -301,8 +305,9 @@ export class QueryInterpreter {

case 'process': {
const { value, lastInsertId } = await this.interpretNode(node.args.expr, context)
evaluateProcessingParameters(node.args.operations, context.scope, context.generators)
return { value: processRecords(value, node.args.operations), lastInsertId }
const ops = cloneObject(node.args.operations)
evaluateProcessingParameters(ops, context.scope, context.generators)
return { value: processRecords(value, ops), lastInsertId }
}

case 'initializeRecord': {
Expand Down Expand Up @@ -360,7 +365,11 @@ export class QueryInterpreter {
}
}

#withQuerySpanAndEvent<T>(query: SqlQuery, queryable: SqlQueryable, execute: () => Promise<T>): Promise<T> {
#withQuerySpanAndEvent<T>(
query: DeepReadonly<SqlQuery>,
queryable: SqlQueryable,
execute: () => Promise<T>,
): Promise<T> {
return withQuerySpanAndEvent({
query,
execute,
Expand Down Expand Up @@ -420,13 +429,20 @@ type JoinExpressionWithRecords = {
childRecords: Value
}

function attachChildrenToParents(parentRecords: unknown, children: JoinExpressionWithRecords[]) {
type KeyCast = (value: Value) => Value

function attachChildrenToParents(
parentRecords: unknown,
children: DeepReadonly<JoinExpressionWithRecords[]>,
canAssumeStrictEquality: boolean,
) {
for (const { joinExpr, childRecords } of children) {
const parentKeys = joinExpr.on.map(([k]) => k)
const childKeys = joinExpr.on.map(([, k]) => k)
const parentMap = {}

for (const parent of Array.isArray(parentRecords) ? parentRecords : [parentRecords]) {
const parentArray = Array.isArray(parentRecords) ? parentRecords : [parentRecords]
for (const parent of parentArray) {
const parentRecord = asRecord(parent)
const key = getRecordKey(parentRecord, parentKeys)
if (!parentMap[key]) {
Expand All @@ -441,12 +457,13 @@ function attachChildrenToParents(parentRecords: unknown, children: JoinExpressio
}
}

const mappers = canAssumeStrictEquality ? undefined : inferKeyCasts(parentArray, parentKeys)
for (const childRecord of Array.isArray(childRecords) ? childRecords : [childRecords]) {
if (childRecord === null) {
continue
}

const key = getRecordKey(asRecord(childRecord), childKeys)
const key = getRecordKey(asRecord(childRecord), childKeys, mappers)
for (const parentRecord of parentMap[key] ?? []) {
if (joinExpr.isRelationUnique) {
parentRecord[joinExpr.parentField] = childRecord
Expand All @@ -460,8 +477,45 @@ function attachChildrenToParents(parentRecords: unknown, children: JoinExpressio
return parentRecords
}

function inferKeyCasts(rows: unknown[], keys: string[]): KeyCast[] {
function getKeyCast(type: string): KeyCast | undefined {
switch (type) {
case 'number':
return Number
case 'string':
return String
case 'boolean':
return Boolean
case 'bigint':
return BigInt as KeyCast
default:
return
}
}

const keyCasts: KeyCast[] = Array.from({ length: keys.length })
let keysFound = 0
for (const parent of rows) {
const parentRecord = asRecord(parent)
for (const [i, key] of keys.entries()) {
if (parentRecord[key] !== null && keyCasts[i] === undefined) {
const keyCast = getKeyCast(typeof parentRecord[key])
if (keyCast !== undefined) {
keyCasts[i] = keyCast
}
keysFound++
}
}
if (keysFound === keys.length) {
break
}
}

return keyCasts
}

function evalFieldInitializer(
initializer: FieldInitializer,
initializer: DeepReadonly<FieldInitializer>,
lastInsertId: string | undefined,
scope: ScopeBindings,
generators: GeneratorRegistrySnapshot,
Expand All @@ -477,7 +531,7 @@ function evalFieldInitializer(
}

function evalFieldOperation(
op: FieldOperation,
op: DeepReadonly<FieldOperation>,
value: Value,
scope: ScopeBindings,
generators: GeneratorRegistrySnapshot,
Expand Down Expand Up @@ -508,7 +562,10 @@ function evalFieldOperation(
}
}

function applyComments(query: SqlQuery, sqlCommenter?: QueryInterpreterSqlCommenter): SqlQuery {
function applyComments(
query: DeepReadonly<SqlQuery>,
sqlCommenter?: QueryInterpreterSqlCommenter,
): DeepReadonly<SqlQuery> {
if (!sqlCommenter || sqlCommenter.plugins.length === 0) {
return query
}
Expand Down Expand Up @@ -543,3 +600,7 @@ function evaluateProcessingParameters(
evaluateProcessingParameters(nested, scope, generators)
}
}

function cloneObject<T>(value: T): DeepUnreadonly<T> {
return klona(value) as DeepUnreadonly<T>
}
32 changes: 19 additions & 13 deletions packages/client-engine-runtime/src/interpreter/render-query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ import {
type QueryPlanDbQuery,
} from '../query-plan'
import { UserFacingError } from '../user-facing-error'
import { assertNever } from '../utils'
import { assertNever, DeepReadonly } from '../utils'
import { GeneratorRegistrySnapshot } from './generators'
import { ScopeBindings } from './scope'

export function renderQuery(
dbQuery: QueryPlanDbQuery,
dbQuery: DeepReadonly<QueryPlanDbQuery>,
scope: ScopeBindings,
generators: GeneratorRegistrySnapshot,
maxChunkSize?: number,
): SqlQuery[] {
): DeepReadonly<SqlQuery>[] {
const args = dbQuery.args.map((arg) => evaluateArg(arg, scope, generators))

switch (dbQuery.type) {
Expand Down Expand Up @@ -69,10 +69,10 @@ export function evaluateArg(arg: unknown, scope: ScopeBindings, generators: Gene
}

function renderTemplateSql(
fragments: Fragment[],
fragments: DeepReadonly<Fragment[]>,
placeholderFormat: PlaceholderFormat,
params: unknown[],
argTypes: DynamicArgType[],
argTypes: DeepReadonly<DynamicArgType[]>,
): SqlQuery {
let sql = ''
const ctx = { placeholderNumber: 1 }
Expand Down Expand Up @@ -112,7 +112,7 @@ function renderTemplateSql(
}
}

function renderFragment<Type extends DynamicArgType | undefined>(
function renderFragment<Type extends DeepReadonly<DynamicArgType> | undefined>(
fragment: FragmentWithParams<Type>,
placeholderFormat: PlaceholderFormat,
ctx: { placeholderNumber: number },
Expand Down Expand Up @@ -158,10 +158,14 @@ function formatPlaceholder(placeholderFormat: PlaceholderFormat, placeholderNumb
return placeholderFormat.hasNumbering ? `${placeholderFormat.prefix}${placeholderNumber}` : placeholderFormat.prefix
}

function renderRawSql(sql: string, args: unknown[], argTypes: ArgType[]): SqlQuery {
function renderRawSql(
sql: string,
args: readonly unknown[],
argTypes: DeepReadonly<ArgType[]>,
): DeepReadonly<SqlQuery> {
return {
sql,
args: args,
args,
argTypes,
}
}
Expand All @@ -170,7 +174,7 @@ function doesRequireEvaluation(param: unknown): param is PrismaValuePlaceholder
return isPrismaValuePlaceholder(param) || isPrismaValueGenerator(param)
}

type FragmentWithParams<Type extends DynamicArgType | undefined = undefined> = Fragment &
type FragmentWithParams<Type extends DeepReadonly<DynamicArgType> | undefined = undefined> = Fragment &
(
| { type: 'stringChunk' }
| { type: 'parameter'; value: unknown; argType: Type }
Expand All @@ -179,10 +183,12 @@ type FragmentWithParams<Type extends DynamicArgType | undefined = undefined> = F
)

function* pairFragmentsWithParams<Types>(
fragments: Fragment[],
fragments: DeepReadonly<Fragment[]>,
params: unknown[],
argTypes: Types,
): Generator<FragmentWithParams<Types extends DynamicArgType[] ? DynamicArgType : undefined>> {
): Generator<
FragmentWithParams<Types extends DeepReadonly<DynamicArgType[]> ? DeepReadonly<DynamicArgType> : undefined>
> {
let index = 0

for (const fragment of fragments) {
Expand Down Expand Up @@ -239,7 +245,7 @@ function* pairFragmentsWithParams<Types>(
}
}

function* flattenedFragmentParams<Type extends DynamicArgType | undefined>(
function* flattenedFragmentParams<Type extends DeepReadonly<DynamicArgType> | undefined>(
fragment: FragmentWithParams<Type>,
): Generator<unknown, undefined, undefined> {
switch (fragment.type) {
Expand All @@ -259,7 +265,7 @@ function* flattenedFragmentParams<Type extends DynamicArgType | undefined>(
}
}

function chunkParams(fragments: Fragment[], params: unknown[], maxChunkSize?: number): unknown[][] {
function chunkParams(fragments: DeepReadonly<Fragment[]>, params: unknown[], maxChunkSize?: number): unknown[][] {
// Find out the total number of parameters once flattened and what the maximum number of
// parameters in a single fragment is.
let totalParamCount = 0
Expand Down
8 changes: 4 additions & 4 deletions packages/client-engine-runtime/src/interpreter/validation.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { DataRule, ValidationError } from '../query-plan'
import { UserFacingError } from '../user-facing-error'
import { assertNever } from '../utils'
import { assertNever, DeepReadonly } from '../utils'

export function performValidation(data: unknown, rules: DataRule[], error: ValidationError) {
export function performValidation(data: unknown, rules: DeepReadonly<DataRule[]>, error: ValidationError) {
if (!rules.every((rule) => doesSatisfyRule(data, rule))) {
const message = renderMessage(data, error)
const code = getErrorCode(error)
Expand Down Expand Up @@ -42,7 +42,7 @@ export function doesSatisfyRule(data: unknown, rule: DataRule): boolean {
}

function renderMessage(data: unknown, error: ValidationError): string {
switch (error.error_identifier) {
switch (error.errorIdentifier) {
case 'RELATION_VIOLATION':
return `The change you are trying to make would violate the required relation '${error.context.relation}' between the \`${error.context.modelA}\` and \`${error.context.modelB}\` models.`
case 'MISSING_RECORD':
Expand Down Expand Up @@ -70,7 +70,7 @@ function renderMessage(data: unknown, error: ValidationError): string {
}

function getErrorCode(error: ValidationError): string {
switch (error.error_identifier) {
switch (error.errorIdentifier) {
case 'RELATION_VIOLATION':
return 'P2014'
case 'RECORDS_NOT_CONNECTED':
Expand Down
Loading
Loading