Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ import {
flattenCompoundUniqueFilters,
getDelegateDescendantModels,
getManyToManyRelation,
getModelFields,
getRelationForeignKeyFieldPairs,
isEnum,
isTypeDef,
getModelFields,
makeDefaultOrderBy,
requireField,
requireIdFields,
Expand Down Expand Up @@ -1162,7 +1162,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

// client-level: check both uncapitalized (current) and original (backward compat) model name
const uncapModel = lowerCaseFirst(model);
const omitConfig = (this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
const omitConfig =
(this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
(this.options.omit as Record<string, any> | undefined)?.[model];
if (omitConfig && typeof omitConfig === 'object' && typeof omitConfig[field] === 'boolean') {
return omitConfig[field];
Expand Down Expand Up @@ -1357,7 +1358,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const computedFields = this.options.computedFields as Record<string, any>;
// check both uncapitalized (current) and original (backward compat) model name
const computedModel = fieldDef.originModel ?? model;
computer = computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? computedFields?.[computedModel]?.[field];
computer =
computedFields?.[lowerCaseFirst(computedModel)]?.[field] ??
computedFields?.[computedModel]?.[field];
}
if (!computer) {
throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`);
Expand Down Expand Up @@ -1489,6 +1492,19 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
*/
abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder<any, any, any>;

/**
* Builds a binary comparison expression between two operands.
*/
buildComparison(
left: Expression<unknown>,
_leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
_rightFieldDef: FieldDef | undefined,
): Expression<SqlBool> {
return this.eb(left, op as any, right) as Expression<SqlBool>;
}

/**
* Builds a JSON path selection expression.
*/
Expand Down
77 changes: 75 additions & 2 deletions packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
Json: 'jsonb',
};

// Maps @db.* attribute names to PostgreSQL SQL types for use in VALUES table casts
private static readonly dbAttributeToSqlTypeMap: Record<string, string> = {
'@db.Uuid': 'uuid',
'@db.Citext': 'citext',
'@db.Inet': 'inet',
'@db.Bit': 'bit',
'@db.VarBit': 'varbit',
'@db.Xml': 'xml',
'@db.Json': 'json',
'@db.JsonB': 'jsonb',
'@db.ByteA': 'bytea',
'@db.Text': 'text',
'@db.Char': 'bpchar',
'@db.VarChar': 'varchar',
'@db.Date': 'date',
'@db.Time': 'time',
'@db.Timetz': 'timetz',
'@db.Timestamp': 'timestamp',
'@db.Timestamptz': 'timestamptz',
'@db.SmallInt': 'smallint',
'@db.Integer': 'integer',
'@db.BigInt': 'bigint',
'@db.Real': 'real',
'@db.DoublePrecision': 'double precision',
'@db.Decimal': 'decimal',
'@db.Boolean': 'boolean',
};

constructor(schema: Schema, options: ClientOptions<Schema>) {
super(schema, options);
this.overrideTypeParsers();
Expand Down Expand Up @@ -406,7 +434,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
);
}

private getSqlType(zmodelType: string) {
private getSqlType(zmodelType: string, attributes?: FieldDef['attributes']) {
// Check @db.* attributes first — they specify the exact native PostgreSQL type
if (attributes) {
for (const attr of attributes) {
const mapped = PostgresCrudDialect.dbAttributeToSqlTypeMap[attr.name];
if (mapped) {
return mapped;
}
}
}
if (isEnum(this.schema, zmodelType)) {
// reduce enum to text for type compatibility
return 'text';
Expand All @@ -415,6 +452,42 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
}
}

// Resolves the effective SQL type for a field: the native type from any @db.* attribute,
// or the base ZModel SQL type if no attribute is present, or undefined if the field is unknown.
private resolveFieldSqlType(fieldDef: FieldDef | undefined): { sqlType: string | undefined; hasDbOverride: boolean } {
if (!fieldDef) {
return { sqlType: undefined, hasDbOverride: false };
}
const dbAttr = fieldDef.attributes?.find((a) => a.name.startsWith('@db.'));
if (dbAttr) {
return { sqlType: PostgresCrudDialect.dbAttributeToSqlTypeMap[dbAttr.name], hasDbOverride: true };
}
return { sqlType: this.getSqlType(fieldDef.type), hasDbOverride: false };
}

override buildComparison(
left: Expression<unknown>,
leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
rightFieldDef: FieldDef | undefined,
) {
const leftResolved = this.resolveFieldSqlType(leftFieldDef);
const rightResolved = this.resolveFieldSqlType(rightFieldDef);
// If the resolved SQL types differ and at least one side carries a @db.* native type override,
// cast that side back to its base ZModel SQL type so PostgreSQL doesn't reject the comparison
// (e.g. "operator does not exist: uuid = text").
if (leftResolved.sqlType !== rightResolved.sqlType && (leftResolved.hasDbOverride || rightResolved.hasDbOverride)) {
if (leftResolved.hasDbOverride) {
left = this.eb.cast(left, sql.raw(this.getSqlType(leftFieldDef!.type)));
}
if (rightResolved.hasDbOverride) {
right = this.eb.cast(right, sql.raw(this.getSqlType(rightFieldDef!.type)));
}
}
return super.buildComparison(left, leftFieldDef, op, right, rightFieldDef);
}
Comment thread
ymc9 marked this conversation as resolved.

override getStringCasingBehavior() {
// Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive
return { supportsILike: true, likeCaseSensitive: true };
Expand Down Expand Up @@ -449,7 +522,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
)
.select(
fields.map((f, i) => {
const mappedType = this.getSqlType(f.type);
const mappedType = this.getSqlType(f.type, f.attributes);
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
}),
Expand Down
15 changes: 7 additions & 8 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
} else if (this.isNullNode(left)) {
return this.transformNullCheck(right, expr.op);
} else {
return BinaryOperationNode.create(left, this.transformOperator(op), right);
const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context.modelOrType);
const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context.modelOrType);
// Map ZModel operator to SQL operator string
const sqlOp = op === '==' ? '=' : op;
return this.dialect
.buildComparison(new ExpressionWrapper(left), leftFieldDef, sqlOp, new ExpressionWrapper(right), rightFieldDef)
.toOperationNode();
Comment thread
ymc9 marked this conversation as resolved.
Outdated
}
}

Expand Down Expand Up @@ -578,13 +584,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return logicalNot(this.dialect, this.transform(expr.operand, context));
}

private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {
const mappedOp = match(op)
.with('==', () => '=' as const)
.otherwise(() => op);
return OperatorNode.create(mappedOp);
}

@expr('call')
// @ts-ignore
private _call(expr: CallExpression, context: ExpressionTransformerContext) {
Expand Down
48 changes: 48 additions & 0 deletions tests/regression/test/issue-2394.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { randomUUID } from 'node:crypto';
import { describe, expect, it } from 'vitest';

describe('Regression for issue #2394', () => {
it('should work with post-update rules when uuid fields are used', async () => {
const db = await createPolicyTestClient(
`
model Item {
id String @id @default(uuid()) @db.Uuid
status String

@@allow('create,read,update', true)
@@deny('post-update', before().status == status)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const item = await db.item.create({ data: { status: 'draft' } });

// updating with a different status should succeed (post-update policy: deny if status didn't change)
const updated = await db.item.update({ where: { id: item.id }, data: { status: 'published' } });
expect(updated.status).toBe('published');

// // updating with the same status should be denied
// await expect(db.item.update({ where: { id: updated.id }, data: { status: 'published' } })).rejects.toThrow();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
});

it('should work with policies comparing string field with uuid field', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id String @id @default(uuid()) @db.Uuid
id1 String
value Int
@@allow('all', id == id1)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const newId = randomUUID();

await expect(db.foo.create({ data: { id: newId, id1: newId, value: 0 } })).toResolveTruthy();
await expect(db.foo.update({ where: { id: newId }, data: { value: 1 } })).toResolveTruthy();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
});
});
Loading