fix: support using aggregations inside orderBy and having of groupBy (#152)

* fix: support using aggregations inside `orderBy` and `having` of `groupBy`

* update

* update
This commit is contained in:
Yiming Cao 2025-08-12 18:28:52 +08:00 committed by GitHub
parent 8833aa75b5
commit 2e95aa5eef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 445 additions and 153 deletions

View file

@ -107,7 +107,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
if (field.type.array && !isDataModel(field.type.reference?.ref)) {
const provider = this.getDataSourceProvider(AstUtils.getContainerOfType(field, isModel)!);
if (provider === 'sqlite') {
accept('error', `Array type is not supported for "${provider}" provider.`, { node: field.type });
accept('error', `List type is not supported for "${provider}" provider.`, { node: field.type });
}
}

View file

@ -556,7 +556,12 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
},
groupBy: (args: unknown) => {
return createPromise('groupBy', args, new GroupByOperationHandler<Schema>(client, model, inputValidator));
return createPromise(
'groupBy',
args,
new GroupByOperationHandler<Schema>(client, model, inputValidator),
true,
);
},
} as ModelOperations<Schema, Model>;
}

View file

@ -17,3 +17,14 @@ export const TRANSACTION_UNSUPPORTED_METHODS = ['$transaction', '$disconnect', '
* Prefix for JSON field used to store joined delegate rows.
*/
export const DELEGATE_JOINED_FIELD_PREFIX = '$delegate$';
/**
* Logical combinators used in filters.
*/
export const LOGICAL_COMBINATORS = ['AND', 'OR', 'NOT'] as const;
/**
* Aggregation operators.
*/
export const AGGREGATE_OPERATORS = ['_count', '_sum', '_avg', '_min', '_max'] as const;
export type AGGREGATE_OPERATORS = (typeof AGGREGATE_OPERATORS)[number];

View file

@ -752,17 +752,18 @@ export type ModelOperations<Schema extends SchemaDef, Model extends GetModels<Sc
* _count: true
* }); // result: `Array<{ country: string, city: string, _count: number }>`
*
* // group by with sorting, the `orderBy` fields must be in the `by` list
* // group by with sorting, the `orderBy` fields must be either an aggregation
* // or a field used in the `by` list
* await db.profile.groupBy({
* by: 'country',
* orderBy: { country: 'desc' }
* });
*
* // group by with having (post-aggregation filter), the `having` fields must
* // be in the `by` list
* // group by with having (post-aggregation filter), the fields used in `having` must
* // be either an aggregation, or a field used in the `by` list
* await db.profile.groupBy({
* by: 'country',
* having: { country: 'US' }
* having: { country: 'US', age: { _avg: { gte: 18 } } }
* });
*/
groupBy<T extends GroupByArgs<Schema, Model>>(

View file

@ -209,6 +209,7 @@ export type WhereInput<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
ScalarOnly extends boolean = false,
WithAggregations extends boolean = false,
> = {
[Key in GetModelFields<Schema, Model> as ScalarOnly extends true
? Key extends RelationFields<Schema, Model>
@ -223,7 +224,12 @@ export type WhereInput<
: FieldIsArray<Schema, Model, Key> extends true
? ArrayFilter<GetModelFieldType<Schema, Model, Key>>
: // primitive
PrimitiveFilter<Schema, GetModelFieldType<Schema, Model, Key>, ModelFieldIsOptional<Schema, Model, Key>>;
PrimitiveFilter<
Schema,
GetModelFieldType<Schema, Model, Key>,
ModelFieldIsOptional<Schema, Model, Key>,
WithAggregations
>;
} & {
$expr?: (eb: ExpressionBuilder<ToKyselySchema<Schema>, Model>) => OperandExpression<SqlBool>;
} & {
@ -249,21 +255,32 @@ type ArrayFilter<T extends string> = {
isEmpty?: boolean;
};
type PrimitiveFilter<Schema extends SchemaDef, T extends string, Nullable extends boolean> = T extends 'String'
? StringFilter<Schema, Nullable>
type PrimitiveFilter<
Schema extends SchemaDef,
T extends string,
Nullable extends boolean,
WithAggregations extends boolean,
> = T extends 'String'
? StringFilter<Schema, Nullable, WithAggregations>
: T extends 'Int' | 'Float' | 'Decimal' | 'BigInt'
? NumberFilter<Schema, T, Nullable>
? NumberFilter<Schema, T, Nullable, WithAggregations>
: T extends 'Boolean'
? BooleanFilter<Nullable>
? BooleanFilter<Schema, Nullable, WithAggregations>
: T extends 'DateTime'
? DateTimeFilter<Schema, Nullable>
? DateTimeFilter<Schema, Nullable, WithAggregations>
: T extends 'Bytes'
? BytesFilter<Nullable>
? BytesFilter<Schema, Nullable, WithAggregations>
: T extends 'Json'
? 'Not implemented yet' // TODO: Json filter
: never;
type CommonPrimitiveFilter<Schema extends SchemaDef, DataType, T extends BuiltinType, Nullable extends boolean> = {
type CommonPrimitiveFilter<
Schema extends SchemaDef,
DataType,
T extends BuiltinType,
Nullable extends boolean,
WithAggregations extends boolean,
> = {
equals?: NullableIf<DataType, Nullable>;
in?: DataType[];
notIn?: DataType[];
@ -271,16 +288,23 @@ type CommonPrimitiveFilter<Schema extends SchemaDef, DataType, T extends Builtin
lte?: DataType;
gt?: DataType;
gte?: DataType;
not?: PrimitiveFilter<Schema, T, Nullable>;
not?: PrimitiveFilter<Schema, T, Nullable, WithAggregations>;
};
export type StringFilter<Schema extends SchemaDef, Nullable extends boolean> =
export type StringFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
| NullableIf<string, Nullable>
| (CommonPrimitiveFilter<Schema, string, 'String', Nullable> & {
| (CommonPrimitiveFilter<Schema, string, 'String', Nullable, WithAggregations> & {
contains?: string;
startsWith?: string;
endsWith?: string;
} & (ProviderSupportsCaseSensitivity<Schema> extends true
} & (WithAggregations extends true
? {
_count?: NumberFilter<Schema, 'Int', false, false>;
_min?: StringFilter<Schema, false, false>;
_max?: StringFilter<Schema, false, false>;
}
: {}) &
(ProviderSupportsCaseSensitivity<Schema> extends true
? {
mode?: 'default' | 'insensitive';
}
@ -290,27 +314,58 @@ export type NumberFilter<
Schema extends SchemaDef,
T extends 'Int' | 'Float' | 'Decimal' | 'BigInt',
Nullable extends boolean,
> = NullableIf<number | bigint, Nullable> | CommonPrimitiveFilter<Schema, number, T, Nullable>;
WithAggregations extends boolean,
> =
| NullableIf<number | bigint, Nullable>
| (CommonPrimitiveFilter<Schema, number, T, Nullable, WithAggregations> &
(WithAggregations extends true
? {
_count?: NumberFilter<Schema, 'Int', false, false>;
_avg?: NumberFilter<Schema, T, false, false>;
_sum?: NumberFilter<Schema, T, false, false>;
_min?: NumberFilter<Schema, T, false, false>;
_max?: NumberFilter<Schema, T, false, false>;
}
: {}));
export type DateTimeFilter<Schema extends SchemaDef, Nullable extends boolean> =
export type DateTimeFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
| NullableIf<Date | string, Nullable>
| CommonPrimitiveFilter<Schema, Date | string, 'DateTime', Nullable>;
| (CommonPrimitiveFilter<Schema, Date | string, 'DateTime', Nullable, WithAggregations> &
(WithAggregations extends true
? {
_count?: NumberFilter<Schema, 'Int', false, false>;
_min?: DateTimeFilter<Schema, false, false>;
_max?: DateTimeFilter<Schema, false, false>;
}
: {}));
export type BytesFilter<Nullable extends boolean> =
export type BytesFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
| NullableIf<Uint8Array | Buffer, Nullable>
| {
| ({
equals?: NullableIf<Uint8Array, Nullable>;
in?: Uint8Array[];
notIn?: Uint8Array[];
not?: BytesFilter<Nullable>;
};
not?: BytesFilter<Schema, Nullable, WithAggregations>;
} & (WithAggregations extends true
? {
_count?: NumberFilter<Schema, 'Int', false, false>;
_min?: BytesFilter<Schema, false, false>;
_max?: BytesFilter<Schema, false, false>;
}
: {}));
export type BooleanFilter<Nullable extends boolean> =
export type BooleanFilter<Schema extends SchemaDef, Nullable extends boolean, WithAggregations extends boolean> =
| NullableIf<boolean, Nullable>
| {
| ({
equals?: NullableIf<boolean, Nullable>;
not?: BooleanFilter<Nullable>;
};
not?: BooleanFilter<Schema, Nullable, WithAggregations>;
} & (WithAggregations extends true
? {
_count?: NumberFilter<Schema, 'Int', false, false>;
_min?: BooleanFilter<Schema, false, false>;
_max?: BooleanFilter<Schema, false, false>;
}
: {}));
export type SortOrder = 'asc' | 'desc';
export type NullsOrder = 'first' | 'last';
@ -340,14 +395,15 @@ export type OrderBy<
: {}) &
(WithAggregation extends true
? {
_count?: OrderBy<Schema, Model, WithRelation, false>;
_count?: OrderBy<Schema, Model, false, false>;
_min?: MinMaxInput<Schema, Model, SortOrder>;
_max?: MinMaxInput<Schema, Model, SortOrder>;
} & (NumericFields<Schema, Model> extends never
? {}
: {
_avg?: SumAvgInput<Schema, Model>;
_sum?: SumAvgInput<Schema, Model>;
_min?: MinMaxInput<Schema, Model>;
_max?: MinMaxInput<Schema, Model>;
// aggregations specific to numeric fields
_avg?: SumAvgInput<Schema, Model, SortOrder>;
_sum?: SumAvgInput<Schema, Model, SortOrder>;
})
: {});
@ -931,13 +987,13 @@ export type AggregateArgs<Schema extends SchemaDef, Model extends GetModels<Sche
orderBy?: OrArray<OrderBy<Schema, Model, true, false>>;
} & {
_count?: true | CountAggregateInput<Schema, Model>;
_min?: MinMaxInput<Schema, Model, true>;
_max?: MinMaxInput<Schema, Model, true>;
} & (NumericFields<Schema, Model> extends never
? {}
: {
_avg?: SumAvgInput<Schema, Model>;
_sum?: SumAvgInput<Schema, Model>;
_min?: MinMaxInput<Schema, Model>;
_max?: MinMaxInput<Schema, Model>;
_avg?: SumAvgInput<Schema, Model, true>;
_sum?: SumAvgInput<Schema, Model, true>;
});
type NumericFields<Schema extends SchemaDef, Model extends GetModels<Schema>> = keyof {
@ -952,16 +1008,16 @@ type NumericFields<Schema extends SchemaDef, Model extends GetModels<Schema>> =
: never]: GetModelField<Schema, Model, Key>;
};
type SumAvgInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
[Key in NumericFields<Schema, Model>]?: true;
type SumAvgInput<Schema extends SchemaDef, Model extends GetModels<Schema>, ValueType> = {
[Key in NumericFields<Schema, Model>]?: ValueType;
};
type MinMaxInput<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
type MinMaxInput<Schema extends SchemaDef, Model extends GetModels<Schema>, ValueType> = {
[Key in GetModelFields<Schema, Model> as FieldIsArray<Schema, Model, Key> extends true
? never
: FieldIsRelation<Schema, Model, Key> extends true
? never
: Key]?: true;
: Key]?: ValueType;
};
export type AggregateResult<
@ -1006,21 +1062,28 @@ type AggCommonOutput<Input> = Input extends true
// #region GroupBy
type GroupByHaving<Schema extends SchemaDef, Model extends GetModels<Schema>> = Omit<
WhereInput<Schema, Model, true, true>,
'$expr'
>;
export type GroupByArgs<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
where?: WhereInput<Schema, Model>;
orderBy?: OrArray<OrderBy<Schema, Model, false, true>>;
by: NonRelationFields<Schema, Model> | NonEmptyArray<NonRelationFields<Schema, Model>>;
having?: WhereInput<Schema, Model, true>;
having?: GroupByHaving<Schema, Model>;
take?: number;
skip?: number;
// aggregations
_count?: true | CountAggregateInput<Schema, Model>;
_min?: MinMaxInput<Schema, Model, true>;
_max?: MinMaxInput<Schema, Model, true>;
} & (NumericFields<Schema, Model> extends never
? {}
: {
_avg?: SumAvgInput<Schema, Model>;
_sum?: SumAvgInput<Schema, Model>;
_min?: MinMaxInput<Schema, Model>;
_max?: MinMaxInput<Schema, Model>;
// aggregations specific to numeric fields
_avg?: SumAvgInput<Schema, Model, true>;
_sum?: SumAvgInput<Schema, Model, true>;
});
export type GroupByResult<

View file

@ -5,7 +5,7 @@ import { match, P } from 'ts-pattern';
import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, SchemaDef } from '../../../schema';
import { enumerate } from '../../../utils/enumerate';
import type { OrArray } from '../../../utils/type-utils';
import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants';
import { AGGREGATE_OPERATORS, DELEGATE_JOINED_FIELD_PREFIX, LOGICAL_COMBINATORS } from '../../constants';
import type {
BooleanFilter,
BytesFilter,
@ -18,6 +18,7 @@ import type {
import { InternalError, QueryError } from '../../errors';
import type { ClientOptions } from '../../options';
import {
aggregate,
buildFieldRef,
buildJoinPairs,
flattenCompoundUniqueFilters,
@ -83,7 +84,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
continue;
}
if (key === 'AND' || key === 'OR' || key === 'NOT') {
if (this.isLogicalCombinator(key)) {
result = this.and(eb, result, this.buildCompositeFilter(eb, model, modelAlias, key, payload));
continue;
}
@ -118,11 +119,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return result;
}
private isLogicalCombinator(key: string): key is (typeof LOGICAL_COMBINATORS)[number] {
return LOGICAL_COMBINATORS.includes(key as any);
}
protected buildCompositeFilter(
eb: ExpressionBuilder<any, any>,
model: string,
modelAlias: string,
key: 'AND' | 'OR' | 'NOT',
key: (typeof LOGICAL_COMBINATORS)[number],
payload: any,
): Expression<SqlBool> {
return match(key)
@ -500,6 +505,20 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
.with('gt', () => eb(lhs, '>', rhs))
.with('gte', () => eb(lhs, '>=', rhs))
.with('not', () => eb.not(recurse(value)))
// aggregations
.with(P.union(...AGGREGATE_OPERATORS), (op) => {
const innerResult = this.buildStandardFilter(
eb,
type,
value,
aggregate(eb, lhs, op),
getRhs,
recurse,
throwIfInvalid,
);
consumedKeys.push(...innerResult.consumedKeys);
return this.and(eb, ...innerResult.conditions);
})
.otherwise(() => {
if (throwIfInvalid) {
throw new QueryError(`Invalid filter key: ${op}`);
@ -520,7 +539,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
private buildStringFilter(
eb: ExpressionBuilder<any, any>,
fieldRef: Expression<any>,
payload: StringFilter<Schema, true>,
payload: StringFilter<Schema, true, boolean>,
) {
let mode: 'default' | 'insensitive' | undefined;
if (payload && typeof payload === 'object' && 'mode' in payload) {
@ -533,7 +552,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
payload,
mode === 'insensitive' ? eb.fn('lower', [fieldRef]) : fieldRef,
(value) => this.prepStringCasing(eb, value, mode),
(value) => this.buildStringFilter(eb, fieldRef, value as StringFilter<Schema, true>),
(value) => this.buildStringFilter(eb, fieldRef, value as StringFilter<Schema, true, boolean>),
);
if (payload && typeof payload === 'object') {
@ -610,7 +629,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
private buildBooleanFilter(
eb: ExpressionBuilder<any, any>,
fieldRef: Expression<any>,
payload: BooleanFilter<true>,
payload: BooleanFilter<Schema, boolean, boolean>,
) {
const { conditions } = this.buildStandardFilter(
eb,
@ -618,7 +637,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
payload,
fieldRef,
(value) => this.transformPrimitive(value, 'Boolean', false),
(value) => this.buildBooleanFilter(eb, fieldRef, value as BooleanFilter<true>),
(value) => this.buildBooleanFilter(eb, fieldRef, value as BooleanFilter<Schema, boolean, boolean>),
true,
['equals', 'not'],
);
@ -628,7 +647,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
private buildDateTimeFilter(
eb: ExpressionBuilder<any, any>,
fieldRef: Expression<any>,
payload: DateTimeFilter<Schema, true>,
payload: DateTimeFilter<Schema, boolean, boolean>,
) {
const { conditions } = this.buildStandardFilter(
eb,
@ -636,20 +655,24 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
payload,
fieldRef,
(value) => this.transformPrimitive(value, 'DateTime', false),
(value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter<Schema, true>),
(value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter<Schema, boolean, boolean>),
true,
);
return this.and(eb, ...conditions);
}
private buildBytesFilter(eb: ExpressionBuilder<any, any>, fieldRef: Expression<any>, payload: BytesFilter<true>) {
private buildBytesFilter(
eb: ExpressionBuilder<any, any>,
fieldRef: Expression<any>,
payload: BytesFilter<Schema, boolean, boolean>,
) {
const conditions = this.buildStandardFilter(
eb,
'Bytes',
payload,
fieldRef,
(value) => this.transformPrimitive(value, 'Bytes', false),
(value) => this.buildBytesFilter(eb, fieldRef, value as BytesFilter<true>),
(value) => this.buildBytesFilter(eb, fieldRef, value as BytesFilter<Schema, boolean, boolean>),
true,
['equals', 'in', 'notIn', 'not'],
);
@ -704,7 +727,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
for (const [k, v] of Object.entries<string>(value)) {
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
result = result.orderBy(
(eb) => eb.fn(field.slice(1), [sql.ref(k)]),
(eb) => aggregate(eb, sql.ref(`${modelAlias}.${k}`), field as AGGREGATE_OPERATORS),
sql.raw(this.negateSort(v, negated)),
);
}

View file

@ -44,9 +44,9 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
return subQuery.as('$sub');
});
// groupBy
const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]);
query = query.groupBy(bys as any);
query = query.groupBy(bys.map((by) => sql.ref(`$sub.${by}`)));
// orderBy
if (parsedArgs.orderBy) {
@ -54,7 +54,7 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
}
if (parsedArgs.having) {
query = query.having((eb1) => this.dialect.buildFilter(eb1, this.model, '$sub', parsedArgs.having));
query = query.having((eb) => this.dialect.buildFilter(eb, this.model, '$sub', parsedArgs.having));
}
// select all by fields

View file

@ -4,7 +4,9 @@ import stableStringify from 'json-stable-stringify';
import { match, P } from 'ts-pattern';
import { z, ZodType } from 'zod';
import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema';
import { NUMERIC_FIELD_TYPES } from '../constants';
import { enumerate } from '../../utils/enumerate';
import { extractFields } from '../../utils/object-utils';
import { AGGREGATE_OPERATORS, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../constants';
import {
type AggregateArgs,
type CountArgs,
@ -231,10 +233,10 @@ export class InputValidator<Schema extends SchemaDef> {
} else {
return match(type)
.with('String', () => z.string())
.with('Int', () => z.number())
.with('Int', () => z.int())
.with('Float', () => z.number())
.with('Boolean', () => z.boolean())
.with('BigInt', () => z.union([z.number(), z.bigint()]))
.with('BigInt', () => z.union([z.int(), z.bigint()]))
.with('Decimal', () => z.union([z.number(), z.instanceof(Decimal), z.string()]))
.with('DateTime', () => z.union([z.date(), z.string().datetime()]))
.with('Bytes', () => z.instanceof(Uint8Array))
@ -268,7 +270,12 @@ export class InputValidator<Schema extends SchemaDef> {
return schema;
}
private makeWhereSchema(model: string, unique: boolean, withoutRelationFields = false): ZodType {
private makeWhereSchema(
model: string,
unique: boolean,
withoutRelationFields = false,
withAggregations = false,
): ZodType {
const modelDef = getModel(this.schema, model);
if (!modelDef) {
throw new QueryError(`Model "${model}" not found in schema`);
@ -313,14 +320,18 @@ export class InputValidator<Schema extends SchemaDef> {
if (enumDef) {
// enum
if (Object.keys(enumDef).length > 0) {
fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional);
fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional, withAggregations);
}
} else if (fieldDef.array) {
// array field
fieldSchema = this.makeArrayFilterSchema(fieldDef.type as BuiltinType);
} else {
// primitive field
fieldSchema = this.makePrimitiveFilterSchema(fieldDef.type as BuiltinType, !!fieldDef.optional);
fieldSchema = this.makePrimitiveFilterSchema(
fieldDef.type as BuiltinType,
!!fieldDef.optional,
withAggregations,
);
}
}
@ -344,7 +355,7 @@ export class InputValidator<Schema extends SchemaDef> {
if (enumDef) {
// enum
if (Object.keys(enumDef).length > 0) {
fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional);
fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional, false);
} else {
fieldSchema = z.never();
}
@ -353,6 +364,7 @@ export class InputValidator<Schema extends SchemaDef> {
fieldSchema = this.makePrimitiveFilterSchema(
def.type as BuiltinType,
!!def.optional,
false,
);
}
return [key, fieldSchema];
@ -407,20 +419,16 @@ export class InputValidator<Schema extends SchemaDef> {
return result;
}
private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean) {
private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean) {
const baseSchema = z.enum(Object.keys(enumDef) as [string, ...string[]]);
const components = this.makeCommonPrimitiveFilterComponents(baseSchema, optional, () =>
z.lazy(() => this.makeEnumFilterSchema(enumDef, optional)),
const components = this.makeCommonPrimitiveFilterComponents(
baseSchema,
optional,
() => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations)),
['equals', 'in', 'notIn', 'not'],
withAggregations ? ['_count', '_min', '_max'] : undefined,
);
return z.union([
this.nullableIf(baseSchema, optional),
z.strictObject({
equals: components.equals,
in: components.in,
notIn: components.notIn,
not: components.not,
}),
]);
return z.union([this.nullableIf(baseSchema, optional), z.strictObject(components)]);
}
private makeArrayFilterSchema(type: BuiltinType) {
@ -433,20 +441,20 @@ export class InputValidator<Schema extends SchemaDef> {
});
}
private makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) {
private makePrimitiveFilterSchema(type: BuiltinType, optional: boolean, withAggregations: boolean) {
if (this.schema.typeDefs && type in this.schema.typeDefs) {
// typed JSON field
return this.makeTypeDefFilterSchema(type, optional);
}
return (
match(type)
.with('String', () => this.makeStringFilterSchema(optional))
.with('String', () => this.makeStringFilterSchema(optional, withAggregations))
.with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) =>
this.makeNumberFilterSchema(this.makePrimitiveSchema(type), optional),
this.makeNumberFilterSchema(this.makePrimitiveSchema(type), optional, withAggregations),
)
.with('Boolean', () => this.makeBooleanFilterSchema(optional))
.with('DateTime', () => this.makeDateTimeFilterSchema(optional))
.with('Bytes', () => this.makeBytesFilterSchema(optional))
.with('Boolean', () => this.makeBooleanFilterSchema(optional, withAggregations))
.with('DateTime', () => this.makeDateTimeFilterSchema(optional, withAggregations))
.with('Bytes', () => this.makeBytesFilterSchema(optional, withAggregations))
// TODO: JSON filters
.with('Json', () => z.any())
.with('Unsupported', () => z.never())
@ -459,40 +467,48 @@ export class InputValidator<Schema extends SchemaDef> {
return z.never();
}
private makeDateTimeFilterSchema(optional: boolean): ZodType {
return this.makeCommonPrimitiveFilterSchema(z.union([z.string().datetime(), z.date()]), optional, () =>
z.lazy(() => this.makeDateTimeFilterSchema(optional)),
private makeDateTimeFilterSchema(optional: boolean, withAggregations: boolean): ZodType {
return this.makeCommonPrimitiveFilterSchema(
z.union([z.iso.datetime(), z.date()]),
optional,
() => z.lazy(() => this.makeDateTimeFilterSchema(optional, withAggregations)),
withAggregations ? ['_count', '_min', '_max'] : undefined,
);
}
private makeBooleanFilterSchema(optional: boolean): ZodType {
return z.union([
this.nullableIf(z.boolean(), optional),
z.strictObject({
equals: this.nullableIf(z.boolean(), optional).optional(),
not: z.lazy(() => this.makeBooleanFilterSchema(optional)).optional(),
}),
]);
private makeBooleanFilterSchema(optional: boolean, withAggregations: boolean): ZodType {
const components = this.makeCommonPrimitiveFilterComponents(
z.boolean(),
optional,
() => z.lazy(() => this.makeBooleanFilterSchema(optional, withAggregations)),
['equals', 'not'],
withAggregations ? ['_count', '_min', '_max'] : undefined,
);
return z.union([this.nullableIf(z.boolean(), optional), z.strictObject(components)]);
}
private makeBytesFilterSchema(optional: boolean): ZodType {
private makeBytesFilterSchema(optional: boolean, withAggregations: boolean): ZodType {
const baseSchema = z.instanceof(Uint8Array);
const components = this.makeCommonPrimitiveFilterComponents(baseSchema, optional, () =>
z.instanceof(Uint8Array),
const components = this.makeCommonPrimitiveFilterComponents(
baseSchema,
optional,
() => z.instanceof(Uint8Array),
['equals', 'in', 'notIn', 'not'],
withAggregations ? ['_count', '_min', '_max'] : undefined,
);
return z.union([
this.nullableIf(baseSchema, optional),
z.strictObject({
equals: components.equals,
in: components.in,
notIn: components.notIn,
not: components.not,
}),
]);
return z.union([this.nullableIf(baseSchema, optional), z.strictObject(components)]);
}
private makeCommonPrimitiveFilterComponents(baseSchema: ZodType, optional: boolean, makeThis: () => ZodType) {
return {
private makeCommonPrimitiveFilterComponents(
baseSchema: ZodType,
optional: boolean,
makeThis: () => ZodType,
supportedOperators: string[] | undefined = undefined,
withAggregations: Array<'_count' | '_avg' | '_sum' | '_min' | '_max'> | undefined = undefined,
) {
const commonAggSchema = () =>
this.makeCommonPrimitiveFilterSchema(baseSchema, false, makeThis, undefined).optional();
let result = {
equals: this.nullableIf(baseSchema.optional(), optional),
notEquals: this.nullableIf(baseSchema.optional(), optional),
in: baseSchema.array().optional(),
@ -502,28 +518,54 @@ export class InputValidator<Schema extends SchemaDef> {
gt: baseSchema.optional(),
gte: baseSchema.optional(),
not: makeThis().optional(),
...(withAggregations?.includes('_count')
? { _count: this.makeNumberFilterSchema(z.int(), false, false).optional() }
: {}),
...(withAggregations?.includes('_avg') ? { _avg: commonAggSchema() } : {}),
...(withAggregations?.includes('_sum') ? { _sum: commonAggSchema() } : {}),
...(withAggregations?.includes('_min') ? { _min: commonAggSchema() } : {}),
...(withAggregations?.includes('_max') ? { _max: commonAggSchema() } : {}),
};
if (supportedOperators) {
const keys = [...supportedOperators, ...(withAggregations ?? [])];
result = extractFields(result, keys) as typeof result;
}
return result;
}
private makeCommonPrimitiveFilterSchema(baseSchema: ZodType, optional: boolean, makeThis: () => ZodType) {
private makeCommonPrimitiveFilterSchema(
baseSchema: ZodType,
optional: boolean,
makeThis: () => ZodType,
withAggregations: Array<AGGREGATE_OPERATORS> | undefined = undefined,
): z.ZodType {
return z.union([
this.nullableIf(baseSchema, optional),
z.strictObject(this.makeCommonPrimitiveFilterComponents(baseSchema, optional, makeThis)),
z.strictObject(
this.makeCommonPrimitiveFilterComponents(baseSchema, optional, makeThis, undefined, withAggregations),
),
]);
}
private makeNumberFilterSchema(baseSchema: ZodType, optional: boolean): ZodType {
return this.makeCommonPrimitiveFilterSchema(baseSchema, optional, () =>
z.lazy(() => this.makeNumberFilterSchema(baseSchema, optional)),
private makeNumberFilterSchema(baseSchema: ZodType, optional: boolean, withAggregations: boolean): ZodType {
return this.makeCommonPrimitiveFilterSchema(
baseSchema,
optional,
() => z.lazy(() => this.makeNumberFilterSchema(baseSchema, optional, withAggregations)),
withAggregations ? ['_count', '_avg', '_sum', '_min', '_max'] : undefined,
);
}
private makeStringFilterSchema(optional: boolean): ZodType {
private makeStringFilterSchema(optional: boolean, withAggregations: boolean): ZodType {
return z.union([
this.nullableIf(z.string(), optional),
z.strictObject({
...this.makeCommonPrimitiveFilterComponents(z.string(), optional, () =>
z.lazy(() => this.makeStringFilterSchema(optional)),
...this.makeCommonPrimitiveFilterComponents(
z.string(),
optional,
() => z.lazy(() => this.makeStringFilterSchema(optional, withAggregations)),
undefined,
withAggregations ? ['_count', '_min', '_max'] : undefined,
),
startsWith: z.string().optional(),
endsWith: z.string().optional(),
@ -973,7 +1015,7 @@ export class InputValidator<Schema extends SchemaDef> {
return z.object({
where: this.makeWhereSchema(model, false).optional(),
data: this.makeUpdateDataSchema(model, [], true),
limit: z.number().int().nonnegative().optional(),
limit: z.int().nonnegative().optional(),
});
}
@ -1113,7 +1155,7 @@ export class InputValidator<Schema extends SchemaDef> {
return z
.object({
where: this.makeWhereSchema(model, false).optional(),
limit: z.number().int().nonnegative().optional(),
limit: z.int().nonnegative().optional(),
})
.optional();
@ -1214,7 +1256,7 @@ export class InputValidator<Schema extends SchemaDef> {
where: this.makeWhereSchema(model, false).optional(),
orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(),
by: this.orArray(z.enum(nonRelationFields), true),
having: this.makeWhereSchema(model, false, true).optional(),
having: this.makeHavingSchema(model).optional(),
skip: this.makeSkipSchema().optional(),
take: this.makeTakeSchema().optional(),
_count: this.makeCountAggregateInputSchema(model).optional(),
@ -1223,26 +1265,41 @@ export class InputValidator<Schema extends SchemaDef> {
_min: this.makeMinMaxInputSchema(model).optional(),
_max: this.makeMinMaxInputSchema(model).optional(),
});
// fields used in `having` must be either in the `by` list, or aggregations
schema = schema.refine((value) => {
const bys = typeof value.by === 'string' ? [value.by] : value.by;
if (
value.having &&
Object.keys(value.having)
.filter((f) => !f.startsWith('_'))
.some((key) => !bys.includes(key))
) {
return false;
} else {
return true;
if (value.having && typeof value.having === 'object') {
for (const [key, val] of Object.entries(value.having)) {
if (AGGREGATE_OPERATORS.includes(key as any)) {
continue;
}
if (bys.includes(key)) {
continue;
}
// we have a key not mentioned in `by`, in this case it must only use
// aggregations in the condition
// 1. payload must be an object
if (!val || typeof val !== 'object') {
return false;
}
// 2. payload must only contain aggregations
if (!this.onlyAggregationFields(val)) {
return false;
}
}
}
return true;
}, 'fields in "having" must be in "by"');
// fields used in `orderBy` must be either in the `by` list, or aggregations
schema = schema.refine((value) => {
const bys = typeof value.by === 'string' ? [value.by] : value.by;
if (
value.orderBy &&
Object.keys(value.orderBy)
.filter((f) => !f.startsWith('_'))
.filter((f) => !AGGREGATE_OPERATORS.includes(f as AGGREGATE_OPERATORS))
.some((key) => !bys.includes(key))
) {
return false;
@ -1254,16 +1311,37 @@ export class InputValidator<Schema extends SchemaDef> {
return schema;
}
private onlyAggregationFields(val: object) {
for (const [key, value] of Object.entries(val)) {
if (AGGREGATE_OPERATORS.includes(key as any)) {
// aggregation field
continue;
}
if (LOGICAL_COMBINATORS.includes(key as any)) {
// logical operators
if (enumerate(value).every((v) => this.onlyAggregationFields(v))) {
continue;
}
}
return false;
}
return true;
}
private makeHavingSchema(model: GetModels<Schema>) {
return this.makeWhereSchema(model, false, true, true);
}
// #endregion
// #region Helpers
private makeSkipSchema() {
return z.number().int().nonnegative();
return z.int().nonnegative();
}
private makeTakeSchema() {
return z.number().int();
return z.int();
}
private refineForSelectIncludeMutuallyExclusive(schema: ZodType) {

View file

@ -1,5 +1,8 @@
import type { ExpressionBuilder, ExpressionWrapper } from 'kysely';
import type { Expression, ExpressionBuilder, ExpressionWrapper } from 'kysely';
import { match } from 'ts-pattern';
import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema';
import { extractFields } from '../utils/object-utils';
import type { AGGREGATE_OPERATORS } from './constants';
import type { OrderBy } from './crud-types';
import { InternalError, QueryError } from './errors';
import type { ClientOptions } from './options';
@ -282,15 +285,6 @@ export function safeJSONStringify(value: unknown) {
});
}
export function extractFields(object: any, fields: string[]) {
return fields.reduce((acc: any, field) => {
if (field in object) {
acc[field] = object[field];
}
return acc;
}, {});
}
export function extractIdFields(entity: any, schema: SchemaDef, model: string) {
const idFields = getIdFields(schema, model);
return extractFields(entity, idFields);
@ -323,3 +317,17 @@ export function getDelegateDescendantModels(
});
return [...collected];
}
export function aggregate(
eb: ExpressionBuilder<any, any>,
expr: Expression<any>,
op: AGGREGATE_OPERATORS,
): Expression<any> {
return match(op)
.with('_count', () => eb.fn.count(expr))
.with('_sum', () => eb.fn.sum(expr))
.with('_avg', () => eb.fn.avg(expr))
.with('_min', () => eb.fn.min(expr))
.with('_max', () => eb.fn.max(expr))
.exhaustive();
}

View file

@ -2,7 +2,7 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import type { ClientContract } from '../../src/client';
import { schema } from '../schemas/basic';
import { createClientSpecs } from './client-specs';
import { createUser } from './utils';
import { createPosts, createUser } from './utils';
const PG_DB_NAME = 'client-api-group-by-tests';
@ -33,6 +33,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl
name: 'User',
role: 'USER',
});
await createPosts(client, '1');
await expect(
client.user.groupBy({
@ -67,7 +68,7 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl
take: -2,
orderBy: { email: 'desc' },
}),
).resolves.toEqual([{ email: 'u2@test.com' }, { email: 'u1@test.com' }]);
).resolves.toEqual(expect.arrayContaining([{ email: 'u2@test.com' }, { email: 'u1@test.com' }]));
await expect(
client.user.groupBy({
@ -93,6 +94,18 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl
{ name: 'User', role: 'USER', _count: 2 },
{ name: 'Admin', role: 'ADMIN', _count: 1 },
]);
await expect(
client.post.groupBy({
by: ['published'],
_count: true,
}),
).resolves.toEqual(
expect.arrayContaining([
{ published: true, _count: 1 },
{ published: false, _count: 1 },
]),
);
});
it('works with multiple bys', async () => {
@ -130,12 +143,14 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl
it('works with different types of aggregation', async () => {
await client.profile.create({
data: {
id: '1',
age: 10,
bio: 'bio',
},
});
await client.profile.create({
data: {
id: '2',
age: 20,
bio: 'bio',
},
@ -144,26 +159,108 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client groupBy tests', ({ createCl
await expect(
client.profile.groupBy({
by: ['bio'],
_count: { age: true },
_count: { age: true, id: true },
_avg: { age: true },
_sum: { age: true },
_min: { age: true },
_max: { age: true },
_min: { age: true, id: true },
_max: { age: true, id: true },
}),
).resolves.toEqual(
expect.arrayContaining([
{
bio: 'bio',
_count: { age: 2 },
_count: { age: 2, id: 2 },
_avg: { age: 15 },
_sum: { age: 30 },
_min: { age: 10 },
_max: { age: 20 },
_min: { age: 10, id: '1' },
_max: { age: 20, id: '2' },
},
]),
);
});
it('works with using aggregations in having', async () => {
await client.profile.create({
data: {
id: '1',
age: 10,
bio: 'bio1',
},
});
await client.profile.create({
data: {
id: '2',
age: 20,
bio: 'bio1',
},
});
await client.profile.create({
data: {
id: '3',
age: 30,
bio: 'bio2',
},
});
await client.profile.create({
data: {
id: '4',
age: 40,
bio: 'bio2',
},
});
await expect(
client.profile.groupBy({
by: ['bio'],
having: {
age: { _avg: { gt: 15, lt: 50 }, _sum: { equals: 70 } },
},
}),
).resolves.toEqual(expect.arrayContaining([{ bio: 'bio2' }]));
});
it('works with using aggregations in orderBy', async () => {
await client.profile.create({
data: {
id: '1',
age: 10,
bio: 'bio1',
},
});
await client.profile.create({
data: {
id: '2',
age: 20,
bio: 'bio1',
},
});
await client.profile.create({
data: {
id: '3',
age: 30,
bio: 'bio2',
},
});
await client.profile.create({
data: {
id: '4',
age: 40,
bio: 'bio2',
},
});
await expect(
client.profile.groupBy({
by: ['bio'],
orderBy: {
_avg: {
age: 'desc',
},
},
}),
).resolves.toEqual(expect.arrayContaining([{ bio: 'bio2' }]));
});
it('complains about fields in having that are not in by', async () => {
await expect(
client.profile.groupBy({

View file

@ -61,6 +61,11 @@ export const schema = {
type: "Profile",
optional: true,
relation: { opposite: "user" }
},
meta: {
name: "meta",
type: "Json",
optional: true
}
},
attributes: [

View file

@ -24,6 +24,7 @@ model User with CommonFields {
role Role @default(USER)
posts Post[]
profile Profile?
meta Json?
// Access policies
@@allow('all', auth().id == id)