Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ import {
flattenCompoundUniqueFilters,
getDelegateDescendantModels,
getManyToManyRelation,
getModelFields,
getRelationForeignKeyFieldPairs,
isEnum,
isTypeDef,
getModelFields,
makeDefaultOrderBy,
requireField,
requireIdFields,
Expand Down Expand Up @@ -1162,7 +1162,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

// client-level: check both uncapitalized (current) and original (backward compat) model name
const uncapModel = lowerCaseFirst(model);
const omitConfig = (this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
const omitConfig =
(this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
(this.options.omit as Record<string, any> | undefined)?.[model];
if (omitConfig && typeof omitConfig === 'object' && typeof omitConfig[field] === 'boolean') {
return omitConfig[field];
Expand Down Expand Up @@ -1357,7 +1358,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const computedFields = this.options.computedFields as Record<string, any>;
// check both uncapitalized (current) and original (backward compat) model name
const computedModel = fieldDef.originModel ?? model;
computer = computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? computedFields?.[computedModel]?.[field];
computer =
computedFields?.[lowerCaseFirst(computedModel)]?.[field] ??
computedFields?.[computedModel]?.[field];
}
if (!computer) {
throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`);
Expand Down Expand Up @@ -1489,6 +1492,19 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
*/
abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder<any, any, any>;

/**
* Builds a binary comparison expression between two operands.
*/
buildComparison(
left: Expression<unknown>,
_leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
_rightFieldDef: FieldDef | undefined,
): Expression<SqlBool> {
return this.eb(left, op as any, right) as Expression<SqlBool>;
}

/**
* Builds a JSON path selection expression.
*/
Expand Down
77 changes: 75 additions & 2 deletions packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
Json: 'jsonb',
};

// Maps @db.* attribute names to PostgreSQL SQL types for use in VALUES table casts
private static readonly dbAttributeToSqlTypeMap: Record<string, string> = {
'@db.Uuid': 'uuid',
'@db.Citext': 'citext',
'@db.Inet': 'inet',
'@db.Bit': 'bit',
'@db.VarBit': 'varbit',
'@db.Xml': 'xml',
'@db.Json': 'json',
'@db.JsonB': 'jsonb',
'@db.ByteA': 'bytea',
'@db.Text': 'text',
'@db.Char': 'bpchar',
'@db.VarChar': 'varchar',
'@db.Date': 'date',
'@db.Time': 'time',
'@db.Timetz': 'timetz',
'@db.Timestamp': 'timestamp',
'@db.Timestamptz': 'timestamptz',
'@db.SmallInt': 'smallint',
'@db.Integer': 'integer',
'@db.BigInt': 'bigint',
'@db.Real': 'real',
'@db.DoublePrecision': 'double precision',
'@db.Decimal': 'decimal',
'@db.Boolean': 'boolean',
};

constructor(schema: Schema, options: ClientOptions<Schema>) {
super(schema, options);
this.overrideTypeParsers();
Expand Down Expand Up @@ -406,7 +434,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
);
}

private getSqlType(zmodelType: string) {
private getSqlType(zmodelType: string, attributes?: FieldDef['attributes']) {
// Check @db.* attributes first — they specify the exact native PostgreSQL type
if (attributes) {
for (const attr of attributes) {
const mapped = PostgresCrudDialect.dbAttributeToSqlTypeMap[attr.name];
if (mapped) {
return mapped;
}
}
}
if (isEnum(this.schema, zmodelType)) {
// reduce enum to text for type compatibility
return 'text';
Expand All @@ -415,6 +452,42 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
}
}

// Resolves the effective SQL type for a field: the native type from any @db.* attribute,
// or the base ZModel SQL type if no attribute is present, or undefined if the field is unknown.
private resolveFieldSqlType(fieldDef: FieldDef | undefined): { sqlType: string | undefined; hasDbOverride: boolean } {
if (!fieldDef) {
return { sqlType: undefined, hasDbOverride: false };
}
const dbAttr = fieldDef.attributes?.find((a) => a.name.startsWith('@db.'));
if (dbAttr) {
return { sqlType: PostgresCrudDialect.dbAttributeToSqlTypeMap[dbAttr.name], hasDbOverride: true };
}
return { sqlType: this.getSqlType(fieldDef.type), hasDbOverride: false };
}

override buildComparison(
left: Expression<unknown>,
leftFieldDef: FieldDef | undefined,
op: string,
right: Expression<unknown>,
rightFieldDef: FieldDef | undefined,
) {
const leftResolved = this.resolveFieldSqlType(leftFieldDef);
const rightResolved = this.resolveFieldSqlType(rightFieldDef);
// If the resolved SQL types differ and at least one side carries a @db.* native type override,
// cast that side back to its base ZModel SQL type so PostgreSQL doesn't reject the comparison
// (e.g. "operator does not exist: uuid = text").
if (leftResolved.sqlType !== rightResolved.sqlType && (leftResolved.hasDbOverride || rightResolved.hasDbOverride)) {
if (leftResolved.hasDbOverride) {
left = this.eb.cast(left, sql.raw(this.getSqlType(leftFieldDef!.type)));
}
if (rightResolved.hasDbOverride) {
right = this.eb.cast(right, sql.raw(this.getSqlType(rightFieldDef!.type)));
}
}
return super.buildComparison(left, leftFieldDef, op, right, rightFieldDef);
}

override getStringCasingBehavior() {
// Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive
return { supportsILike: true, likeCaseSensitive: true };
Expand Down Expand Up @@ -449,7 +522,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
)
.select(
fields.map((f, i) => {
const mappedType = this.getSqlType(f.type);
const mappedType = this.getSqlType(f.type, f.attributes);
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
}),
Expand Down
50 changes: 34 additions & 16 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
} else if (this.isNullNode(left)) {
return this.transformNullCheck(right, expr.op);
} else {
return BinaryOperationNode.create(left, this.transformOperator(op), right);
const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context);
const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context);
// Map ZModel operator to SQL operator string
const sqlOp = op === '==' ? '=' : op;
return this.dialect
.buildComparison(new ExpressionWrapper(left), leftFieldDef, sqlOp, new ExpressionWrapper(right), rightFieldDef)
.toOperationNode();
}
}

Expand Down Expand Up @@ -298,17 +304,17 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
// if relation fields are used directly in comparison, it can only be compared with null,
// so we normalize the args with the id field (use the first id field if multiple)
let normalizedLeft: Expression = expr.left;
if (this.isRelationField(expr.left, context.modelOrType)) {
if (this.isRelationField(expr.left, context)) {
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType);
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context);
invariant(leftRelDef, 'failed to get relation field definition');
const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type);
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
}
let normalizedRight: Expression = expr.right;
if (this.isRelationField(expr.right, context.modelOrType)) {
if (this.isRelationField(expr.right, context)) {
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.modelOrType);
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context);
invariant(rightRelDef, 'failed to get relation field definition');
const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type);
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
Expand Down Expand Up @@ -349,7 +355,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
);

let newContextModel: string;
const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType);
const fieldDef = this.getFieldDefFromFieldRef(expr.left, context);
if (fieldDef) {
invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`);
newContextModel = fieldDef.type;
Expand Down Expand Up @@ -578,13 +584,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
return logicalNot(this.dialect, this.transform(expr.operand, context));
}

private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {
const mappedOp = match(op)
.with('==', () => '=' as const)
.otherwise(() => op);
return OperatorNode.create(mappedOp);
}

@expr('call')
// @ts-ignore
private _call(expr: CallExpression, context: ExpressionTransformerContext) {
Expand Down Expand Up @@ -979,12 +978,18 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}
}

private isRelationField(expr: Expression, model: string) {
const fieldDef = this.getFieldDefFromFieldRef(expr, model);
private isRelationField(expr: Expression, context: ExpressionTransformerContext) {
const fieldDef = this.getFieldDefFromFieldRef(expr, context);
return !!fieldDef?.relation;
}

private getFieldDefFromFieldRef(expr: Expression, model: string): FieldDef | undefined {
private getFieldDefFromFieldRef(expr: Expression, context: ExpressionTransformerContext): FieldDef | undefined {
// `this.foo` references belong to `thisType` (the outer model in collection-predicate
// contexts); everything else uses `modelOrType`.
const model =
ExpressionUtils.isMember(expr) && ExpressionUtils.isThis(expr.receiver)
? context.thisType
: context.modelOrType;
if (ExpressionUtils.isField(expr)) {
return QueryUtils.getField(this.schema, model, expr.field);
} else if (
Expand All @@ -993,6 +998,19 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
ExpressionUtils.isThis(expr.receiver)
) {
return QueryUtils.getField(this.schema, model, expr.members[0]!);
} else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) {
// relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the
// relation hops and return the terminal field's FieldDef so native-type info
// (@db.*) is available for casting in buildComparison
const receiverDef = QueryUtils.getField(this.schema, model, expr.receiver.field);
if (!receiverDef?.relation) return undefined;
let currModel = receiverDef.type;
for (let i = 0; i < expr.members.length - 1; i++) {
const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!);
if (!hopDef?.relation) return undefined;
currModel = hopDef.type;
}
return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!);
} else {
return undefined;
}
Expand Down
Loading
Loading