diff --git a/package-lock.json b/package-lock.json index faf9e73..343217b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,6 +16,7 @@ "@langchain/community": "^0.3.53", "@langchain/core": "^0.3.72", "ai": "^4.3.19", + "dedent": "^1.7.0", "react-basic-contenteditable": "^1.0.6", "react-markdown": "^10.1.0" }, @@ -6972,6 +6973,20 @@ "node": ">=0.10" } }, + "node_modules/dedent": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/dedent/-/dedent-1.7.0.tgz", + "integrity": "sha512-HGFtf8yhuhGhqO07SV79tRp+br4MnbdjeVxotpn1QBl30pcLLCQjX5b2295ll0fv8RKDKsmWYrl05usHM9CewQ==", + "license": "MIT", + "peerDependencies": { + "babel-plugin-macros": "^3.1.0" + }, + "peerDependenciesMeta": { + "babel-plugin-macros": { + "optional": true + } + } + }, "node_modules/deep-eql": { "version": "5.0.2", "resolved": "https://registry.npmjs.org/deep-eql/-/deep-eql-5.0.2.tgz", diff --git a/package.json b/package.json index 1c37de1..0aef9d5 100644 --- a/package.json +++ b/package.json @@ -67,6 +67,7 @@ "@langchain/community": "^0.3.53", "@langchain/core": "^0.3.72", "ai": "^4.3.19", + "dedent": "^1.7.0", "react-basic-contenteditable": "^1.0.6", "react-markdown": "^10.1.0" }, diff --git a/src/plugin/Panel/ChatInput/index.tsx b/src/plugin/Panel/ChatInput/index.tsx index afbe03b..297e072 100644 --- a/src/plugin/Panel/ChatInput/index.tsx +++ b/src/plugin/Panel/ChatInput/index.tsx @@ -38,6 +38,9 @@ export const ChatInput: FC = () => { content: input, }, ], + context: { + canvas: state.activeCanvas, + }, }; if (state.selectedMedia.length) { diff --git a/src/plugin/Panel/index.tsx b/src/plugin/Panel/index.tsx index 4d9517a..6182baa 100644 --- a/src/plugin/Panel/index.tsx +++ b/src/plugin/Panel/index.tsx @@ -22,15 +22,10 @@ export function PluginPanelComponent(props: CloverPlugin & PluginProps) { useEffect(() => { if (provider) { provider.update_dispatch(dispatch); + provider.update_plugin_state(state); + provider.set_system_prompt(); dispatch({ type: "updateProvider", provider }); } - - if (!state.systemPrompt) { - dispatch({ - type: "setSystemPrompt", - systemPrompt: `You are a helpful assistant that can answer questions about the item in the viewer`, - }); - } }, []); // eslint-disable-line react-hooks/exhaustive-deps useEffect(() => { @@ -51,16 +46,11 @@ export function PluginPanelComponent(props: CloverPlugin & PluginProps) { useEffect(() => { if (state.manifest) { - // Update system prompt with manifest metadata - dispatch({ - type: "setSystemPrompt", - systemPrompt: `You are a helpful assistant that can answer questions about the item in the viewer. Here is the manifest data for the item:\n\n${JSON.stringify(state.manifest["metadata"], null, 2)}`, - }); const label = state.manifest?.label ?? undefined; const title = getLabelByUserLanguage(label); setItemTitle(title.length > 0 ? title[0] : "this item"); } - }, [state.manifest, dispatch]); + }, [state.manifest]); useEffect(() => { dispatch({ diff --git a/src/plugin/base_provider.tsx b/src/plugin/base_provider.tsx index 47371ef..8658953 100644 --- a/src/plugin/base_provider.tsx +++ b/src/plugin/base_provider.tsx @@ -1,5 +1,8 @@ -import type { PluginContextActions } from "@context"; +import type { PluginContextActions, PluginContextStore } from "@context"; +import { ManifestNormalized } from "@iiif/presentation-3-normalized"; import type { ConversationState, Message } from "@types"; +import { getLabelByUserLanguage } from "@utils"; +import dedent from "dedent"; import type { Dispatch } from "react"; type ProviderStatus = "initializing" | "ready" | "error"; @@ -9,42 +12,74 @@ type ProviderStatus = "initializing" | "ready" | "error"; * */ export abstract class BaseProvider { - #dispatch: Dispatch | undefined; + #plugin_dispatch: Dispatch | undefined; + #plugin_state: PluginContextStore | undefined; #status: ProviderStatus; constructor() { this.#status = "ready"; } - private get dispatch(): Dispatch { - if (!this.#dispatch) { + private get plugin_dispatch(): Dispatch { + if (!this.#plugin_dispatch) { throw new Error("Provider dispatch not initialized."); } - return this.#dispatch; + return this.#plugin_dispatch; } /** * Sets the dispatch function to allow the provider to update Plugin state */ - private set dispatch(dispatch: Dispatch) { - this.#dispatch = dispatch; + private set plugin_dispatch(dispatch: Dispatch) { + this.#plugin_dispatch = dispatch; + } + + protected get plugin_state(): PluginContextStore { + if (!this.#plugin_state) { + throw new Error("Provider plugin_state not initialized."); + } + return this.#plugin_state; + } + + protected set plugin_state(state: PluginContextStore) { + this.#plugin_state = state; } /** * Add messages to the Plugin state */ protected add_messages(messages: Message[]) { - this.dispatch({ + this.plugin_dispatch({ type: "addMessages", messages, }); } + /** + * Generate a system prompt based on the provided manifest + * + * @param manifest the IIIF manifest + * @returns a system prompt string based on the manifest data + */ + protected generate_system_prompt(manifest: ManifestNormalized) { + const title = getLabelByUserLanguage(manifest.label ?? undefined)?.[0] ?? "N/A"; + const summary = getLabelByUserLanguage(manifest.summary ?? undefined)?.[0] ?? "N/A"; + return dedent` + You are a helpful assistant that can answer questions about the item in the image viewer. + + Here is the manifest data for the item: + + ## Title: ${title} + ## Summary: ${summary} + ## Raw Metadata: ${JSON.stringify(manifest.metadata)} + `; + } + /** * Update the Plugin's conversation state. */ protected set_conversation_state(state: ConversationState) { - this.dispatch({ + this.plugin_dispatch({ type: "setConversationState", conversationState: state, }); @@ -54,7 +89,7 @@ export abstract class BaseProvider { * Update the last message in the Plugin state. */ protected update_last_message(message: Message) { - this.dispatch({ + this.plugin_dispatch({ type: "updateLastMessage", message, }); @@ -64,7 +99,7 @@ export abstract class BaseProvider { * Update the Plugin state with the current provider. */ protected update_plugin_provider() { - this.dispatch({ + this.plugin_dispatch({ type: "updateProvider", provider: this, }); @@ -83,6 +118,18 @@ export abstract class BaseProvider { this.#status = value; } + /** + * Set the system prompt in the Plugin state based on the current manifest. + */ + set_system_prompt() { + const systemPrompt = this.generate_system_prompt(this.plugin_state.manifest); + + this.plugin_dispatch({ + type: "setSystemPrompt", + systemPrompt, + }); + } + /** * A component that providers can implement to set up their UI. */ @@ -91,6 +138,10 @@ export abstract class BaseProvider { } update_dispatch(dispatch: Dispatch) { - this.dispatch = dispatch; + this.plugin_dispatch = dispatch; + } + + update_plugin_state(context: PluginContextStore) { + this.plugin_state = context; } } diff --git a/src/plugin/context/index.ts b/src/plugin/context/index.ts index 8561ab5..4f86a2c 100644 --- a/src/plugin/context/index.ts +++ b/src/plugin/context/index.ts @@ -1,6 +1,8 @@ // using a barrel file helps tsc-alias resolve the path correctly -import type { PluginContextActions } from "./plugin-context"; -import { PluginContextProvider, usePlugin } from "./plugin-context"; -export { PluginContextProvider, usePlugin }; -export type { PluginContextActions }; +export { + PluginContextProvider, + usePlugin, + type PluginContextActions, + type PluginContextStore, +} from "./plugin-context"; diff --git a/src/providers/userTokenProvider/index.tsx b/src/providers/userTokenProvider/index.tsx index bb04002..73b5a13 100644 --- a/src/providers/userTokenProvider/index.tsx +++ b/src/providers/userTokenProvider/index.tsx @@ -2,9 +2,13 @@ import { createAnthropic, type AnthropicProvider } from "@ai-sdk/anthropic"; import { createGoogleGenerativeAI, type google } from "@ai-sdk/google"; import { createOpenAI, type OpenAIProvider } from "@ai-sdk/openai"; import { Button, Heading, Input } from "@components"; +import { serializeConfigPresentation3, Traverse } from "@iiif/parser"; +import type { Canvas } from "@iiif/presentation-3"; import { Tool } from "@langchain/core/tools"; import type { AssistantMessage, Message } from "@types"; -import { streamText, tool } from "ai"; +import { getLabelByUserLanguage } from "@utils"; +import { CoreMessage, streamText, tool } from "ai"; +import dedent from "dedent"; import React from "react"; import { BaseProvider } from "../../plugin/base_provider"; import { ModelSelection } from "./components/ModelSelection"; @@ -44,8 +48,12 @@ export class UserTokenProvider extends BaseProvider { * * @param message * @returns a formatted message + * + * @privateRemarks + * + * Use an arrow function so `this` references the `UserTokenProvider` class */ - #format_message(message: Message) { + #format_message = (message: Message, index: number, messages: Message[]): CoreMessage => { switch (message.role) { case "user": return { @@ -54,7 +62,56 @@ export class UserTokenProvider extends BaseProvider { if (c.type === "media") { return { type: "image", image: c.content.src }; } - return { type: "text", text: c.content }; + + const prevMessages = messages.slice(0, index); + const lastUserMessage = prevMessages.findLast((m) => m.role === "user"); + let context = ""; + + // only add new context to the messages when it changes to save on tokens + if ( + !lastUserMessage || + lastUserMessage.context.canvas.id !== message.context.canvas.id + ) { + const canvas = this.plugin_state.vault.serialize( + { + type: "Canvas", + id: message.context.canvas.id, + }, + serializeConfigPresentation3, + ); + + const annotationTexts: string[] = []; + const traverse = new Traverse({ + annotation: [ + (a) => { + if ( + a.body && + typeof a.body === "object" && + "type" in a.body && + a.body.type === "TextualBody" && + a.body.value + ) { + annotationTexts.push(a.body.value); + } + }, + ], + }); + + traverse.traverseCanvas(canvas); + + // prettier-ignore + context = dedent.withOptions({ alignValues: true })` + ## Context + The following context is about the latest Canvas in the image viewer. + Use this information if possible to inform your answer. + + ### Canvas${canvas.label ? ` + - Label: ${getLabelByUserLanguage(canvas.label)[0]}` : ""}${annotationTexts.length ? ` + - Annotations: ${annotationTexts.join(", ")}` : ""} + `; + } + + return { type: "text", text: c.content + `${context ? "\n" + context : ""}` }; }), }; case "assistant": @@ -65,7 +122,7 @@ export class UserTokenProvider extends BaseProvider { // @ts-expect-error - this is a catch-all for unsupported roles throw new Error(`Unsupported message role: ${message.role}`); } - } + }; #is_valid_model_provider_model(provider: Provider, model: string): boolean { return this.models_by_provider[provider].includes(model); @@ -181,7 +238,6 @@ export class UserTokenProvider extends BaseProvider { model, tools: this.#transform_tools(), maxSteps: this.max_steps, - // @ts-expect-error - there is a type mismatch here, but it works messages: all_messages.map(this.#format_message), }); diff --git a/src/types.d.ts b/src/types.d.ts index 53002de..24cc98f 100644 --- a/src/types.d.ts +++ b/src/types.d.ts @@ -1,3 +1,4 @@ +import type { CanvasNormalized } from "@iiif/presentation-3-normalized"; export type ConversationState = "idle" | "assistant_responding" | "error"; export type Role = "assistant" | "system" | "user"; @@ -39,6 +40,10 @@ export type AssistantMessage = { export interface UserMessage { content: (TextContent | MediaContent)[]; + /** Context that can be added to user messages when generating a response */ + context: { + canvas: CanvasNormalized; + }; role: Extract; }