diff --git a/src/utils/entity-resolver.test.ts b/src/utils/entity-resolver.test.ts index c32d721..12707c5 100644 --- a/src/utils/entity-resolver.test.ts +++ b/src/utils/entity-resolver.test.ts @@ -411,4 +411,123 @@ describe("EntityResolver", () => { expect(built.referenceImages).toHaveLength(3); }); }); + + describe("should handle compound words correctly", () => { + it("should match 'Darien' in prompt with 'superdarien' compound word", async () => { + const mockSupabase = createMockSupabaseClient( + ["darien", "cody"], + { + darien: [{ name: "darien1.jpg", id: "1" }], + cody: [{ name: "cody1.jpg", id: "2" }], + }, + ); + const resolver = new EntityResolver(mockSupabase, mockLogger); + + const prompt = + "samebot draw Darien standing next to a superdarien which is 2x wider in each dimension, standing next to a long Darien who is 2x taller but same width and depth"; + + const result = await resolver.resolve(prompt); + + expect(result).not.toBeNull(); + expect(result?.entities).toHaveLength(1); + expect(result?.entities[0]?.name).toBe("darien"); + const entityNames = result?.entities.map((e) => e.name); + expect(entityNames).not.toContain("cody"); + }); + + it("should not match 'cody' when prompt contains 'darien'", async () => { + const mockSupabase = createMockSupabaseClient( + ["darien", "cody"], + { + darien: [{ name: "darien1.jpg", id: "1" }], + cody: [{ name: "cody1.jpg", id: "2" }], + }, + ); + const resolver = new EntityResolver(mockSupabase, mockLogger); + + const prompt = "draw Darien standing next to a superdarien"; + + const result = await resolver.resolve(prompt); + + if (result) { + const entityNames = result.entities.map((e) => e.name); + expect(entityNames).not.toContain("cody"); + expect(entityNames).toContain("darien"); + } + }); + + it("should not match 'superdarien' to 'cody' via fuzzy matching", async () => { + const mockSupabase = createMockSupabaseClient( + ["darien", "cody"], + { + darien: [{ name: "darien1.jpg", id: "1" }], + cody: [{ name: "cody1.jpg", id: "2" }], + }, + ); + const resolver = new EntityResolver(mockSupabase, mockLogger); + + const prompt = "draw a superdarien"; + + const result = await resolver.resolve(prompt); + + if (result) { + const entityNames = result.entities.map((e) => e.name); + expect(entityNames).not.toContain("cody"); + } + }); + + it("should match 'darien' as substring in 'superdarien' compound word", async () => { + const mockSupabase = createMockSupabaseClient( + ["darien", "cody"], + { + darien: [{ name: "darien1.jpg", id: "1" }], + cody: [{ name: "cody1.jpg", id: "2" }], + }, + ); + const resolver = new EntityResolver(mockSupabase, mockLogger); + + const prompt = "draw a superdarien"; + + const result = await resolver.resolve(prompt); + + expect(result).not.toBeNull(); + const entityNames = result?.entities.map((e) => e.name); + expect(entityNames).toContain("darien"); + expect(entityNames).not.toContain("cody"); + }); + + it("should prioritize exact 'darien' match over fuzzy 'cody' match", async () => { + const mockSupabase = createMockSupabaseClient( + ["darien", "cody"], + { + darien: [{ name: "darien1.jpg", id: "1" }], + cody: [{ name: "cody1.jpg", id: "2" }], + }, + ); + const resolver = new EntityResolver(mockSupabase, mockLogger); + + const prompt = "draw Darien and a superdarien"; + + const result = await resolver.resolve(prompt); + + expect(result).not.toBeNull(); + const entityNames = result?.entities.map((e) => e.name).sort(); + expect(entityNames).toEqual(["darien"]); + expect(entityNames).not.toContain("cody"); + }); + + it("should not match entity name at start of word (to avoid typo matches)", async () => { + const mockSupabase = createMockSupabaseClient( + ["tyrus"], + { tyrus: [{ name: "tyrus1.jpg", id: "1" }] }, + ); + const resolver = new EntityResolver(mockSupabase, mockLogger); + + const prompt = "draw a tyrusss"; + + const result = await resolver.resolve(prompt); + + expect(result).toBeNull(); + }); + }); }); diff --git a/src/utils/entity-resolver.ts b/src/utils/entity-resolver.ts index 5d5278c..489a636 100644 --- a/src/utils/entity-resolver.ts +++ b/src/utils/entity-resolver.ts @@ -69,6 +69,47 @@ export class EntityResolver { continue; } + const substringMatch = searchableEntities.find((entity) => { + const normalizedEntityName = this.normalizeWord(entity.searchTerm); + const entityLength = normalizedEntityName.length; + const wordLength = normalizedWord.length; + + if (entityLength < 3) { + return false; + } + + if (wordLength <= entityLength) { + return false; + } + + if (wordLength - entityLength < 2) { + return false; + } + + if (!normalizedWord.includes(normalizedEntityName)) { + return false; + } + + const index = normalizedWord.indexOf(normalizedEntityName); + if (index === 0) { + return false; + } + + return true; + }); + + if (substringMatch) { + const existingMatch = matchedEntities.get(substringMatch.folderName); + if (!existingMatch || existingMatch.score < 0.95) { + matchedEntities.set(substringMatch.folderName, { + word, + folder: substringMatch.folderName, + score: 0.95, + }); + } + continue; + } + const results = fuse.search(normalizedWord); if (results.length > 0) { const topResult = results[0]!;