Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ import {
flattenCompoundUniqueFilters,
getDelegateDescendantModels,
getManyToManyRelation,
getModelFields,
getRelationForeignKeyFieldPairs,
isEnum,
isTypeDef,
getModelFields,
makeDefaultOrderBy,
requireField,
requireIdFields,
Expand Down Expand Up @@ -1162,7 +1162,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

// client-level: check both uncapitalized (current) and original (backward compat) model name
const uncapModel = lowerCaseFirst(model);
const omitConfig = (this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
const omitConfig =
(this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
(this.options.omit as Record<string, any> | undefined)?.[model];
if (omitConfig && typeof omitConfig === 'object' && typeof omitConfig[field] === 'boolean') {
return omitConfig[field];
Expand Down Expand Up @@ -1357,7 +1358,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const computedFields = this.options.computedFields as Record<string, any>;
// check both uncapitalized (current) and original (backward compat) model name
const computedModel = fieldDef.originModel ?? model;
computer = computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? computedFields?.[computedModel]?.[field];
computer =
computedFields?.[lowerCaseFirst(computedModel)]?.[field] ??
computedFields?.[computedModel]?.[field];
}
if (!computer) {
throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`);
Expand Down Expand Up @@ -1489,6 +1492,19 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
*/
abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder<any, any, any>;

/**
* Builds a binary comparison expression between two operands.
*/
buildComparison(
left: Expression<unknown>,
_leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
_rightFieldDef: FieldDef | undefined,
): Expression<SqlBool> {
return this.eb(left, op as any, right) as Expression<SqlBool>;
}

/**
* Builds a JSON path selection expression.
*/
Expand Down
62 changes: 60 additions & 2 deletions packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
Json: 'jsonb',
};

// Maps @db.* attribute names to PostgreSQL SQL types for use in VALUES table casts
private static readonly dbAttributeToSqlTypeMap: Record<string, string> = {
'@db.Uuid': 'uuid',
'@db.Citext': 'citext',
'@db.Inet': 'inet',
'@db.Bit': 'bit',
'@db.VarBit': 'varbit',
'@db.Xml': 'xml',
'@db.Json': 'json',
'@db.JsonB': 'jsonb',
'@db.ByteA': 'bytea',
'@db.Text': 'text',
'@db.Char': 'bpchar',
'@db.VarChar': 'varchar',
'@db.Date': 'date',
'@db.Time': 'time',
'@db.Timetz': 'timetz',
'@db.Timestamp': 'timestamp',
'@db.Timestamptz': 'timestamptz',
'@db.SmallInt': 'smallint',
'@db.Integer': 'integer',
'@db.BigInt': 'bigint',
'@db.Real': 'real',
'@db.DoublePrecision': 'double precision',
'@db.Decimal': 'decimal',
'@db.Boolean': 'boolean',
};

constructor(schema: Schema, options: ClientOptions<Schema>) {
super(schema, options);
this.overrideTypeParsers();
Expand Down Expand Up @@ -406,7 +434,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
);
}

private getSqlType(zmodelType: string) {
private getSqlType(zmodelType: string, attributes?: FieldDef['attributes']) {
// Check @db.* attributes first — they specify the exact native PostgreSQL type
if (attributes) {
for (const attr of attributes) {
const mapped = PostgresCrudDialect.dbAttributeToSqlTypeMap[attr.name];
if (mapped) {
return mapped;
}
}
}
if (isEnum(this.schema, zmodelType)) {
// reduce enum to text for type compatibility
return 'text';
Expand All @@ -415,6 +452,27 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
}
}

override buildComparison(
left: Expression<unknown>,
leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
rightFieldDef: FieldDef | undefined,
) {
const leftHasNativeType = leftFieldDef?.attributes?.some((a) => a.name.startsWith('@db.')) ?? false;
const rightHasNativeType = rightFieldDef?.attributes?.some((a) => a.name.startsWith('@db.')) ?? false;
// When one side has a @db.* native type override and the other doesn't (or its type can't be
// determined, e.g. auth() values arrive as untyped params), cast the @db.* side back to the
// base SQL type for its ZModel type so PostgreSQL doesn't reject the comparison
// (e.g. "operator does not exist: uuid = text").
if (leftHasNativeType && !rightHasNativeType) {
left = this.eb.cast(left, sql.raw(this.getSqlType(leftFieldDef!.type)));
} else if (rightHasNativeType && !leftHasNativeType) {
right = this.eb.cast(right, sql.raw(this.getSqlType(rightFieldDef!.type)));
}
return super.buildComparison(left, leftFieldDef, op, right, rightFieldDef);
}
Comment thread
ymc9 marked this conversation as resolved.

override getStringCasingBehavior() {
// Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive
return { supportsILike: true, likeCaseSensitive: true };
Expand Down Expand Up @@ -449,7 +507,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
)
.select(
fields.map((f, i) => {
const mappedType = this.getSqlType(f.type);
const mappedType = this.getSqlType(f.type, f.attributes);
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
}),
Expand Down
15 changes: 7 additions & 8 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
} else if (this.isNullNode(left)) {
return this.transformNullCheck(right, expr.op);
} else {
return BinaryOperationNode.create(left, this.transformOperator(op), right);
const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context.modelOrType);
const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context.modelOrType);
// Map ZModel operator to SQL operator string
const sqlOp = op === '==' ? '=' : op;
return this.dialect
.buildComparison(new ExpressionWrapper(left), leftFieldDef, sqlOp, new ExpressionWrapper(right), rightFieldDef)
.toOperationNode();
Comment thread
ymc9 marked this conversation as resolved.
Outdated
}
}

Expand Down Expand Up @@ -578,13 +584,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return logicalNot(this.dialect, this.transform(expr.operand, context));
}

private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {
const mappedOp = match(op)
.with('==', () => '=' as const)
.otherwise(() => op);
return OperatorNode.create(mappedOp);
}

@expr('call')
// @ts-ignore
private _call(expr: CallExpression, context: ExpressionTransformerContext) {
Expand Down
48 changes: 48 additions & 0 deletions tests/regression/test/issue-2394.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { randomUUID } from 'node:crypto';
import { describe, expect, it } from 'vitest';

describe('Regression for issue #2394', () => {
it('should work with post-update rules when uuid fields are used', async () => {
const db = await createPolicyTestClient(
`
model Item {
id String @id @default(uuid()) @db.Uuid
status String

@@allow('create,read,update', true)
@@deny('post-update', before().status == status)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const item = await db.item.create({ data: { status: 'draft' } });

// updating with a different status should succeed (post-update policy: deny if status didn't change)
const updated = await db.item.update({ where: { id: item.id }, data: { status: 'published' } });
expect(updated.status).toBe('published');

// // updating with the same status should be denied
// await expect(db.item.update({ where: { id: updated.id }, data: { status: 'published' } })).rejects.toThrow();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
});

it('should work with policies comparing string field with uuid field', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id String @id @default(uuid()) @db.Uuid
id1 String
value Int
@@allow('all', id == id1)
}
`,
{ provider: 'postgresql', usePrismaPush: true },
);

const newId = randomUUID();

await expect(db.foo.create({ data: { id: newId, id1: newId, value: 0 } })).toResolveTruthy();
await expect(db.foo.update({ where: { id: newId }, data: { value: 1 } })).toResolveTruthy();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
});
});
Loading