diff --git a/packages/auth/src/account/provision-account.ts b/packages/auth/src/account/provision-account.ts index 8c3d7ca..2686dc4 100644 --- a/packages/auth/src/account/provision-account.ts +++ b/packages/auth/src/account/provision-account.ts @@ -5,6 +5,7 @@ import type { SupabaseToken, } from "@listee/db"; import { + and, categories, createRlsClient, DEFAULT_CATEGORY_KIND, @@ -79,9 +80,14 @@ export function createAccountProvisioner( id: params.userId, email, }) - .onConflictDoNothing(); + .onConflictDoUpdate({ + target: profiles.id, + set: { + email, + }, + }); - await tx + const insertedCategories = await tx .insert(categories) .values({ name: defaultCategoryName, @@ -92,7 +98,43 @@ export function createAccountProvisioner( .onConflictDoNothing({ target: [categories.createdBy, categories.name], where: eq(categories.kind, defaultCategoryKind), - }); + }) + .returning({ categoryId: categories.id }); + + const categoryRecord = + insertedCategories[0] ?? + ( + await tx + .select({ categoryId: categories.id }) + .from(categories) + .where( + and( + eq(categories.createdBy, params.userId), + eq(categories.name, defaultCategoryName), + eq(categories.kind, defaultCategoryKind), + ), + ) + .limit(1) + )[0]; + + if (categoryRecord === undefined) { + throw new Error("Failed to resolve default category for profile"); + } + + const defaultCategoryId = categoryRecord.categoryId; + + const shouldUpdateProfile = await tx + .update(profiles) + .set({ + defaultCategoryId, + email, + }) + .where(eq(profiles.id, params.userId)) + .returning({ id: profiles.id }); + + if (shouldUpdateProfile.length === 0) { + throw new Error("Failed to update profile with default category"); + } }); } diff --git a/packages/db/src/schema/index.ts b/packages/db/src/schema/index.ts index 8c76996..6bcda48 100644 --- a/packages/db/src/schema/index.ts +++ b/packages/db/src/schema/index.ts @@ -28,7 +28,10 @@ export const profiles = pgTable( id: uuid("id").primaryKey(), email: text("email").notNull().unique(), name: text("name"), - defaultCategoryId: uuid("default_category_id"), + defaultCategoryId: uuid("default_category_id").references( + (): AnyPgColumn => categories.id, + { onDelete: "restrict" }, + ), ...timestamps, }, (table) => {