diff --git a/packages/orm/src/client/crud/dialects/base-dialect.ts b/packages/orm/src/client/crud/dialects/base-dialect.ts index 5fae51ef5..169213e62 100644 --- a/packages/orm/src/client/crud/dialects/base-dialect.ts +++ b/packages/orm/src/client/crud/dialects/base-dialect.ts @@ -24,10 +24,10 @@ import { flattenCompoundUniqueFilters, getDelegateDescendantModels, getManyToManyRelation, + getModelFields, getRelationForeignKeyFieldPairs, isEnum, isTypeDef, - getModelFields, makeDefaultOrderBy, requireField, requireIdFields, @@ -1162,7 +1162,8 @@ export abstract class BaseCrudDialect { // client-level: check both uncapitalized (current) and original (backward compat) model name const uncapModel = lowerCaseFirst(model); - const omitConfig = (this.options.omit as Record | undefined)?.[uncapModel] ?? + const omitConfig = + (this.options.omit as Record | undefined)?.[uncapModel] ?? (this.options.omit as Record | undefined)?.[model]; if (omitConfig && typeof omitConfig === 'object' && typeof omitConfig[field] === 'boolean') { return omitConfig[field]; @@ -1357,7 +1358,9 @@ export abstract class BaseCrudDialect { const computedFields = this.options.computedFields as Record; // check both uncapitalized (current) and original (backward compat) model name const computedModel = fieldDef.originModel ?? model; - computer = computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? computedFields?.[computedModel]?.[field]; + computer = + computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? + computedFields?.[computedModel]?.[field]; } if (!computer) { throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`); @@ -1489,6 +1492,19 @@ export abstract class BaseCrudDialect { */ abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder; + /** + * Builds a binary comparison expression between two operands. + */ + buildComparison( + left: Expression, + _leftFieldDef: FieldDef | undefined, + op: string, + right: Expression, + _rightFieldDef: FieldDef | undefined, + ): Expression { + return this.eb(left, op as any, right) as Expression; + } + /** * Builds a JSON path selection expression. */ diff --git a/packages/orm/src/client/crud/dialects/postgresql.ts b/packages/orm/src/client/crud/dialects/postgresql.ts index 48d78b1d4..13c84adb9 100644 --- a/packages/orm/src/client/crud/dialects/postgresql.ts +++ b/packages/orm/src/client/crud/dialects/postgresql.ts @@ -32,6 +32,34 @@ export class PostgresCrudDialect extends LateralJoinDi Json: 'jsonb', }; + // Maps @db.* attribute names to PostgreSQL SQL types for use in VALUES table casts + private static readonly dbAttributeToSqlTypeMap: Record = { + '@db.Uuid': 'uuid', + '@db.Citext': 'citext', + '@db.Inet': 'inet', + '@db.Bit': 'bit', + '@db.VarBit': 'varbit', + '@db.Xml': 'xml', + '@db.Json': 'json', + '@db.JsonB': 'jsonb', + '@db.ByteA': 'bytea', + '@db.Text': 'text', + '@db.Char': 'bpchar', + '@db.VarChar': 'varchar', + '@db.Date': 'date', + '@db.Time': 'time', + '@db.Timetz': 'timetz', + '@db.Timestamp': 'timestamp', + '@db.Timestamptz': 'timestamptz', + '@db.SmallInt': 'smallint', + '@db.Integer': 'integer', + '@db.BigInt': 'bigint', + '@db.Real': 'real', + '@db.DoublePrecision': 'double precision', + '@db.Decimal': 'decimal', + '@db.Boolean': 'boolean', + }; + constructor(schema: Schema, options: ClientOptions) { super(schema, options); this.overrideTypeParsers(); @@ -406,7 +434,16 @@ export class PostgresCrudDialect extends LateralJoinDi ); } - private getSqlType(zmodelType: string) { + private getSqlType(zmodelType: string, attributes?: FieldDef['attributes']) { + // Check @db.* attributes first — they specify the exact native PostgreSQL type + if (attributes) { + for (const attr of attributes) { + const mapped = PostgresCrudDialect.dbAttributeToSqlTypeMap[attr.name]; + if (mapped) { + return mapped; + } + } + } if (isEnum(this.schema, zmodelType)) { // reduce enum to text for type compatibility return 'text'; @@ -415,6 +452,42 @@ export class PostgresCrudDialect extends LateralJoinDi } } + // Resolves the effective SQL type for a field: the native type from any @db.* attribute, + // or the base ZModel SQL type if no attribute is present, or undefined if the field is unknown. + private resolveFieldSqlType(fieldDef: FieldDef | undefined): { sqlType: string | undefined; hasDbOverride: boolean } { + if (!fieldDef) { + return { sqlType: undefined, hasDbOverride: false }; + } + const dbAttr = fieldDef.attributes?.find((a) => a.name.startsWith('@db.')); + if (dbAttr) { + return { sqlType: PostgresCrudDialect.dbAttributeToSqlTypeMap[dbAttr.name], hasDbOverride: true }; + } + return { sqlType: this.getSqlType(fieldDef.type), hasDbOverride: false }; + } + + override buildComparison( + left: Expression, + leftFieldDef: FieldDef | undefined, + op: string, + right: Expression, + rightFieldDef: FieldDef | undefined, + ) { + const leftResolved = this.resolveFieldSqlType(leftFieldDef); + const rightResolved = this.resolveFieldSqlType(rightFieldDef); + // If the resolved SQL types differ and at least one side carries a @db.* native type override, + // cast that side back to its base ZModel SQL type so PostgreSQL doesn't reject the comparison + // (e.g. "operator does not exist: uuid = text"). + if (leftResolved.sqlType !== rightResolved.sqlType && (leftResolved.hasDbOverride || rightResolved.hasDbOverride)) { + if (leftResolved.hasDbOverride) { + left = this.eb.cast(left, sql.raw(this.getSqlType(leftFieldDef!.type))); + } + if (rightResolved.hasDbOverride) { + right = this.eb.cast(right, sql.raw(this.getSqlType(rightFieldDef!.type))); + } + } + return super.buildComparison(left, leftFieldDef, op, right, rightFieldDef); + } + override getStringCasingBehavior() { // Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive return { supportsILike: true, likeCaseSensitive: true }; @@ -449,7 +522,7 @@ export class PostgresCrudDialect extends LateralJoinDi ) .select( fields.map((f, i) => { - const mappedType = this.getSqlType(f.type); + const mappedType = this.getSqlType(f.type, f.attributes); const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType); return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name); }), diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 2a237c5cb..f85c35faf 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -265,7 +265,13 @@ export class ExpressionTransformer { } else if (this.isNullNode(left)) { return this.transformNullCheck(right, expr.op); } else { - return BinaryOperationNode.create(left, this.transformOperator(op), right); + const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context); + const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context); + // Map ZModel operator to SQL operator string + const sqlOp = op === '==' ? '=' : op; + return this.dialect + .buildComparison(new ExpressionWrapper(left), leftFieldDef, sqlOp, new ExpressionWrapper(right), rightFieldDef) + .toOperationNode(); } } @@ -298,17 +304,17 @@ export class ExpressionTransformer { // if relation fields are used directly in comparison, it can only be compared with null, // so we normalize the args with the id field (use the first id field if multiple) let normalizedLeft: Expression = expr.left; - if (this.isRelationField(expr.left, context.modelOrType)) { + if (this.isRelationField(expr.left, context)) { invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field'); - const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType); + const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context); invariant(leftRelDef, 'failed to get relation field definition'); const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type); normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!); } let normalizedRight: Expression = expr.right; - if (this.isRelationField(expr.right, context.modelOrType)) { + if (this.isRelationField(expr.right, context)) { invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field'); - const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.modelOrType); + const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context); invariant(rightRelDef, 'failed to get relation field definition'); const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type); normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!); @@ -349,7 +355,7 @@ export class ExpressionTransformer { ); let newContextModel: string; - const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType); + const fieldDef = this.getFieldDefFromFieldRef(expr.left, context); if (fieldDef) { invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`); newContextModel = fieldDef.type; @@ -578,13 +584,6 @@ export class ExpressionTransformer { return logicalNot(this.dialect, this.transform(expr.operand, context)); } - private transformOperator(op: Exclude) { - const mappedOp = match(op) - .with('==', () => '=' as const) - .otherwise(() => op); - return OperatorNode.create(mappedOp); - } - @expr('call') // @ts-ignore private _call(expr: CallExpression, context: ExpressionTransformerContext) { @@ -979,12 +978,18 @@ export class ExpressionTransformer { } } - private isRelationField(expr: Expression, model: string) { - const fieldDef = this.getFieldDefFromFieldRef(expr, model); + private isRelationField(expr: Expression, context: ExpressionTransformerContext) { + const fieldDef = this.getFieldDefFromFieldRef(expr, context); return !!fieldDef?.relation; } - private getFieldDefFromFieldRef(expr: Expression, model: string): FieldDef | undefined { + private getFieldDefFromFieldRef(expr: Expression, context: ExpressionTransformerContext): FieldDef | undefined { + // `this.foo` references belong to `thisType` (the outer model in collection-predicate + // contexts); everything else uses `modelOrType`. + const model = + ExpressionUtils.isMember(expr) && ExpressionUtils.isThis(expr.receiver) + ? context.thisType + : context.modelOrType; if (ExpressionUtils.isField(expr)) { return QueryUtils.getField(this.schema, model, expr.field); } else if ( @@ -993,6 +998,19 @@ export class ExpressionTransformer { ExpressionUtils.isThis(expr.receiver) ) { return QueryUtils.getField(this.schema, model, expr.members[0]!); + } else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) { + // relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the + // relation hops and return the terminal field's FieldDef so native-type info + // (@db.*) is available for casting in buildComparison + const receiverDef = QueryUtils.getField(this.schema, model, expr.receiver.field); + if (!receiverDef?.relation) return undefined; + let currModel = receiverDef.type; + for (let i = 0; i < expr.members.length - 1; i++) { + const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!); + if (!hopDef?.relation) return undefined; + currModel = hopDef.type; + } + return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!); } else { return undefined; } diff --git a/tests/regression/test/issue-2394.test.ts b/tests/regression/test/issue-2394.test.ts new file mode 100644 index 000000000..5acdcb801 --- /dev/null +++ b/tests/regression/test/issue-2394.test.ts @@ -0,0 +1,156 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { randomUUID } from 'node:crypto'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue #2394', () => { + it('should work with post-update rules when uuid fields are used', async () => { + const db = await createPolicyTestClient( + ` +model Item { + id String @id @default(uuid()) @db.Uuid + status String + + @@allow('create,read,update', true) + @@deny('post-update', before().status == status) +} + `, + { provider: 'postgresql', usePrismaPush: true }, + ); + + const item = await db.item.create({ data: { status: 'draft' } }); + + // updating with a different status should succeed (post-update policy: deny if status didn't change) + const updated = await db.item.update({ where: { id: item.id }, data: { status: 'published' } }); + expect(updated.status).toBe('published'); + + // updating with the same status should be denied + await expect( + db.item.update({ where: { id: updated.id }, data: { status: 'published' } }), + ).toBeRejectedByPolicy(); + }); + + it('should work with policies comparing string field with uuid field', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id String @id @default(uuid()) @db.Uuid + id1 String + value Int + @@allow('all', id == id1) +} + `, + { provider: 'postgresql', usePrismaPush: true }, + ); + + const newId = randomUUID(); + + await expect(db.foo.create({ data: { id: newId, id1: newId, value: 0 } })).toResolveTruthy(); + await expect(db.foo.update({ where: { id: newId }, data: { value: 1 } })).toResolveTruthy(); + }); + + it('should work with policies comparing related @db.Uuid field to plain string field (single-hop)', async () => { + // `owner.tag` (a non-FK @db.Uuid field) generates a correlated subquery whose result + // type is uuid, compared against tagRef which is plain text. FK fields are all kept + // type-compatible so the migration engine doesn't reject the schema. + const db = await createPolicyTestClient( + ` +model User { + id String @id @default(uuid()) @db.Uuid + tag String @db.Uuid + items Item[] + + @@allow('all', true) +} + +model Item { + id String @id @default(uuid()) @db.Uuid + ownerId String @db.Uuid + owner User @relation(fields: [ownerId], references: [id]) + tagRef String + + @@allow('all', owner.tag == tagRef) +} + `, + { provider: 'postgresql', usePrismaPush: true }, + ); + + const rawDb = db.$unuseAll(); + const tag = randomUUID(); + const user = await rawDb.user.create({ data: { tag } }); + await rawDb.item.create({ data: { ownerId: user.id, tagRef: tag } }); + + const items = await db.item.findMany(); + expect(items).toHaveLength(1); + }); + + it('should work with policies comparing related @db.Uuid field to plain string field (multi-hop)', async () => { + // `org.owner.token` is a two-hop chain through non-FK @db.Uuid fields; without + // multi-hop traversal the terminal FieldDef is invisible and the uuid/text mismatch + // is not caught. All FK fields are kept type-compatible with their PKs. + const db = await createPolicyTestClient( + ` +model User { + id String @id @default(uuid()) @db.Uuid + token String @db.Uuid + ownedOrgs Org[] + orgs OrgMember[] + + @@allow('all', true) +} + +model Org { + id String @id @default(uuid()) @db.Uuid + ownerId String @db.Uuid + owner User @relation(fields: [ownerId], references: [id]) + members OrgMember[] + + @@allow('all', true) +} + +model OrgMember { + id String @id @default(uuid()) @db.Uuid + orgId String @db.Uuid + org Org @relation(fields: [orgId], references: [id]) + userId String @db.Uuid + user User @relation(fields: [userId], references: [id]) + tokenRef String + + @@allow('all', org.owner.token == tokenRef) +} + `, + { provider: 'postgresql', usePrismaPush: true }, + ); + + const rawDb = db.$unuseAll(); + const token = randomUUID(); + const user = await rawDb.user.create({ data: { token } }); + const org = await rawDb.org.create({ data: { ownerId: user.id } }); + await rawDb.orgMember.create({ data: { orgId: org.id, userId: user.id, tokenRef: token } }); + + const members = await db.orgMember.findMany(); + expect(members).toHaveLength(1); + }); + + it('should work with policies comparing @db.Uuid field to auth()', async () => { + // Exercises transformAuthBinary: `id == auth()` expands to `id == auth().id`, where auth().id + // is emitted as a text parameter even though the auth model's id also has @db.Uuid. + const db = await createPolicyTestClient( + ` +model User { + id String @id @default(uuid()) @db.Uuid + value Int + + @@allow('all', id == auth().id) +} + `, + { provider: 'postgresql', usePrismaPush: true }, + ); + + const rawDb = db.$unuseAll(); + const user = await rawDb.user.create({ data: { value: 0 } }); + + const authedDb = db.$setAuth(user); + await expect(authedDb.user.findMany()).toResolveTruthy(); + await expect(authedDb.user.update({ where: { id: user.id }, data: { value: 1 } })).toResolveTruthy(); + }); +});