Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions src/utils/entity-resolver.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,123 @@ describe("EntityResolver", () => {
expect(built.referenceImages).toHaveLength(3);
});
});

describe("should handle compound words correctly", () => {
it("should match 'Darien' in prompt with 'superdarien' compound word", async () => {
const mockSupabase = createMockSupabaseClient(
["darien", "cody"],
{
darien: [{ name: "darien1.jpg", id: "1" }],
cody: [{ name: "cody1.jpg", id: "2" }],
},
);
const resolver = new EntityResolver(mockSupabase, mockLogger);

const prompt =
"samebot draw Darien standing next to a superdarien which is 2x wider in each dimension, standing next to a long Darien who is 2x taller but same width and depth";

const result = await resolver.resolve(prompt);

expect(result).not.toBeNull();
expect(result?.entities).toHaveLength(1);
expect(result?.entities[0]?.name).toBe("darien");
const entityNames = result?.entities.map((e) => e.name);
expect(entityNames).not.toContain("cody");
});

it("should not match 'cody' when prompt contains 'darien'", async () => {
const mockSupabase = createMockSupabaseClient(
["darien", "cody"],
{
darien: [{ name: "darien1.jpg", id: "1" }],
cody: [{ name: "cody1.jpg", id: "2" }],
},
);
const resolver = new EntityResolver(mockSupabase, mockLogger);

const prompt = "draw Darien standing next to a superdarien";

const result = await resolver.resolve(prompt);

if (result) {
const entityNames = result.entities.map((e) => e.name);
expect(entityNames).not.toContain("cody");
expect(entityNames).toContain("darien");
}
});

it("should not match 'superdarien' to 'cody' via fuzzy matching", async () => {
const mockSupabase = createMockSupabaseClient(
["darien", "cody"],
{
darien: [{ name: "darien1.jpg", id: "1" }],
cody: [{ name: "cody1.jpg", id: "2" }],
},
);
const resolver = new EntityResolver(mockSupabase, mockLogger);

const prompt = "draw a superdarien";

const result = await resolver.resolve(prompt);

if (result) {
const entityNames = result.entities.map((e) => e.name);
expect(entityNames).not.toContain("cody");
}
});

it("should match 'darien' as substring in 'superdarien' compound word", async () => {
const mockSupabase = createMockSupabaseClient(
["darien", "cody"],
{
darien: [{ name: "darien1.jpg", id: "1" }],
cody: [{ name: "cody1.jpg", id: "2" }],
},
);
const resolver = new EntityResolver(mockSupabase, mockLogger);

const prompt = "draw a superdarien";

const result = await resolver.resolve(prompt);

expect(result).not.toBeNull();
const entityNames = result?.entities.map((e) => e.name);
expect(entityNames).toContain("darien");
expect(entityNames).not.toContain("cody");
});

it("should prioritize exact 'darien' match over fuzzy 'cody' match", async () => {
const mockSupabase = createMockSupabaseClient(
["darien", "cody"],
{
darien: [{ name: "darien1.jpg", id: "1" }],
cody: [{ name: "cody1.jpg", id: "2" }],
},
);
const resolver = new EntityResolver(mockSupabase, mockLogger);

const prompt = "draw Darien and a superdarien";

const result = await resolver.resolve(prompt);

expect(result).not.toBeNull();
const entityNames = result?.entities.map((e) => e.name).sort();
expect(entityNames).toEqual(["darien"]);
expect(entityNames).not.toContain("cody");
});

it("should not match entity name at start of word (to avoid typo matches)", async () => {
const mockSupabase = createMockSupabaseClient(
["tyrus"],
{ tyrus: [{ name: "tyrus1.jpg", id: "1" }] },
);
const resolver = new EntityResolver(mockSupabase, mockLogger);

const prompt = "draw a tyrusss";

const result = await resolver.resolve(prompt);

expect(result).toBeNull();
});
});
});
41 changes: 41 additions & 0 deletions src/utils/entity-resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,47 @@ export class EntityResolver {
continue;
}

const substringMatch = searchableEntities.find((entity) => {
const normalizedEntityName = this.normalizeWord(entity.searchTerm);
const entityLength = normalizedEntityName.length;
const wordLength = normalizedWord.length;

if (entityLength < 3) {
return false;
}

if (wordLength <= entityLength) {
return false;
}

if (wordLength - entityLength < 2) {
return false;
}

if (!normalizedWord.includes(normalizedEntityName)) {
return false;
}

const index = normalizedWord.indexOf(normalizedEntityName);
if (index === 0) {
return false;
}

return true;
});

if (substringMatch) {
const existingMatch = matchedEntities.get(substringMatch.folderName);
if (!existingMatch || existingMatch.score < 0.95) {
matchedEntities.set(substringMatch.folderName, {
word,
folder: substringMatch.folderName,
score: 0.95,
});
}
continue;
}

const results = fuse.search(normalizedWord);
if (results.length > 0) {
const topResult = results[0]!;
Expand Down