diff --git a/packages/runtime/src/client/index.ts b/packages/runtime/src/client/index.ts index 6baba91b..255e5faf 100644 --- a/packages/runtime/src/client/index.ts +++ b/packages/runtime/src/client/index.ts @@ -1,30 +1,33 @@ -import SQLite from 'better-sqlite3'; -import { Kysely, SqliteDialect } from 'kysely'; +import { Kysely, ParseJSONResultsPlugin, type KyselyConfig } from 'kysely'; import { type GetModels, type SchemaDef } from '../schema/schema'; +import { NotFoundError } from './errors'; import { runCreate } from './operations/create'; +import { runFind } from './operations/find'; import type { toKysely } from './query-builder'; import type { DBClient, ModelOperations } from './types'; -import { runFind } from './operations/find'; -import { NotFoundError } from './errors'; -export function makeClient(schema: Schema) { - return new Client(schema) as unknown as DBClient; +export type ClientOptions = { + dialect: KyselyConfig['dialect']; + plugins?: KyselyConfig['plugins']; + log?: KyselyConfig['log']; +}; + +export function makeClient( + schema: Schema, + options: ClientOptions +) { + return new Client(schema, options) as unknown as DBClient; } -class Client { +export class Client { public readonly $db: Kysely>; - constructor(schema: Schema) { - this.$db = this.createKysely(schema); - return createClientProxy(this, schema); - } - - private createKysely( - _schema: Schema - ): Kysely> { - return new Kysely({ - dialect: new SqliteDialect({ database: new SQLite(':memory:') }), + constructor(schema: Schema, options: ClientOptions) { + this.$db = new Kysely({ + ...options, + plugins: [...(options.plugins ?? []), new ParseJSONResultsPlugin()], }); + return createClientProxy(this, schema); } } @@ -63,16 +66,26 @@ function createModelProxy< ): ModelOperations { return { create: async (args) => { - return runCreate(db, schema, model, args); + const r = await runCreate( + { db, schema, model, operation: 'create' }, + args + ); + return r; }, findUnique: async (args) => { - const r = await runFind(db, schema, model, 'findUnique', args); + const r = await runFind( + { db, schema, model, operation: 'findUnique' }, + args + ); return r ?? null; }, findUniqueOrThrow: async (args) => { - const r = await runFind(db, schema, model, 'findUnique', args); + const r = await runFind( + { db, schema, model, operation: 'findUnique' }, + args + ); if (!r) { throw new NotFoundError(`No "${model}" found`); } else { @@ -81,11 +94,14 @@ function createModelProxy< }, findFirst: async (args) => { - return runFind(db, schema, model, 'findFirst', args); + return runFind({ db, schema, model, operation: 'findFirst' }, args); }, findFirstOrThrow: async (args) => { - const r = await runFind(db, schema, model, 'findFirst', args); + const r = await runFind( + { db, schema, model, operation: 'findFirst' }, + args + ); if (!r) { throw new NotFoundError(`No "${model}" found`); } else { @@ -94,7 +110,7 @@ function createModelProxy< }, findMany: async (args) => { - return runFind(db, schema, model, 'findMany', args); + return runFind({ db, schema, model, operation: 'findMany' }, args); }, }; } diff --git a/packages/runtime/src/client/operations/common.ts b/packages/runtime/src/client/operations/common.ts deleted file mode 100644 index 13c59e01..00000000 --- a/packages/runtime/src/client/operations/common.ts +++ /dev/null @@ -1,108 +0,0 @@ -import { Array } from 'effect'; -import type { SchemaDef } from '../../schema/schema'; -import { requireField, requireIdFields } from '../query-utils'; -import type { FindArgs } from '../types'; - -export function assembleResult( - schema: SchemaDef, - model: string, - data: any, - args: FindArgs | undefined -) { - if (!data) { - return data; - } - - const arrayData = Array.isArray(data) ? data : [data]; - return doAssembleResult(schema, model, '$', arrayData, args); -} - -function doAssembleResult( - schema: SchemaDef, - model: string, - path: string, - data: any[], - args: FindArgs | undefined -) { - const grouped = Array.groupBy(data, (item) => - getEntityKey(schema, model, path, item) - ); - return Object.values(grouped).map((rows) => { - const entity = constructEntity(schema, model, path, rows, args); - return entity; - }); -} - -function getEntityKey( - schema: SchemaDef, - model: string, - path: string, - data: any -) { - const idFields = requireIdFields(schema, model); - return JSON.stringify( - idFields.reduce((acc, f) => ({ ...acc, [f]: data[`${path}>${f}`] }), {}) - ); -} - -function constructEntity( - schema: SchemaDef, - model: string, - path: string, - rows: any[], - args: FindArgs | undefined -) { - const result: any = {}; - - // scalar fields - for (const [k, v] of Object.entries(rows[0])) { - if (!k.startsWith(`${path}>`)) { - continue; - } - - const field = k.substring(`${path}>`.length).split('>')[0]; - if (!field) { - continue; - } - - const fieldDef = requireField(schema, model, field); - if (!fieldDef.relation) { - if (!args?.select || args?.select[field]) { - result[field] = v; - } - } - } - - // relation fields - const selectInclude = args?.select ?? args?.include; - if (selectInclude) { - for (const [field, payload] of Object.entries(selectInclude)) { - if (!payload) { - continue; - } - const fieldDef = requireField(schema, model, field); - if (!fieldDef.relation) { - continue; - } - - const childSelectInclude = - typeof payload === 'object' - ? (payload as any)[field] ?? (payload as any)[field] - : undefined; - const child = doAssembleResult( - schema, - fieldDef.type, - `${path}>${field}`, - rows, - childSelectInclude - ); - if (fieldDef.array) { - result[field] = child; - } else { - result[field] = child[0] ?? null; - } - } - } - - return result; -} diff --git a/packages/runtime/src/client/operations/context.ts b/packages/runtime/src/client/operations/context.ts new file mode 100644 index 00000000..d9b2bd1e --- /dev/null +++ b/packages/runtime/src/client/operations/context.ts @@ -0,0 +1,12 @@ +import type { Kysely } from 'kysely'; +import type { SchemaDef } from '../../schema/schema'; +import type { toKysely } from '../query-builder'; + +export type Operations = 'findMany' | 'findUnique' | 'findFirst' | 'create'; + +export type OperationContext = { + db: Kysely>; + schema: SchemaDef; + model: string; + operation: Operations; +}; diff --git a/packages/runtime/src/client/operations/create.ts b/packages/runtime/src/client/operations/create.ts index d5021701..664028b5 100644 --- a/packages/runtime/src/client/operations/create.ts +++ b/packages/runtime/src/client/operations/create.ts @@ -11,12 +11,15 @@ import { import { clone } from '../../utils/clone'; import { InternalError, QueryError } from '../errors'; import { + getIdValues, getRelationForeignKeyFieldPairs, isForeignKeyField, isScalarField, requireField, requireModelEffect, } from '../query-utils'; +import type { OperationContext } from './context'; +import { runQuery as runFindQuery } from './find'; const CreateArgsSchema = z.object({ data: z.record(z.string(), z.any()), @@ -34,9 +37,7 @@ const RelationPayloadSchema = z.union([ ]); export function runCreate( - db: Kysely, - schema: SchemaDef, - model: string, + { db, schema, model }: OperationContext, args: unknown ) { return Effect.runPromise( @@ -44,7 +45,6 @@ export function runCreate( // parse args const parsedArgs = yield* parseCreateArgs(args); - // run query const result = yield* runQuery(db, schema, model, parsedArgs); yield* Console.log('create result:', result); return result; @@ -66,12 +66,14 @@ function runQuery( args: CreateArgs ) { return Effect.gen(function* () { - const hasRelation = Object.keys(args.data).some( + const hasRelationCreate = Object.keys(args.data).some( (f) => !!requireField(schema, model, f).relation ); + const returnRelations = needReturnRelations(schema, model, args); + let result: any; - if (hasRelation) { + if (hasRelationCreate || returnRelations) { // employ a transaction result = yield* Effect.tryPromise({ try: () => @@ -80,17 +82,32 @@ function runQuery( .setIsolationLevel('repeatable read') .execute(async (trx) => Effect.runPromise( - doCreate(trx, schema, model, args.data) + Effect.gen(function* () { + const createResult = yield* doCreate( + trx, + schema, + model, + args.data + ); + return yield* readBackResult( + trx, + schema, + model, + createResult, + args + ); + }) ) ), catch: (e) => new QueryError(`Error during create: ${e}`), }); } else { // simple create - result = yield* doCreate(db, schema, model, args.data); + const createResult = yield* doCreate(db, schema, model, args.data); + result = trimResult(createResult, args); } - return assembleResult(result, args); + return result; }); } @@ -140,7 +157,10 @@ function doCreate( ); } - const subPayload = parseRelationPayload(payload, field); + const subPayload = yield* parseRelationPayload( + payload, + field + ); const r = yield* Match.value(subPayload).pipe( Match.when( @@ -233,6 +253,48 @@ function evalGenerator(generator: FieldGenerators) { ); } -function assembleResult(primaryData: any, _args: CreateArgs) { - return primaryData; +function trimResult(data: any, args: CreateArgs) { + if (!args.select) { + return data; + } + return Object.keys(args.select).reduce((acc, field) => { + acc[field] = data[field]; + return acc; + }, {} as any); +} + +function readBackResult( + db: Kysely, + schema: SchemaDef, + model: string, + primaryData: any, + args: CreateArgs +) { + return Effect.gen(function* () { + // fetch relations based on include or select + const read = yield* runFindQuery(db, schema, model, 'findUnique', { + where: getIdValues(schema, model, primaryData), + select: args.select, + include: args.include, + }); + return read[0] ?? null; + }); +} + +function needReturnRelations( + schema: SchemaDef, + model: string, + args: CreateArgs +) { + let returnRelation = false; + + if (args.include) { + returnRelation = Object.keys(args.include).length > 0; + } else if (args.select) { + returnRelation = Object.entries(args.select).some(([K, v]) => { + const fieldDef = requireField(schema, model, K); + return fieldDef.relation && v; + }); + } + return returnRelation; } diff --git a/packages/runtime/src/client/operations/dialect/index.ts b/packages/runtime/src/client/operations/dialect/index.ts new file mode 100644 index 00000000..a5116427 --- /dev/null +++ b/packages/runtime/src/client/operations/dialect/index.ts @@ -0,0 +1,24 @@ +import type { SelectQueryBuilder } from 'kysely'; +import type { SchemaDef, SupportedProviders } from '../../../schema/schema'; +import { PostgresQueryDialect } from './postgres'; +import { SqliteQueryDialect } from './sqlite'; + +export interface QueryDialect { + buildRelationSelection( + query: SelectQueryBuilder, + schema: SchemaDef, + model: string, + relationField: string, + parentName: string, + _payload: any + ): SelectQueryBuilder; +} + +const dialects: Record = { + postgresql: new PostgresQueryDialect(), + sqlite: new SqliteQueryDialect(), +}; + +export function getQueryDialect(provider: SupportedProviders) { + return dialects[provider]; +} diff --git a/packages/runtime/src/client/operations/dialect/postgres.ts b/packages/runtime/src/client/operations/dialect/postgres.ts new file mode 100644 index 00000000..e36e1fe8 --- /dev/null +++ b/packages/runtime/src/client/operations/dialect/postgres.ts @@ -0,0 +1,63 @@ +import type { SelectQueryBuilder } from 'kysely'; +import type { QueryDialect } from '.'; +import type { SchemaDef } from '../../../schema/schema'; +import { + getRelationForeignKeyFieldPairs, + requireField, + requireModel, +} from '../../query-utils'; + +export class PostgresQueryDialect implements QueryDialect { + buildRelationSelection( + query: SelectQueryBuilder, + schema: SchemaDef, + model: string, + relationField: string, + parentName: string, + _payload: any + ): SelectQueryBuilder { + const relationFieldDef = requireField(schema, model, relationField); + const relationModel = requireModel(schema, relationFieldDef.type); + const keyPairs = getRelationForeignKeyFieldPairs( + schema, + model, + relationField + ); + + let result = query; + + result = result.leftJoinLateral( + (eb) => { + let tbl = eb.selectFrom(relationModel.dbTable); + + keyPairs.forEach(({ fk, pk }) => { + tbl = tbl.whereRef( + `${relationModel.dbTable}.${fk}`, + '=', + `${parentName}.${pk}` + ); + }); + + return tbl + .select((eb1) => + eb1.fn + .coalesce( + eb1.fn.jsonAgg( + eb1.fn('jsonb_build_object', []) + ), + '[]' + ) + .as('data') + ) + .as(`${model}$${relationField}`); + }, + (join) => join.onTrue() + ); + + result = result.select( + `${model}$${relationField}.data as ${relationField}` + ); + + return result; + } +} diff --git a/packages/runtime/src/client/operations/dialect/sqlite.ts b/packages/runtime/src/client/operations/dialect/sqlite.ts new file mode 100644 index 00000000..4bdd5490 --- /dev/null +++ b/packages/runtime/src/client/operations/dialect/sqlite.ts @@ -0,0 +1,61 @@ +import { sql, type SelectQueryBuilder } from 'kysely'; +import type { QueryDialect } from '.'; +import type { SchemaDef } from '../../../schema/schema'; +import { + getRelationForeignKeyFieldPairs, + requireField, + requireModel, +} from '../../query-utils'; + +export class SqliteQueryDialect implements QueryDialect { + buildRelationSelection( + query: SelectQueryBuilder, + schema: SchemaDef, + model: string, + relationField: string, + parentName: string, + _payload: any + ): SelectQueryBuilder { + const relationFieldDef = requireField(schema, model, relationField); + const relationModel = requireModel(schema, relationFieldDef.type); + const keyPairs = getRelationForeignKeyFieldPairs( + schema, + model, + relationField + ); + + let result = query; + + result = result.select((eb) => { + let tbl = eb + .selectFrom( + `${relationModel.dbTable} as ${parentName}$${relationField}` + ) + .select((eb1) => { + const objArgs = Object.keys(relationModel.fields) + .filter((f) => !relationModel.fields[f]?.relation) + .map((field) => [field, eb1.ref(field)]) + .flatMap((v) => v); + + return eb1.fn + .coalesce( + sql`json_group_array(json_object(${sql.join( + objArgs + )}))`, + sql`json_array()` + ) + .as('data'); + }); + keyPairs.forEach(({ fk, pk }) => { + tbl = tbl.whereRef( + `${parentName}$${relationField}.${fk}`, + '=', + `${parentName}.${pk}` + ); + }); + return tbl.as(relationField); + }); + + return result; + } +} diff --git a/packages/runtime/src/client/operations/find.ts b/packages/runtime/src/client/operations/find.ts index c65d168f..019be969 100644 --- a/packages/runtime/src/client/operations/find.ts +++ b/packages/runtime/src/client/operations/find.ts @@ -1,29 +1,20 @@ import { Console, Effect } from 'effect'; import type { Kysely, SelectQueryBuilder } from 'kysely'; -import { z, ZodSchema } from 'zod'; import type { SchemaDef } from '../../schema/schema'; -import { InternalError, QueryError } from '../errors'; +import { QueryError } from '../errors'; import { - getRelationForeignKeyFieldPairs, - getUniqueFields, isScalarField, requireField, requireModel, requireModelEffect, } from '../query-utils'; import type { FindArgs } from '../types'; -import { assembleResult } from './common'; -import { makeIncludeSchema, makeSelectSchema, makeWhereSchema } from './parse'; - -type FindOperation = 'findMany' | 'findUnique' | 'findFirst'; - -const ROOT_ALIAS = '$'; +import type { OperationContext, Operations } from './context'; +import { getQueryDialect } from './dialect'; +import { makeFindSchema } from './parse'; export function runFind( - db: Kysely, - schema: SchemaDef, - model: string, - operation: FindOperation, + { db, schema, model, operation }: OperationContext, args: unknown ) { return Effect.runPromise( @@ -45,7 +36,8 @@ export function runFind( parsedArgs ); - const finalResult = operation === 'findMany' ? result : result[0]; + const finalResult = + operation === 'findMany' ? result : result[0] ?? null; yield* Console.log(`${operation} result:`, finalResult); return finalResult; }) @@ -55,67 +47,14 @@ export function runFind( function parseFindArgs( schema: SchemaDef, model: string, - operation: FindOperation, + operation: Operations, args: unknown ) { - if (!args || typeof args !== 'object') { - if (operation === 'findUnique') { - // args is required for findUnique - return Effect.fail(new QueryError(`Missing query args`)); - } else { - return Effect.succeed(undefined); - } - } - - const baseWhere = makeWhereSchema(schema, model); - let where: ZodSchema = baseWhere; - - if (operation === 'findUnique') { - // findUnique requires at least one unique field (field set) is required - - const uniqueFields = getUniqueFields(schema, model); - if (uniqueFields.length === 0) { - return Effect.fail( - new InternalError(`Model "${model}" has no unique fields`) - ); - } - - if (uniqueFields.length === 1) { - // only one unique field (set), mark the field(s) required - where = baseWhere.required( - uniqueFields[0]!.reduce( - (acc, k) => ({ - ...acc, - [k.name]: true, - }), - {} - ) - ); - } else { - where = baseWhere.refine((value) => { - // check that at least one unique field is set - return uniqueFields.some((fields) => - fields.every(({ name }) => value[name] !== undefined) - ); - }, `At least one unique field or field set must be set`); - } - } else { - // where clause is optional - where = where.optional(); - } - - const select = makeSelectSchema(schema, model).optional(); - const include = makeIncludeSchema(schema, model).optional(); - - if ('select' in args && 'include' in args) { - return Effect.fail( - new QueryError( - 'Cannot use both "select" and "include" in find args' - ) - ); - } - - const findSchema = z.object({ where, select, include }); + const findSchema = makeFindSchema( + schema, + model, + operation === 'findUnique' + ); return Effect.try({ try: () => findSchema.parse(args), @@ -123,18 +62,18 @@ function parseFindArgs( }); } -function runQuery( +export function runQuery( db: Kysely, schema: SchemaDef, model: string, operation: string, args: FindArgs | undefined -): Effect.Effect { +): Effect.Effect { return Effect.gen(function* () { const modelDef = yield* requireModelEffect(schema, model); // table - let query = db.selectFrom(`${modelDef.dbTable} as ${ROOT_ALIAS}`); + let query = db.selectFrom(`${modelDef.dbTable}`); if (operation !== 'findMany') { query = query.limit(1); @@ -145,16 +84,38 @@ function runQuery( query = buildWhere(query, args.where); } + // skip + if (args?.skip) { + query = query.offset(args.skip); + } + + // take + if (args?.take) { + query = query.limit(args.take); + } + // select if (args?.select) { - query = buildFieldSelection(schema, model, query, args?.select); + query = buildFieldSelection( + schema, + model, + query, + args?.select, + modelDef.dbTable + ); } else { query = buildSelectAllFields(schema, model, query); } // include if (args?.include) { - query = buildFieldSelection(schema, model, query, args?.include); + query = buildFieldSelection( + schema, + model, + query, + args?.include, + modelDef.dbTable + ); } const compiled = query.compile(); @@ -170,15 +131,13 @@ function runQuery( }); yield* Console.log(`Raw results:`, rows); - const assembled = assembleResult(schema, model, rows, args); - return assembled; + return rows; }); } function buildWhere( - query: SelectQueryBuilder, - where: Record | undefined, - tableAlias = ROOT_ALIAS + query: SelectQueryBuilder, + where: Record | undefined ) { let result = query; if (!where) { @@ -186,12 +145,7 @@ function buildWhere( } result = Object.entries(where).reduce( - (acc, [field, value]) => - acc.where( - tableAlias ? `${tableAlias}.${field}` : field, - '=', - value - ), + (acc, [field, value]) => acc.where(field, '=', value), result ); @@ -201,9 +155,9 @@ function buildWhere( function buildFieldSelection( schema: SchemaDef, model: string, - query: SelectQueryBuilder, + query: SelectQueryBuilder, selectOrInclude: Record, - tableAlias = ROOT_ALIAS + parentName: string ) { let result = query; @@ -213,15 +167,15 @@ function buildFieldSelection( } const fieldDef = requireField(schema, model, field); if (!fieldDef.relation) { - result = result.select(selectField(tableAlias, field)); + result = result.select(field); } else { result = buildRelationSelection( + result, schema, model, field, - result, - payload, - tableAlias + parentName, + payload ); } } @@ -230,77 +184,36 @@ function buildFieldSelection( } function buildRelationSelection( + query: SelectQueryBuilder, schema: SchemaDef, model: string, relationField: string, - query: SelectQueryBuilder, - payload: any, - tableAlias = ROOT_ALIAS + parentName: string, + payload: any ) { - const relationFieldDef = requireField(schema, model, relationField); - const relationModel = requireModel(schema, relationFieldDef.type); - const keyPairs = getRelationForeignKeyFieldPairs( - schema, - model, - relationField - ); - - let result = query; - - const nextAlias = joinAlias(tableAlias, relationField); - result = result.leftJoin( - `${relationModel.dbTable} as ${nextAlias}`, - (join) => - keyPairs.reduce( - (acc, { fk, pk }) => - acc.onRef( - `${nextAlias}.${fk}`, - '=', - tableAlias ? `${tableAlias}.${pk}` : pk - ), - join - ) - ); - - if (payload === true) { - result = buildSelectAllFields( - schema, - relationFieldDef.type, - result, - nextAlias - ); - } else { - result = buildFieldSelection( - schema, - relationFieldDef.type, - result, - payload, - nextAlias - ); + const queryDialect = getQueryDialect(schema.provider); + if (!queryDialect) { + throw new QueryError(`Unsupported provider: ${schema.provider}`); } - return result; + return queryDialect.buildRelationSelection( + query, + schema, + model, + relationField, + parentName, + payload + ); } function buildSelectAllFields( schema: SchemaDef, model: string, - query: SelectQueryBuilder, - tableAlias = ROOT_ALIAS + query: SelectQueryBuilder ) { let result = query; const modelDef = requireModel(schema, model); return Object.keys(modelDef.fields) .filter((f) => isScalarField(schema, model, f)) - .reduce((acc, f) => acc.select(selectField(tableAlias, f)), result); -} - -function joinAlias(tableAlias: string, field: string) { - return `${tableAlias}>${field}`; -} - -function selectField(tableAlias: string, field: string) { - return tableAlias - ? `${tableAlias}.${field} as ${joinAlias(tableAlias, field)}` - : `${field} as ${field}`; + .reduce((acc, f) => acc.select(f), result); } diff --git a/packages/runtime/src/client/operations/parse.ts b/packages/runtime/src/client/operations/parse.ts index c1a26b61..c817dfcb 100644 --- a/packages/runtime/src/client/operations/parse.ts +++ b/packages/runtime/src/client/operations/parse.ts @@ -1,11 +1,18 @@ import { Match } from 'effect'; -import { z, ZodObject, ZodSchema } from 'zod'; +import { z, ZodSchema } from 'zod'; import type { SchemaDef } from '../../schema/schema'; -import { requireField, requireModel } from '../query-utils'; +import { InternalError } from '../errors'; +import { getUniqueFields, requireField, requireModel } from '../query-utils'; const schemas = new Map(); -type SchemaKinds = 'where' | 'select' | 'include'; +type SchemaKinds = + | 'where' + | 'whereUnique' + | 'select' + | 'include' + | 'find' + | 'findUnique'; function getCache(model: string, kind: SchemaKinds) { return schemas.get(`${model}:${kind}`); @@ -17,9 +24,11 @@ function putCache(model: string, kind: SchemaKinds, schema: ZodSchema) { export function makeWhereSchema( schema: SchemaDef, - model: string -): ZodObject { - let result = getCache(model, 'where') as ZodObject | undefined; + model: string, + unique: boolean +): ZodSchema { + const cacheKey = unique ? 'whereUnique' : 'where'; + let result = getCache(model, cacheKey); if (result) { return result; } @@ -30,24 +39,51 @@ export function makeWhereSchema( const fieldDef = requireField(schema, model, field); if (fieldDef.relation) { fields[field] = z.lazy(() => - makeWhereSchema(schema, fieldDef.type).optional() + makeWhereSchema(schema, fieldDef.type, false).optional() ); } else { fields[field] = makePrimitiveSchema(fieldDef.type).optional(); } } - result = z.object(fields); - putCache(model, 'where', result); + const baseWhere = z.object(fields); + result = baseWhere; + + if (unique) { + // requires at least one unique field (field set) is required + const uniqueFields = getUniqueFields(schema, model); + if (uniqueFields.length === 0) { + throw new InternalError(`Model "${model}" has no unique fields`); + } + + if (uniqueFields.length === 1) { + // only one unique field (set), mark the field(s) required + result = baseWhere.required( + uniqueFields[0]!.reduce( + (acc, k) => ({ + ...acc, + [k.name]: true, + }), + {} + ) + ); + } else { + result = baseWhere.refine((value) => { + // check that at least one unique field is set + return uniqueFields.some((fields) => + fields.every(({ name }) => value[name] !== undefined) + ); + }, `At least one unique field or field set must be set`); + } + } + + putCache(model, cacheKey, result); return result; } -export function makeSelectSchema( - schema: SchemaDef, - model: string -): ZodObject { - let result = getCache(model, 'select') as ZodObject | undefined; +export function makeSelectSchema(schema: SchemaDef, model: string) { + let result = getCache(model, 'select'); if (result) { return result; } @@ -82,11 +118,8 @@ export function makeSelectSchema( return result; } -export function makeIncludeSchema( - schema: SchemaDef, - model: string -): ZodObject { - let result = getCache(model, 'include') as ZodObject | undefined; +export function makeIncludeSchema(schema: SchemaDef, model: string) { + let result = getCache(model, 'include'); if (result) { return result; } @@ -119,6 +152,41 @@ export function makeIncludeSchema( return result; } +export function makeFindSchema( + schema: SchemaDef, + model: string, + unique: boolean +) { + const cacheKey = unique ? 'findUnique' : 'find'; + let result = getCache(model, cacheKey); + if (result) { + return result; + } + + const where = makeWhereSchema(schema, model, unique); + const select = makeSelectSchema(schema, model); + const include = makeIncludeSchema(schema, model); + result = z + .object({ + where: unique ? where : where.optional(), + select: select.optional(), + include: include.optional(), + skip: z.number().int().nonnegative().optional(), + take: z.number().int().nonnegative().optional(), + }) + .refine( + (value) => !value.select || !value.include, + '"select" and "include" cannot be used together' + ); + + if (!unique) { + result = result.optional(); + } + + putCache(model, cacheKey, result); + return result; +} + function makePrimitiveSchema(type: string) { return Match.value(type).pipe( Match.when('String', () => z.string()), @@ -131,5 +199,3 @@ function makePrimitiveSchema(type: string) { Match.orElse(() => z.unknown()) ); } - -//#endregion diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index ef20991d..be5d0138 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -157,10 +157,13 @@ export function getIdValues( schema: SchemaDef, model: string, data: any -): Array<{ field: string; value: any }> { +): Record { const idFields = getIdFields(schema, model); if (!idFields) { throw new InternalError(`ID fields not defined for model "${model}"`); } - return idFields.map((field) => ({ field, value: data[field] })); + return idFields.reduce( + (acc, field) => ({ ...acc, [field]: data[field] }), + {} + ); } diff --git a/packages/runtime/src/client/types.ts b/packages/runtime/src/client/types.ts index 1ff6ea74..1913592b 100644 --- a/packages/runtime/src/client/types.ts +++ b/packages/runtime/src/client/types.ts @@ -204,14 +204,6 @@ export type MapFieldType< Field extends GetFields > = MapFieldDefType>; -// WrapType< -// GetFieldType extends GetEnums -// ? 'foo' -// : MapBaseType>, -// FieldIsOptional, -// FieldIsArray -// >; - type MapFieldDefType< Schema extends SchemaDef, T extends Pick @@ -303,6 +295,8 @@ export type FindArgs< Model extends GetModels > = { where?: Where; + skip?: number; + take?: number; } & SelectInclude; export type FindUniqueArgs< diff --git a/packages/runtime/src/schema/type-utils.ts b/packages/runtime/src/schema/type-utils.ts index 62467fe7..ca60cab4 100644 --- a/packages/runtime/src/schema/type-utils.ts +++ b/packages/runtime/src/schema/type-utils.ts @@ -22,8 +22,21 @@ export type MapBaseType = T extends 'String' ? Decimal : T extends 'DateTime' ? Date + : T extends 'Json' + ? JsonValue : unknown; +export type JsonValue = + | string + | number + | boolean + | null + | JsonObject + | JsonArray; + +export type JsonObject = { [key: string]: JsonValue }; +export type JsonArray = Array; + export type Simplify = { [Key in keyof T]: T[Key] } & {}; export function call(code: string) { diff --git a/packages/runtime/test/client-api/create.test.ts b/packages/runtime/test/client-api/create.test.ts index 9fcee31e..4da97f96 100644 --- a/packages/runtime/test/client-api/create.test.ts +++ b/packages/runtime/test/client-api/create.test.ts @@ -1,3 +1,5 @@ +import Sqlite from 'better-sqlite3'; +import { SqliteDialect } from 'kysely'; import { beforeEach, describe, expect, it } from 'vitest'; import { makeClient } from '../../src/client'; import type { DBClient } from '../../src/client/types'; @@ -7,7 +9,11 @@ describe('Client API create tests', () => { let client: DBClient; beforeEach(async () => { - client = makeClient(Schema); + client = makeClient(Schema, { + dialect: new SqliteDialect({ + database: new Sqlite(':memory:'), + }), + }); await pushSchema(client.$db); }); diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index 83164524..6c82ef8e 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -1,14 +1,20 @@ +import Sqlite from 'better-sqlite3'; +import { SqliteDialect } from 'kysely'; import { beforeEach, describe, expect, it } from 'vitest'; import { makeClient } from '../../src/client'; +import { NotFoundError } from '../../src/client/errors'; import type { DBClient } from '../../src/client/types'; import { pushSchema, Schema } from '../test-schema'; -import { NotFoundError } from '../../src/client/errors'; describe('Client API find tests', () => { let client: DBClient; beforeEach(async () => { - client = makeClient(Schema); + client = makeClient(Schema, { + dialect: new SqliteDialect({ + database: new Sqlite(':memory:'), + }), + }); await pushSchema(client.$db); }); @@ -111,62 +117,21 @@ describe('Client API find tests', () => { const user = await createUser(); await createPosts(user.id); - const q = client.$db - .selectFrom('user') - .where('user.id', 'in', (qb) => - qb - .selectFrom('user') - .select('id') - .orderBy('createdAt desc') - .limit(1) - ) - .leftJoin( - (eb) => - eb - .selectFrom('post') - .select([ - 'post.id', - 'post.title', - 'post.authorId', - 'author.email as authorEmail', - ]) - .where('post.published', '!=', 1 as any) - .leftJoin( - (eb1) => - eb1.selectFrom('user').selectAll().as('author'), - (join) => - join.onRef('author.id', '=', 'post.authorId') - ) - .as('post'), - (join) => join.onRef('post.authorId', '=', 'user.id') - ) - .select([ - 'user.id as user.id', - 'post.id as post.id', - 'post.title as post.title', - 'post.authorEmail', - ]); - const { sql, parameters } = q.compile(); - console.log('SQL:', sql, 'PARAMS', parameters); - console.log(await q.execute()); + let r = await client.user.findUnique({ + where: { id: '1' }, + select: { id: true, email: true, posts: true }, + }); + expect(r?.id).toBeTruthy(); + expect(r?.email).toBeTruthy(); + expect('name' in r!).toBeFalsy(); + expect(r?.posts).toHaveLength(2); - // let r = await client.user.findUnique({ - // where: { id: '1' }, - // select: { id: true, email: true, posts: true }, - // }); - // expect(r?.id).toBeTruthy(); - // expect(r?.email).toBeTruthy(); - // expect('name' in r!).toBeFalsy(); - // expect(r?.posts).toHaveLength(2); - - // await expect( - // client.user.findUnique({ - // where: { id: '1' }, - // select: { id: true, email: true }, - // include: { posts: true }, - // } as any) - // ).rejects.toThrow( - // 'Cannot use both "select" and "include" in find args' - // ); + await expect( + client.user.findUnique({ + where: { id: '1' }, + select: { id: true, email: true }, + include: { posts: true }, + } as any) + ).rejects.toThrow('cannot be used together'); }); });