diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py
index e04efd92a2f..d23f2229cc1 100644
--- a/backend/danswer/chat/models.py
+++ b/backend/danswer/chat/models.py
@@ -74,12 +74,24 @@ class DanswerQuotes(BaseModel):
     quotes: list[DanswerQuote]
 
 
+class DanswerContext(BaseModel):
+    content: str
+    document_id: str
+    semantic_identifier: str
+    blurb: str
+
+
+class DanswerContexts(BaseModel):
+    contexts: list[DanswerContext]
+
+
 class DanswerAnswer(BaseModel):
     answer: str | None
 
 
 class QAResponse(SearchResponse, DanswerAnswer):
     quotes: list[DanswerQuote] | None
+    contexts: list[DanswerContext] | None  # item type, parallel to `quotes`
     predicted_flow: QueryFlow
     predicted_search: SearchType
     eval_res_valid: bool | None = None
@@ -87,11 +99,8 @@ class QAResponse(SearchResponse, DanswerAnswer):
     error_msg: str | None = None
 
 
-AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
-
-
 AnswerQuestionStreamReturn = Iterator[
-    DanswerAnswerPiece | DanswerQuotes | StreamingError
+    DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
 ]
 
 
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 04f95fbdbb2..4a2cb493cdb 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -1,3 +1,4 @@
+import itertools
 from collections.abc import Callable
 from collections.abc import Iterator
 from typing import cast
@@ -6,6 +7,8 @@
 
 from danswer.chat.chat_utils import get_chunks_for_qa
 from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import DanswerContext
+from danswer.chat.models import DanswerContexts
 from danswer.chat.models import DanswerQuotes
 from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import LLMRelevanceFilterResponse
@@ -67,6 +70,7 @@ def stream_answer_objects(
     | LLMRelevanceFilterResponse
     | DanswerAnswerPiece
     | DanswerQuotes
+    | DanswerContexts
     | StreamingError
     | ChatMessageDetail
 ]:
@@ -229,6 +233,21 @@ def stream_answer_objects(
         else no_gen_ai_response()
     )
 
+    if qa_model is not None and query_req.return_contexts:
+        contexts = DanswerContexts(
+            contexts=[
+                DanswerContext(
+                    content=context_doc.content,
+                    document_id=context_doc.document_id,
+                    semantic_identifier=context_doc.semantic_identifier,
+                    blurb=context_doc.blurb,
+                )
+                for context_doc in llm_chunks
+            ]
+        )
+
+        response_packets = itertools.chain(response_packets, [contexts])  # sent last
+
     # Capture outputs and errors
     llm_output = ""
     error: str | None = None
@@ -316,6 +335,8 @@ def get_search_answer(
             qa_response.llm_chunks_indices = packet.relevant_chunk_indices
         elif isinstance(packet, DanswerQuotes):
             qa_response.quotes = packet
+        elif isinstance(packet, DanswerContexts):
+            qa_response.contexts = packet
         elif isinstance(packet, StreamingError):
             qa_response.error_msg = packet.error
         elif isinstance(packet, ChatMessageDetail):
diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py
index 1e5d94d27c7..6401b34404e 100644
--- a/backend/danswer/one_shot_answer/models.py
+++ b/backend/danswer/one_shot_answer/models.py
@@ -3,6 +3,7 @@
 from pydantic import BaseModel
 from pydantic import root_validator
 
+from danswer.chat.models import DanswerContexts
 from danswer.chat.models import DanswerQuotes
 from danswer.chat.models import QADocsResponse
 from danswer.configs.constants import MessageType
@@ -25,6 +26,7 @@ class DirectQARequest(BaseModel):
     persona_id: int
     retrieval_options: RetrievalDetails
     chain_of_thought: bool = False
+    return_contexts: bool = False  # opt-in; off by default
 
     @root_validator
     def check_chain_of_thought_and_prompt_id(
@@ -53,3 +55,4 @@ class OneShotQAResponse(BaseModel):
     error_msg: str | None = None
     answer_valid: bool = True  # Reflexion result, default True if Reflexion not run
     chat_message_id: int | None = None
+    contexts: DanswerContexts | None = None