mirror of
https://github.com/zenstackhq/zenstack
synced 2026-05-24 10:08:55 +00:00
WIP: create completion
This commit is contained in:
parent
cedda3cc7d
commit
2eb35ec8cd
14 changed files with 476 additions and 386 deletions
|
|
@ -1,30 +1,33 @@
|
|||
import SQLite from 'better-sqlite3';
|
||||
import { Kysely, SqliteDialect } from 'kysely';
|
||||
import { Kysely, ParseJSONResultsPlugin, type KyselyConfig } from 'kysely';
|
||||
import { type GetModels, type SchemaDef } from '../schema/schema';
|
||||
import { NotFoundError } from './errors';
|
||||
import { runCreate } from './operations/create';
|
||||
import { runFind } from './operations/find';
|
||||
import type { toKysely } from './query-builder';
|
||||
import type { DBClient, ModelOperations } from './types';
|
||||
import { runFind } from './operations/find';
|
||||
import { NotFoundError } from './errors';
|
||||
|
||||
export function makeClient<Schema extends SchemaDef>(schema: Schema) {
|
||||
return new Client<Schema>(schema) as unknown as DBClient<Schema>;
|
||||
export type ClientOptions = {
|
||||
dialect: KyselyConfig['dialect'];
|
||||
plugins?: KyselyConfig['plugins'];
|
||||
log?: KyselyConfig['log'];
|
||||
};
|
||||
|
||||
export function makeClient<Schema extends SchemaDef>(
|
||||
schema: Schema,
|
||||
options: ClientOptions
|
||||
) {
|
||||
return new Client<Schema>(schema, options) as unknown as DBClient<Schema>;
|
||||
}
|
||||
|
||||
class Client<Schema extends SchemaDef> {
|
||||
export class Client<Schema extends SchemaDef> {
|
||||
public readonly $db: Kysely<toKysely<Schema>>;
|
||||
|
||||
constructor(schema: Schema) {
|
||||
this.$db = this.createKysely(schema);
|
||||
return createClientProxy(this, schema);
|
||||
}
|
||||
|
||||
private createKysely<Schema extends SchemaDef>(
|
||||
_schema: Schema
|
||||
): Kysely<toKysely<Schema>> {
|
||||
return new Kysely({
|
||||
dialect: new SqliteDialect({ database: new SQLite(':memory:') }),
|
||||
constructor(schema: Schema, options: ClientOptions) {
|
||||
this.$db = new Kysely({
|
||||
...options,
|
||||
plugins: [...(options.plugins ?? []), new ParseJSONResultsPlugin()],
|
||||
});
|
||||
return createClientProxy(this, schema);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,16 +66,26 @@ function createModelProxy<
|
|||
): ModelOperations<Schema, Model> {
|
||||
return {
|
||||
create: async (args) => {
|
||||
return runCreate(db, schema, model, args);
|
||||
const r = await runCreate(
|
||||
{ db, schema, model, operation: 'create' },
|
||||
args
|
||||
);
|
||||
return r;
|
||||
},
|
||||
|
||||
findUnique: async (args) => {
|
||||
const r = await runFind(db, schema, model, 'findUnique', args);
|
||||
const r = await runFind(
|
||||
{ db, schema, model, operation: 'findUnique' },
|
||||
args
|
||||
);
|
||||
return r ?? null;
|
||||
},
|
||||
|
||||
findUniqueOrThrow: async (args) => {
|
||||
const r = await runFind(db, schema, model, 'findUnique', args);
|
||||
const r = await runFind(
|
||||
{ db, schema, model, operation: 'findUnique' },
|
||||
args
|
||||
);
|
||||
if (!r) {
|
||||
throw new NotFoundError(`No "${model}" found`);
|
||||
} else {
|
||||
|
|
@ -81,11 +94,14 @@ function createModelProxy<
|
|||
},
|
||||
|
||||
findFirst: async (args) => {
|
||||
return runFind(db, schema, model, 'findFirst', args);
|
||||
return runFind({ db, schema, model, operation: 'findFirst' }, args);
|
||||
},
|
||||
|
||||
findFirstOrThrow: async (args) => {
|
||||
const r = await runFind(db, schema, model, 'findFirst', args);
|
||||
const r = await runFind(
|
||||
{ db, schema, model, operation: 'findFirst' },
|
||||
args
|
||||
);
|
||||
if (!r) {
|
||||
throw new NotFoundError(`No "${model}" found`);
|
||||
} else {
|
||||
|
|
@ -94,7 +110,7 @@ function createModelProxy<
|
|||
},
|
||||
|
||||
findMany: async (args) => {
|
||||
return runFind(db, schema, model, 'findMany', args);
|
||||
return runFind({ db, schema, model, operation: 'findMany' }, args);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,108 +0,0 @@
|
|||
import { Array } from 'effect';
|
||||
import type { SchemaDef } from '../../schema/schema';
|
||||
import { requireField, requireIdFields } from '../query-utils';
|
||||
import type { FindArgs } from '../types';
|
||||
|
||||
export function assembleResult(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
data: any,
|
||||
args: FindArgs<SchemaDef, string> | undefined
|
||||
) {
|
||||
if (!data) {
|
||||
return data;
|
||||
}
|
||||
|
||||
const arrayData = Array.isArray(data) ? data : [data];
|
||||
return doAssembleResult(schema, model, '$', arrayData, args);
|
||||
}
|
||||
|
||||
function doAssembleResult(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
path: string,
|
||||
data: any[],
|
||||
args: FindArgs<SchemaDef, string> | undefined
|
||||
) {
|
||||
const grouped = Array.groupBy(data, (item) =>
|
||||
getEntityKey(schema, model, path, item)
|
||||
);
|
||||
return Object.values(grouped).map((rows) => {
|
||||
const entity = constructEntity(schema, model, path, rows, args);
|
||||
return entity;
|
||||
});
|
||||
}
|
||||
|
||||
function getEntityKey(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
path: string,
|
||||
data: any
|
||||
) {
|
||||
const idFields = requireIdFields(schema, model);
|
||||
return JSON.stringify(
|
||||
idFields.reduce((acc, f) => ({ ...acc, [f]: data[`${path}>${f}`] }), {})
|
||||
);
|
||||
}
|
||||
|
||||
function constructEntity(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
path: string,
|
||||
rows: any[],
|
||||
args: FindArgs<SchemaDef, string> | undefined
|
||||
) {
|
||||
const result: any = {};
|
||||
|
||||
// scalar fields
|
||||
for (const [k, v] of Object.entries(rows[0])) {
|
||||
if (!k.startsWith(`${path}>`)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const field = k.substring(`${path}>`.length).split('>')[0];
|
||||
if (!field) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const fieldDef = requireField(schema, model, field);
|
||||
if (!fieldDef.relation) {
|
||||
if (!args?.select || args?.select[field]) {
|
||||
result[field] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// relation fields
|
||||
const selectInclude = args?.select ?? args?.include;
|
||||
if (selectInclude) {
|
||||
for (const [field, payload] of Object.entries(selectInclude)) {
|
||||
if (!payload) {
|
||||
continue;
|
||||
}
|
||||
const fieldDef = requireField(schema, model, field);
|
||||
if (!fieldDef.relation) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const childSelectInclude =
|
||||
typeof payload === 'object'
|
||||
? (payload as any)[field] ?? (payload as any)[field]
|
||||
: undefined;
|
||||
const child = doAssembleResult(
|
||||
schema,
|
||||
fieldDef.type,
|
||||
`${path}>${field}`,
|
||||
rows,
|
||||
childSelectInclude
|
||||
);
|
||||
if (fieldDef.array) {
|
||||
result[field] = child;
|
||||
} else {
|
||||
result[field] = child[0] ?? null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
12
packages/runtime/src/client/operations/context.ts
Normal file
12
packages/runtime/src/client/operations/context.ts
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import type { Kysely } from 'kysely';
|
||||
import type { SchemaDef } from '../../schema/schema';
|
||||
import type { toKysely } from '../query-builder';
|
||||
|
||||
export type Operations = 'findMany' | 'findUnique' | 'findFirst' | 'create';
|
||||
|
||||
export type OperationContext = {
|
||||
db: Kysely<toKysely<any>>;
|
||||
schema: SchemaDef;
|
||||
model: string;
|
||||
operation: Operations;
|
||||
};
|
||||
|
|
@ -11,12 +11,15 @@ import {
|
|||
import { clone } from '../../utils/clone';
|
||||
import { InternalError, QueryError } from '../errors';
|
||||
import {
|
||||
getIdValues,
|
||||
getRelationForeignKeyFieldPairs,
|
||||
isForeignKeyField,
|
||||
isScalarField,
|
||||
requireField,
|
||||
requireModelEffect,
|
||||
} from '../query-utils';
|
||||
import type { OperationContext } from './context';
|
||||
import { runQuery as runFindQuery } from './find';
|
||||
|
||||
const CreateArgsSchema = z.object({
|
||||
data: z.record(z.string(), z.any()),
|
||||
|
|
@ -34,9 +37,7 @@ const RelationPayloadSchema = z.union([
|
|||
]);
|
||||
|
||||
export function runCreate(
|
||||
db: Kysely<any>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
{ db, schema, model }: OperationContext,
|
||||
args: unknown
|
||||
) {
|
||||
return Effect.runPromise(
|
||||
|
|
@ -44,7 +45,6 @@ export function runCreate(
|
|||
// parse args
|
||||
const parsedArgs = yield* parseCreateArgs(args);
|
||||
|
||||
// run query
|
||||
const result = yield* runQuery(db, schema, model, parsedArgs);
|
||||
yield* Console.log('create result:', result);
|
||||
return result;
|
||||
|
|
@ -66,12 +66,14 @@ function runQuery(
|
|||
args: CreateArgs
|
||||
) {
|
||||
return Effect.gen(function* () {
|
||||
const hasRelation = Object.keys(args.data).some(
|
||||
const hasRelationCreate = Object.keys(args.data).some(
|
||||
(f) => !!requireField(schema, model, f).relation
|
||||
);
|
||||
|
||||
const returnRelations = needReturnRelations(schema, model, args);
|
||||
|
||||
let result: any;
|
||||
if (hasRelation) {
|
||||
if (hasRelationCreate || returnRelations) {
|
||||
// employ a transaction
|
||||
result = yield* Effect.tryPromise({
|
||||
try: () =>
|
||||
|
|
@ -80,17 +82,32 @@ function runQuery(
|
|||
.setIsolationLevel('repeatable read')
|
||||
.execute(async (trx) =>
|
||||
Effect.runPromise(
|
||||
doCreate(trx, schema, model, args.data)
|
||||
Effect.gen(function* () {
|
||||
const createResult = yield* doCreate(
|
||||
trx,
|
||||
schema,
|
||||
model,
|
||||
args.data
|
||||
);
|
||||
return yield* readBackResult(
|
||||
trx,
|
||||
schema,
|
||||
model,
|
||||
createResult,
|
||||
args
|
||||
);
|
||||
})
|
||||
)
|
||||
),
|
||||
catch: (e) => new QueryError(`Error during create: ${e}`),
|
||||
});
|
||||
} else {
|
||||
// simple create
|
||||
result = yield* doCreate(db, schema, model, args.data);
|
||||
const createResult = yield* doCreate(db, schema, model, args.data);
|
||||
result = trimResult(createResult, args);
|
||||
}
|
||||
|
||||
return assembleResult(result, args);
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -140,7 +157,10 @@ function doCreate(
|
|||
);
|
||||
}
|
||||
|
||||
const subPayload = parseRelationPayload(payload, field);
|
||||
const subPayload = yield* parseRelationPayload(
|
||||
payload,
|
||||
field
|
||||
);
|
||||
|
||||
const r = yield* Match.value(subPayload).pipe(
|
||||
Match.when(
|
||||
|
|
@ -233,6 +253,48 @@ function evalGenerator(generator: FieldGenerators) {
|
|||
);
|
||||
}
|
||||
|
||||
function assembleResult(primaryData: any, _args: CreateArgs) {
|
||||
return primaryData;
|
||||
function trimResult(data: any, args: CreateArgs) {
|
||||
if (!args.select) {
|
||||
return data;
|
||||
}
|
||||
return Object.keys(args.select).reduce((acc, field) => {
|
||||
acc[field] = data[field];
|
||||
return acc;
|
||||
}, {} as any);
|
||||
}
|
||||
|
||||
function readBackResult(
|
||||
db: Kysely<any>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
primaryData: any,
|
||||
args: CreateArgs
|
||||
) {
|
||||
return Effect.gen(function* () {
|
||||
// fetch relations based on include or select
|
||||
const read = yield* runFindQuery(db, schema, model, 'findUnique', {
|
||||
where: getIdValues(schema, model, primaryData),
|
||||
select: args.select,
|
||||
include: args.include,
|
||||
});
|
||||
return read[0] ?? null;
|
||||
});
|
||||
}
|
||||
|
||||
function needReturnRelations(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
args: CreateArgs
|
||||
) {
|
||||
let returnRelation = false;
|
||||
|
||||
if (args.include) {
|
||||
returnRelation = Object.keys(args.include).length > 0;
|
||||
} else if (args.select) {
|
||||
returnRelation = Object.entries(args.select).some(([K, v]) => {
|
||||
const fieldDef = requireField(schema, model, K);
|
||||
return fieldDef.relation && v;
|
||||
});
|
||||
}
|
||||
return returnRelation;
|
||||
}
|
||||
|
|
|
|||
24
packages/runtime/src/client/operations/dialect/index.ts
Normal file
24
packages/runtime/src/client/operations/dialect/index.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import type { SelectQueryBuilder } from 'kysely';
|
||||
import type { SchemaDef, SupportedProviders } from '../../../schema/schema';
|
||||
import { PostgresQueryDialect } from './postgres';
|
||||
import { SqliteQueryDialect } from './sqlite';
|
||||
|
||||
export interface QueryDialect {
|
||||
buildRelationSelection(
|
||||
query: SelectQueryBuilder<any, any, {}>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
relationField: string,
|
||||
parentName: string,
|
||||
_payload: any
|
||||
): SelectQueryBuilder<any, any, {}>;
|
||||
}
|
||||
|
||||
const dialects: Record<SupportedProviders, QueryDialect> = {
|
||||
postgresql: new PostgresQueryDialect(),
|
||||
sqlite: new SqliteQueryDialect(),
|
||||
};
|
||||
|
||||
export function getQueryDialect(provider: SupportedProviders) {
|
||||
return dialects[provider];
|
||||
}
|
||||
63
packages/runtime/src/client/operations/dialect/postgres.ts
Normal file
63
packages/runtime/src/client/operations/dialect/postgres.ts
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import type { SelectQueryBuilder } from 'kysely';
|
||||
import type { QueryDialect } from '.';
|
||||
import type { SchemaDef } from '../../../schema/schema';
|
||||
import {
|
||||
getRelationForeignKeyFieldPairs,
|
||||
requireField,
|
||||
requireModel,
|
||||
} from '../../query-utils';
|
||||
|
||||
export class PostgresQueryDialect implements QueryDialect {
|
||||
buildRelationSelection(
|
||||
query: SelectQueryBuilder<any, any, {}>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
relationField: string,
|
||||
parentName: string,
|
||||
_payload: any
|
||||
): SelectQueryBuilder<any, any, {}> {
|
||||
const relationFieldDef = requireField(schema, model, relationField);
|
||||
const relationModel = requireModel(schema, relationFieldDef.type);
|
||||
const keyPairs = getRelationForeignKeyFieldPairs(
|
||||
schema,
|
||||
model,
|
||||
relationField
|
||||
);
|
||||
|
||||
let result = query;
|
||||
|
||||
result = result.leftJoinLateral(
|
||||
(eb) => {
|
||||
let tbl = eb.selectFrom(relationModel.dbTable);
|
||||
|
||||
keyPairs.forEach(({ fk, pk }) => {
|
||||
tbl = tbl.whereRef(
|
||||
`${relationModel.dbTable}.${fk}`,
|
||||
'=',
|
||||
`${parentName}.${pk}`
|
||||
);
|
||||
});
|
||||
|
||||
return tbl
|
||||
.select((eb1) =>
|
||||
eb1.fn
|
||||
.coalesce(
|
||||
eb1.fn.jsonAgg(
|
||||
eb1.fn('jsonb_build_object', [])
|
||||
),
|
||||
'[]'
|
||||
)
|
||||
.as('data')
|
||||
)
|
||||
.as(`${model}$${relationField}`);
|
||||
},
|
||||
(join) => join.onTrue()
|
||||
);
|
||||
|
||||
result = result.select(
|
||||
`${model}$${relationField}.data as ${relationField}`
|
||||
);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
61
packages/runtime/src/client/operations/dialect/sqlite.ts
Normal file
61
packages/runtime/src/client/operations/dialect/sqlite.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import { sql, type SelectQueryBuilder } from 'kysely';
|
||||
import type { QueryDialect } from '.';
|
||||
import type { SchemaDef } from '../../../schema/schema';
|
||||
import {
|
||||
getRelationForeignKeyFieldPairs,
|
||||
requireField,
|
||||
requireModel,
|
||||
} from '../../query-utils';
|
||||
|
||||
export class SqliteQueryDialect implements QueryDialect {
|
||||
buildRelationSelection(
|
||||
query: SelectQueryBuilder<any, any, {}>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
relationField: string,
|
||||
parentName: string,
|
||||
_payload: any
|
||||
): SelectQueryBuilder<any, any, {}> {
|
||||
const relationFieldDef = requireField(schema, model, relationField);
|
||||
const relationModel = requireModel(schema, relationFieldDef.type);
|
||||
const keyPairs = getRelationForeignKeyFieldPairs(
|
||||
schema,
|
||||
model,
|
||||
relationField
|
||||
);
|
||||
|
||||
let result = query;
|
||||
|
||||
result = result.select((eb) => {
|
||||
let tbl = eb
|
||||
.selectFrom(
|
||||
`${relationModel.dbTable} as ${parentName}$${relationField}`
|
||||
)
|
||||
.select((eb1) => {
|
||||
const objArgs = Object.keys(relationModel.fields)
|
||||
.filter((f) => !relationModel.fields[f]?.relation)
|
||||
.map((field) => [field, eb1.ref(field)])
|
||||
.flatMap((v) => v);
|
||||
|
||||
return eb1.fn
|
||||
.coalesce(
|
||||
sql`json_group_array(json_object(${sql.join(
|
||||
objArgs
|
||||
)}))`,
|
||||
sql`json_array()`
|
||||
)
|
||||
.as('data');
|
||||
});
|
||||
keyPairs.forEach(({ fk, pk }) => {
|
||||
tbl = tbl.whereRef(
|
||||
`${parentName}$${relationField}.${fk}`,
|
||||
'=',
|
||||
`${parentName}.${pk}`
|
||||
);
|
||||
});
|
||||
return tbl.as(relationField);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,29 +1,20 @@
|
|||
import { Console, Effect } from 'effect';
|
||||
import type { Kysely, SelectQueryBuilder } from 'kysely';
|
||||
import { z, ZodSchema } from 'zod';
|
||||
import type { SchemaDef } from '../../schema/schema';
|
||||
import { InternalError, QueryError } from '../errors';
|
||||
import { QueryError } from '../errors';
|
||||
import {
|
||||
getRelationForeignKeyFieldPairs,
|
||||
getUniqueFields,
|
||||
isScalarField,
|
||||
requireField,
|
||||
requireModel,
|
||||
requireModelEffect,
|
||||
} from '../query-utils';
|
||||
import type { FindArgs } from '../types';
|
||||
import { assembleResult } from './common';
|
||||
import { makeIncludeSchema, makeSelectSchema, makeWhereSchema } from './parse';
|
||||
|
||||
type FindOperation = 'findMany' | 'findUnique' | 'findFirst';
|
||||
|
||||
const ROOT_ALIAS = '$';
|
||||
import type { OperationContext, Operations } from './context';
|
||||
import { getQueryDialect } from './dialect';
|
||||
import { makeFindSchema } from './parse';
|
||||
|
||||
export function runFind(
|
||||
db: Kysely<any>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
operation: FindOperation,
|
||||
{ db, schema, model, operation }: OperationContext,
|
||||
args: unknown
|
||||
) {
|
||||
return Effect.runPromise(
|
||||
|
|
@ -45,7 +36,8 @@ export function runFind(
|
|||
parsedArgs
|
||||
);
|
||||
|
||||
const finalResult = operation === 'findMany' ? result : result[0];
|
||||
const finalResult =
|
||||
operation === 'findMany' ? result : result[0] ?? null;
|
||||
yield* Console.log(`${operation} result:`, finalResult);
|
||||
return finalResult;
|
||||
})
|
||||
|
|
@ -55,67 +47,14 @@ export function runFind(
|
|||
function parseFindArgs(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
operation: FindOperation,
|
||||
operation: Operations,
|
||||
args: unknown
|
||||
) {
|
||||
if (!args || typeof args !== 'object') {
|
||||
if (operation === 'findUnique') {
|
||||
// args is required for findUnique
|
||||
return Effect.fail(new QueryError(`Missing query args`));
|
||||
} else {
|
||||
return Effect.succeed(undefined);
|
||||
}
|
||||
}
|
||||
|
||||
const baseWhere = makeWhereSchema(schema, model);
|
||||
let where: ZodSchema = baseWhere;
|
||||
|
||||
if (operation === 'findUnique') {
|
||||
// findUnique requires at least one unique field (field set) is required
|
||||
|
||||
const uniqueFields = getUniqueFields(schema, model);
|
||||
if (uniqueFields.length === 0) {
|
||||
return Effect.fail(
|
||||
new InternalError(`Model "${model}" has no unique fields`)
|
||||
);
|
||||
}
|
||||
|
||||
if (uniqueFields.length === 1) {
|
||||
// only one unique field (set), mark the field(s) required
|
||||
where = baseWhere.required(
|
||||
uniqueFields[0]!.reduce(
|
||||
(acc, k) => ({
|
||||
...acc,
|
||||
[k.name]: true,
|
||||
}),
|
||||
{}
|
||||
)
|
||||
);
|
||||
} else {
|
||||
where = baseWhere.refine((value) => {
|
||||
// check that at least one unique field is set
|
||||
return uniqueFields.some((fields) =>
|
||||
fields.every(({ name }) => value[name] !== undefined)
|
||||
);
|
||||
}, `At least one unique field or field set must be set`);
|
||||
}
|
||||
} else {
|
||||
// where clause is optional
|
||||
where = where.optional();
|
||||
}
|
||||
|
||||
const select = makeSelectSchema(schema, model).optional();
|
||||
const include = makeIncludeSchema(schema, model).optional();
|
||||
|
||||
if ('select' in args && 'include' in args) {
|
||||
return Effect.fail(
|
||||
new QueryError(
|
||||
'Cannot use both "select" and "include" in find args'
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const findSchema = z.object({ where, select, include });
|
||||
const findSchema = makeFindSchema(
|
||||
schema,
|
||||
model,
|
||||
operation === 'findUnique'
|
||||
);
|
||||
|
||||
return Effect.try({
|
||||
try: () => findSchema.parse(args),
|
||||
|
|
@ -123,18 +62,18 @@ function parseFindArgs(
|
|||
});
|
||||
}
|
||||
|
||||
function runQuery(
|
||||
export function runQuery(
|
||||
db: Kysely<any>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
operation: string,
|
||||
args: FindArgs<SchemaDef, string> | undefined
|
||||
): Effect.Effect<any, QueryError, never> {
|
||||
): Effect.Effect<any[], QueryError, never> {
|
||||
return Effect.gen(function* () {
|
||||
const modelDef = yield* requireModelEffect(schema, model);
|
||||
|
||||
// table
|
||||
let query = db.selectFrom(`${modelDef.dbTable} as ${ROOT_ALIAS}`);
|
||||
let query = db.selectFrom(`${modelDef.dbTable}`);
|
||||
|
||||
if (operation !== 'findMany') {
|
||||
query = query.limit(1);
|
||||
|
|
@ -145,16 +84,38 @@ function runQuery(
|
|||
query = buildWhere(query, args.where);
|
||||
}
|
||||
|
||||
// skip
|
||||
if (args?.skip) {
|
||||
query = query.offset(args.skip);
|
||||
}
|
||||
|
||||
// take
|
||||
if (args?.take) {
|
||||
query = query.limit(args.take);
|
||||
}
|
||||
|
||||
// select
|
||||
if (args?.select) {
|
||||
query = buildFieldSelection(schema, model, query, args?.select);
|
||||
query = buildFieldSelection(
|
||||
schema,
|
||||
model,
|
||||
query,
|
||||
args?.select,
|
||||
modelDef.dbTable
|
||||
);
|
||||
} else {
|
||||
query = buildSelectAllFields(schema, model, query);
|
||||
}
|
||||
|
||||
// include
|
||||
if (args?.include) {
|
||||
query = buildFieldSelection(schema, model, query, args?.include);
|
||||
query = buildFieldSelection(
|
||||
schema,
|
||||
model,
|
||||
query,
|
||||
args?.include,
|
||||
modelDef.dbTable
|
||||
);
|
||||
}
|
||||
|
||||
const compiled = query.compile();
|
||||
|
|
@ -170,15 +131,13 @@ function runQuery(
|
|||
});
|
||||
|
||||
yield* Console.log(`Raw results:`, rows);
|
||||
const assembled = assembleResult(schema, model, rows, args);
|
||||
return assembled;
|
||||
return rows;
|
||||
});
|
||||
}
|
||||
|
||||
function buildWhere(
|
||||
query: SelectQueryBuilder<any, string, {}>,
|
||||
where: Record<string, any> | undefined,
|
||||
tableAlias = ROOT_ALIAS
|
||||
query: SelectQueryBuilder<any, any, {}>,
|
||||
where: Record<string, any> | undefined
|
||||
) {
|
||||
let result = query;
|
||||
if (!where) {
|
||||
|
|
@ -186,12 +145,7 @@ function buildWhere(
|
|||
}
|
||||
|
||||
result = Object.entries(where).reduce(
|
||||
(acc, [field, value]) =>
|
||||
acc.where(
|
||||
tableAlias ? `${tableAlias}.${field}` : field,
|
||||
'=',
|
||||
value
|
||||
),
|
||||
(acc, [field, value]) => acc.where(field, '=', value),
|
||||
result
|
||||
);
|
||||
|
||||
|
|
@ -201,9 +155,9 @@ function buildWhere(
|
|||
function buildFieldSelection(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
query: SelectQueryBuilder<any, string, {}>,
|
||||
query: SelectQueryBuilder<any, any, {}>,
|
||||
selectOrInclude: Record<string, any>,
|
||||
tableAlias = ROOT_ALIAS
|
||||
parentName: string
|
||||
) {
|
||||
let result = query;
|
||||
|
||||
|
|
@ -213,15 +167,15 @@ function buildFieldSelection(
|
|||
}
|
||||
const fieldDef = requireField(schema, model, field);
|
||||
if (!fieldDef.relation) {
|
||||
result = result.select(selectField(tableAlias, field));
|
||||
result = result.select(field);
|
||||
} else {
|
||||
result = buildRelationSelection(
|
||||
result,
|
||||
schema,
|
||||
model,
|
||||
field,
|
||||
result,
|
||||
payload,
|
||||
tableAlias
|
||||
parentName,
|
||||
payload
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -230,77 +184,36 @@ function buildFieldSelection(
|
|||
}
|
||||
|
||||
function buildRelationSelection(
|
||||
query: SelectQueryBuilder<any, any, {}>,
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
relationField: string,
|
||||
query: SelectQueryBuilder<any, string, {}>,
|
||||
payload: any,
|
||||
tableAlias = ROOT_ALIAS
|
||||
parentName: string,
|
||||
payload: any
|
||||
) {
|
||||
const relationFieldDef = requireField(schema, model, relationField);
|
||||
const relationModel = requireModel(schema, relationFieldDef.type);
|
||||
const keyPairs = getRelationForeignKeyFieldPairs(
|
||||
schema,
|
||||
model,
|
||||
relationField
|
||||
);
|
||||
|
||||
let result = query;
|
||||
|
||||
const nextAlias = joinAlias(tableAlias, relationField);
|
||||
result = result.leftJoin(
|
||||
`${relationModel.dbTable} as ${nextAlias}`,
|
||||
(join) =>
|
||||
keyPairs.reduce(
|
||||
(acc, { fk, pk }) =>
|
||||
acc.onRef(
|
||||
`${nextAlias}.${fk}`,
|
||||
'=',
|
||||
tableAlias ? `${tableAlias}.${pk}` : pk
|
||||
),
|
||||
join
|
||||
)
|
||||
);
|
||||
|
||||
if (payload === true) {
|
||||
result = buildSelectAllFields(
|
||||
schema,
|
||||
relationFieldDef.type,
|
||||
result,
|
||||
nextAlias
|
||||
);
|
||||
} else {
|
||||
result = buildFieldSelection(
|
||||
schema,
|
||||
relationFieldDef.type,
|
||||
result,
|
||||
payload,
|
||||
nextAlias
|
||||
);
|
||||
const queryDialect = getQueryDialect(schema.provider);
|
||||
if (!queryDialect) {
|
||||
throw new QueryError(`Unsupported provider: ${schema.provider}`);
|
||||
}
|
||||
|
||||
return result;
|
||||
return queryDialect.buildRelationSelection(
|
||||
query,
|
||||
schema,
|
||||
model,
|
||||
relationField,
|
||||
parentName,
|
||||
payload
|
||||
);
|
||||
}
|
||||
|
||||
function buildSelectAllFields(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
query: SelectQueryBuilder<any, string, {}>,
|
||||
tableAlias = ROOT_ALIAS
|
||||
query: SelectQueryBuilder<any, any, {}>
|
||||
) {
|
||||
let result = query;
|
||||
const modelDef = requireModel(schema, model);
|
||||
return Object.keys(modelDef.fields)
|
||||
.filter((f) => isScalarField(schema, model, f))
|
||||
.reduce((acc, f) => acc.select(selectField(tableAlias, f)), result);
|
||||
}
|
||||
|
||||
function joinAlias(tableAlias: string, field: string) {
|
||||
return `${tableAlias}>${field}`;
|
||||
}
|
||||
|
||||
function selectField(tableAlias: string, field: string) {
|
||||
return tableAlias
|
||||
? `${tableAlias}.${field} as ${joinAlias(tableAlias, field)}`
|
||||
: `${field} as ${field}`;
|
||||
.reduce((acc, f) => acc.select(f), result);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
import { Match } from 'effect';
|
||||
import { z, ZodObject, ZodSchema } from 'zod';
|
||||
import { z, ZodSchema } from 'zod';
|
||||
import type { SchemaDef } from '../../schema/schema';
|
||||
import { requireField, requireModel } from '../query-utils';
|
||||
import { InternalError } from '../errors';
|
||||
import { getUniqueFields, requireField, requireModel } from '../query-utils';
|
||||
|
||||
const schemas = new Map<string, ZodSchema>();
|
||||
|
||||
type SchemaKinds = 'where' | 'select' | 'include';
|
||||
type SchemaKinds =
|
||||
| 'where'
|
||||
| 'whereUnique'
|
||||
| 'select'
|
||||
| 'include'
|
||||
| 'find'
|
||||
| 'findUnique';
|
||||
|
||||
function getCache(model: string, kind: SchemaKinds) {
|
||||
return schemas.get(`${model}:${kind}`);
|
||||
|
|
@ -17,9 +24,11 @@ function putCache(model: string, kind: SchemaKinds, schema: ZodSchema) {
|
|||
|
||||
export function makeWhereSchema(
|
||||
schema: SchemaDef,
|
||||
model: string
|
||||
): ZodObject<any> {
|
||||
let result = getCache(model, 'where') as ZodObject<any> | undefined;
|
||||
model: string,
|
||||
unique: boolean
|
||||
): ZodSchema {
|
||||
const cacheKey = unique ? 'whereUnique' : 'where';
|
||||
let result = getCache(model, cacheKey);
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
|
@ -30,24 +39,51 @@ export function makeWhereSchema(
|
|||
const fieldDef = requireField(schema, model, field);
|
||||
if (fieldDef.relation) {
|
||||
fields[field] = z.lazy(() =>
|
||||
makeWhereSchema(schema, fieldDef.type).optional()
|
||||
makeWhereSchema(schema, fieldDef.type, false).optional()
|
||||
);
|
||||
} else {
|
||||
fields[field] = makePrimitiveSchema(fieldDef.type).optional();
|
||||
}
|
||||
}
|
||||
|
||||
result = z.object(fields);
|
||||
putCache(model, 'where', result);
|
||||
const baseWhere = z.object(fields);
|
||||
result = baseWhere;
|
||||
|
||||
if (unique) {
|
||||
// requires at least one unique field (field set) is required
|
||||
const uniqueFields = getUniqueFields(schema, model);
|
||||
if (uniqueFields.length === 0) {
|
||||
throw new InternalError(`Model "${model}" has no unique fields`);
|
||||
}
|
||||
|
||||
if (uniqueFields.length === 1) {
|
||||
// only one unique field (set), mark the field(s) required
|
||||
result = baseWhere.required(
|
||||
uniqueFields[0]!.reduce(
|
||||
(acc, k) => ({
|
||||
...acc,
|
||||
[k.name]: true,
|
||||
}),
|
||||
{}
|
||||
)
|
||||
);
|
||||
} else {
|
||||
result = baseWhere.refine((value) => {
|
||||
// check that at least one unique field is set
|
||||
return uniqueFields.some((fields) =>
|
||||
fields.every(({ name }) => value[name] !== undefined)
|
||||
);
|
||||
}, `At least one unique field or field set must be set`);
|
||||
}
|
||||
}
|
||||
|
||||
putCache(model, cacheKey, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export function makeSelectSchema(
|
||||
schema: SchemaDef,
|
||||
model: string
|
||||
): ZodObject<any> {
|
||||
let result = getCache(model, 'select') as ZodObject<any> | undefined;
|
||||
export function makeSelectSchema(schema: SchemaDef, model: string) {
|
||||
let result = getCache(model, 'select');
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
|
@ -82,11 +118,8 @@ export function makeSelectSchema(
|
|||
return result;
|
||||
}
|
||||
|
||||
export function makeIncludeSchema(
|
||||
schema: SchemaDef,
|
||||
model: string
|
||||
): ZodObject<any> {
|
||||
let result = getCache(model, 'include') as ZodObject<any> | undefined;
|
||||
export function makeIncludeSchema(schema: SchemaDef, model: string) {
|
||||
let result = getCache(model, 'include');
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
|
@ -119,6 +152,41 @@ export function makeIncludeSchema(
|
|||
return result;
|
||||
}
|
||||
|
||||
export function makeFindSchema(
|
||||
schema: SchemaDef,
|
||||
model: string,
|
||||
unique: boolean
|
||||
) {
|
||||
const cacheKey = unique ? 'findUnique' : 'find';
|
||||
let result = getCache(model, cacheKey);
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
const where = makeWhereSchema(schema, model, unique);
|
||||
const select = makeSelectSchema(schema, model);
|
||||
const include = makeIncludeSchema(schema, model);
|
||||
result = z
|
||||
.object({
|
||||
where: unique ? where : where.optional(),
|
||||
select: select.optional(),
|
||||
include: include.optional(),
|
||||
skip: z.number().int().nonnegative().optional(),
|
||||
take: z.number().int().nonnegative().optional(),
|
||||
})
|
||||
.refine(
|
||||
(value) => !value.select || !value.include,
|
||||
'"select" and "include" cannot be used together'
|
||||
);
|
||||
|
||||
if (!unique) {
|
||||
result = result.optional();
|
||||
}
|
||||
|
||||
putCache(model, cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
function makePrimitiveSchema(type: string) {
|
||||
return Match.value(type).pipe(
|
||||
Match.when('String', () => z.string()),
|
||||
|
|
@ -131,5 +199,3 @@ function makePrimitiveSchema(type: string) {
|
|||
Match.orElse(() => z.unknown())
|
||||
);
|
||||
}
|
||||
|
||||
//#endregion
|
||||
|
|
|
|||
|
|
@ -157,10 +157,13 @@ export function getIdValues(
|
|||
schema: SchemaDef,
|
||||
model: string,
|
||||
data: any
|
||||
): Array<{ field: string; value: any }> {
|
||||
): Record<string, any> {
|
||||
const idFields = getIdFields(schema, model);
|
||||
if (!idFields) {
|
||||
throw new InternalError(`ID fields not defined for model "${model}"`);
|
||||
}
|
||||
return idFields.map((field) => ({ field, value: data[field] }));
|
||||
return idFields.reduce(
|
||||
(acc, field) => ({ ...acc, [field]: data[field] }),
|
||||
{}
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -204,14 +204,6 @@ export type MapFieldType<
|
|||
Field extends GetFields<Schema, Model>
|
||||
> = MapFieldDefType<Schema, GetField<Schema, Model, Field>>;
|
||||
|
||||
// WrapType<
|
||||
// GetFieldType<Schema, Model, Field> extends GetEnums<Schema>
|
||||
// ? 'foo'
|
||||
// : MapBaseType<GetField<Schema, Model, Field>>,
|
||||
// FieldIsOptional<Schema, Model, Field>,
|
||||
// FieldIsArray<Schema, Model, Field>
|
||||
// >;
|
||||
|
||||
type MapFieldDefType<
|
||||
Schema extends SchemaDef,
|
||||
T extends Pick<FieldDef, 'type' | 'optional' | 'array'>
|
||||
|
|
@ -303,6 +295,8 @@ export type FindArgs<
|
|||
Model extends GetModels<Schema>
|
||||
> = {
|
||||
where?: Where<Schema, Model>;
|
||||
skip?: number;
|
||||
take?: number;
|
||||
} & SelectInclude<Schema, Model>;
|
||||
|
||||
export type FindUniqueArgs<
|
||||
|
|
|
|||
|
|
@ -22,8 +22,21 @@ export type MapBaseType<T> = T extends 'String'
|
|||
? Decimal
|
||||
: T extends 'DateTime'
|
||||
? Date
|
||||
: T extends 'Json'
|
||||
? JsonValue
|
||||
: unknown;
|
||||
|
||||
export type JsonValue =
|
||||
| string
|
||||
| number
|
||||
| boolean
|
||||
| null
|
||||
| JsonObject
|
||||
| JsonArray;
|
||||
|
||||
export type JsonObject = { [key: string]: JsonValue };
|
||||
export type JsonArray = Array<JsonValue>;
|
||||
|
||||
export type Simplify<T> = { [Key in keyof T]: T[Key] } & {};
|
||||
|
||||
export function call(code: string) {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import Sqlite from 'better-sqlite3';
|
||||
import { SqliteDialect } from 'kysely';
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { makeClient } from '../../src/client';
|
||||
import type { DBClient } from '../../src/client/types';
|
||||
|
|
@ -7,7 +9,11 @@ describe('Client API create tests', () => {
|
|||
let client: DBClient<typeof Schema>;
|
||||
|
||||
beforeEach(async () => {
|
||||
client = makeClient(Schema);
|
||||
client = makeClient(Schema, {
|
||||
dialect: new SqliteDialect({
|
||||
database: new Sqlite(':memory:'),
|
||||
}),
|
||||
});
|
||||
await pushSchema(client.$db);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,20 @@
|
|||
import Sqlite from 'better-sqlite3';
|
||||
import { SqliteDialect } from 'kysely';
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { makeClient } from '../../src/client';
|
||||
import { NotFoundError } from '../../src/client/errors';
|
||||
import type { DBClient } from '../../src/client/types';
|
||||
import { pushSchema, Schema } from '../test-schema';
|
||||
import { NotFoundError } from '../../src/client/errors';
|
||||
|
||||
describe('Client API find tests', () => {
|
||||
let client: DBClient<typeof Schema>;
|
||||
|
||||
beforeEach(async () => {
|
||||
client = makeClient(Schema);
|
||||
client = makeClient(Schema, {
|
||||
dialect: new SqliteDialect({
|
||||
database: new Sqlite(':memory:'),
|
||||
}),
|
||||
});
|
||||
await pushSchema(client.$db);
|
||||
});
|
||||
|
||||
|
|
@ -111,62 +117,21 @@ describe('Client API find tests', () => {
|
|||
const user = await createUser();
|
||||
await createPosts(user.id);
|
||||
|
||||
const q = client.$db
|
||||
.selectFrom('user')
|
||||
.where('user.id', 'in', (qb) =>
|
||||
qb
|
||||
.selectFrom('user')
|
||||
.select('id')
|
||||
.orderBy('createdAt desc')
|
||||
.limit(1)
|
||||
)
|
||||
.leftJoin(
|
||||
(eb) =>
|
||||
eb
|
||||
.selectFrom('post')
|
||||
.select([
|
||||
'post.id',
|
||||
'post.title',
|
||||
'post.authorId',
|
||||
'author.email as authorEmail',
|
||||
])
|
||||
.where('post.published', '!=', 1 as any)
|
||||
.leftJoin(
|
||||
(eb1) =>
|
||||
eb1.selectFrom('user').selectAll().as('author'),
|
||||
(join) =>
|
||||
join.onRef('author.id', '=', 'post.authorId')
|
||||
)
|
||||
.as('post'),
|
||||
(join) => join.onRef('post.authorId', '=', 'user.id')
|
||||
)
|
||||
.select([
|
||||
'user.id as user.id',
|
||||
'post.id as post.id',
|
||||
'post.title as post.title',
|
||||
'post.authorEmail',
|
||||
]);
|
||||
const { sql, parameters } = q.compile();
|
||||
console.log('SQL:', sql, 'PARAMS', parameters);
|
||||
console.log(await q.execute());
|
||||
let r = await client.user.findUnique({
|
||||
where: { id: '1' },
|
||||
select: { id: true, email: true, posts: true },
|
||||
});
|
||||
expect(r?.id).toBeTruthy();
|
||||
expect(r?.email).toBeTruthy();
|
||||
expect('name' in r!).toBeFalsy();
|
||||
expect(r?.posts).toHaveLength(2);
|
||||
|
||||
// let r = await client.user.findUnique({
|
||||
// where: { id: '1' },
|
||||
// select: { id: true, email: true, posts: true },
|
||||
// });
|
||||
// expect(r?.id).toBeTruthy();
|
||||
// expect(r?.email).toBeTruthy();
|
||||
// expect('name' in r!).toBeFalsy();
|
||||
// expect(r?.posts).toHaveLength(2);
|
||||
|
||||
// await expect(
|
||||
// client.user.findUnique({
|
||||
// where: { id: '1' },
|
||||
// select: { id: true, email: true },
|
||||
// include: { posts: true },
|
||||
// } as any)
|
||||
// ).rejects.toThrow(
|
||||
// 'Cannot use both "select" and "include" in find args'
|
||||
// );
|
||||
await expect(
|
||||
client.user.findUnique({
|
||||
where: { id: '1' },
|
||||
select: { id: true, email: true },
|
||||
include: { posts: true },
|
||||
} as any)
|
||||
).rejects.toThrow('cannot be used together');
|
||||
});
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in a new issue