Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ import {
flattenCompoundUniqueFilters,
getDelegateDescendantModels,
getManyToManyRelation,
getModelFields,
getRelationForeignKeyFieldPairs,
isEnum,
isTypeDef,
getModelFields,
makeDefaultOrderBy,
requireField,
requireIdFields,
Expand Down Expand Up @@ -1162,7 +1162,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

// client-level: check both uncapitalized (current) and original (backward compat) model name
const uncapModel = lowerCaseFirst(model);
const omitConfig = (this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
const omitConfig =
(this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
(this.options.omit as Record<string, any> | undefined)?.[model];
if (omitConfig && typeof omitConfig === 'object' && typeof omitConfig[field] === 'boolean') {
return omitConfig[field];
Expand Down Expand Up @@ -1357,7 +1358,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const computedFields = this.options.computedFields as Record<string, any>;
// check both uncapitalized (current) and original (backward compat) model name
const computedModel = fieldDef.originModel ?? model;
computer = computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? computedFields?.[computedModel]?.[field];
computer =
computedFields?.[lowerCaseFirst(computedModel)]?.[field] ??
computedFields?.[computedModel]?.[field];
}
if (!computer) {
throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`);
Expand Down Expand Up @@ -1489,6 +1492,19 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
*/
abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder<any, any, any>;

/**
* Builds a binary comparison expression between two operands.
*/
buildComparison(
left: Expression<unknown>,
_leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
_rightFieldDef: FieldDef | undefined,
): Expression<SqlBool> {
return this.eb(left, op as any, right) as Expression<SqlBool>;
}

/**
* Builds a JSON path selection expression.
*/
Expand Down
77 changes: 75 additions & 2 deletions packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
Json: 'jsonb',
};

// Maps @db.* attribute names to PostgreSQL SQL types for use in VALUES table casts
private static readonly dbAttributeToSqlTypeMap: Record<string, string> = {
'@db.Uuid': 'uuid',
'@db.Citext': 'citext',
'@db.Inet': 'inet',
'@db.Bit': 'bit',
'@db.VarBit': 'varbit',
'@db.Xml': 'xml',
'@db.Json': 'json',
'@db.JsonB': 'jsonb',
'@db.ByteA': 'bytea',
'@db.Text': 'text',
'@db.Char': 'bpchar',
'@db.VarChar': 'varchar',
'@db.Date': 'date',
'@db.Time': 'time',
'@db.Timetz': 'timetz',
'@db.Timestamp': 'timestamp',
'@db.Timestamptz': 'timestamptz',
'@db.SmallInt': 'smallint',
'@db.Integer': 'integer',
'@db.BigInt': 'bigint',
'@db.Real': 'real',
'@db.DoublePrecision': 'double precision',
'@db.Decimal': 'decimal',
'@db.Boolean': 'boolean',
};

constructor(schema: Schema, options: ClientOptions<Schema>) {
super(schema, options);
this.overrideTypeParsers();
Expand Down Expand Up @@ -406,7 +434,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
);
}

private getSqlType(zmodelType: string) {
private getSqlType(zmodelType: string, attributes?: FieldDef['attributes']) {
// Check @db.* attributes first — they specify the exact native PostgreSQL type
if (attributes) {
for (const attr of attributes) {
const mapped = PostgresCrudDialect.dbAttributeToSqlTypeMap[attr.name];
if (mapped) {
return mapped;
}
}
}
if (isEnum(this.schema, zmodelType)) {
// reduce enum to text for type compatibility
return 'text';
Expand All @@ -415,6 +452,42 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
}
}

// Resolves the effective SQL type for a field: the native type from any @db.* attribute,
// or the base ZModel SQL type if no attribute is present, or undefined if the field is unknown.
private resolveFieldSqlType(fieldDef: FieldDef | undefined): { sqlType: string | undefined; hasDbOverride: boolean } {
if (!fieldDef) {
return { sqlType: undefined, hasDbOverride: false };
}
const dbAttr = fieldDef.attributes?.find((a) => a.name.startsWith('@db.'));
if (dbAttr) {
return { sqlType: PostgresCrudDialect.dbAttributeToSqlTypeMap[dbAttr.name], hasDbOverride: true };
}
return { sqlType: this.getSqlType(fieldDef.type), hasDbOverride: false };
}

override buildComparison(
left: Expression<unknown>,
leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
rightFieldDef: FieldDef | undefined,
) {
const leftResolved = this.resolveFieldSqlType(leftFieldDef);
const rightResolved = this.resolveFieldSqlType(rightFieldDef);
// If the resolved SQL types differ and at least one side carries a @db.* native type override,
// cast that side back to its base ZModel SQL type so PostgreSQL doesn't reject the comparison
// (e.g. "operator does not exist: uuid = text").
if (leftResolved.sqlType !== rightResolved.sqlType && (leftResolved.hasDbOverride || rightResolved.hasDbOverride)) {
if (leftResolved.hasDbOverride) {
left = this.eb.cast(left, sql.raw(this.getSqlType(leftFieldDef!.type)));
}
if (rightResolved.hasDbOverride) {
right = this.eb.cast(right, sql.raw(this.getSqlType(rightFieldDef!.type)));
}
}
return super.buildComparison(left, leftFieldDef, op, right, rightFieldDef);
}
Comment thread
ymc9 marked this conversation as resolved.

override getStringCasingBehavior() {
// Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive
return { supportsILike: true, likeCaseSensitive: true };
Expand Down Expand Up @@ -449,7 +522,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
)
.select(
fields.map((f, i) => {
const mappedType = this.getSqlType(f.type);
const mappedType = this.getSqlType(f.type, f.attributes);
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
}),
Expand Down
38 changes: 30 additions & 8 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,23 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
} else if (this.isNullNode(left)) {
return this.transformNullCheck(right, expr.op);
} else {
return BinaryOperationNode.create(left, this.transformOperator(op), right);
// `this.foo` references belong to `thisType`, not `modelOrType` (which is the
// collection element type in predicate contexts); everything else uses `modelOrType`.
const leftModel =
ExpressionUtils.isMember(normalizedLeft) && ExpressionUtils.isThis(normalizedLeft.receiver)
? context.thisType
: context.modelOrType;
const rightModel =
ExpressionUtils.isMember(normalizedRight) && ExpressionUtils.isThis(normalizedRight.receiver)
? context.thisType
: context.modelOrType;
const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, leftModel);
const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, rightModel);
// Map ZModel operator to SQL operator string
const sqlOp = op === '==' ? '=' : op;
return this.dialect
.buildComparison(new ExpressionWrapper(left), leftFieldDef, sqlOp, new ExpressionWrapper(right), rightFieldDef)
.toOperationNode();
}
}

Expand Down Expand Up @@ -578,13 +594,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return logicalNot(this.dialect, this.transform(expr.operand, context));
}

private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {
const mappedOp = match(op)
.with('==', () => '=' as const)
.otherwise(() => op);
return OperatorNode.create(mappedOp);
}

@expr('call')
// @ts-ignore
private _call(expr: CallExpression, context: ExpressionTransformerContext) {
Expand Down Expand Up @@ -993,6 +1002,19 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
ExpressionUtils.isThis(expr.receiver)
) {
return QueryUtils.getField(this.schema, model, expr.members[0]!);
} else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) {
// relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the
// relation hops and return the terminal field's FieldDef so native-type info
// (@db.*) is available for casting in buildComparison
const receiverDef = QueryUtils.getField(this.schema, model, expr.receiver.field);
if (!receiverDef?.relation) return undefined;
let currModel = receiverDef.type;
for (let i = 0; i < expr.members.length - 1; i++) {
const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!);
if (!hopDef?.relation) return undefined;
currModel = hopDef.type;
}
return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!);
} else {
return undefined;
}
Expand Down
156 changes: 156 additions & 0 deletions tests/regression/test/issue-2394.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { randomUUID } from 'node:crypto';
import { describe, expect, it } from 'vitest';

describe('Regression for issue #2394', () => {
it('should work with post-update rules when uuid fields are used', async () => {
const db = await createPolicyTestClient(
`
model Item {
id String @id @default(uuid()) @db.Uuid
status String

@@allow('create,read,update', true)
@@deny('post-update', before().status == status)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const item = await db.item.create({ data: { status: 'draft' } });

// updating with a different status should succeed (post-update policy: deny if status didn't change)
const updated = await db.item.update({ where: { id: item.id }, data: { status: 'published' } });
expect(updated.status).toBe('published');

// updating with the same status should be denied
await expect(
db.item.update({ where: { id: updated.id }, data: { status: 'published' } }),
).toBeRejectedByPolicy();
});

it('should work with policies comparing string field with uuid field', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id String @id @default(uuid()) @db.Uuid
id1 String
value Int
@@allow('all', id == id1)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const newId = randomUUID();

await expect(db.foo.create({ data: { id: newId, id1: newId, value: 0 } })).toResolveTruthy();
await expect(db.foo.update({ where: { id: newId }, data: { value: 1 } })).toResolveTruthy();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
});

it('should work with policies comparing related @db.Uuid field to plain string field (single-hop)', async () => {
// `owner.tag` (a non-FK @db.Uuid field) generates a correlated subquery whose result
// type is uuid, compared against tagRef which is plain text. FK fields are all kept
// type-compatible so the migration engine doesn't reject the schema.
const db = await createPolicyTestClient(
`
model User {
id String @id @default(uuid()) @db.Uuid
tag String @db.Uuid
items Item[]

@@allow('all', true)
}

model Item {
id String @id @default(uuid()) @db.Uuid
ownerId String @db.Uuid
owner User @relation(fields: [ownerId], references: [id])
tagRef String

@@allow('all', owner.tag == tagRef)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const rawDb = db.$unuseAll();
const tag = randomUUID();
const user = await rawDb.user.create({ data: { tag } });
await rawDb.item.create({ data: { ownerId: user.id, tagRef: tag } });

const items = await db.item.findMany();
expect(items).toHaveLength(1);
});

it('should work with policies comparing related @db.Uuid field to plain string field (multi-hop)', async () => {
// `org.owner.token` is a two-hop chain through non-FK @db.Uuid fields; without
// multi-hop traversal the terminal FieldDef is invisible and the uuid/text mismatch
// is not caught. All FK fields are kept type-compatible with their PKs.
const db = await createPolicyTestClient(
`
model User {
id String @id @default(uuid()) @db.Uuid
token String @db.Uuid
ownedOrgs Org[]
orgs OrgMember[]

@@allow('all', true)
}

model Org {
id String @id @default(uuid()) @db.Uuid
ownerId String @db.Uuid
owner User @relation(fields: [ownerId], references: [id])
members OrgMember[]

@@allow('all', true)
}

model OrgMember {
id String @id @default(uuid()) @db.Uuid
orgId String @db.Uuid
org Org @relation(fields: [orgId], references: [id])
userId String @db.Uuid
user User @relation(fields: [userId], references: [id])
tokenRef String

@@allow('all', org.owner.token == tokenRef)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const rawDb = db.$unuseAll();
const token = randomUUID();
const user = await rawDb.user.create({ data: { token } });
const org = await rawDb.org.create({ data: { ownerId: user.id } });
await rawDb.orgMember.create({ data: { orgId: org.id, userId: user.id, tokenRef: token } });

const members = await db.orgMember.findMany();
expect(members).toHaveLength(1);
});

it('should work with policies comparing @db.Uuid field to auth()', async () => {
// Exercises transformAuthBinary: `id == auth()` expands to `id == auth().id`, where auth().id
// is emitted as a text parameter even though the auth model's id also has @db.Uuid.
const db = await createPolicyTestClient(
`
model User {
id String @id @default(uuid()) @db.Uuid
value Int

@@allow('all', id == auth().id)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const rawDb = db.$unuseAll();
const user = await rawDb.user.create({ data: { value: 0 } });

const authedDb = db.$setAuth(user);
await expect(authedDb.user.findMany()).toResolveTruthy();
await expect(authedDb.user.update({ where: { id: user.id }, data: { value: 1 } })).toResolveTruthy();
});
});
Loading