WIP: create completion

This commit is contained in:
ymc9 2025-02-03 19:59:44 -08:00
parent cedda3cc7d
commit 2eb35ec8cd
14 changed files with 476 additions and 386 deletions

View file

@ -1,30 +1,33 @@
import SQLite from 'better-sqlite3';
import { Kysely, SqliteDialect } from 'kysely';
import { Kysely, ParseJSONResultsPlugin, type KyselyConfig } from 'kysely';
import { type GetModels, type SchemaDef } from '../schema/schema';
import { NotFoundError } from './errors';
import { runCreate } from './operations/create';
import { runFind } from './operations/find';
import type { toKysely } from './query-builder';
import type { DBClient, ModelOperations } from './types';
import { runFind } from './operations/find';
import { NotFoundError } from './errors';
export function makeClient<Schema extends SchemaDef>(schema: Schema) {
return new Client<Schema>(schema) as unknown as DBClient<Schema>;
export type ClientOptions = {
dialect: KyselyConfig['dialect'];
plugins?: KyselyConfig['plugins'];
log?: KyselyConfig['log'];
};
export function makeClient<Schema extends SchemaDef>(
schema: Schema,
options: ClientOptions
) {
return new Client<Schema>(schema, options) as unknown as DBClient<Schema>;
}
class Client<Schema extends SchemaDef> {
export class Client<Schema extends SchemaDef> {
public readonly $db: Kysely<toKysely<Schema>>;
constructor(schema: Schema) {
this.$db = this.createKysely(schema);
return createClientProxy(this, schema);
}
private createKysely<Schema extends SchemaDef>(
_schema: Schema
): Kysely<toKysely<Schema>> {
return new Kysely({
dialect: new SqliteDialect({ database: new SQLite(':memory:') }),
constructor(schema: Schema, options: ClientOptions) {
this.$db = new Kysely({
...options,
plugins: [...(options.plugins ?? []), new ParseJSONResultsPlugin()],
});
return createClientProxy(this, schema);
}
}
@ -63,16 +66,26 @@ function createModelProxy<
): ModelOperations<Schema, Model> {
return {
create: async (args) => {
return runCreate(db, schema, model, args);
const r = await runCreate(
{ db, schema, model, operation: 'create' },
args
);
return r;
},
findUnique: async (args) => {
const r = await runFind(db, schema, model, 'findUnique', args);
const r = await runFind(
{ db, schema, model, operation: 'findUnique' },
args
);
return r ?? null;
},
findUniqueOrThrow: async (args) => {
const r = await runFind(db, schema, model, 'findUnique', args);
const r = await runFind(
{ db, schema, model, operation: 'findUnique' },
args
);
if (!r) {
throw new NotFoundError(`No "${model}" found`);
} else {
@ -81,11 +94,14 @@ function createModelProxy<
},
findFirst: async (args) => {
return runFind(db, schema, model, 'findFirst', args);
return runFind({ db, schema, model, operation: 'findFirst' }, args);
},
findFirstOrThrow: async (args) => {
const r = await runFind(db, schema, model, 'findFirst', args);
const r = await runFind(
{ db, schema, model, operation: 'findFirst' },
args
);
if (!r) {
throw new NotFoundError(`No "${model}" found`);
} else {
@ -94,7 +110,7 @@ function createModelProxy<
},
findMany: async (args) => {
return runFind(db, schema, model, 'findMany', args);
return runFind({ db, schema, model, operation: 'findMany' }, args);
},
};
}

View file

@ -1,108 +0,0 @@
import { Array } from 'effect';
import type { SchemaDef } from '../../schema/schema';
import { requireField, requireIdFields } from '../query-utils';
import type { FindArgs } from '../types';
export function assembleResult(
schema: SchemaDef,
model: string,
data: any,
args: FindArgs<SchemaDef, string> | undefined
) {
if (!data) {
return data;
}
const arrayData = Array.isArray(data) ? data : [data];
return doAssembleResult(schema, model, '$', arrayData, args);
}
function doAssembleResult(
schema: SchemaDef,
model: string,
path: string,
data: any[],
args: FindArgs<SchemaDef, string> | undefined
) {
const grouped = Array.groupBy(data, (item) =>
getEntityKey(schema, model, path, item)
);
return Object.values(grouped).map((rows) => {
const entity = constructEntity(schema, model, path, rows, args);
return entity;
});
}
function getEntityKey(
schema: SchemaDef,
model: string,
path: string,
data: any
) {
const idFields = requireIdFields(schema, model);
return JSON.stringify(
idFields.reduce((acc, f) => ({ ...acc, [f]: data[`${path}>${f}`] }), {})
);
}
function constructEntity(
schema: SchemaDef,
model: string,
path: string,
rows: any[],
args: FindArgs<SchemaDef, string> | undefined
) {
const result: any = {};
// scalar fields
for (const [k, v] of Object.entries(rows[0])) {
if (!k.startsWith(`${path}>`)) {
continue;
}
const field = k.substring(`${path}>`.length).split('>')[0];
if (!field) {
continue;
}
const fieldDef = requireField(schema, model, field);
if (!fieldDef.relation) {
if (!args?.select || args?.select[field]) {
result[field] = v;
}
}
}
// relation fields
const selectInclude = args?.select ?? args?.include;
if (selectInclude) {
for (const [field, payload] of Object.entries(selectInclude)) {
if (!payload) {
continue;
}
const fieldDef = requireField(schema, model, field);
if (!fieldDef.relation) {
continue;
}
const childSelectInclude =
typeof payload === 'object'
? (payload as any)[field] ?? (payload as any)[field]
: undefined;
const child = doAssembleResult(
schema,
fieldDef.type,
`${path}>${field}`,
rows,
childSelectInclude
);
if (fieldDef.array) {
result[field] = child;
} else {
result[field] = child[0] ?? null;
}
}
}
return result;
}

View file

@ -0,0 +1,12 @@
import type { Kysely } from 'kysely';
import type { SchemaDef } from '../../schema/schema';
import type { toKysely } from '../query-builder';
export type Operations = 'findMany' | 'findUnique' | 'findFirst' | 'create';
export type OperationContext = {
db: Kysely<toKysely<any>>;
schema: SchemaDef;
model: string;
operation: Operations;
};

View file

@ -11,12 +11,15 @@ import {
import { clone } from '../../utils/clone';
import { InternalError, QueryError } from '../errors';
import {
getIdValues,
getRelationForeignKeyFieldPairs,
isForeignKeyField,
isScalarField,
requireField,
requireModelEffect,
} from '../query-utils';
import type { OperationContext } from './context';
import { runQuery as runFindQuery } from './find';
const CreateArgsSchema = z.object({
data: z.record(z.string(), z.any()),
@ -34,9 +37,7 @@ const RelationPayloadSchema = z.union([
]);
export function runCreate(
db: Kysely<any>,
schema: SchemaDef,
model: string,
{ db, schema, model }: OperationContext,
args: unknown
) {
return Effect.runPromise(
@ -44,7 +45,6 @@ export function runCreate(
// parse args
const parsedArgs = yield* parseCreateArgs(args);
// run query
const result = yield* runQuery(db, schema, model, parsedArgs);
yield* Console.log('create result:', result);
return result;
@ -66,12 +66,14 @@ function runQuery(
args: CreateArgs
) {
return Effect.gen(function* () {
const hasRelation = Object.keys(args.data).some(
const hasRelationCreate = Object.keys(args.data).some(
(f) => !!requireField(schema, model, f).relation
);
const returnRelations = needReturnRelations(schema, model, args);
let result: any;
if (hasRelation) {
if (hasRelationCreate || returnRelations) {
// employ a transaction
result = yield* Effect.tryPromise({
try: () =>
@ -80,17 +82,32 @@ function runQuery(
.setIsolationLevel('repeatable read')
.execute(async (trx) =>
Effect.runPromise(
doCreate(trx, schema, model, args.data)
Effect.gen(function* () {
const createResult = yield* doCreate(
trx,
schema,
model,
args.data
);
return yield* readBackResult(
trx,
schema,
model,
createResult,
args
);
})
)
),
catch: (e) => new QueryError(`Error during create: ${e}`),
});
} else {
// simple create
result = yield* doCreate(db, schema, model, args.data);
const createResult = yield* doCreate(db, schema, model, args.data);
result = trimResult(createResult, args);
}
return assembleResult(result, args);
return result;
});
}
@ -140,7 +157,10 @@ function doCreate(
);
}
const subPayload = parseRelationPayload(payload, field);
const subPayload = yield* parseRelationPayload(
payload,
field
);
const r = yield* Match.value(subPayload).pipe(
Match.when(
@ -233,6 +253,48 @@ function evalGenerator(generator: FieldGenerators) {
);
}
function assembleResult(primaryData: any, _args: CreateArgs) {
return primaryData;
function trimResult(data: any, args: CreateArgs) {
if (!args.select) {
return data;
}
return Object.keys(args.select).reduce((acc, field) => {
acc[field] = data[field];
return acc;
}, {} as any);
}
function readBackResult(
db: Kysely<any>,
schema: SchemaDef,
model: string,
primaryData: any,
args: CreateArgs
) {
return Effect.gen(function* () {
// fetch relations based on include or select
const read = yield* runFindQuery(db, schema, model, 'findUnique', {
where: getIdValues(schema, model, primaryData),
select: args.select,
include: args.include,
});
return read[0] ?? null;
});
}
function needReturnRelations(
schema: SchemaDef,
model: string,
args: CreateArgs
) {
let returnRelation = false;
if (args.include) {
returnRelation = Object.keys(args.include).length > 0;
} else if (args.select) {
returnRelation = Object.entries(args.select).some(([K, v]) => {
const fieldDef = requireField(schema, model, K);
return fieldDef.relation && v;
});
}
return returnRelation;
}

View file

@ -0,0 +1,24 @@
import type { SelectQueryBuilder } from 'kysely';
import type { SchemaDef, SupportedProviders } from '../../../schema/schema';
import { PostgresQueryDialect } from './postgres';
import { SqliteQueryDialect } from './sqlite';
export interface QueryDialect {
buildRelationSelection(
query: SelectQueryBuilder<any, any, {}>,
schema: SchemaDef,
model: string,
relationField: string,
parentName: string,
_payload: any
): SelectQueryBuilder<any, any, {}>;
}
const dialects: Record<SupportedProviders, QueryDialect> = {
postgresql: new PostgresQueryDialect(),
sqlite: new SqliteQueryDialect(),
};
export function getQueryDialect(provider: SupportedProviders) {
return dialects[provider];
}

View file

@ -0,0 +1,63 @@
import type { SelectQueryBuilder } from 'kysely';
import type { QueryDialect } from '.';
import type { SchemaDef } from '../../../schema/schema';
import {
getRelationForeignKeyFieldPairs,
requireField,
requireModel,
} from '../../query-utils';
export class PostgresQueryDialect implements QueryDialect {
buildRelationSelection(
query: SelectQueryBuilder<any, any, {}>,
schema: SchemaDef,
model: string,
relationField: string,
parentName: string,
_payload: any
): SelectQueryBuilder<any, any, {}> {
const relationFieldDef = requireField(schema, model, relationField);
const relationModel = requireModel(schema, relationFieldDef.type);
const keyPairs = getRelationForeignKeyFieldPairs(
schema,
model,
relationField
);
let result = query;
result = result.leftJoinLateral(
(eb) => {
let tbl = eb.selectFrom(relationModel.dbTable);
keyPairs.forEach(({ fk, pk }) => {
tbl = tbl.whereRef(
`${relationModel.dbTable}.${fk}`,
'=',
`${parentName}.${pk}`
);
});
return tbl
.select((eb1) =>
eb1.fn
.coalesce(
eb1.fn.jsonAgg(
eb1.fn('jsonb_build_object', [])
),
'[]'
)
.as('data')
)
.as(`${model}$${relationField}`);
},
(join) => join.onTrue()
);
result = result.select(
`${model}$${relationField}.data as ${relationField}`
);
return result;
}
}

View file

@ -0,0 +1,61 @@
import { sql, type SelectQueryBuilder } from 'kysely';
import type { QueryDialect } from '.';
import type { SchemaDef } from '../../../schema/schema';
import {
getRelationForeignKeyFieldPairs,
requireField,
requireModel,
} from '../../query-utils';
export class SqliteQueryDialect implements QueryDialect {
buildRelationSelection(
query: SelectQueryBuilder<any, any, {}>,
schema: SchemaDef,
model: string,
relationField: string,
parentName: string,
_payload: any
): SelectQueryBuilder<any, any, {}> {
const relationFieldDef = requireField(schema, model, relationField);
const relationModel = requireModel(schema, relationFieldDef.type);
const keyPairs = getRelationForeignKeyFieldPairs(
schema,
model,
relationField
);
let result = query;
result = result.select((eb) => {
let tbl = eb
.selectFrom(
`${relationModel.dbTable} as ${parentName}$${relationField}`
)
.select((eb1) => {
const objArgs = Object.keys(relationModel.fields)
.filter((f) => !relationModel.fields[f]?.relation)
.map((field) => [field, eb1.ref(field)])
.flatMap((v) => v);
return eb1.fn
.coalesce(
sql`json_group_array(json_object(${sql.join(
objArgs
)}))`,
sql`json_array()`
)
.as('data');
});
keyPairs.forEach(({ fk, pk }) => {
tbl = tbl.whereRef(
`${parentName}$${relationField}.${fk}`,
'=',
`${parentName}.${pk}`
);
});
return tbl.as(relationField);
});
return result;
}
}

View file

@ -1,29 +1,20 @@
import { Console, Effect } from 'effect';
import type { Kysely, SelectQueryBuilder } from 'kysely';
import { z, ZodSchema } from 'zod';
import type { SchemaDef } from '../../schema/schema';
import { InternalError, QueryError } from '../errors';
import { QueryError } from '../errors';
import {
getRelationForeignKeyFieldPairs,
getUniqueFields,
isScalarField,
requireField,
requireModel,
requireModelEffect,
} from '../query-utils';
import type { FindArgs } from '../types';
import { assembleResult } from './common';
import { makeIncludeSchema, makeSelectSchema, makeWhereSchema } from './parse';
type FindOperation = 'findMany' | 'findUnique' | 'findFirst';
const ROOT_ALIAS = '$';
import type { OperationContext, Operations } from './context';
import { getQueryDialect } from './dialect';
import { makeFindSchema } from './parse';
export function runFind(
db: Kysely<any>,
schema: SchemaDef,
model: string,
operation: FindOperation,
{ db, schema, model, operation }: OperationContext,
args: unknown
) {
return Effect.runPromise(
@ -45,7 +36,8 @@ export function runFind(
parsedArgs
);
const finalResult = operation === 'findMany' ? result : result[0];
const finalResult =
operation === 'findMany' ? result : result[0] ?? null;
yield* Console.log(`${operation} result:`, finalResult);
return finalResult;
})
@ -55,67 +47,14 @@ export function runFind(
function parseFindArgs(
schema: SchemaDef,
model: string,
operation: FindOperation,
operation: Operations,
args: unknown
) {
if (!args || typeof args !== 'object') {
if (operation === 'findUnique') {
// args is required for findUnique
return Effect.fail(new QueryError(`Missing query args`));
} else {
return Effect.succeed(undefined);
}
}
const baseWhere = makeWhereSchema(schema, model);
let where: ZodSchema = baseWhere;
if (operation === 'findUnique') {
// findUnique requires at least one unique field (field set) is required
const uniqueFields = getUniqueFields(schema, model);
if (uniqueFields.length === 0) {
return Effect.fail(
new InternalError(`Model "${model}" has no unique fields`)
);
}
if (uniqueFields.length === 1) {
// only one unique field (set), mark the field(s) required
where = baseWhere.required(
uniqueFields[0]!.reduce(
(acc, k) => ({
...acc,
[k.name]: true,
}),
{}
)
);
} else {
where = baseWhere.refine((value) => {
// check that at least one unique field is set
return uniqueFields.some((fields) =>
fields.every(({ name }) => value[name] !== undefined)
);
}, `At least one unique field or field set must be set`);
}
} else {
// where clause is optional
where = where.optional();
}
const select = makeSelectSchema(schema, model).optional();
const include = makeIncludeSchema(schema, model).optional();
if ('select' in args && 'include' in args) {
return Effect.fail(
new QueryError(
'Cannot use both "select" and "include" in find args'
)
);
}
const findSchema = z.object({ where, select, include });
const findSchema = makeFindSchema(
schema,
model,
operation === 'findUnique'
);
return Effect.try({
try: () => findSchema.parse(args),
@ -123,18 +62,18 @@ function parseFindArgs(
});
}
function runQuery(
export function runQuery(
db: Kysely<any>,
schema: SchemaDef,
model: string,
operation: string,
args: FindArgs<SchemaDef, string> | undefined
): Effect.Effect<any, QueryError, never> {
): Effect.Effect<any[], QueryError, never> {
return Effect.gen(function* () {
const modelDef = yield* requireModelEffect(schema, model);
// table
let query = db.selectFrom(`${modelDef.dbTable} as ${ROOT_ALIAS}`);
let query = db.selectFrom(`${modelDef.dbTable}`);
if (operation !== 'findMany') {
query = query.limit(1);
@ -145,16 +84,38 @@ function runQuery(
query = buildWhere(query, args.where);
}
// skip
if (args?.skip) {
query = query.offset(args.skip);
}
// take
if (args?.take) {
query = query.limit(args.take);
}
// select
if (args?.select) {
query = buildFieldSelection(schema, model, query, args?.select);
query = buildFieldSelection(
schema,
model,
query,
args?.select,
modelDef.dbTable
);
} else {
query = buildSelectAllFields(schema, model, query);
}
// include
if (args?.include) {
query = buildFieldSelection(schema, model, query, args?.include);
query = buildFieldSelection(
schema,
model,
query,
args?.include,
modelDef.dbTable
);
}
const compiled = query.compile();
@ -170,15 +131,13 @@ function runQuery(
});
yield* Console.log(`Raw results:`, rows);
const assembled = assembleResult(schema, model, rows, args);
return assembled;
return rows;
});
}
function buildWhere(
query: SelectQueryBuilder<any, string, {}>,
where: Record<string, any> | undefined,
tableAlias = ROOT_ALIAS
query: SelectQueryBuilder<any, any, {}>,
where: Record<string, any> | undefined
) {
let result = query;
if (!where) {
@ -186,12 +145,7 @@ function buildWhere(
}
result = Object.entries(where).reduce(
(acc, [field, value]) =>
acc.where(
tableAlias ? `${tableAlias}.${field}` : field,
'=',
value
),
(acc, [field, value]) => acc.where(field, '=', value),
result
);
@ -201,9 +155,9 @@ function buildWhere(
function buildFieldSelection(
schema: SchemaDef,
model: string,
query: SelectQueryBuilder<any, string, {}>,
query: SelectQueryBuilder<any, any, {}>,
selectOrInclude: Record<string, any>,
tableAlias = ROOT_ALIAS
parentName: string
) {
let result = query;
@ -213,15 +167,15 @@ function buildFieldSelection(
}
const fieldDef = requireField(schema, model, field);
if (!fieldDef.relation) {
result = result.select(selectField(tableAlias, field));
result = result.select(field);
} else {
result = buildRelationSelection(
result,
schema,
model,
field,
result,
payload,
tableAlias
parentName,
payload
);
}
}
@ -230,77 +184,36 @@ function buildFieldSelection(
}
function buildRelationSelection(
query: SelectQueryBuilder<any, any, {}>,
schema: SchemaDef,
model: string,
relationField: string,
query: SelectQueryBuilder<any, string, {}>,
payload: any,
tableAlias = ROOT_ALIAS
parentName: string,
payload: any
) {
const relationFieldDef = requireField(schema, model, relationField);
const relationModel = requireModel(schema, relationFieldDef.type);
const keyPairs = getRelationForeignKeyFieldPairs(
schema,
model,
relationField
);
let result = query;
const nextAlias = joinAlias(tableAlias, relationField);
result = result.leftJoin(
`${relationModel.dbTable} as ${nextAlias}`,
(join) =>
keyPairs.reduce(
(acc, { fk, pk }) =>
acc.onRef(
`${nextAlias}.${fk}`,
'=',
tableAlias ? `${tableAlias}.${pk}` : pk
),
join
)
);
if (payload === true) {
result = buildSelectAllFields(
schema,
relationFieldDef.type,
result,
nextAlias
);
} else {
result = buildFieldSelection(
schema,
relationFieldDef.type,
result,
payload,
nextAlias
);
const queryDialect = getQueryDialect(schema.provider);
if (!queryDialect) {
throw new QueryError(`Unsupported provider: ${schema.provider}`);
}
return result;
return queryDialect.buildRelationSelection(
query,
schema,
model,
relationField,
parentName,
payload
);
}
function buildSelectAllFields(
schema: SchemaDef,
model: string,
query: SelectQueryBuilder<any, string, {}>,
tableAlias = ROOT_ALIAS
query: SelectQueryBuilder<any, any, {}>
) {
let result = query;
const modelDef = requireModel(schema, model);
return Object.keys(modelDef.fields)
.filter((f) => isScalarField(schema, model, f))
.reduce((acc, f) => acc.select(selectField(tableAlias, f)), result);
}
function joinAlias(tableAlias: string, field: string) {
return `${tableAlias}>${field}`;
}
function selectField(tableAlias: string, field: string) {
return tableAlias
? `${tableAlias}.${field} as ${joinAlias(tableAlias, field)}`
: `${field} as ${field}`;
.reduce((acc, f) => acc.select(f), result);
}

View file

@ -1,11 +1,18 @@
import { Match } from 'effect';
import { z, ZodObject, ZodSchema } from 'zod';
import { z, ZodSchema } from 'zod';
import type { SchemaDef } from '../../schema/schema';
import { requireField, requireModel } from '../query-utils';
import { InternalError } from '../errors';
import { getUniqueFields, requireField, requireModel } from '../query-utils';
const schemas = new Map<string, ZodSchema>();
type SchemaKinds = 'where' | 'select' | 'include';
type SchemaKinds =
| 'where'
| 'whereUnique'
| 'select'
| 'include'
| 'find'
| 'findUnique';
function getCache(model: string, kind: SchemaKinds) {
return schemas.get(`${model}:${kind}`);
@ -17,9 +24,11 @@ function putCache(model: string, kind: SchemaKinds, schema: ZodSchema) {
export function makeWhereSchema(
schema: SchemaDef,
model: string
): ZodObject<any> {
let result = getCache(model, 'where') as ZodObject<any> | undefined;
model: string,
unique: boolean
): ZodSchema {
const cacheKey = unique ? 'whereUnique' : 'where';
let result = getCache(model, cacheKey);
if (result) {
return result;
}
@ -30,24 +39,51 @@ export function makeWhereSchema(
const fieldDef = requireField(schema, model, field);
if (fieldDef.relation) {
fields[field] = z.lazy(() =>
makeWhereSchema(schema, fieldDef.type).optional()
makeWhereSchema(schema, fieldDef.type, false).optional()
);
} else {
fields[field] = makePrimitiveSchema(fieldDef.type).optional();
}
}
result = z.object(fields);
putCache(model, 'where', result);
const baseWhere = z.object(fields);
result = baseWhere;
if (unique) {
// requires at least one unique field (field set) is required
const uniqueFields = getUniqueFields(schema, model);
if (uniqueFields.length === 0) {
throw new InternalError(`Model "${model}" has no unique fields`);
}
if (uniqueFields.length === 1) {
// only one unique field (set), mark the field(s) required
result = baseWhere.required(
uniqueFields[0]!.reduce(
(acc, k) => ({
...acc,
[k.name]: true,
}),
{}
)
);
} else {
result = baseWhere.refine((value) => {
// check that at least one unique field is set
return uniqueFields.some((fields) =>
fields.every(({ name }) => value[name] !== undefined)
);
}, `At least one unique field or field set must be set`);
}
}
putCache(model, cacheKey, result);
return result;
}
export function makeSelectSchema(
schema: SchemaDef,
model: string
): ZodObject<any> {
let result = getCache(model, 'select') as ZodObject<any> | undefined;
export function makeSelectSchema(schema: SchemaDef, model: string) {
let result = getCache(model, 'select');
if (result) {
return result;
}
@ -82,11 +118,8 @@ export function makeSelectSchema(
return result;
}
export function makeIncludeSchema(
schema: SchemaDef,
model: string
): ZodObject<any> {
let result = getCache(model, 'include') as ZodObject<any> | undefined;
export function makeIncludeSchema(schema: SchemaDef, model: string) {
let result = getCache(model, 'include');
if (result) {
return result;
}
@ -119,6 +152,41 @@ export function makeIncludeSchema(
return result;
}
export function makeFindSchema(
schema: SchemaDef,
model: string,
unique: boolean
) {
const cacheKey = unique ? 'findUnique' : 'find';
let result = getCache(model, cacheKey);
if (result) {
return result;
}
const where = makeWhereSchema(schema, model, unique);
const select = makeSelectSchema(schema, model);
const include = makeIncludeSchema(schema, model);
result = z
.object({
where: unique ? where : where.optional(),
select: select.optional(),
include: include.optional(),
skip: z.number().int().nonnegative().optional(),
take: z.number().int().nonnegative().optional(),
})
.refine(
(value) => !value.select || !value.include,
'"select" and "include" cannot be used together'
);
if (!unique) {
result = result.optional();
}
putCache(model, cacheKey, result);
return result;
}
function makePrimitiveSchema(type: string) {
return Match.value(type).pipe(
Match.when('String', () => z.string()),
@ -131,5 +199,3 @@ function makePrimitiveSchema(type: string) {
Match.orElse(() => z.unknown())
);
}
//#endregion

View file

@ -157,10 +157,13 @@ export function getIdValues(
schema: SchemaDef,
model: string,
data: any
): Array<{ field: string; value: any }> {
): Record<string, any> {
const idFields = getIdFields(schema, model);
if (!idFields) {
throw new InternalError(`ID fields not defined for model "${model}"`);
}
return idFields.map((field) => ({ field, value: data[field] }));
return idFields.reduce(
(acc, field) => ({ ...acc, [field]: data[field] }),
{}
);
}

View file

@ -204,14 +204,6 @@ export type MapFieldType<
Field extends GetFields<Schema, Model>
> = MapFieldDefType<Schema, GetField<Schema, Model, Field>>;
// WrapType<
// GetFieldType<Schema, Model, Field> extends GetEnums<Schema>
// ? 'foo'
// : MapBaseType<GetField<Schema, Model, Field>>,
// FieldIsOptional<Schema, Model, Field>,
// FieldIsArray<Schema, Model, Field>
// >;
type MapFieldDefType<
Schema extends SchemaDef,
T extends Pick<FieldDef, 'type' | 'optional' | 'array'>
@ -303,6 +295,8 @@ export type FindArgs<
Model extends GetModels<Schema>
> = {
where?: Where<Schema, Model>;
skip?: number;
take?: number;
} & SelectInclude<Schema, Model>;
export type FindUniqueArgs<

View file

@ -22,8 +22,21 @@ export type MapBaseType<T> = T extends 'String'
? Decimal
: T extends 'DateTime'
? Date
: T extends 'Json'
? JsonValue
: unknown;
export type JsonValue =
| string
| number
| boolean
| null
| JsonObject
| JsonArray;
export type JsonObject = { [key: string]: JsonValue };
export type JsonArray = Array<JsonValue>;
export type Simplify<T> = { [Key in keyof T]: T[Key] } & {};
export function call(code: string) {

View file

@ -1,3 +1,5 @@
import Sqlite from 'better-sqlite3';
import { SqliteDialect } from 'kysely';
import { beforeEach, describe, expect, it } from 'vitest';
import { makeClient } from '../../src/client';
import type { DBClient } from '../../src/client/types';
@ -7,7 +9,11 @@ describe('Client API create tests', () => {
let client: DBClient<typeof Schema>;
beforeEach(async () => {
client = makeClient(Schema);
client = makeClient(Schema, {
dialect: new SqliteDialect({
database: new Sqlite(':memory:'),
}),
});
await pushSchema(client.$db);
});

View file

@ -1,14 +1,20 @@
import Sqlite from 'better-sqlite3';
import { SqliteDialect } from 'kysely';
import { beforeEach, describe, expect, it } from 'vitest';
import { makeClient } from '../../src/client';
import { NotFoundError } from '../../src/client/errors';
import type { DBClient } from '../../src/client/types';
import { pushSchema, Schema } from '../test-schema';
import { NotFoundError } from '../../src/client/errors';
describe('Client API find tests', () => {
let client: DBClient<typeof Schema>;
beforeEach(async () => {
client = makeClient(Schema);
client = makeClient(Schema, {
dialect: new SqliteDialect({
database: new Sqlite(':memory:'),
}),
});
await pushSchema(client.$db);
});
@ -111,62 +117,21 @@ describe('Client API find tests', () => {
const user = await createUser();
await createPosts(user.id);
const q = client.$db
.selectFrom('user')
.where('user.id', 'in', (qb) =>
qb
.selectFrom('user')
.select('id')
.orderBy('createdAt desc')
.limit(1)
)
.leftJoin(
(eb) =>
eb
.selectFrom('post')
.select([
'post.id',
'post.title',
'post.authorId',
'author.email as authorEmail',
])
.where('post.published', '!=', 1 as any)
.leftJoin(
(eb1) =>
eb1.selectFrom('user').selectAll().as('author'),
(join) =>
join.onRef('author.id', '=', 'post.authorId')
)
.as('post'),
(join) => join.onRef('post.authorId', '=', 'user.id')
)
.select([
'user.id as user.id',
'post.id as post.id',
'post.title as post.title',
'post.authorEmail',
]);
const { sql, parameters } = q.compile();
console.log('SQL:', sql, 'PARAMS', parameters);
console.log(await q.execute());
let r = await client.user.findUnique({
where: { id: '1' },
select: { id: true, email: true, posts: true },
});
expect(r?.id).toBeTruthy();
expect(r?.email).toBeTruthy();
expect('name' in r!).toBeFalsy();
expect(r?.posts).toHaveLength(2);
// let r = await client.user.findUnique({
// where: { id: '1' },
// select: { id: true, email: true, posts: true },
// });
// expect(r?.id).toBeTruthy();
// expect(r?.email).toBeTruthy();
// expect('name' in r!).toBeFalsy();
// expect(r?.posts).toHaveLength(2);
// await expect(
// client.user.findUnique({
// where: { id: '1' },
// select: { id: true, email: true },
// include: { posts: true },
// } as any)
// ).rejects.toThrow(
// 'Cannot use both "select" and "include" in find args'
// );
await expect(
client.user.findUnique({
where: { id: '1' },
select: { id: true, email: true },
include: { posts: true },
} as any)
).rejects.toThrow('cannot be used together');
});
});