Skip to content

Commit

Permalink
Removed QueryRequest from the stack, in favor of separate pieces of i…
Browse files Browse the repository at this point in the history
…nformation
  • Loading branch information
jamesbraza committed Jan 9, 2025
1 parent 8b0ed16 commit 618e857
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 193 deletions.
2 changes: 0 additions & 2 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from paperqa.agents import ask
from paperqa.agents.main import agent_query
from paperqa.agents.models import QueryRequest
from paperqa.docs import Docs, PQASession, print_callback
from paperqa.llms import (
NumpyVectorStore,
Expand Down Expand Up @@ -46,7 +45,6 @@
"NumpyVectorStore",
"PQASession",
"QdrantVectorStore",
"QueryRequest",
"SentenceTransformerEmbeddingModel",
"Settings",
"SparseEmbeddingModel",
Expand Down
7 changes: 2 additions & 5 deletions paperqa/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paperqa.version import __version__

from .main import agent_query, index_search
from .models import AnswerResponse, QueryRequest
from .models import AnswerResponse
from .search import SearchIndex, get_directory_index

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,10 +102,7 @@ def ask(query: str | MultipleChoiceQuestion, settings: Settings) -> AnswerRespon
"""Query PaperQA via an agent."""
configure_cli_logging(settings)
return get_loop().run_until_complete(
agent_query(
QueryRequest(query=query, settings=settings),
agent_type=settings.agent.agent_type,
)
agent_query(query, settings, agent_type=settings.agent.agent_type)
)


Expand Down
37 changes: 23 additions & 14 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from copy import deepcopy
from typing import Any, ClassVar, Self, cast
from uuid import UUID

from aviary.core import (
Environment,
Expand All @@ -23,7 +24,6 @@
from paperqa.types import PQASession
from paperqa.utils import get_year

from .models import QueryRequest
from .tools import (
AVAILABLE_TOOL_NAME_TO_CLASS,
DEFAULT_TOOL_NAMES,
Expand Down Expand Up @@ -207,25 +207,27 @@ class PaperQAEnvironment(Environment[EnvironmentState]):

def __init__(
self,
query: QueryRequest,
query: str | MultipleChoiceQuestion,
settings: Settings,
docs: Docs,
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
session_id: UUID | None = None,
**env_kwargs,
):
super().__init__(**env_kwargs)
# Hold onto QueryRequest to create fresh tools and answer during each reset
self._query = query
# Hold onto Docs to clear and reuse in state during each reset
self._settings = settings
self._docs = docs
self._llm_model = llm_model
self._summary_llm_model = summary_llm_model
self._embedding_model = embedding_model
self._session_id = session_id

def make_tools(self) -> list[Tool]:
return settings_to_tools(
settings=self._query.settings,
settings=self._settings,
llm_model=self._llm_model,
summary_llm_model=self._summary_llm_model,
embedding_model=self._embedding_model,
Expand All @@ -235,17 +237,23 @@ def make_initial_state(self) -> EnvironmentState:
status_fn = None

if ClinicalTrialsSearch.TOOL_FN_NAME in (
self._query.settings.agent.tool_names or DEFAULT_TOOL_NAMES
self._settings.agent.tool_names or DEFAULT_TOOL_NAMES
):
status_fn = clinical_trial_status

query: str | MultipleChoiceQuestion = self._query.query
session_kwargs: dict[str, Any] = {}
if self._session_id:
session_kwargs["id"] = self._session_id
return EnvironmentState(
docs=self._docs,
session=PQASession(
question=query if isinstance(query, str) else query.question_prompt,
config_md5=self._query.settings.md5,
id=self._query.id,
question=(
self._query
if isinstance(self._query, str)
else self._query.question_prompt
),
config_md5=self._settings.md5,
**session_kwargs,
),
status_fn=status_fn,
)
Expand All @@ -259,7 +267,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
return (
[
Message(
content=self._query.settings.agent.agent_prompt.format(
content=self._settings.agent.agent_prompt.format(
question=self.state.session.question,
status=self.state.status,
complete_tool_name=Complete.TOOL_FN_NAME,
Expand All @@ -273,15 +281,15 @@ def export_frame(self) -> Frame:
return Frame(state=self.state, info={"query": self._query})

def _has_excess_answer_failures(self) -> bool:
if self._query.settings.answer.max_answer_attempts is None:
if self._settings.answer.max_answer_attempts is None:
return False
return (
sum(
tn == GenerateAnswer.gen_answer.__name__
for s in self.state.session.tool_history
for tn in s
)
> self._query.settings.answer.max_answer_attempts
> self._settings.answer.max_answer_attempts
)

USE_POST_PROCESSED_REWARD: ClassVar[float] = 0.0
Expand Down Expand Up @@ -331,7 +339,8 @@ def __deepcopy__(self, memo) -> Self:
)
}
copy_self = type(self)(
query=deepcopy(self._query, memo), # deepcopy for _docs_name
query=self._query, # No need to copy since we read only
settings=deepcopy(self._settings, memo), # Deepcopy just to be safe
docs=copy_state.docs,
**env_model_kwargs,
)
Expand Down
Loading

0 comments on commit 618e857

Please sign in to comment.