diff --git a/packages/cli/test/ts-schema-gen.test.ts b/packages/cli/test/ts-schema-gen.test.ts index 2673567d..1056478f 100644 --- a/packages/cli/test/ts-schema-gen.test.ts +++ b/packages/cli/test/ts-schema-gen.test.ts @@ -184,6 +184,30 @@ model Post { }); }); + it('generates correct procedures with array params and returns', async () => { + const { schema } = await generateTsSchema(` +model User { + id Int @id +} + +procedure findByIds(ids: Int[]): User[] +procedure getIds(): Int[] + `); + + expect(schema.procedures).toMatchObject({ + findByIds: { + params: { ids: { name: 'ids', type: 'Int', array: true } }, + returnType: 'User', + returnArray: true, + }, + getIds: { + params: {}, + returnType: 'Int', + returnArray: true, + }, + }); + }); + it('merges fields and attributes from mixins', async () => { const { schema } = await generateTsSchema(` type Timestamped { diff --git a/packages/clients/tanstack-query/src/common/types.ts b/packages/clients/tanstack-query/src/common/types.ts index 8993445e..2e61d286 100644 --- a/packages/clients/tanstack-query/src/common/types.ts +++ b/packages/clients/tanstack-query/src/common/types.ts @@ -1,6 +1,6 @@ import type { Logger, OptimisticDataProvider } from '@zenstackhq/client-helpers'; import type { FetchFn } from '@zenstackhq/client-helpers/fetch'; -import type { OperationsIneligibleForDelegateModels } from '@zenstackhq/orm'; +import type { GetProcedureNames, OperationsIneligibleForDelegateModels, ProcedureFunc } from '@zenstackhq/orm'; import type { GetModels, IsDelegateModel, SchemaDef } from '@zenstackhq/schema'; /** @@ -76,3 +76,7 @@ type WithOptimisticFlag = T extends object : T; export type WithOptimistic = T extends Array ? Array> : WithOptimisticFlag; + +export type ProcedureReturn> = Awaited< + ReturnType> +>; diff --git a/packages/clients/tanstack-query/src/react.ts b/packages/clients/tanstack-query/src/react.ts index 4449e29f..819a5421 100644 --- a/packages/clients/tanstack-query/src/react.ts +++ b/packages/clients/tanstack-query/src/react.ts @@ -36,8 +36,11 @@ import type { FindFirstArgs, FindManyArgs, FindUniqueArgs, + GetProcedure, + GetProcedureNames, GroupByArgs, GroupByResult, + ProcedureEnvelope, QueryOptions, SelectSubset, SimplifiedPlainResult, @@ -55,12 +58,23 @@ import { getQueryKey } from './common/query-key'; import type { ExtraMutationOptions, ExtraQueryOptions, + ProcedureReturn, QueryContext, TrimDelegateModelOperations, WithOptimistic, } from './common/types'; export type { FetchFn } from '@zenstackhq/client-helpers/fetch'; +type ProcedureHookFn< + Schema extends SchemaDef, + ProcName extends GetProcedureNames, + Options, + Result, + Input = ProcedureEnvelope, +> = { args: undefined } extends Input + ? (input?: Input, options?: Options) => Result + : (input: Input, options?: Options) => Result; + /** * React context for query settings. */ @@ -133,8 +147,61 @@ export type ModelMutationModelResult< export type ClientHooks = QueryOptions> = { [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; +} & ProcedureHooks; + +type ProcedureHookGroup = { + [Name in GetProcedureNames]: GetProcedure extends { mutation: true } + ? { + useMutation( + options?: Omit< + UseMutationOptions, DefaultError, ProcedureEnvelope>, + 'mutationFn' + > & + QueryContext, + ): UseMutationResult, DefaultError, ProcedureEnvelope>; + } + : { + useQuery: ProcedureHookFn< + Schema, + Name, + Omit>, 'optimisticUpdate'>, + UseQueryResult, DefaultError> & { queryKey: QueryKey } + >; + + useSuspenseQuery: ProcedureHookFn< + Schema, + Name, + Omit>, 'optimisticUpdate'>, + UseSuspenseQueryResult, DefaultError> & { queryKey: QueryKey } + >; + + // Infinite queries for procedures are currently disabled, will add back later if needed + // + // useInfiniteQuery: ProcedureHookFn< + // Schema, + // Name, + // ModelInfiniteQueryOptions>, + // ModelInfiniteQueryResult>> + // >; + + // useSuspenseInfiniteQuery: ProcedureHookFn< + // Schema, + // Name, + // ModelSuspenseInfiniteQueryOptions>, + // ModelSuspenseInfiniteQueryResult>> + // >; + }; }; +export type ProcedureHooks = Schema extends { procedures: Record } + ? { + /** + * Custom procedures. + */ + $procs: ProcedureHookGroup; + } + : {}; + // Note that we can potentially use TypeScript's mapped type to directly map from ORM contract, but that seems // to significantly slow down tsc performance ... export type ModelQueryHooks< @@ -263,7 +330,7 @@ export function useClientQueries { - return Object.keys(schema.models).reduce( + const result = Object.keys(schema.models).reduce( (acc, model) => { (acc as any)[lowerCaseFirst(model)] = useModelQueries, Options>( schema, @@ -274,6 +341,46 @@ export function useClientQueries, ); + + const procedures = (schema as any).procedures as Record | undefined; + if (procedures) { + const buildProcedureHooks = (endpointModel: '$procs') => { + return Object.keys(procedures).reduce((acc, name) => { + const procDef = procedures[name]; + if (procDef?.mutation) { + acc[name] = { + useMutation: (hookOptions?: any) => + useInternalMutation(schema, endpointModel, 'POST', name, { ...options, ...hookOptions }), + }; + } else { + acc[name] = { + useQuery: (args?: any, hookOptions?: any) => + useInternalQuery(schema, endpointModel, name, args, { ...options, ...hookOptions }), + useSuspenseQuery: (args?: any, hookOptions?: any) => + useInternalSuspenseQuery(schema, endpointModel, name, args, { + ...options, + ...hookOptions, + }), + useInfiniteQuery: (args?: any, hookOptions?: any) => + useInternalInfiniteQuery(schema, endpointModel, name, args, { + ...options, + ...hookOptions, + }), + useSuspenseInfiniteQuery: (args?: any, hookOptions?: any) => + useInternalSuspenseInfiniteQuery(schema, endpointModel, name, args, { + ...options, + ...hookOptions, + }), + }; + } + return acc; + }, {} as any); + }; + + (result as any).$procs = buildProcedureHooks('$procs'); + } + + return result; } /** diff --git a/packages/clients/tanstack-query/src/svelte/index.svelte.ts b/packages/clients/tanstack-query/src/svelte/index.svelte.ts index a94941c4..eaab402f 100644 --- a/packages/clients/tanstack-query/src/svelte/index.svelte.ts +++ b/packages/clients/tanstack-query/src/svelte/index.svelte.ts @@ -37,8 +37,11 @@ import type { FindFirstArgs, FindManyArgs, FindUniqueArgs, + GetProcedure, + GetProcedureNames, GroupByArgs, GroupByResult, + ProcedureEnvelope, QueryOptions, SelectSubset, SimplifiedPlainResult, @@ -56,12 +59,23 @@ import { getQueryKey } from '../common/query-key'; import type { ExtraMutationOptions, ExtraQueryOptions, + ProcedureReturn, QueryContext, TrimDelegateModelOperations, WithOptimistic, } from '../common/types'; export type { FetchFn } from '@zenstackhq/client-helpers/fetch'; +type ProcedureHookFn< + Schema extends SchemaDef, + ProcName extends GetProcedureNames, + Options, + Result, + Input = ProcedureEnvelope, +> = { args: undefined } extends Input + ? (args?: Accessor, options?: Accessor) => Result + : (args: Accessor, options?: Accessor) => Result; + /** * Key for setting and getting the global query context. */ @@ -88,6 +102,14 @@ function useQuerySettings() { return { endpoint: endpoint ?? DEFAULT_QUERY_ENDPOINT, ...rest }; } +function merge(rootOpt: unknown, opt: unknown): Accessor { + return () => { + const rootOptVal = typeof rootOpt === 'function' ? (rootOpt as any)() : rootOpt; + const optVal = typeof opt === 'function' ? (opt as any)() : opt; + return { ...rootOptVal, ...optVal }; + }; +} + export type ModelQueryOptions = Omit, 'queryKey'> & ExtraQueryOptions; export type ModelQueryResult = CreateQueryResult, DefaultError> & { queryKey: QueryKey }; @@ -122,8 +144,51 @@ export type ModelMutationModelResult< export type ClientHooks = QueryOptions> = { [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; +} & ProcedureHooks; + +type ProcedureHookGroup = { + [Name in GetProcedureNames]: GetProcedure extends { mutation: true } + ? { + useMutation( + options?: Omit< + CreateMutationOptions< + ProcedureReturn, + DefaultError, + ProcedureEnvelope + >, + 'mutationFn' + > & + QueryContext, + ): CreateMutationResult, DefaultError, ProcedureEnvelope>; + } + : { + useQuery: ProcedureHookFn< + Schema, + Name, + Omit>, 'optimisticUpdate'>, + CreateQueryResult, DefaultError> & { queryKey: QueryKey } + >; + + // Infinite queries for procedures are currently disabled, will add back later if needed + // + // useInfiniteQuery: ProcedureHookFn< + // Schema, + // Name, + // ModelInfiniteQueryOptions>, + // ModelInfiniteQueryResult>> + // >; + }; }; +export type ProcedureHooks = Schema extends { procedures: Record } + ? { + /** + * Custom procedures. + */ + $procs: ProcedureHookGroup; + } + : {}; + // Note that we can potentially use TypeScript's mapped type to directly map from ORM contract, but that seems // to significantly slow down tsc performance ... export type ModelQueryHooks< @@ -212,7 +277,7 @@ export function useClientQueries, ): ClientHooks { - return Object.keys(schema.models).reduce( + const result = Object.keys(schema.models).reduce( (acc, model) => { (acc as any)[lowerCaseFirst(model)] = useModelQueries, Options>( schema, @@ -223,6 +288,33 @@ export function useClientQueries, ); + + const procedures = (schema as any).procedures as Record | undefined; + if (procedures) { + const buildProcedureHooks = (endpointModel: '$procs') => { + return Object.keys(procedures).reduce((acc, name) => { + const procDef = procedures[name]; + if (procDef?.mutation) { + acc[name] = { + useMutation: (hookOptions?: any) => + useInternalMutation(schema, endpointModel, 'POST', name, merge(options, hookOptions)), + }; + } else { + acc[name] = { + useQuery: (args?: any, hookOptions?: any) => + useInternalQuery(schema, endpointModel, name, args, merge(options, hookOptions)), + useInfiniteQuery: (args?: any, hookOptions?: any) => + useInternalInfiniteQuery(schema, endpointModel, name, args, merge(options, hookOptions)), + }; + } + return acc; + }, {} as any); + }; + + (result as any).$procs = buildProcedureHooks('$procs'); + } + + return result; } /** @@ -240,14 +332,6 @@ export function useModelQueries< const modelName = modelDef.name; - const merge = (rootOpt: unknown, opt: unknown): Accessor => { - return () => { - const rootOptVal = typeof rootOpt === 'function' ? rootOpt() : rootOpt; - const optVal = typeof opt === 'function' ? opt() : opt; - return { ...rootOptVal, ...optVal }; - }; - }; - return { useFindUnique: (args: any, options?: any) => { return useInternalQuery(schema, modelName, 'findUnique', args, merge(rootOptions, options)); diff --git a/packages/clients/tanstack-query/src/vue.ts b/packages/clients/tanstack-query/src/vue.ts index bd4dcf74..126fef96 100644 --- a/packages/clients/tanstack-query/src/vue.ts +++ b/packages/clients/tanstack-query/src/vue.ts @@ -35,8 +35,11 @@ import type { FindFirstArgs, FindManyArgs, FindUniqueArgs, + GetProcedure, + GetProcedureNames, GroupByArgs, GroupByResult, + ProcedureEnvelope, QueryOptions, SelectSubset, SimplifiedPlainResult, @@ -54,6 +57,7 @@ import { getQueryKey } from './common/query-key'; import type { ExtraMutationOptions, ExtraQueryOptions, + ProcedureReturn, QueryContext, TrimDelegateModelOperations, WithOptimistic, @@ -61,6 +65,16 @@ import type { export type { FetchFn } from '@zenstackhq/client-helpers/fetch'; export const VueQueryContextKey = 'zenstack-vue-query-context'; +type ProcedureHookFn< + Schema extends SchemaDef, + ProcName extends GetProcedureNames, + Options, + Result, + Input = ProcedureEnvelope, +> = { args: undefined } extends Input + ? (args?: MaybeRefOrGetter, options?: MaybeRefOrGetter) => Result + : (args: MaybeRefOrGetter, options?: MaybeRefOrGetter) => Result; + /** * Provide context for query settings. * @@ -123,8 +137,60 @@ export type ModelMutationModelResult< export type ClientHooks = QueryOptions> = { [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; +} & ProcedureHooks; + +type ProcedureHookGroup = { + [Name in GetProcedureNames]: GetProcedure extends { mutation: true } + ? { + useMutation( + options?: MaybeRefOrGetter< + Omit< + UnwrapRef< + UseMutationOptions< + ProcedureReturn, + DefaultError, + ProcedureEnvelope + > + >, + 'mutationFn' + > & + QueryContext + >, + ): UseMutationReturnType< + ProcedureReturn, + DefaultError, + ProcedureEnvelope, + unknown + >; + } + : { + useQuery: ProcedureHookFn< + Schema, + Name, + Omit>, 'optimisticUpdate'>, + UseQueryReturnType, DefaultError> & { queryKey: Ref } + >; + + // Infinite queries for procedures are currently disabled, will add back later if needed + // + // useInfiniteQuery: ProcedureHookFn< + // Schema, + // Name, + // ModelInfiniteQueryOptions>, + // ModelInfiniteQueryResult>> + // >; + }; }; +export type ProcedureHooks = Schema extends { procedures: Record } + ? { + /** + * Custom procedures. + */ + $procs: ProcedureHookGroup; + } + : {}; + // Note that we can potentially use TypeScript's mapped type to directly map from ORM contract, but that seems // to significantly slow down tsc performance ... export type ModelQueryHooks< @@ -215,7 +281,15 @@ export function useClientQueries, ): ClientHooks { - return Object.keys(schema.models).reduce( + const merge = (rootOpt: MaybeRefOrGetter | undefined, opt: MaybeRefOrGetter | undefined): any => { + return computed(() => { + const rootVal = toValue(rootOpt) ?? {}; + const optVal = toValue(opt) ?? {}; + return { ...(rootVal as object), ...(optVal as object) }; + }); + }; + + const result = Object.keys(schema.models).reduce( (acc, model) => { (acc as any)[lowerCaseFirst(model)] = useModelQueries, Options>( schema, @@ -226,6 +300,33 @@ export function useClientQueries, ); + + const procedures = (schema as any).procedures as Record | undefined; + if (procedures) { + const buildProcedureHooks = (endpointModel: '$procs') => { + return Object.keys(procedures).reduce((acc, name) => { + const procDef = procedures[name]; + if (procDef?.mutation) { + acc[name] = { + useMutation: (hookOptions?: any) => + useInternalMutation(schema, endpointModel, 'POST', name, merge(options, hookOptions)), + }; + } else { + acc[name] = { + useQuery: (args?: any, hookOptions?: any) => + useInternalQuery(schema, endpointModel, name, args, merge(options, hookOptions)), + useInfiniteQuery: (args?: any, hookOptions?: any) => + useInternalInfiniteQuery(schema, endpointModel, name, args, merge(options, hookOptions)), + }; + } + return acc; + }, {} as any); + }; + + (result as any).$procs = buildProcedureHooks('$procs'); + } + + return result; } /** diff --git a/packages/clients/tanstack-query/test/react-typing-test.ts b/packages/clients/tanstack-query/test/react-typing-test.ts index 8f57ec67..4763bb41 100644 --- a/packages/clients/tanstack-query/test/react-typing-test.ts +++ b/packages/clients/tanstack-query/test/react-typing-test.ts @@ -1,7 +1,9 @@ import { useClientQueries } from '../src/react'; import { schema } from './schemas/basic/schema-lite'; +import { schema as proceduresSchema } from './schemas/procedures/schema-lite'; const client = useClientQueries(schema); +const proceduresClient = useClientQueries(proceduresSchema); // @ts-expect-error missing args client.user.useFindUnique(); @@ -111,3 +113,28 @@ client.foo.useCreate(); client.foo.useUpdate(); client.bar.useCreate(); + +// procedures (query) +check(proceduresClient.$procs.greet.useQuery().data?.toUpperCase()); +check(proceduresClient.$procs.greet.useQuery({ args: { name: 'bob' } }).data?.toUpperCase()); +check(proceduresClient.$procs.greet.useQuery({ args: { name: 'bob' } }, { enabled: true }).queryKey); +// @ts-expect-error wrong arg shape +proceduresClient.$procs.greet.useQuery({ args: { hello: 'world' } }); + +// Infinite queries for procedures are currently disabled, will add back later if needed +// check(proceduresClient.$procs.greetMany.useInfiniteQuery({ args: { name: 'bob' } }).data?.pages[0]?.[0]?.toUpperCase()); +// check(proceduresClient.$procs.greetMany.useInfiniteQuery({ args: { name: 'bob' } }).queryKey); + +// @ts-expect-error missing args +proceduresClient.$procs.greetMany.useQuery(); +// @ts-expect-error greet is not a mutation procedure +proceduresClient.$procs.greet.useMutation(); + +// procedures (mutation) +proceduresClient.$procs.sum.useMutation().mutate({ args: { a: 1, b: 2 } }); +// @ts-expect-error wrong arg shape for multi-param procedure +proceduresClient.$procs.sum.useMutation().mutate([1, 2]); +proceduresClient.$procs.sum + .useMutation() + .mutateAsync({ args: { a: 1, b: 2 } }) + .then((d) => check(d.toFixed(2))); diff --git a/packages/clients/tanstack-query/test/schemas/procedures/schema-lite.ts b/packages/clients/tanstack-query/test/schemas/procedures/schema-lite.ts new file mode 100644 index 00000000..630d3141 --- /dev/null +++ b/packages/clients/tanstack-query/test/schemas/procedures/schema-lite.ts @@ -0,0 +1,60 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// NOTE: Test fixture schema used for TanStack Query typing tests. // +////////////////////////////////////////////////////////////////////////////////////////////// + +import { type SchemaDef, ExpressionUtils } from '@zenstackhq/orm/schema'; + +export class SchemaType implements SchemaDef { + provider = { + type: 'sqlite', + } as const; + + models = { + User: { + name: 'User', + fields: { + id: { + name: 'id', + type: 'String', + id: true, + default: ExpressionUtils.call('cuid'), + }, + email: { + name: 'email', + type: 'String', + unique: true, + }, + }, + idFields: ['id'], + uniqueFields: { + id: { type: 'String' }, + email: { type: 'String' }, + }, + }, + } as const; + + procedures = { + greet: { + params: { name: { name: 'name', type: 'String', optional: true } }, + returnType: 'String', + }, + greetMany: { + params: { name: { name: 'name', type: 'String' } }, + returnType: 'String', + returnArray: true, + }, + sum: { + params: { + a: { name: 'a', type: 'Int' }, + b: { name: 'b', type: 'Int' }, + }, + returnType: 'Int', + mutation: true, + }, + } as const; + + authType = 'User' as const; + plugins = {}; +} + +export const schema = new SchemaType(); diff --git a/packages/clients/tanstack-query/test/svelte-typing-test.ts b/packages/clients/tanstack-query/test/svelte-typing-test.ts index 9c8788eb..0a309211 100644 --- a/packages/clients/tanstack-query/test/svelte-typing-test.ts +++ b/packages/clients/tanstack-query/test/svelte-typing-test.ts @@ -1,14 +1,21 @@ import { useClientQueries } from '../src/svelte/index.svelte'; import { schema } from './schemas/basic/schema-lite'; +import { schema as proceduresSchema } from './schemas/procedures/schema-lite'; const client = useClientQueries(schema); +const proceduresClient = useClientQueries(proceduresSchema); // @ts-expect-error missing args client.user.useFindUnique(); check(client.user.useFindUnique(() => ({ where: { id: '1' } })).data?.email); check(client.user.useFindUnique(() => ({ where: { id: '1' } })).queryKey); -check(client.user.useFindUnique(() => ({ where: { id: '1' } }), () => ({ optimisticUpdate: true, enabled: false }))); +check( + client.user.useFindUnique( + () => ({ where: { id: '1' } }), + () => ({ optimisticUpdate: true, enabled: false }), + ), +); // @ts-expect-error unselected field check(client.user.useFindUnique(() => ({ select: { email: true }, where: { id: '1' } })).data?.name); @@ -43,28 +50,34 @@ check(client.user.useGroupBy(() => ({ by: ['email'], _max: { name: true } })).da // @ts-expect-error missing args client.user.useCreate().mutate(); client.user.useCreate().mutate({ data: { email: 'test@example.com' } }); -client.user.useCreate(() => ({ optimisticUpdate: true, invalidateQueries: false, retry: 3 })).mutate({ - data: { email: 'test@example.com' }, -}); - -client.user.useCreate() +client.user + .useCreate(() => ({ optimisticUpdate: true, invalidateQueries: false, retry: 3 })) + .mutate({ + data: { email: 'test@example.com' }, + }); + +client.user + .useCreate() .mutateAsync({ data: { email: 'test@example.com' }, include: { posts: true } }) .then((d) => check(d.posts[0]?.title)); -client.user.useCreateMany() +client.user + .useCreateMany() .mutateAsync({ data: [{ email: 'test@example.com' }, { email: 'test2@example.com' }], skipDuplicates: true, }) .then((d) => d.count); -client.user.useCreateManyAndReturn() +client.user + .useCreateManyAndReturn() .mutateAsync({ data: [{ email: 'test@example.com' }], }) .then((d) => check(d[0]?.name)); -client.user.useCreateManyAndReturn() +client.user + .useCreateManyAndReturn() .mutateAsync({ data: [{ email: 'test@example.com' }], select: { email: true }, @@ -83,7 +96,8 @@ client.user.useUpdate().mutate( client.user.useUpdateMany().mutate({ data: { email: 'updated@example.com' } }); -client.user.useUpdateManyAndReturn() +client.user + .useUpdateManyAndReturn() .mutateAsync({ data: { email: 'updated@example.com' } }) .then((d) => check(d[0]?.email)); @@ -106,3 +120,28 @@ client.foo.useCreate(); client.foo.useUpdate(); client.bar.useCreate(); + +// procedures (query) +check(proceduresClient.$procs.greet.useQuery().data?.toUpperCase()); +check(proceduresClient.$procs.greet.useQuery(() => ({ args: { name: 'bob' } })).data?.toUpperCase()); +check(proceduresClient.$procs.greet.useQuery(() => ({ args: { name: 'bob' } })).queryKey); + +// Infinite queries for procedures are currently disabled, will add back later if needed +// check( +// proceduresClient.$procs.greetMany +// .useInfiniteQuery(() => ({ args: { name: 'bob' } })) +// .data?.pages[0]?.[0]?.toUpperCase(), +// ); +// check(proceduresClient.$procs.greetMany.useInfiniteQuery(() => ({ args: { name: 'bob' } })).queryKey); + +// @ts-expect-error greet is not a mutation procedure +proceduresClient.$procs.greet.useMutation(); + +// procedures (mutation) +proceduresClient.$procs.sum.useMutation().mutate({ args: { a: 1, b: 2 } }); +// @ts-expect-error wrong arg shape for multi-param procedure +proceduresClient.$procs.sum.useMutation().mutate([1, 2]); +proceduresClient.$procs.sum + .useMutation() + .mutateAsync({ args: { a: 1, b: 2 } }) + .then((d) => check(d.toFixed(2))); diff --git a/packages/clients/tanstack-query/test/vue-typing-test.ts b/packages/clients/tanstack-query/test/vue-typing-test.ts index f134378c..ee73c67f 100644 --- a/packages/clients/tanstack-query/test/vue-typing-test.ts +++ b/packages/clients/tanstack-query/test/vue-typing-test.ts @@ -1,7 +1,9 @@ import { useClientQueries } from '../src/vue'; import { schema } from './schemas/basic/schema-lite'; +import { schema as proceduresSchema } from './schemas/procedures/schema-lite'; const client = useClientQueries(schema); +const proceduresClient = useClientQueries(proceduresSchema); // @ts-expect-error missing args client.user.useFindUnique(); @@ -109,3 +111,26 @@ client.foo.useCreate(); client.foo.useUpdate(); client.bar.useCreate(); + +// procedures (query) +check(proceduresClient.$procs.greet.useQuery().data.value?.toUpperCase()); +check(proceduresClient.$procs.greet.useQuery({ args: { name: 'bob' } }).data.value?.toUpperCase()); +check(proceduresClient.$procs.greet.useQuery({ args: { name: 'bob' } }).queryKey.value); + +// Infinite queries for procedures are currently disabled, will add back later if needed +// check( +// proceduresClient.$procs.greetMany.useInfiniteQuery({ args: { name: 'bob' } }).data.value?.pages[0]?.[0]?.toUpperCase(), +// ); +// check(proceduresClient.$procs.greetMany.useInfiniteQuery({ args: { name: 'bob' } }).queryKey.value); + +// @ts-expect-error greet is not a mutation procedure +proceduresClient.$procs.greet.useMutation(); + +// procedures (mutation) +proceduresClient.$procs.sum.useMutation().mutate({ args: { a: 1, b: 2 } }); +// @ts-expect-error wrong arg shape for multi-param procedure +proceduresClient.$procs.sum.useMutation().mutate([1, 2]); +proceduresClient.$procs.sum + .useMutation() + .mutateAsync({ args: { a: 1, b: 2 } }) + .then((d) => check(d.toFixed(2))); diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index e759aa1f..7d6a589c 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -53,7 +53,9 @@ export type ZModelKeywordNames = | "Object" | "String" | "TransitiveFieldReference" + | "Undefined" | "Unsupported" + | "Void" | "[" | "]" | "^" @@ -120,10 +122,10 @@ export function isExpression(item: unknown): item is Expression { return reflection.isInstance(item, Expression); } -export type ExpressionType = 'Any' | 'Boolean' | 'DateTime' | 'Float' | 'Int' | 'Null' | 'Object' | 'String' | 'Unsupported'; +export type ExpressionType = 'Any' | 'BigInt' | 'Boolean' | 'Bytes' | 'DateTime' | 'Decimal' | 'Float' | 'Int' | 'Json' | 'Null' | 'Object' | 'String' | 'Undefined' | 'Unsupported' | 'Void'; export function isExpressionType(item: unknown): item is ExpressionType { - return item === 'String' || item === 'Int' || item === 'Float' || item === 'Boolean' || item === 'DateTime' || item === 'Null' || item === 'Object' || item === 'Any' || item === 'Unsupported'; + return item === 'String' || item === 'Int' || item === 'Float' || item === 'Boolean' || item === 'BigInt' || item === 'Decimal' || item === 'DateTime' || item === 'Json' || item === 'Bytes' || item === 'Null' || item === 'Object' || item === 'Any' || item === 'Void' || item === 'Undefined' || item === 'Unsupported'; } export type LiteralExpr = BooleanLiteral | NumberLiteral | StringLiteral; @@ -156,10 +158,10 @@ export function isRegularID(item: unknown): item is RegularID { return item === 'model' || item === 'enum' || item === 'attribute' || item === 'datasource' || item === 'plugin' || item === 'abstract' || item === 'in' || item === 'view' || item === 'import' || item === 'type' || (typeof item === 'string' && (/[_a-zA-Z][\w_]*/.test(item))); } -export type RegularIDWithTypeNames = 'Any' | 'BigInt' | 'Boolean' | 'Bytes' | 'DateTime' | 'Decimal' | 'Float' | 'Int' | 'Json' | 'Null' | 'Object' | 'String' | 'Unsupported' | RegularID; +export type RegularIDWithTypeNames = 'Any' | 'BigInt' | 'Boolean' | 'Bytes' | 'DateTime' | 'Decimal' | 'Float' | 'Int' | 'Json' | 'Null' | 'Object' | 'String' | 'Unsupported' | 'Void' | RegularID; export function isRegularIDWithTypeNames(item: unknown): item is RegularIDWithTypeNames { - return isRegularID(item) || item === 'String' || item === 'Boolean' || item === 'Int' || item === 'BigInt' || item === 'Float' || item === 'Decimal' || item === 'DateTime' || item === 'Json' || item === 'Bytes' || item === 'Null' || item === 'Object' || item === 'Any' || item === 'Unsupported'; + return isRegularID(item) || item === 'String' || item === 'Boolean' || item === 'Int' || item === 'BigInt' || item === 'Float' || item === 'Decimal' || item === 'DateTime' || item === 'Json' || item === 'Bytes' || item === 'Null' || item === 'Object' || item === 'Any' || item === 'Void' || item === 'Unsupported'; } export type TypeDeclaration = DataModel | Enum | TypeDef; @@ -477,7 +479,7 @@ export function isFunctionDecl(item: unknown): item is FunctionDecl { } export interface FunctionParam extends langium.AstNode { - readonly $container: FunctionDecl | Procedure; + readonly $container: FunctionDecl; readonly $type: 'FunctionParam'; name: RegularID; optional: boolean; @@ -648,7 +650,7 @@ export interface Procedure extends langium.AstNode { attributes: Array; mutation: boolean; name: RegularID; - params: Array; + params: Array; returnType: FunctionParamType; } diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 02260ccd..53688b82 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -2927,7 +2927,8 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "Keyword", "value": "mutation" - } + }, + "cardinality": "?" }, { "$type": "Keyword", @@ -2978,7 +2979,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@47" + "$ref": "#/rules@49" }, "arguments": [] } @@ -3156,6 +3157,10 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "Keyword", "value": "Any" }, + { + "$type": "Keyword", + "value": "Void" + }, { "$type": "Keyword", "value": "Unsupported" @@ -3766,10 +3771,26 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "Keyword", "value": "Boolean" }, + { + "$type": "Keyword", + "value": "BigInt" + }, + { + "$type": "Keyword", + "value": "Decimal" + }, { "$type": "Keyword", "value": "DateTime" }, + { + "$type": "Keyword", + "value": "Json" + }, + { + "$type": "Keyword", + "value": "Bytes" + }, { "$type": "Keyword", "value": "Null" @@ -3782,6 +3803,14 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "Keyword", "value": "Any" }, + { + "$type": "Keyword", + "value": "Void" + }, + { + "$type": "Keyword", + "value": "Undefined" + }, { "$type": "Keyword", "value": "Unsupported" diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 4b489fd5..1e963d9b 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -105,8 +105,8 @@ export function typeAssignable(destType: ExpressionType, sourceType: ExpressionT * Maps a ZModel builtin type to expression type */ export function mapBuiltinTypeToExpressionType( - type: BuiltinType | 'Any' | 'Object' | 'Null' | 'Unsupported', -): ExpressionType | 'Any' { + type: BuiltinType | ExpressionType, +): ExpressionType { switch (type) { case 'Any': case 'Boolean': @@ -115,6 +115,10 @@ export function mapBuiltinTypeToExpressionType( case 'Int': case 'Float': case 'Null': + case 'Object': + case 'Unsupported': + case 'Void': + case 'Undefined': return type; case 'BigInt': return 'Int'; @@ -123,10 +127,6 @@ export function mapBuiltinTypeToExpressionType( case 'Json': case 'Bytes': return 'Any'; - case 'Object': - return 'Object'; - case 'Unsupported': - return 'Unsupported'; } } diff --git a/packages/language/src/validator.ts b/packages/language/src/validator.ts index dcbf549f..8663e90f 100644 --- a/packages/language/src/validator.ts +++ b/packages/language/src/validator.ts @@ -9,6 +9,7 @@ import type { GeneratorDecl, InvocationExpr, Model, + Procedure, TypeDef, ZModelAstType, } from './generated/ast'; @@ -20,6 +21,7 @@ import EnumValidator from './validators/enum-validator'; import ExpressionValidator from './validators/expression-validator'; import FunctionDeclValidator from './validators/function-decl-validator'; import FunctionInvocationValidator from './validators/function-invocation-validator'; +import ProcedureValidator from './validators/procedure-validator'; import SchemaValidator from './validators/schema-validator'; import TypeDefValidator from './validators/typedef-validator'; @@ -40,6 +42,7 @@ export function registerValidationChecks(services: ZModelServices) { Expression: validator.checkExpression, InvocationExpr: validator.checkFunctionInvocation, FunctionDecl: validator.checkFunctionDecl, + Procedure: validator.checkProcedure, }; registry.register(checks, validator); } @@ -89,4 +92,9 @@ export class ZModelValidator { checkFunctionDecl(node: FunctionDecl, accept: ValidationAcceptor): void { new FunctionDeclValidator().validate(node, accept); } + + checkProcedure(node: Procedure, accept: ValidationAcceptor): void { + new ProcedureValidator().validate(node, accept); + } } + diff --git a/packages/language/src/validators/common.ts b/packages/language/src/validators/common.ts index 257efc79..01aad460 100644 --- a/packages/language/src/validators/common.ts +++ b/packages/language/src/validators/common.ts @@ -20,10 +20,11 @@ export function validateDuplicatedDeclarations( accept: ValidationAcceptor, ): void { const groupByName = decls.reduce>>((group, decl) => { + // Use a null-prototype map to avoid issues with names like "__proto__"/"constructor". group[decl.name] = group[decl.name] ?? []; group[decl.name]!.push(decl); return group; - }, {}); + }, Object.create(null) as Record>); for (const [name, decls] of Object.entries(groupByName)) { if (decls.length > 1) { diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index d3bb9e7c..cfc39192 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -20,6 +20,7 @@ import { getLiteral, isCheckInvocation, isDataFieldReference, + mapBuiltinTypeToExpressionType, typeAssignable, } from '../utils'; import type { AstValidator } from './common'; @@ -173,7 +174,9 @@ export default class FunctionInvocationValidator implements AstValidator { + validate(proc: Procedure, accept: ValidationAcceptor): void { + this.validateName(proc, accept); + proc.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); + } + + private validateName(proc: Procedure, accept: ValidationAcceptor): void { + if (RESERVED_PROCEDURE_NAMES.has(proc.name)) { + accept('error', `Procedure name "${proc.name}" is reserved`, { + node: proc, + property: 'name', + }); + } + } +} diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index 8d279787..9f09dbf8 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -226,7 +226,7 @@ ProcedureParam: TRIPLE_SLASH_COMMENT* name=RegularID ':' type=FunctionParamType (optional?='?')?; Procedure: - TRIPLE_SLASH_COMMENT* (mutation?='mutation') 'procedure' name=RegularID '(' (params+=ProcedureParam (',' params+=FunctionParam)*)? ')' ':' returnType=FunctionParamType (attributes+=InternalAttribute)*; + TRIPLE_SLASH_COMMENT* (mutation?='mutation')? 'procedure' name=RegularID '(' (params+=ProcedureParam (',' params+=ProcedureParam)*)? ')' ':' returnType=FunctionParamType (attributes+=InternalAttribute)*; // https://github.com/langium/langium/discussions/1012 RegularID returns string: @@ -234,7 +234,7 @@ RegularID returns string: ID | 'model' | 'enum' | 'attribute' | 'datasource' | 'plugin' | 'abstract' | 'in' | 'view' | 'import' | 'type'; RegularIDWithTypeNames returns string: - RegularID | 'String' | 'Boolean' | 'Int' | 'BigInt' | 'Float' | 'Decimal' | 'DateTime' | 'Json' | 'Bytes' | 'Null' | 'Object' | 'Any' | 'Unsupported'; + RegularID | 'String' | 'Boolean' | 'Int' | 'BigInt' | 'Float' | 'Decimal' | 'DateTime' | 'Json' | 'Bytes' | 'Null' | 'Object' | 'Any' | 'Void' | 'Unsupported'; // attribute Attribute: @@ -266,7 +266,7 @@ AttributeArg: (name=RegularID ':')? value=Expression; ExpressionType returns string: - 'String' | 'Int' | 'Float' | 'Boolean' | 'DateTime' | 'Null' | 'Object' | 'Any' | 'Unsupported'; + 'String' | 'Int' | 'Float' | 'Boolean' | 'BigInt' | 'Decimal' | 'DateTime' | 'Json' | 'Bytes' | 'Null' | 'Object' | 'Any' | 'Void' | 'Undefined' | 'Unsupported'; BuiltinType returns string: 'String' | 'Boolean' | 'Int' | 'BigInt' | 'Float' | 'Decimal' | 'DateTime' | 'Json' | 'Bytes'; diff --git a/packages/language/test/procedure-validation.test.ts b/packages/language/test/procedure-validation.test.ts new file mode 100644 index 00000000..a830ac66 --- /dev/null +++ b/packages/language/test/procedure-validation.test.ts @@ -0,0 +1,43 @@ +import { describe, it } from 'vitest'; +import { loadSchemaWithError } from './utils'; + +describe('Procedure validation', () => { + it('rejects unknown parameter type', async () => { + await loadSchemaWithError( + ` +model User { + id Int @id +} + +procedure foo(a: NotAType): Int + `, + /unknown type|could not resolve reference/i, + ); + }); + + it('rejects unknown return type', async () => { + await loadSchemaWithError( + ` +model User { + id Int @id +} + +procedure foo(): NotAType + `, + /unknown type|could not resolve reference/i, + ); + }); + + it('rejects reserved procedure names', async () => { + await loadSchemaWithError( + ` +model User { + id Int @id +} + +procedure __proto__(): Int + `, + /reserved/i, + ); + }); +}); diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index 8e2c8382..a7605e32 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -214,28 +214,67 @@ export class ClientImpl { } } - get $procedures() { + get $procs() { return Object.keys(this.$schema.procedures ?? {}).reduce((acc, name) => { - acc[name] = (...args: unknown[]) => this.handleProc(name, args); + acc[name] = (input?: unknown) => this.handleProc(name, input); return acc; }, {} as any); } - private async handleProc(name: string, args: unknown[]) { + private async handleProc(name: string, input: unknown) { if (!('procedures' in this.$options) || !this.$options || typeof this.$options.procedures !== 'object') { throw createConfigError('Procedures are not configured for the client.'); } + const procDef = (this.$schema.procedures ?? {})[name]; + if (!procDef) { + throw createConfigError(`Procedure "${name}" is not defined in schema.`); + } + const procOptions = this.$options.procedures as ProceduresOptions< SchemaDef & { procedures: Record; } >; if (!procOptions[name] || typeof procOptions[name] !== 'function') { - throw new Error(`Procedure "${name}" does not have a handler configured.`); + throw createConfigError(`Procedure "${name}" does not have a handler configured.`); } - return (procOptions[name] as Function).apply(this, [this, ...args]); + // Validate inputs using the same validator infrastructure as CRUD operations. + const inputValidator = new InputValidator(this as any); + const validatedInput = inputValidator.validateProcedureInput(name, input); + + const handler = procOptions[name] as Function; + + const invokeWithClient = async (client: any, _input: unknown) => { + let proceed = async (nextInput: unknown) => { + const sanitizedNextInput = + nextInput && typeof nextInput === 'object' && !Array.isArray(nextInput) ? nextInput : {}; + + return handler({ client, ...sanitizedNextInput }); + }; + + // apply plugins + const plugins = [...(client.$options?.plugins ?? [])]; + for (const plugin of plugins) { + const onProcedure = plugin.onProcedure; + if (onProcedure) { + const _proceed = proceed; + proceed = (nextInput: unknown) => + onProcedure({ + client, + name, + mutation: !!procDef.mutation, + input: nextInput, + proceed: (finalInput: unknown) => _proceed(finalInput), + }) as Promise; + } + } + + return proceed(_input); + }; + + return invokeWithClient(this as any, validatedInput); } async $connect() { diff --git a/packages/orm/src/client/contract.ts b/packages/orm/src/client/contract.ts index 050381fd..04c483fc 100644 --- a/packages/orm/src/client/contract.ts +++ b/packages/orm/src/client/contract.ts @@ -1,4 +1,3 @@ -import type Decimal from 'decimal.js'; import { type FieldIsArray, type GetModels, @@ -10,7 +9,7 @@ import { type SchemaDef, } from '../schema'; import type { AnyKysely } from '../utils/kysely-utils'; -import type { OrUndefinedIf, Simplify, UnwrapTuplePromises } from '../utils/type-utils'; +import type { Simplify, UnwrapTuplePromises } from '../utils/type-utils'; import type { TRANSACTION_UNSUPPORTED_METHODS } from './constants'; import type { AggregateArgs, @@ -27,9 +26,10 @@ import type { FindFirstArgs, FindManyArgs, FindUniqueArgs, + GetProcedureNames, GroupByArgs, GroupByResult, - ModelResult, + ProcedureFunc, SelectSubset, SimplifiedPlainResult, Subset, @@ -194,7 +194,7 @@ export type ClientContract; } & { [Key in GetModels as Uncapitalize]: ModelOperations>; -} & Procedures; +} & ProcedureOperations; /** * The contract for a client in a transaction. @@ -204,41 +204,18 @@ export type TransactionClientContract; -type _TypeMap = { - String: string; - Int: number; - Float: number; - BigInt: bigint; - Decimal: Decimal; - Boolean: boolean; - DateTime: Date; -}; - -type MapType = T extends keyof _TypeMap - ? _TypeMap[T] - : T extends GetModels - ? ModelResult - : unknown; - -export type Procedures = +export type ProcedureOperations = Schema['procedures'] extends Record ? { - $procedures: { - [Key in keyof Schema['procedures']]: ProcedureFunc; + /** + * Custom procedures. + */ + $procs: { + [Key in GetProcedureNames]: ProcedureFunc; }; } : {}; -export type ProcedureFunc = ( - ...args: MapProcedureParams -) => Promise>; - -type MapProcedureParams = { - [P in keyof Params]: Params[P] extends { type: infer U } - ? OrUndefinedIf, Params[P] extends { optional: true } ? true : false> - : never; -}; - /** * Creates a new ZenStack client instance. */ diff --git a/packages/orm/src/client/crud-types.ts b/packages/orm/src/client/crud-types.ts index 1fefdc7d..566ec908 100644 --- a/packages/orm/src/client/crud-types.ts +++ b/packages/orm/src/client/crud-types.ts @@ -23,6 +23,7 @@ import type { GetTypeDefs, ModelFieldIsOptional, NonRelationFields, + ProcedureDef, RelationFields, RelationFieldType, RelationInfo, @@ -34,16 +35,20 @@ import type { import type { AtLeast, MapBaseType, + MaybePromise, NonEmptyArray, NullableIf, Optional, OrArray, + OrUndefinedIf, PartialIf, Simplify, + TypeMap, ValueOfPotentialTuple, WrapType, XOR, } from '../utils/type-utils'; +import type { ClientContract } from './contract'; import type { QueryOptions } from './options'; import type { ToKyselySchema } from './query-builder'; @@ -1976,6 +1981,91 @@ type NestedDeleteManyInput< // #endregion +// #region Procedures + +export type GetProcedureNames = Schema extends { procedures: Record } + ? keyof Schema['procedures'] + : never; + +export type GetProcedureParams> = Schema extends { + procedures: Record; +} + ? Schema['procedures'][ProcName]['params'] + : never; + +export type GetProcedure> = Schema extends { + procedures: Record; +} + ? Schema['procedures'][ProcName] + : never; + +type _OptionalProcedureParamNames = keyof { + [K in keyof Params as Params[K] extends { optional: true } ? K : never]: K; +}; + +type _RequiredProcedureParamNames = keyof { + [K in keyof Params as Params[K] extends { optional: true } ? never : K]: K; +}; + +type _HasRequiredProcedureParams = _RequiredProcedureParamNames extends never ? false : true; + +type MapProcedureArgsObject = Simplify< + Optional< + { + [K in keyof Params]: MapProcedureParam; + }, + _OptionalProcedureParamNames + > +>; + +export type ProcedureEnvelope< + Schema extends SchemaDef, + ProcName extends GetProcedureNames, + Params = GetProcedureParams, +> = keyof Params extends never + ? // no params + { args?: Record } + : _HasRequiredProcedureParams extends true + ? // has required params + { args: MapProcedureArgsObject } + : // no required params + { args?: MapProcedureArgsObject }; + +type ProcedureHandlerCtx> = { + client: ClientContract; +} & ProcedureEnvelope; + +/** + * Shape of a procedure's runtime function. + */ +export type ProcedureFunc> = ( + ...args: _HasRequiredProcedureParams> extends true + ? [input: ProcedureEnvelope] + : [input?: ProcedureEnvelope] +) => MaybePromise>>; + +/** + * Signature for procedure handlers configured via client options. + */ +export type ProcedureHandlerFunc> = ( + ctx: ProcedureHandlerCtx, +) => MaybePromise>>; + +type MapProcedureReturn = Proc extends { returnType: infer R } + ? Proc extends { returnArray: true } + ? Array> + : MapType + : never; + +type MapProcedureParam = P extends { type: infer U } + ? OrUndefinedIf< + P extends { array: true } ? Array> : MapType, + P extends { optional: true } ? true : false + > + : never; + +// #endregion + // #region Utilities type NonOwnedRelationFields> = keyof { @@ -1992,6 +2082,21 @@ type HasToManyRelations> = GetEnum[keyof GetEnum< + Schema, + Enum +>]; + +type MapType = T extends keyof TypeMap + ? TypeMap[T] + : T extends GetModels + ? ModelResult + : T extends GetTypeDefs + ? TypeDefResult + : T extends GetEnums + ? EnumValue + : unknown; + // type ProviderSupportsDistinct = Schema['provider']['type'] extends 'postgresql' // ? true // : false; diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index dc1a6f83..ef69d78f 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -9,6 +9,7 @@ import { type BuiltinType, type EnumDef, type FieldDef, + type ProcedureDef, type GetModels, type ModelDef, type SchemaDef, @@ -39,6 +40,8 @@ import { getEnum, getTypeDef, getUniqueFields, + isEnum, + isTypeDef, requireField, requireModel, } from '../../query-utils'; @@ -70,6 +73,120 @@ export class InputValidator { return this.client.$options.validateInput !== false; } + validateProcedureInput(proc: string, input: unknown): unknown { + const procDef = (this.schema.procedures ?? {})[proc] as ProcedureDef | undefined; + invariant(procDef, `Procedure "${proc}" not found in schema`); + + const params = Object.values(procDef.params ?? {}); + + // For procedures where every parameter is optional, allow omitting the input entirely. + if (typeof input === 'undefined') { + if (params.length === 0) { + return undefined; + } + if (params.every((p) => p.optional)) { + return undefined; + } + throw createInvalidInputError('Missing procedure arguments', `$procs.${proc}`); + } + + if (typeof input !== 'object') { + throw createInvalidInputError('Procedure input must be an object', `$procs.${proc}`); + } + + const envelope = input as Record; + const argsPayload = Object.prototype.hasOwnProperty.call(envelope, 'args') ? (envelope as any).args : undefined; + + if (params.length === 0) { + if (typeof argsPayload === 'undefined') { + return input; + } + if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { + throw createInvalidInputError('Procedure `args` must be an object', `$procs.${proc}`); + } + if (Object.keys(argsPayload as any).length === 0) { + return input; + } + throw createInvalidInputError('Procedure does not accept arguments', `$procs.${proc}`); + } + + if (typeof argsPayload === 'undefined') { + if (params.every((p) => p.optional)) { + return input; + } + throw createInvalidInputError('Missing procedure arguments', `$procs.${proc}`); + } + + if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { + throw createInvalidInputError('Procedure `args` must be an object', `$procs.${proc}`); + } + + const obj = argsPayload as Record; + + for (const param of params) { + const value = (obj as any)[param.name]; + + if (!Object.prototype.hasOwnProperty.call(obj, param.name)) { + if (param.optional) { + continue; + } + throw createInvalidInputError(`Missing procedure argument: ${param.name}`, `$procs.${proc}`); + } + + if (typeof value === 'undefined') { + if (param.optional) { + continue; + } + throw createInvalidInputError( + `Invalid procedure argument: ${param.name} is required`, + `$procs.${proc}`, + ); + } + + const schema = this.makeProcedureParamSchema(param); + const parsed = schema.safeParse(value); + if (!parsed.success) { + throw createInvalidInputError( + `Invalid procedure argument: ${param.name}: ${formatError(parsed.error)}`, + `$procs.${proc}`, + ); + } + } + + return input; + } + + private makeProcedureParamSchema(param: { type: string; array?: boolean; optional?: boolean }): z.ZodType { + let schema: z.ZodType; + + if (isTypeDef(this.schema, param.type)) { + schema = this.makeTypeDefSchema(param.type); + } else if (isEnum(this.schema, param.type)) { + schema = this.makeEnumSchema(param.type); + } else if (param.type in (this.schema.models ?? {})) { + // For model-typed values, accept any object (no deep shape validation). + schema = z.record(z.string(), z.unknown()); + } else { + // Builtin scalar types. + schema = this.makeScalarSchema(param.type as BuiltinType); + + // If a type isn't recognized by any of the above branches, `makeScalarSchema` returns `unknown`. + // Treat it as configuration/schema error. + if (schema instanceof z.ZodUnknown) { + throw createInternalError(`Unsupported procedure parameter type: ${param.type}`); + } + } + + if (param.array) { + schema = schema.array(); + } + if (param.optional) { + schema = schema.optional(); + } + + return schema; + } + validateFindArgs( model: GetModels, args: unknown, diff --git a/packages/orm/src/client/index.ts b/packages/orm/src/client/index.ts index 6a320300..e69e4180 100644 --- a/packages/orm/src/client/index.ts +++ b/packages/orm/src/client/index.ts @@ -3,6 +3,7 @@ export * from './contract'; export type * from './crud-types'; export { getCrudDialect } from './crud/dialects'; export { BaseCrudDialect } from './crud/dialects/base-dialect'; +export { InputValidator } from './crud/validator'; export { ORMError, ORMErrorReason, RejectedByPolicyReason } from './errors'; export * from './options'; export * from './plugin'; diff --git a/packages/orm/src/client/options.ts b/packages/orm/src/client/options.ts index 64f50b61..c5e6c94d 100644 --- a/packages/orm/src/client/options.ts +++ b/packages/orm/src/client/options.ts @@ -1,7 +1,8 @@ import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysely'; import type { GetModel, GetModelFields, GetModels, ProcedureDef, ScalarFields, SchemaDef } from '../schema'; import type { PrependParameter } from '../utils/type-utils'; -import type { ClientContract, CRUD_EXT, ProcedureFunc } from './contract'; +import type { ClientContract, CRUD_EXT } from './contract'; +import type { GetProcedureNames, ProcedureHandlerFunc } from './crud-types'; import type { BaseCrudDialect } from './crud/dialects/base-dialect'; import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; @@ -134,10 +135,7 @@ export type ProceduresOptions = Schema extends { procedures: Record; } ? { - [Key in keyof Schema['procedures']]: PrependParameter< - ClientContract, - ProcedureFunc - >; + [Key in GetProcedureNames]: ProcedureHandlerFunc; } : {}; diff --git a/packages/orm/src/client/plugin.ts b/packages/orm/src/client/plugin.ts index 5664abb8..cd092f4a 100644 --- a/packages/orm/src/client/plugin.ts +++ b/packages/orm/src/client/plugin.ts @@ -36,6 +36,11 @@ export interface RuntimePlugin { */ onQuery?: OnQueryCallback; + /** + * Intercepts a procedure invocation. + */ + onProcedure?: OnProcedureCallback; + /** * Intercepts an entity mutation. */ @@ -56,6 +61,42 @@ export function definePlugin(plugin: RuntimePlugin = (ctx: OnProcedureHookContext) => Promise; + +export type OnProcedureHookContext = { + /** + * The procedure name. + */ + name: string; + + /** + * Whether the procedure is a mutation. + */ + mutation: boolean; + + /** + * Procedure invocation input (envelope). + * + * The canonical shape is `{ args?: Record }`. + * When a procedure has required params, `args` is required. + */ + input: unknown; + + /** + * Continues the invocation. The input passed here is forwarded to the next handler. + */ + proceed: (input: unknown) => Promise; + + /** + * The ZenStack client that is invoking the procedure. + */ + client: ClientContract; +}; + +// #endregion + // #region OnQuery hooks type OnQueryCallback = (ctx: OnQueryHookContext) => Promise; diff --git a/packages/orm/src/utils/type-utils.ts b/packages/orm/src/utils/type-utils.ts index 4c671275..f1ad3d35 100644 --- a/packages/orm/src/utils/type-utils.ts +++ b/packages/orm/src/utils/type-utils.ts @@ -31,7 +31,7 @@ export type WrapType = Array extends true ? T | null : T; -type TypeMap = { +export type TypeMap = { String: string; Boolean: boolean; Int: number; @@ -41,6 +41,12 @@ type TypeMap = { DateTime: Date; Bytes: Uint8Array; Json: JsonValue; + Null: null; + Object: Record; + Any: unknown; + Unsupported: unknown; + Void: void; + Undefined: undefined; }; export type MapBaseType = T extends keyof TypeMap ? TypeMap[T] : unknown; diff --git a/packages/orm/test/procedures.test.ts b/packages/orm/test/procedures.test.ts new file mode 100644 index 00000000..6959c8bc --- /dev/null +++ b/packages/orm/test/procedures.test.ts @@ -0,0 +1,138 @@ +import SQLite from 'better-sqlite3'; +import { SqliteDialect } from 'kysely'; +import { describe, expect, it } from 'vitest'; + +import { ZenStackClient } from '../src/client/client-impl'; +import { definePlugin } from '../src/client/plugin'; + +const baseSchema = { + provider: { type: 'sqlite' }, + models: {}, + enums: {}, + typeDefs: {}, +} as const; + +describe('ORM procedures', () => { + it('exposes `$procs`', async () => { + const schema: any = { + ...baseSchema, + procedures: { + hello: { params: [], returnType: 'String' }, + }, + }; + + const client: any = new ZenStackClient(schema, { + dialect: new SqliteDialect({ database: new SQLite(':memory:') }), + procedures: { + hello: async () => 'ok', + }, + }); + + expect(typeof client.$procs.hello).toBe('function'); + expect(await client.$procs.hello()).toBe('ok'); + }); + + it('throws config error when procedures are not configured', async () => { + const schema: any = { + ...baseSchema, + procedures: { + hello: { params: [], returnType: 'String' }, + }, + }; + + const client: any = new ZenStackClient(schema, { + dialect: new SqliteDialect({ database: new SQLite(':memory:') }), + } as any); + + await expect(client.$procs.hello()).rejects.toThrow(/not configured/i); + }); + + it('throws error when a procedure handler is missing', async () => { + const schema: any = { + ...baseSchema, + procedures: { + hello: { params: [], returnType: 'String' }, + }, + }; + + const client: any = new ZenStackClient(schema, { + dialect: new SqliteDialect({ database: new SQLite(':memory:') }), + procedures: {}, + } as any); + + await expect(client.$procs.hello()).rejects.toThrow(/does not have a handler configured/i); + }); + + it('validates procedure args against schema', async () => { + const schema: any = { + ...baseSchema, + procedures: { + echoInt: { + params: [{ name: 'n', type: 'Int' }], + returnType: 'Int', + }, + }, + }; + + const client: any = new ZenStackClient(schema, { + dialect: new SqliteDialect({ database: new SQLite(':memory:') }), + procedures: { + echoInt: async ({ args }: any) => args.n, + }, + }); + + await expect(client.$procs.echoInt({ args: { n: '1' } })).rejects.toThrow(/invalid input/i); + }); + + it('runs procedure through onProcedure hooks', async () => { + const schema: any = { + ...baseSchema, + procedures: { + add: { + params: [ + { name: 'a', type: 'Int' }, + { name: 'b', type: 'Int' }, + ], + returnType: 'Int', + }, + }, + }; + + const calls: string[] = []; + + const p1 = definePlugin({ + id: 'p1', + onProcedure: async (ctx) => { + calls.push(`p1:${ctx.name}`); + return ctx.proceed(ctx.input); + }, + }); + + const p2 = definePlugin({ + id: 'p2', + onProcedure: async (ctx) => { + calls.push(`p2:${ctx.name}`); + // mutate args: add +1 to `a` + const nextInput: any = { + ...(ctx.input as any), + args: { + ...((ctx.input as any)?.args ?? {}), + a: Number((ctx.input as any)?.args?.a) + 1, + }, + }; + return ctx.proceed(nextInput); + }, + }); + + const client: any = new ZenStackClient(schema, { + dialect: new SqliteDialect({ database: new SQLite(':memory:') }), + plugins: [p1, p2], + procedures: { + add: async ({ args }: any) => args.a + args.b, + }, + }); + + await expect(client.$procs.add({ args: { a: 1, b: 2 } })).resolves.toBe(4); + expect(calls).toEqual(['p2:add', 'p1:add']); + }); +}); diff --git a/packages/schema/src/schema.ts b/packages/schema/src/schema.ts index 83640c35..40f7d8bd 100644 --- a/packages/schema/src/schema.ts +++ b/packages/schema/src/schema.ts @@ -77,11 +77,12 @@ export type FieldDef = { isDiscriminator?: boolean; }; -export type ProcedureParam = { name: string; type: string; optional?: boolean }; +export type ProcedureParam = { name: string; type: string; array?: boolean; optional?: boolean }; export type ProcedureDef = { - params: [...ProcedureParam[]]; + params: Record; returnType: string; + returnArray?: boolean; mutation?: boolean; }; diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index f68bb0bc..5df0b923 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -1124,65 +1124,38 @@ export class TsSchemaGenerator { } private createProcedureObject(proc: Procedure) { - const params = ts.factory.createArrayLiteralExpression( + const params = ts.factory.createObjectLiteralExpression( proc.params.map((param) => - ts.factory.createObjectLiteralExpression([ - ts.factory.createPropertyAssignment('name', ts.factory.createStringLiteral(param.name)), - ...(param.optional - ? [ts.factory.createPropertyAssignment('optional', ts.factory.createTrue())] - : []), - ts.factory.createPropertyAssignment( - 'type', - ts.factory.createStringLiteral(param.type.type ?? param.type.reference!.$refText), - ), - ]), - ), - true, - ); - - const paramsType = ts.factory.createTupleTypeNode([ - ...proc.params.map((param) => - ts.factory.createNamedTupleMember( - undefined, - ts.factory.createIdentifier(param.name), - undefined, - ts.factory.createTypeLiteralNode([ - ts.factory.createPropertySignature( - undefined, - ts.factory.createStringLiteral('name'), - undefined, - ts.factory.createLiteralTypeNode(ts.factory.createStringLiteral(param.name)), - ), - ts.factory.createPropertySignature( - undefined, - ts.factory.createStringLiteral('type'), - undefined, - ts.factory.createLiteralTypeNode( - ts.factory.createStringLiteral(param.type.type ?? param.type.reference!.$refText), - ), - ), + ts.factory.createPropertyAssignment( + param.name, + ts.factory.createObjectLiteralExpression([ + ts.factory.createPropertyAssignment('name', ts.factory.createStringLiteral(param.name)), ...(param.optional - ? [ - ts.factory.createPropertySignature( - undefined, - ts.factory.createStringLiteral('optional'), - undefined, - ts.factory.createLiteralTypeNode(ts.factory.createTrue()), - ), - ] + ? [ts.factory.createPropertyAssignment('optional', ts.factory.createTrue())] : []), + ...(param.type.array + ? [ts.factory.createPropertyAssignment('array', ts.factory.createTrue())] + : []), + ts.factory.createPropertyAssignment( + 'type', + ts.factory.createStringLiteral(param.type.type ?? param.type.reference!.$refText), + ), ]), ), ), - ]); + true, + ); return ts.factory.createObjectLiteralExpression( [ - ts.factory.createPropertyAssignment('params', ts.factory.createAsExpression(params, paramsType)), + ts.factory.createPropertyAssignment('params', params), ts.factory.createPropertyAssignment( 'returnType', ts.factory.createStringLiteral(proc.returnType.type ?? proc.returnType.reference!.$refText), ), + ...(proc.returnType.array + ? [ts.factory.createPropertyAssignment('returnArray', ts.factory.createTrue())] + : []), ...(proc.mutation ? [ts.factory.createPropertyAssignment('mutation', ts.factory.createTrue())] : []), ], true, diff --git a/packages/server/src/api/common/procedures.ts b/packages/server/src/api/common/procedures.ts new file mode 100644 index 00000000..60680158 --- /dev/null +++ b/packages/server/src/api/common/procedures.ts @@ -0,0 +1,137 @@ +import { ORMError } from '@zenstackhq/orm'; +import type { ProcedureDef, ProcedureParam, SchemaDef } from '@zenstackhq/orm/schema'; + +export const PROCEDURE_ROUTE_PREFIXES = ['$procs'] as const; + +export function getProcedureDef(schema: SchemaDef, proc: string): ProcedureDef | undefined { + const procs = schema.procedures ?? {}; + if (!Object.prototype.hasOwnProperty.call(procs, proc)) { + return undefined; + } + return procs[proc]; +} + +/** + * Maps and validates the incoming procedure payload for server-side routing. + * + * Supported payload formats: + * - **Envelope (preferred)**: `{ args: { ... } }` + * - **Direct object**: `{ ... }` (allowed only when *every* parameter is optional) + * + * The function returns the original `payload` unchanged; it only enforces payload + * *shape* and argument presence/keys so downstream code can safely assume a + * consistent contract. + * + * Validation / branching behavior (mirrors the code below): + * - **Zero-parameter procedures** (`params.length === 0`) + * - `undefined` payload is accepted. + * - Object payloads without an `args` key are treated as “no args” and accepted. + * - Envelope payloads with `args: {}` are accepted. + * - Any other payload (including `args` with keys) is rejected. + * - **All-optional parameter procedures** + * - Payload may be omitted (`undefined`). + * - If payload is an object and has no `args` key, it is treated as the direct + * object form. + * - **Missing payload** (required parameters exist) + * - `undefined` is rejected. + * - **Non-object or array payload** + * - Rejected. + * - **Undefined/invalid `args` (envelope form)** + * - If `args` is missing and not all params are optional: rejected. + * - If `args` exists but is not a non-array object: rejected. + * - **Unknown keys** + * - Any key in the `args` object that is not declared by the procedure is + * rejected (prevents silently ignoring typos). + * - **Missing required params** + * - Any declared non-optional param missing from `args` is rejected. + * + * Rationale for rejecting null/falsey payloads: + * - The checks `!payload` and `!argsPayload` intentionally reject values like + * `null`, `false`, `0`, or `''` instead of treating them as “no args”. This + * keeps the API strictly object-based and yields deterministic, descriptive + * errors rather than surprising coercion. + * + * @throws {Error} "procedure does not accept arguments" + * @throws {Error} "missing procedure arguments" + * @throws {Error} "procedure payload must be an object" + * @throws {Error} "procedure `args` must be an object" + * @throws {Error} "unknown procedure argument: " + * @throws {Error} "missing procedure argument: " + */ +export function mapProcedureArgs(procDef: { params: Record }, payload: unknown): unknown { + const params = Object.values(procDef.params ?? {}); + if (params.length === 0) { + if (typeof payload === 'undefined') { + return undefined; + } + if (payload && typeof payload === 'object' && !Array.isArray(payload)) { + const envelope = payload as Record; + const argsPayload = Object.prototype.hasOwnProperty.call(envelope, 'args') + ? (envelope as any).args + : undefined; + + if (typeof argsPayload === 'undefined') { + return payload; + } + + if (argsPayload && typeof argsPayload === 'object' && !Array.isArray(argsPayload)) { + if (Object.keys(argsPayload as any).length === 0) { + return payload; + } + } + } + throw new Error('procedure does not accept arguments'); + } + + // For procedures where every parameter is optional, allow omitting the payload entirely. + if (typeof payload === 'undefined' && params.every((p) => p.optional)) { + return undefined; + } + + if (typeof payload === 'undefined') { + throw new Error('missing procedure arguments'); + } + + if (!payload || typeof payload !== 'object' || Array.isArray(payload)) { + throw new Error('procedure payload must be an object'); + } + + const envelope = payload as Record; + const argsPayload = Object.prototype.hasOwnProperty.call(envelope, 'args') ? (envelope as any).args : undefined; + + if (typeof argsPayload === 'undefined') { + if (params.every((p) => p.optional)) { + return payload; + } + throw new Error('missing procedure arguments'); + } + + if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { + throw new Error('procedure `args` must be an object'); + } + + const obj = argsPayload as Record; + + // reject unknown keys to avoid silently ignoring user mistakes + for (const key of Object.keys(obj)) { + if (!params.some((p) => p.name === key)) { + throw new Error(`unknown procedure argument: ${key}`); + } + } + + // ensure required params are present + for (const p of params) { + if (!Object.prototype.hasOwnProperty.call(obj, p.name)) { + if (p.optional) { + continue; + } + throw new Error(`missing procedure argument: ${p.name}`); + } + } + + return payload; +} + +export function isOrmError(err: unknown): err is ORMError { + return err instanceof ORMError; +} diff --git a/packages/server/src/api/common/utils.ts b/packages/server/src/api/common/utils.ts new file mode 100644 index 00000000..ff42a3f9 --- /dev/null +++ b/packages/server/src/api/common/utils.ts @@ -0,0 +1,56 @@ +import SuperJSON from 'superjson'; + +/** + * Supports the SuperJSON request payload format used by api handlers + * `{ meta: { serialization }, ...json }`. + */ +export async function processSuperJsonRequestPayload(payload: unknown) : Promise<{ result: unknown; error: string | undefined; }> { + if (!payload || typeof payload !== 'object' || Array.isArray(payload) || !('meta' in (payload as any))) { + return { result: payload, error: undefined }; + } + + const { meta, ...rest } = payload as any; + if (meta?.serialization) { + try { + return { + result: SuperJSON.deserialize({ json: rest, meta: meta.serialization }), + error: undefined, + }; + } catch (err) { + return { + result: undefined, + error: `failed to deserialize request payload: ${(err as Error).message}`, + }; + } + } + + // drop meta when no serialization info is present + return { result: rest, error: undefined }; +} + +/** + * Supports the SuperJSON query format used by api handlers: + */ +export function unmarshalQ(value: string, meta: string | undefined) { + let parsedValue: any; + try { + parsedValue = JSON.parse(value); + } catch { + throw new Error('invalid "q" query parameter'); + } + + if (meta) { + let parsedMeta: any; + try { + parsedMeta = JSON.parse(meta); + } catch { + throw new Error('invalid "meta" query parameter'); + } + + if (parsedMeta.serialization) { + return SuperJSON.deserialize({ json: parsedValue, meta: parsedMeta.serialization }); + } + } + + return parsedValue; +} \ No newline at end of file diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index ed270cb1..1b5580af 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -9,6 +9,13 @@ import UrlPattern from 'url-pattern'; import z from 'zod'; import type { ApiHandler, LogConfig, RequestContext, Response } from '../../types'; import { getZodErrorMessage, log, registerCustomSerializers } from '../utils'; +import { + getProcedureDef, + mapProcedureArgs, +} from '../common/procedures'; +import { + processSuperJsonRequestPayload, +} from '../common/utils'; /** * Options for {@link RestApiHandler} @@ -336,6 +343,11 @@ export class RestApiHandler implements Api } try { + if (path.startsWith('/$procs/')) { + const proc = path.split('/')[2]; + return await this.processProcedureRequest({ client, method, proc, query, requestBody }); + } + switch (method) { case 'GET': { let match = this.matchUrlPattern(path, UrlPatterns.SINGLE); @@ -473,6 +485,89 @@ export class RestApiHandler implements Api return this.makeError('unknownError', err instanceof Error ? `${err.message}\n${err.stack}` : 'Unknown error'); } + private async processProcedureRequest({ + client, + method, + proc, + query, + requestBody, + }: { + client: ClientContract; + method: string; + proc?: string; + query?: Record; + requestBody?: unknown; + }): Promise { + if (!proc) { + return this.makeProcBadInputErrorResponse('missing procedure name'); + } + + const procDef = getProcedureDef(this.schema, proc); + if (!procDef) { + return this.makeProcBadInputErrorResponse(`unknown procedure: ${proc}`); + } + + const isMutation = !!procDef.mutation; + if (isMutation) { + if (method !== 'POST') { + return this.makeProcBadInputErrorResponse('invalid request method, only POST is supported'); + } + } else { + if (method !== 'GET') { + return this.makeProcBadInputErrorResponse('invalid request method, only GET is supported'); + } + } + + const argsPayload = method === 'POST' ? requestBody : query; + + // support SuperJSON request payload format + const { result: processedArgsPayload, error } = await processSuperJsonRequestPayload(argsPayload); + if (error) { + return this.makeProcBadInputErrorResponse(error); + } + + let procInput: unknown; + try { + procInput = mapProcedureArgs(procDef, processedArgsPayload); + } catch (err) { + return this.makeProcBadInputErrorResponse(err instanceof Error ? err.message : 'invalid procedure arguments'); + } + + try { + log(this.log, 'debug', () => `handling "$procs.${proc}" request`); + + const clientResult = await (client as any).$procs?.[proc](procInput); + const toSerialize = this.toPlainObject(clientResult); + + const { json, meta } = SuperJSON.serialize(toSerialize); + const responseBody: any = { data: json }; + if (meta) { + responseBody.meta = { serialization: meta }; + } + + return { status: 200, body: responseBody }; + } catch (err) { + log(this.log, 'error', `error occurred when handling "$procs.${proc}" request`, err); + if (err instanceof ORMError) { + throw err; // top-level handler will take care of it + } + return this.makeProcGenericErrorResponse(err); + } + } + + private makeProcBadInputErrorResponse(message: string): Response { + const resp = this.makeError('invalidPayload', message, 400); + log(this.log, 'debug', () => `sending error response: ${JSON.stringify(resp)}`); + return resp; + } + + private makeProcGenericErrorResponse(err: unknown): Response { + const message = err instanceof Error ? err.message : 'unknown error'; + const resp = this.makeError('unknownError', message, 500); + log(this.log, 'debug', () => `sending error response: ${JSON.stringify(resp)}`); + return resp; + } + private async processSingleRead( client: ClientContract, type: string, @@ -831,16 +926,16 @@ export class RestApiHandler implements Api prev: offset - limit >= 0 && offset - limit <= total - 1 ? this.replaceURLSearchParams(baseUrl, { - 'page[offset]': offset - limit, - 'page[limit]': limit, - }) + 'page[offset]': offset - limit, + 'page[limit]': limit, + }) : null, next: offset + limit <= total - 1 ? this.replaceURLSearchParams(baseUrl, { - 'page[offset]': offset + limit, - 'page[limit]': limit, - }) + 'page[offset]': offset + limit, + 'page[limit]': limit, + }) : null, })); } @@ -1906,8 +2001,8 @@ export class RestApiHandler implements Api } else { currPayload[relation] = select ? { - select: { ...select }, - } + select: { ...select }, + } : true; } } diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts index e821366f..1511c3bd 100644 --- a/packages/server/src/api/rpc/index.ts +++ b/packages/server/src/api/rpc/index.ts @@ -5,6 +5,14 @@ import SuperJSON from 'superjson'; import { match } from 'ts-pattern'; import type { ApiHandler, LogConfig, RequestContext, Response } from '../../types'; import { log, registerCustomSerializers } from '../utils'; +import { + getProcedureDef, + mapProcedureArgs, +} from '../common/procedures'; +import { + processSuperJsonRequestPayload, + unmarshalQ, +} from '../common/utils'; registerCustomSerializers(); @@ -27,7 +35,7 @@ export type RPCApiHandlerOptions = { * RPC style API request handler that mirrors the ZenStackClient API */ export class RPCApiHandler implements ApiHandler { - constructor(private readonly options: RPCApiHandlerOptions) {} + constructor(private readonly options: RPCApiHandlerOptions) { } get schema(): Schema { return this.options.schema; @@ -46,6 +54,16 @@ export class RPCApiHandler implements ApiH return this.makeBadInputErrorResponse('invalid request path'); } + if (model === '$procs') { + return this.handleProcedureRequest({ + client, + method: method.toUpperCase(), + proc: op, + query, + requestBody, + }); + } + model = lowerCaseFirst(model); method = method.toUpperCase(); let args: unknown; @@ -78,7 +96,7 @@ export class RPCApiHandler implements ApiH } try { args = query?.['q'] - ? this.unmarshalQ(query['q'] as string, query['meta'] as string | undefined) + ? unmarshalQ(query['q'] as string, query['meta'] as string | undefined) : {}; } catch { return this.makeBadInputErrorResponse('invalid "q" query parameter'); @@ -105,7 +123,7 @@ export class RPCApiHandler implements ApiH } try { args = query?.['q'] - ? this.unmarshalQ(query['q'] as string, query['meta'] as string | undefined) + ? unmarshalQ(query['q'] as string, query['meta'] as string | undefined) : {}; } catch (err) { return this.makeBadInputErrorResponse( @@ -163,6 +181,86 @@ export class RPCApiHandler implements ApiH } } + private async handleProcedureRequest({ + client, + method, + proc, + query, + requestBody, + }: { + client: ClientContract; + method: string; + proc?: string; + query?: Record; + requestBody?: unknown; + }): Promise { + if (!proc) { + return this.makeBadInputErrorResponse('missing procedure name'); + } + + const procDef = getProcedureDef(this.options.schema, proc); + if (!procDef) { + return this.makeBadInputErrorResponse(`unknown procedure: ${proc}`); + } + + const isMutation = !!procDef.mutation; + + if (isMutation) { + if (method !== 'POST') { + return this.makeBadInputErrorResponse('invalid request method, only POST is supported'); + } + } else { + if (method !== 'GET') { + return this.makeBadInputErrorResponse('invalid request method, only GET is supported'); + } + } + + let argsPayload = method === 'POST' ? requestBody : undefined; + if (method === 'GET') { + try { + argsPayload = query?.['q'] + ? unmarshalQ(query['q'] as string, query['meta'] as string | undefined) + : undefined; + } catch (err) { + return this.makeBadInputErrorResponse(err instanceof Error ? err.message : 'invalid "q" query parameter'); + } + } + + const { result: processedArgsPayload, error } = await processSuperJsonRequestPayload(argsPayload); + if (error) { + return this.makeBadInputErrorResponse(error); + } + + let procInput: unknown; + try { + procInput = mapProcedureArgs(procDef, processedArgsPayload); + } catch (err) { + return this.makeBadInputErrorResponse(err instanceof Error ? err.message : 'invalid procedure arguments'); + } + + try { + log(this.options.log, 'debug', () => `handling "$procs.${proc}" request`); + + const clientResult = await (client as any).$procs?.[proc](procInput); + + const { json, meta } = SuperJSON.serialize(clientResult); + const responseBody: any = { data: json }; + if (meta) { + responseBody.meta = { serialization: meta }; + } + + const response = { status: 200, body: responseBody }; + log(this.options.log, 'debug', () => `sending response for "$procs.${proc}" request: ${safeJSONStringify(response)}`); + return response; + } catch (err) { + log(this.options.log, 'error', `error occurred when handling "$procs.${proc}" request`, err); + if (err instanceof ORMError) { + return this.makeORMErrorResponse(err); + } + return this.makeGenericErrorResponse(err); + } + } + private isValidModel(client: ClientContract, model: string) { return Object.keys(client.$schema.models).some((m) => lowerCaseFirst(m) === lowerCaseFirst(model)); } @@ -206,14 +304,14 @@ export class RPCApiHandler implements ApiH .with(ORMErrorReason.REJECTED_BY_POLICY, () => { status = 403; error.rejectedByPolicy = true; - error.rejectReason = err.rejectedByPolicyReason; error.model = err.model; + error.rejectReason = err.rejectedByPolicyReason; }) .with(ORMErrorReason.DB_QUERY_ERROR, () => { status = 400; error.dbErrorCode = err.dbErrorCode; }) - .otherwise(() => {}); + .otherwise(() => { }); const resp = { status, body: { error } }; log(this.options.log, 'debug', () => `sending error response: ${safeJSONStringify(resp)}`); @@ -235,28 +333,4 @@ export class RPCApiHandler implements ApiH } return { result: args, error: undefined }; } - - private unmarshalQ(value: string, meta: string | undefined) { - let parsedValue: any; - try { - parsedValue = JSON.parse(value); - } catch { - throw new Error('invalid "q" query parameter'); - } - - if (meta) { - let parsedMeta: any; - try { - parsedMeta = JSON.parse(meta); - } catch { - throw new Error('invalid "meta" query parameter'); - } - - if (parsedMeta.serialization) { - return SuperJSON.deserialize({ json: parsedValue, meta: parsedMeta.serialization }); - } - } - - return parsedValue; - } } diff --git a/packages/server/test/api/procedures.e2e.test.ts b/packages/server/test/api/procedures.e2e.test.ts new file mode 100644 index 00000000..60c6f06b --- /dev/null +++ b/packages/server/test/api/procedures.e2e.test.ts @@ -0,0 +1,86 @@ +import type { ClientContract } from '@zenstackhq/orm'; +import type { SchemaDef } from '@zenstackhq/orm/schema'; +import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { RestApiHandler } from '../../src/api/rest'; + +describe('Procedures E2E', () => { + let client: ClientContract; + let api: RestApiHandler; + + const schema = ` +datasource db { + provider = 'sqlite' + url = 'file:./test.db' +} + +model User { + id Int @id @default(autoincrement()) + email String @unique +} + +procedure greet(name: String?): String +mutation procedure createTwoAndFail(email1: String, email2: String): Int +`; + + beforeEach(async () => { + client = await createTestClient( + schema, + { + procedures: { + greet: async ({ args }: any) => { + const name = args?.name as string | undefined; + return `hello ${name ?? 'world'}`; + }, + createTwoAndFail: async ({ client, args }: any) => { + const email1 = args.email1 as string; + const email2 = args.email2 as string; + await client.user.create({ data: { email: email1 } }); + await client.user.create({ data: { email: email2 } }); + throw new Error('boom'); + }, + }, + } as any + ); + + api = new RestApiHandler({ + schema: client.$schema, + endpoint: 'http://localhost/api', + pageSize: 5, + }); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('supports $procs routes', async () => { + const r = await api.handleRequest({ + client, + method: 'get', + path: '/$procs/greet', + query: { args: { name: 'alice' } } as any, + }); + expect(r.status).toBe(200); + expect(r.body).toEqual({ data: 'hello alice' }); + }); + + it('returns 422 for invalid input', async () => { + const r = await api.handleRequest({ + client, + method: 'get', + path: '/$procs/greet', + query: { args: { name: 123 } } as any, + }); + + expect(r.status).toBe(422); + expect(r.body).toMatchObject({ + errors: [ + { + status: 422, + code: 'validation-error', + }, + ], + }); + }); +}); diff --git a/packages/server/test/api/rest.test.ts b/packages/server/test/api/rest.test.ts index b40ff604..124967ed 100644 --- a/packages/server/test/api/rest.test.ts +++ b/packages/server/test/api/rest.test.ts @@ -3163,4 +3163,301 @@ describe('REST server tests', () => { }); }); }); + + describe('REST server tests - procedures', () => { + const schema = ` +model User { + id String @id @default(cuid()) + email String @unique +} + +enum Role { + ADMIN + USER +} + +type Overview { + total Int +} + +procedure echoDecimal(x: Decimal): Decimal +procedure greet(name: String?): String +procedure echoInt(x: Int): Int +procedure opt2(a: Int?, b: Int?): Int +procedure sumIds(ids: Int[]): Int +procedure echoRole(r: Role): Role +procedure echoOverview(o: Overview): Overview + +mutation procedure sum(a: Int, b: Int): Int +`; + + beforeEach(async () => { + interface ProcCtx { + client: ClientContract; + args: TArgs; + } + + interface ProcCtxOptionalArgs { + client: ClientContract; + args?: TArgs; + } + + type Role = 'ADMIN' | 'USER'; + + interface Overview { + total: number; + } + + interface EchoDecimalArgs { + x: Decimal; + } + + interface GreetArgs { + name?: string | null; + } + + interface EchoIntArgs { + x: number; + } + + interface Opt2Args { + a?: number | null; + b?: number | null; + } + + interface SumIdsArgs { + ids: number[]; + } + + interface EchoRoleArgs { + r: Role; + } + + interface EchoOverviewArgs { + o: Overview; + } + + interface SumArgs { + a: number; + b: number; + } + + interface Procedures { + echoDecimal: (ctx: ProcCtx) => Promise; + greet: (ctx: ProcCtxOptionalArgs) => Promise; + echoInt: (ctx: ProcCtx) => Promise; + opt2: (ctx: ProcCtxOptionalArgs) => Promise; + sumIds: (ctx: ProcCtx) => Promise; + echoRole: (ctx: ProcCtx) => Promise; + echoOverview: (ctx: ProcCtx) => Promise; + sum: (ctx: ProcCtx) => Promise; + } + + client = await createTestClient(schema as unknown as SchemaDef, { + procedures: { + echoDecimal: async ({ args }: ProcCtx) => args.x, + greet: async ({ args }: ProcCtxOptionalArgs) => { + const name = args?.name as string | undefined; + return `hi ${name ?? 'anon'}`; + }, + echoInt: async ({ args }: ProcCtx) => args.x, + opt2: async ({ args }: ProcCtxOptionalArgs) => { + const a = args?.a as number | undefined; + const b = args?.b as number | undefined; + return (a ?? 0) + (b ?? 0); + }, + sumIds: async ({ args }: ProcCtx) => (args.ids as number[]).reduce((acc, x) => acc + x, 0), + echoRole: async ({ args }: ProcCtx) => args.r, + echoOverview: async ({ args }: ProcCtx) => args.o, + sum: async ({ args }: ProcCtx) => args.a + args.b, + } as Procedures, + }); + + const _handler = new RestApiHandler({ + schema: client.$schema, + endpoint: 'http://localhost/api', + pageSize: 5, + }); + + handler = (args) => _handler.handleRequest({ ...args, url: new URL(`http://localhost/${args.path}`) }); + }); + + it('supports GET query procedures with q/meta (SuperJSON)', async () => { + const { json, meta } = SuperJSON.serialize({ args: { x: new Decimal('1.23') } }); + const r = await handler({ + method: 'get', + path: '/$procs/echoDecimal', + query: { ...json as object, meta: { serialization: meta } } as any, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: '1.23' }); + }); + + it('supports GET procedures without args when param is optional', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/greet', + query: {}, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: 'hi anon' }); + }); + + it('errors for missing required single-param arg', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/echoInt', + query: {}, + client, + }); + + expect(r.status).toBe(400); + expect(r.body).toMatchObject({ + errors: [ + { + status: 400, + code: 'invalid-payload', + detail: 'missing procedure arguments', + }, + ], + }); + }); + + it('supports GET procedures without args when all params are optional', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/opt2', + query: {}, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: 0 }); + }); + + it('supports array-typed single param via envelope args', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/sumIds', + query: { args: { ids: [1, 2, 3] } } as any, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: 6 }); + }); + + it('supports enum-typed params with validation', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/echoRole', + query: { args: { r: 'ADMIN' } } as any, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: 'ADMIN' }); + }); + + it('supports typedef params (object payload)', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/echoOverview', + query: { args: { o: { total: 123 } } } as any, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: { total: 123 } }); + }); + + it('errors for wrong type input', async () => { + const r = await handler({ + method: 'get', + path: '/$procs/echoInt', + query: { args: { x: 'not-an-int' } } as any, + client, + }); + + expect(r.status).toBe(422); + expect(r.body).toMatchObject({ + errors: [ + { + status: 422, + code: 'validation-error', + }, + ], + }); + expect(r.body.errors?.[0]?.detail).toMatch(/invalid input/i); + }); + + it('supports POST mutation procedures with args passed via q/meta', async () => { + const { json, meta } = SuperJSON.serialize({ args: { a: 1, b: 2 } }); + const r = await handler({ + method: 'post', + path: '/$procs/sum', + requestBody: { ...json as object, meta: { serialization: meta } } as any, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: 3 }); + }); + + it('errors for invalid `args` payload type', async () => { + const r = await handler({ + method: 'post', + path: '/$procs/sum', + requestBody: { args: [1, 2, 3] } as any, + client, + }); + + expect(r.status).toBe(400); + expect(r.body).toMatchObject({ + errors: [ + { + status: 400, + code: 'invalid-payload', + }, + ], + }); + expect(r.body.errors?.[0]?.detail).toMatch(/args/i); + }); + + it('errors for unknown argument keys (object mapping)', async () => { + const r = await handler({ + method: 'post', + path: '/$procs/sum', + requestBody: { args: { a: 1, b: 2, c: 3 } } as any, + client, + }); + + expect(r.status).toBe(400); + expect(r.body).toMatchObject({ + errors: [ + { + status: 400, + code: 'invalid-payload', + }, + ], + }); + expect(r.body.errors?.[0]?.detail).toMatch(/unknown procedure argument/i); + }); + + it('supports /$procs path', async () => { + const r = await handler({ + method: 'post', + path: '/$procs/sum', + requestBody: { args: { a: 1, b: 2 } } as any, + client, + }); + + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ data: 3 }); + }); + }); }); diff --git a/packages/server/test/api/rpc.test.ts b/packages/server/test/api/rpc.test.ts index 4329e857..3842e5d3 100644 --- a/packages/server/test/api/rpc.test.ts +++ b/packages/server/test/api/rpc.test.ts @@ -125,6 +125,217 @@ describe('RPC API Handler Tests', () => { expect(r.data.count).toBe(1); }); + it('procedures', async () => { + const procSchema = ` +model User { + id String @id @default(cuid()) + email String @unique @email + + @@allow('all', true) +} + +procedure echo(input: String): String +mutation procedure createUser(email: String): User +procedure getFalse(): Boolean +procedure getUndefined(): Undefined +`; + + const procClient = await createPolicyTestClient(procSchema, { + procedures: { + echo: async ({ args }: any) => args.input, + createUser: async ({ client, args }: any) => { + return client.user.create({ data: { email: args.email } }); + }, + getFalse: async () => false, + getUndefined: async () => undefined, + }, + }); + + const handler = new RPCApiHandler({ schema: procClient.$schema }); + const handleProcRequest = async (args: any) => { + const r = await handler.handleRequest({ + ...args, + client: procClient, + url: new URL(`http://localhost/${args.path}`), + }); + return { + status: r.status, + body: r.body as any, + data: (r.body as any).data, + error: (r.body as any).error, + meta: (r.body as any).meta, + }; + }; + + // query procedure: GET only, args via q + let r = await handleProcRequest({ + method: 'get', + path: '/$procs/echo', + query: { q: JSON.stringify({ args: { input: 'hello' } }) }, + }); + expect(r.status).toBe(200); + expect(r.data).toBe('hello'); + + r = await handleProcRequest({ + method: 'post', + path: '/$procs/echo', + requestBody: { args: { input: 'hello' } }, + }); + expect(r.status).toBe(400); + expect(r.error?.message).toMatch(/only GET is supported/i); + + // mutation procedure: POST only, args via body + r = await handleProcRequest({ + method: 'post', + path: '/$procs/createUser', + requestBody: { args: { email: 'user1@abc.com' } }, + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(expect.objectContaining({ email: 'user1@abc.com' })); + + r = await handleProcRequest({ + method: 'get', + path: '/$procs/createUser', + query: { q: JSON.stringify({ args: { email: 'user2@abc.com' } }) }, + }); + expect(r.status).toBe(400); + expect(r.error?.message).toMatch(/only POST is supported/i); + + // falsy/undefined return serialization + r = await handleProcRequest({ method: 'get', path: '/$procs/getFalse' }); + expect(r.status).toBe(200); + expect(r.data).toBe(false); + + r = await handleProcRequest({ method: 'get', path: '/$procs/getUndefined' }); + expect(r.status).toBe(200); + expect(r.data).toBeNull(); + expect(r.meta?.serialization).toBeTruthy(); + }); + + it('procedures - edge cases', async () => { + const procSchema = ` +model User { + id String @id @default(cuid()) + email String @unique @email +} + +enum Role { + ADMIN + USER +} + +type Overview { + total Int +} + +procedure echoInt(x: Int): Int +procedure opt2(a: Int?, b: Int?): Int +procedure sum3(a: Int, b: Int, c: Int): Int +procedure sumIds(ids: Int[]): Int +procedure echoRole(r: Role): Role +procedure echoOverview(o: Overview): Overview +`; + + const procClient = await createPolicyTestClient(procSchema, { + procedures: { + echoInt: async ({ args }: any) => args.x, + opt2: async ({ args }: any) => { + const a = args?.a as number | undefined; + const b = args?.b as number | undefined; + return (a ?? 0) + (b ?? 0); + }, + sum3: async ({ args }: any) => args.a + args.b + args.c, + sumIds: async ({ args }: any) => (args.ids as number[]).reduce((acc: number, x: number) => acc + x, 0), + echoRole: async ({ args }: any) => args.r, + echoOverview: async ({ args }: any) => args.o, + }, + }); + + const handler = new RPCApiHandler({ schema: procClient.$schema }); + const handleProcRequest = async (args: any) => { + const r = await handler.handleRequest({ + ...args, + client: procClient, + url: new URL(`http://localhost/${args.path}`), + }); + return { + status: r.status, + body: r.body as any, + data: (r.body as any).data, + error: (r.body as any).error, + meta: (r.body as any).meta, + }; + }; + + // > 2 params object mapping + let r = await handleProcRequest({ + method: 'get', + path: '/$procs/sum3', + query: { q: JSON.stringify({ args: { a: 1, b: 2, c: 3 } }) }, + }); + expect(r.status).toBe(200); + expect(r.data).toBe(6); + + // all optional params can omit payload + r = await handleProcRequest({ method: 'get', path: '/$procs/opt2' }); + expect(r.status).toBe(200); + expect(r.data).toBe(0); + + // array-typed single param via q JSON array + r = await handleProcRequest({ + method: 'get', + path: '/$procs/sumIds', + query: { q: JSON.stringify({ args: { ids: [1, 2, 3] } }) }, + }); + expect(r.status).toBe(200); + expect(r.data).toBe(6); + + // enum param validation + r = await handleProcRequest({ + method: 'get', + path: '/$procs/echoRole', + query: { q: JSON.stringify({ args: { r: 'ADMIN' } }) }, + }); + expect(r.status).toBe(200); + expect(r.data).toBe('ADMIN'); + + // typedef param (object payload) + r = await handleProcRequest({ + method: 'get', + path: '/$procs/echoOverview', + query: { q: JSON.stringify({ args: { o: { total: 123 } } }) }, + }); + expect(r.status).toBe(200); + expect(r.data).toMatchObject({ total: 123 }); + + // wrong type input + r = await handleProcRequest({ + method: 'get', + path: '/$procs/echoInt', + query: { q: JSON.stringify({ args: { x: 'x' } }) }, + }); + expect(r.status).toBe(422); + expect(r.error?.message).toMatch(/invalid input/i); + + // invalid args payload type + r = await handleProcRequest({ + method: 'get', + path: '/$procs/sum3', + query: { q: JSON.stringify({ args: [1, 2, 3, 4] }) }, + }); + expect(r.status).toBe(400); + expect(r.error?.message).toMatch(/args/i); + + // unknown keys + r = await handleProcRequest({ + method: 'get', + path: '/$procs/sum3', + query: { q: JSON.stringify({ args: { a: 1, b: 2, c: 3, d: 4 } }) }, + }); + expect(r.status).toBe(400); + expect(r.error?.message).toMatch(/unknown procedure argument/i); + }); + it('pagination and ordering', async () => { const handleRequest = makeHandler(); diff --git a/samples/orm/main.ts b/samples/orm/main.ts index 59e814c1..4b7f82ef 100644 --- a/samples/orm/main.ts +++ b/samples/orm/main.ts @@ -16,6 +16,18 @@ async function main() { .select(({ fn }) => fn.countAll().as('postCount')), }, }, + procedures: { + signUp: ({ client, args }) => + client.user.create({ + data: { ...args }, + }), + listPublicPosts: ({ client }) => + client.post.findMany({ + where: { + published: true, + }, + }), + }, }).$use({ id: 'cost-logger', onQuery: async ({ model, operation, args, proceed }) => { @@ -101,6 +113,12 @@ async function main() { console.log('Posts readable to', user2.email); console.table(await user2Db.post.findMany({ select: { title: true, published: true } })); + + const newUser = await authDb.$procs.signUp({ args: { email: 'new@zenstack.dev', name: 'New User' } }); + console.log('New user signed up via procedure:', newUser); + + const publicPosts = await authDb.$procs.listPublicPosts(); + console.log('Public posts via procedure:', publicPosts); } main(); diff --git a/samples/orm/zenstack/schema.ts b/samples/orm/zenstack/schema.ts index deb07cb9..e3c02e5c 100644 --- a/samples/orm/zenstack/schema.ts +++ b/samples/orm/zenstack/schema.ts @@ -239,6 +239,22 @@ export class SchemaType implements SchemaDef { } } as const; authType = "User" as const; + procedures = { + signUp: { + params: { + email: { name: "email", type: "String" }, + name: { name: "name", optional: true, type: "String" }, + role: { name: "role", optional: true, type: "Role" } + }, + returnType: "User", + mutation: true + }, + listPublicPosts: { + params: {}, + returnType: "Post", + returnArray: true + } + } as const; plugins = {}; } export const schema = new SchemaType(); diff --git a/samples/orm/zenstack/schema.zmodel b/samples/orm/zenstack/schema.zmodel index 3669d799..2567801a 100644 --- a/samples/orm/zenstack/schema.zmodel +++ b/samples/orm/zenstack/schema.zmodel @@ -47,10 +47,13 @@ model Profile with CommonFields { model Post with CommonFields { title String content String - published Boolean @default(false) - author User @relation(fields: [authorId], references: [id]) + published Boolean @default(false) + author User @relation(fields: [authorId], references: [id]) authorId String @@allow('read', published) @@allow('all', author == auth()) } + +mutation procedure signUp(email: String, name: String?, role: Role?): User +procedure listPublicPosts(): Post[] diff --git a/tests/e2e/orm/client-api/procedures.test.ts b/tests/e2e/orm/client-api/procedures.test.ts new file mode 100644 index 00000000..6bd7b5bf --- /dev/null +++ b/tests/e2e/orm/client-api/procedures.test.ts @@ -0,0 +1,217 @@ +import type { ClientContract } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { schema } from '../schemas/procedures/schema'; +import type { User } from '../schemas/procedures/models'; + +describe('Procedures tests', () => { + let client: ClientContract; + + beforeEach(async () => { + client = await createTestClient(schema, { + procedures: { + // Query procedure that returns a single User + getUser: async ({ client, args: { id } }) => { + return await client.user.findUniqueOrThrow({ + where: { id }, + }); + }, + + // Query procedure that returns an array of Users + listUsers: async ({ client }) => { + return await client.user.findMany(); + }, + + // Mutation procedure that creates a User + signUp: async ({ client, args: { name, role } }) => { + return await client.user.create({ + data: { + name, + role, + }, + }); + }, + + // Query procedure that returns Void + setAdmin: async ({ client, args: { userId } }) => { + await client.user.update({ + where: { id: userId }, + data: { role: 'ADMIN' }, + }); + }, + + // Query procedure that returns a custom type + getOverview: async ({ client }) => { + const userIds = await client.user.findMany({ select: { id: true } }); + const total = await client.user.count(); + return { + userIds: userIds.map((u) => u.id), + total, + roles: ['ADMIN', 'USER'], + meta: { hello: 'world' }, + }; + }, + + createMultiple: async ({ client, args: { names } }) => { + return await client.$transaction(async (tx) => { + const createdUsers: User[] = []; + for (const name of names) { + const user = await tx.user.create({ + data: { name }, + }); + createdUsers.push(user); + } + return createdUsers; + }); + }, + }, + }); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with query proc with parameters', async () => { + // Create a user first + const created = await client.user.create({ + data: { + name: 'Alice', + role: 'USER', + }, + }); + + // Call the procedure + const result = await client.$procs.getUser({ args: { id: created.id } }); + + expect(result).toMatchObject({ + id: created.id, + name: 'Alice', + role: 'USER', + }); + }); + + it('works with query proc without parameters', async () => { + // Create multiple users + await client.user.create({ + data: { name: 'Alice', role: 'USER' }, + }); + await client.user.create({ + data: { name: 'Bob', role: 'ADMIN' }, + }); + await client.user.create({ + data: { name: 'Charlie', role: 'USER' }, + }); + + const result = await client.$procs.listUsers(); + + expect(result).toHaveLength(3); + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'Alice', role: 'USER' }), + expect.objectContaining({ name: 'Bob', role: 'ADMIN' }), + expect.objectContaining({ name: 'Charlie', role: 'USER' }), + ]), + ); + }); + + it('works with mutation with parameters', async () => { + const result = await client.$procs.signUp({ args: { name: 'Alice' } }); + + expect(result).toMatchObject({ + id: expect.any(Number), + name: 'Alice', + role: 'USER', + }); + + // Verify user was created in database + const users = await client.user.findMany(); + expect(users).toHaveLength(1); + expect(users[0]).toMatchObject({ + name: 'Alice', + role: 'USER', + }); + + // accepts optional role parameter + const result1 = await client.$procs.signUp({ + args: { + name: 'Bob', + role: 'ADMIN', + }, + }); + + expect(result1).toMatchObject({ + id: expect.any(Number), + name: 'Bob', + role: 'ADMIN', + }); + + // Verify user was created with correct role + const user1 = await client.user.findUnique({ + where: { id: result1.id }, + }); + expect(user1?.role).toBe('ADMIN'); + }); + + it('works with mutation proc that returns void', async () => { + // Create a regular user + const user = await client.user.create({ + data: { name: 'Alice', role: 'USER' }, + }); + + expect(user.role).toBe('USER'); + + // Call setAdmin procedure + const result = await client.$procs.setAdmin({ args: { userId: user.id } }); + + // Procedure returns void + expect(result).toBeUndefined(); + + // Verify user role was updated + const updated = await client.user.findUnique({ + where: { id: user.id }, + }); + expect(updated?.role).toBe('ADMIN'); + }); + + it('works with procedure returning custom type', async () => { + await client.user.create({ data: { name: 'Alice', role: 'USER' } }); + await client.user.create({ data: { name: 'Bob', role: 'ADMIN' } }); + + const result = await client.$procs.getOverview(); + expect(result.total).toBe(2); + expect(result.userIds).toHaveLength(2); + expect(result.roles).toEqual(expect.arrayContaining(['ADMIN', 'USER'])); + expect(result.meta).toEqual({ hello: 'world' }); + }); + + it('works with transactional mutation procs', async () => { + // unique constraint violation should rollback the transaction + await expect(client.$procs.createMultiple({ args: { names: ['Alice', 'Alice'] } })).rejects.toThrow(); + await expect(client.user.count()).resolves.toBe(0); + + // successful transaction + await expect(client.$procs.createMultiple({ args: { names: ['Alice', 'Bob'] } })).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'Alice' }), + expect.objectContaining({ name: 'Bob' }), + ]), + ); + }); + + it('respects the outer transaction context', async () => { + // outer client creates a transaction + await expect( + client.$transaction(async (tx) => { + await tx.$procs.signUp({ args: { name: 'Alice' } }); + await tx.$procs.signUp({ args: { name: 'Alice' } }); + }), + ).rejects.toThrow(); + await expect(client.user.count()).resolves.toBe(0); + + // without transaction + await client.$procs.signUp({ args: { name: 'Alice' } }); + await expect(client.$procs.signUp({ args: { name: 'Alice' } })).rejects.toThrow(); + await expect(client.user.count()).resolves.toBe(1); + }); +}); diff --git a/tests/e2e/orm/schemas/procedures/input.ts b/tests/e2e/orm/schemas/procedures/input.ts new file mode 100644 index 00000000..88a751d1 --- /dev/null +++ b/tests/e2e/orm/schemas/procedures/input.ts @@ -0,0 +1,30 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaType as $Schema } from "./schema"; +import type { FindManyArgs as $FindManyArgs, FindUniqueArgs as $FindUniqueArgs, FindFirstArgs as $FindFirstArgs, CreateArgs as $CreateArgs, CreateManyArgs as $CreateManyArgs, CreateManyAndReturnArgs as $CreateManyAndReturnArgs, UpdateArgs as $UpdateArgs, UpdateManyArgs as $UpdateManyArgs, UpdateManyAndReturnArgs as $UpdateManyAndReturnArgs, UpsertArgs as $UpsertArgs, DeleteArgs as $DeleteArgs, DeleteManyArgs as $DeleteManyArgs, CountArgs as $CountArgs, AggregateArgs as $AggregateArgs, GroupByArgs as $GroupByArgs, WhereInput as $WhereInput, SelectInput as $SelectInput, IncludeInput as $IncludeInput, OmitInput as $OmitInput, QueryOptions as $QueryOptions } from "@zenstackhq/orm"; +import type { SimplifiedPlainResult as $Result, SelectIncludeOmit as $SelectIncludeOmit } from "@zenstackhq/orm"; +export type UserFindManyArgs = $FindManyArgs<$Schema, "User">; +export type UserFindUniqueArgs = $FindUniqueArgs<$Schema, "User">; +export type UserFindFirstArgs = $FindFirstArgs<$Schema, "User">; +export type UserCreateArgs = $CreateArgs<$Schema, "User">; +export type UserCreateManyArgs = $CreateManyArgs<$Schema, "User">; +export type UserCreateManyAndReturnArgs = $CreateManyAndReturnArgs<$Schema, "User">; +export type UserUpdateArgs = $UpdateArgs<$Schema, "User">; +export type UserUpdateManyArgs = $UpdateManyArgs<$Schema, "User">; +export type UserUpdateManyAndReturnArgs = $UpdateManyAndReturnArgs<$Schema, "User">; +export type UserUpsertArgs = $UpsertArgs<$Schema, "User">; +export type UserDeleteArgs = $DeleteArgs<$Schema, "User">; +export type UserDeleteManyArgs = $DeleteManyArgs<$Schema, "User">; +export type UserCountArgs = $CountArgs<$Schema, "User">; +export type UserAggregateArgs = $AggregateArgs<$Schema, "User">; +export type UserGroupByArgs = $GroupByArgs<$Schema, "User">; +export type UserWhereInput = $WhereInput<$Schema, "User">; +export type UserSelect = $SelectInput<$Schema, "User">; +export type UserInclude = $IncludeInput<$Schema, "User">; +export type UserOmit = $OmitInput<$Schema, "User">; +export type UserGetPayload, Options extends $QueryOptions<$Schema> = $QueryOptions<$Schema>> = $Result<$Schema, "User", Args, Options>; diff --git a/tests/e2e/orm/schemas/procedures/models.ts b/tests/e2e/orm/schemas/procedures/models.ts new file mode 100644 index 00000000..9920c101 --- /dev/null +++ b/tests/e2e/orm/schemas/procedures/models.ts @@ -0,0 +1,13 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { schema as $schema, type SchemaType as $Schema } from "./schema"; +import { type ModelResult as $ModelResult, type TypeDefResult as $TypeDefResult } from "@zenstackhq/orm"; +export type User = $ModelResult<$Schema, "User">; +export type Overview = $TypeDefResult<$Schema, "Overview">; +export const Role = $schema.enums.Role.values; +export type Role = (typeof Role)[keyof typeof Role]; diff --git a/tests/e2e/orm/schemas/procedures/schema.ts b/tests/e2e/orm/schemas/procedures/schema.ts new file mode 100644 index 00000000..f5a9044e --- /dev/null +++ b/tests/e2e/orm/schemas/procedures/schema.ts @@ -0,0 +1,121 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaDef, ExpressionUtils } from "@zenstackhq/orm/schema"; +export class SchemaType implements SchemaDef { + provider = { + type: "sqlite" + } as const; + models = { + User: { + name: "User", + fields: { + id: { + name: "id", + type: "Int", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], + default: ExpressionUtils.call("autoincrement") + }, + name: { + name: "name", + type: "String", + unique: true, + attributes: [{ name: "@unique" }] + }, + role: { + name: "role", + type: "Role", + attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal("USER") }] }], + default: "USER" + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "Int" }, + name: { type: "String" } + } + } + } as const; + typeDefs = { + Overview: { + name: "Overview", + fields: { + userIds: { + name: "userIds", + type: "Int", + array: true + }, + total: { + name: "total", + type: "Int" + }, + roles: { + name: "roles", + type: "Role", + array: true + }, + meta: { + name: "meta", + type: "Json", + optional: true + } + } + } + } as const; + enums = { + Role: { + values: { + ADMIN: "ADMIN", + USER: "USER" + } + } + } as const; + authType = "User" as const; + procedures = { + getUser: { + params: { + id: { name: "id", type: "Int" } + }, + returnType: "User" + }, + listUsers: { + params: {}, + returnType: "User", + returnArray: true + }, + signUp: { + params: { + name: { name: "name", type: "String" }, + role: { name: "role", optional: true, type: "Role" } + }, + returnType: "User", + mutation: true + }, + setAdmin: { + params: { + userId: { name: "userId", type: "Int" } + }, + returnType: "Void", + mutation: true + }, + getOverview: { + params: {}, + returnType: "Overview" + }, + createMultiple: { + params: { + names: { name: "names", array: true, type: "String" } + }, + returnType: "User", + returnArray: true, + mutation: true + } + } as const; + plugins = {}; +} +export const schema = new SchemaType(); diff --git a/tests/e2e/orm/schemas/procedures/schema.zmodel b/tests/e2e/orm/schemas/procedures/schema.zmodel new file mode 100644 index 00000000..25380dab --- /dev/null +++ b/tests/e2e/orm/schemas/procedures/schema.zmodel @@ -0,0 +1,29 @@ +datasource db { + provider = 'sqlite' + url = 'file:./test.db' +} + +enum Role { + ADMIN + USER +} + +type Overview { + userIds Int[] + total Int + roles Role[] + meta Json? +} + +model User { + id Int @id @default(autoincrement()) + name String @unique + role Role @default(USER) +} + +procedure getUser(id: Int): User +procedure listUsers(): User[] +mutation procedure signUp(name: String, role: Role?): User +mutation procedure setAdmin(userId: Int): Void +procedure getOverview(): Overview +mutation procedure createMultiple(names: String[]): User[]