Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion drizzle-orm/src/bun-sql/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ export class BunSQLTransaction<
override transaction<T>(
transaction: (tx: BunSQLTransaction<TFullSchema, TSchema>) => Promise<T>,
): Promise<T> {
return (this.session.client as TransactionSQL).savepoint((client) => {
return (this.session.client as TransactionSQL).savepoint((client: SavepointSQL) => {
const session = new BunSQLSession<SavepointSQL, TFullSchema, TSchema>(
client,
this.dialect,
Expand Down
104 changes: 53 additions & 51 deletions drizzle-orm/src/sql/expressions/conditions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ export function bindIfParam(value: unknown, column: SQLWrapper): SQLChunk {
return value as SQLChunk;
}

export interface BinaryOperator {
export interface BinaryOperator<Return = unknown> {
<TColumn extends Column>(
left: TColumn,
right: GetColumnData<TColumn, 'raw'> | SQLWrapper,
): SQL;
<T>(left: SQL.Aliased<T>, right: T | SQLWrapper): SQL;
): SQL<Return>;
<T>(left: SQL.Aliased<T>, right: T | SQLWrapper): SQL<Return>;
<T extends SQLWrapper>(
left: Exclude<T, SQL.Aliased | Column>,
right: unknown,
): SQL;
): SQL<Return>;
}

/**
Expand All @@ -59,7 +59,7 @@ export interface BinaryOperator {
*
* @see isNull for a way to test equality to NULL.
*/
export const eq: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export const eq: BinaryOperator<boolean | null> = (left: SQLWrapper, right: unknown): SQL<boolean | null> => {
return sql`${left} = ${bindIfParam(right, left)}`;
};

Expand All @@ -81,7 +81,7 @@ export const eq: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
*
* @see isNotNull for a way to test whether a value is not null.
*/
export const ne: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export const ne: BinaryOperator<boolean | null> = (left: SQLWrapper, right: unknown): SQL<boolean | null> => {
return sql`${left} <> ${bindIfParam(right, left)}`;
};

Expand All @@ -101,10 +101,11 @@ export const ne: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
* )
* ```
*/
export function and(...conditions: (SQLWrapper | undefined)[]): SQL | undefined;
export function and(...conditions: [SQLWrapper, ...(SQLWrapper | undefined)[]]): SQL<boolean | null>;
export function and(...conditions: (SQLWrapper | undefined)[]): SQL<boolean | null> | undefined;
export function and(
...unfilteredConditions: (SQLWrapper | undefined)[]
): SQL | undefined {
): SQL<boolean | null> | undefined {
const conditions = unfilteredConditions.filter(
(c): c is Exclude<typeof c, undefined> => c !== undefined,
);
Expand Down Expand Up @@ -140,10 +141,11 @@ export function and(
* )
* ```
*/
export function or(...conditions: (SQLWrapper | undefined)[]): SQL | undefined;
export function or(...conditions: [SQLWrapper, ...(SQLWrapper | undefined)[]]): SQL<boolean | null>;
export function or(...conditions: (SQLWrapper | undefined)[]): SQL<boolean | null> | undefined;
export function or(
...unfilteredConditions: (SQLWrapper | undefined)[]
): SQL | undefined {
): SQL<boolean | null> | undefined {
const conditions = unfilteredConditions.filter(
(c): c is Exclude<typeof c, undefined> => c !== undefined,
);
Expand Down Expand Up @@ -174,7 +176,7 @@ export function or(
* .where(not(inArray(cars.make, ['GM', 'Ford'])))
* ```
*/
export function not(condition: SQLWrapper): SQL {
export function not(condition: SQLWrapper): SQL<boolean | null> {
return sql`not ${condition}`;
}

Expand All @@ -192,7 +194,7 @@ export function not(condition: SQLWrapper): SQL {
*
* @see gte for greater-than-or-equal
*/
export const gt: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export const gt: BinaryOperator<boolean | null> = (left: SQLWrapper, right: unknown): SQL<boolean | null> => {
return sql`${left} > ${bindIfParam(right, left)}`;
};

Expand All @@ -212,7 +214,7 @@ export const gt: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
*
* @see gt for a strictly greater-than condition
*/
export const gte: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export const gte: BinaryOperator<boolean | null> = (left: SQLWrapper, right: unknown): SQL<boolean | null> => {
return sql`${left} >= ${bindIfParam(right, left)}`;
};

Expand All @@ -230,7 +232,7 @@ export const gte: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
*
* @see lte for less-than-or-equal
*/
export const lt: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export const lt: BinaryOperator<boolean | null> = (left: SQLWrapper, right: unknown): SQL<boolean | null> => {
return sql`${left} < ${bindIfParam(right, left)}`;
};

Expand All @@ -248,7 +250,7 @@ export const lt: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
*
* @see lt for a strictly less-than condition
*/
export const lte: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export const lte: BinaryOperator<boolean | null> = (left: SQLWrapper, right: unknown): SQL<boolean | null> => {
return sql`${left} <= ${bindIfParam(right, left)}`;
};

Expand All @@ -269,19 +271,19 @@ export const lte: BinaryOperator = (left: SQLWrapper, right: unknown): SQL => {
export function inArray<T>(
column: SQL.Aliased<T>,
values: (T | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function inArray<TColumn extends Column>(
column: TColumn,
values: ReadonlyArray<GetColumnData<TColumn, 'raw'> | Placeholder> | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function inArray<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
values: ReadonlyArray<unknown | Placeholder> | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function inArray(
column: SQLWrapper,
values: ReadonlyArray<unknown | Placeholder> | SQLWrapper,
): SQL {
): SQL<boolean | null> {
if (Array.isArray(values)) {
if (values.length === 0) {
return sql`false`;
Expand Down Expand Up @@ -310,19 +312,19 @@ export function inArray(
export function notInArray<T>(
column: SQL.Aliased<T>,
values: (T | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function notInArray<TColumn extends Column>(
column: TColumn,
values: (GetColumnData<TColumn, 'raw'> | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function notInArray<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function notInArray(
column: SQLWrapper,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL {
): SQL<boolean | null> {
if (Array.isArray(values)) {
if (values.length === 0) {
return sql`true`;
Expand All @@ -349,7 +351,7 @@ export function notInArray(
*
* @see isNotNull for the inverse of this test
*/
export function isNull(value: SQLWrapper): SQL {
export function isNull(value: SQLWrapper): SQL<boolean> {
return sql`${value} is null`;
}

Expand All @@ -369,7 +371,7 @@ export function isNull(value: SQLWrapper): SQL {
*
* @see isNull for the inverse of this test
*/
export function isNotNull(value: SQLWrapper): SQL {
export function isNotNull(value: SQLWrapper): SQL<boolean> {
return sql`${value} is not null`;
}

Expand All @@ -393,7 +395,7 @@ export function isNotNull(value: SQLWrapper): SQL {
*
* @see notExists for the inverse of this test
*/
export function exists(subquery: SQLWrapper): SQL {
export function exists(subquery: SQLWrapper): SQL<boolean> {
return sql`exists ${subquery}`;
}

Expand All @@ -418,7 +420,7 @@ export function exists(subquery: SQLWrapper): SQL {
*
* @see exists for the inverse of this test
*/
export function notExists(subquery: SQLWrapper): SQL {
export function notExists(subquery: SQLWrapper): SQL<boolean> {
return sql`not exists ${subquery}`;
}

Expand All @@ -445,18 +447,18 @@ export function between<T>(
column: SQL.Aliased,
min: T | SQLWrapper,
max: T | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function between<TColumn extends AnyColumn>(
column: TColumn,
min: GetColumnData<TColumn, 'raw'> | SQLWrapper,
max: GetColumnData<TColumn, 'raw'> | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function between<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
min: unknown,
max: unknown,
): SQL;
export function between(column: SQLWrapper, min: unknown, max: unknown): SQL {
): SQL<boolean | null>;
export function between(column: SQLWrapper, min: unknown, max: unknown): SQL<boolean | null> {
return sql`${column} between ${bindIfParam(min, column)} and ${
bindIfParam(
max,
Expand Down Expand Up @@ -486,22 +488,22 @@ export function notBetween<T>(
column: SQL.Aliased,
min: T | SQLWrapper,
max: T | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function notBetween<TColumn extends AnyColumn>(
column: TColumn,
min: GetColumnData<TColumn, 'raw'> | SQLWrapper,
max: GetColumnData<TColumn, 'raw'> | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function notBetween<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
min: unknown,
max: unknown,
): SQL;
): SQL<boolean | null>;
export function notBetween(
column: SQLWrapper,
min: unknown,
max: unknown,
): SQL {
): SQL<boolean | null> {
return sql`${column} not between ${
bindIfParam(
min,
Expand All @@ -526,7 +528,7 @@ export function notBetween(
*
* @see ilike for a case-insensitive version of this condition
*/
export function like(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL {
export function like(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL<boolean | null> {
return sql`${column} like ${value}`;
}

Expand All @@ -548,7 +550,7 @@ export function like(column: Column | SQL.Aliased | SQL, value: string | SQLWrap
* @see like for the inverse condition
* @see notIlike for a case-insensitive version of this condition
*/
export function notLike(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL {
export function notLike(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL<boolean | null> {
return sql`${column} not like ${value}`;
}

Expand All @@ -571,7 +573,7 @@ export function notLike(column: Column | SQL.Aliased | SQL, value: string | SQLW
*
* @see like for a case-sensitive version of this condition
*/
export function ilike(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL {
export function ilike(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL<boolean | null> {
return sql`${column} ilike ${value}`;
}

Expand All @@ -593,7 +595,7 @@ export function ilike(column: Column | SQL.Aliased | SQL, value: string | SQLWra
* @see ilike for the inverse condition
* @see notLike for a case-sensitive version of this condition
*/
export function notIlike(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL {
export function notIlike(column: Column | SQL.Aliased | SQL, value: string | SQLWrapper): SQL<boolean | null> {
return sql`${column} not ilike ${value}`;
}

Expand All @@ -620,19 +622,19 @@ export function notIlike(column: Column | SQL.Aliased | SQL, value: string | SQL
export function arrayContains<T>(
column: SQL.Aliased<T>,
values: (T | Placeholder) | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayContains<TColumn extends Column>(
column: TColumn,
values: (GetColumnData<TColumn, 'raw'> | Placeholder) | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayContains<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayContains(
column: SQLWrapper,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL {
): SQL<boolean | null> {
if (Array.isArray(values)) {
if (values.length === 0) {
throw new Error('arrayContains requires at least one value');
Expand Down Expand Up @@ -668,19 +670,19 @@ export function arrayContains(
export function arrayContained<T>(
column: SQL.Aliased<T>,
values: (T | Placeholder) | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayContained<TColumn extends Column>(
column: TColumn,
values: (GetColumnData<TColumn, 'raw'> | Placeholder) | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayContained<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayContained(
column: SQLWrapper,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL {
): SQL<boolean | null> {
if (Array.isArray(values)) {
if (values.length === 0) {
throw new Error('arrayContained requires at least one value');
Expand Down Expand Up @@ -715,19 +717,19 @@ export function arrayContained(
export function arrayOverlaps<T>(
column: SQL.Aliased<T>,
values: (T | Placeholder) | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayOverlaps<TColumn extends Column>(
column: TColumn,
values: (GetColumnData<TColumn, 'raw'> | Placeholder) | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayOverlaps<T extends SQLWrapper>(
column: Exclude<T, SQL.Aliased | Column>,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL;
): SQL<boolean | null>;
export function arrayOverlaps(
column: SQLWrapper,
values: (unknown | Placeholder)[] | SQLWrapper,
): SQL {
): SQL<boolean | null> {
if (Array.isArray(values)) {
if (values.length === 0) {
throw new Error('arrayOverlaps requires at least one value');
Expand Down
2 changes: 1 addition & 1 deletion integration-tests/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"description": "",
"type": "module",
"scripts": {
"test:types": "tsc && cd type-tests/join-nodenext && tsc",
"test:types": "tsc",
"test": "pnpm test:vitest",
"test:vitest": "vitest run --pass-with-no-tests",
"test:esm": "node tests/imports.test.mjs && node tests/imports.test.cjs",
Expand Down
26 changes: 26 additions & 0 deletions integration-tests/type-tests/conditions/conditions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { and, or, type SQL } from 'drizzle-orm';

declare function returnsMaybeSql(): SQL | undefined;
declare function returnsSql(): SQL;

void function testConditionTypes() {
// @ts-expect-error it will return undefined
and() satisfies SQL;
// @ts-expect-error it could return undefined
and(returnsMaybeSql()) satisfies SQL;
// @ts-expect-error the SQL could return null
and(returnsSql()) satisfies SQL<boolean>;
// this should be ok
and(returnsSql()) satisfies SQL<boolean | null>;
and(returnsSql(), undefined) satisfies SQL<boolean | null>;

// @ts-expect-error it will return undefined
or() satisfies SQL;
// @ts-expect-error it could return undefined
or(returnsMaybeSql()) satisfies SQL;
// @ts-expect-error the SQL could return null
or(returnsSql()) satisfies SQL<boolean>;
// this should be ok
or(returnsSql()) satisfies SQL<boolean | null>;
or(returnsSql(), undefined) satisfies SQL<boolean | null>;
};
Loading