Skip to content

Commit

Permalink
Refactor to remove QueryRequest entity (#799)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Jan 10, 2025
1 parent 06b0a4c commit ec4e486
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 212 deletions.
28 changes: 10 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,11 @@ answer_response = ask(
`ask` is just a convenience wrapper around the real entrypoint, which can be accessed if you'd like to run concurrent asynchronous workloads:

```python
from paperqa import Settings, agent_query, QueryRequest
from paperqa import Settings, agent_query

answer_response = await agent_query(
QueryRequest(
query="What manufacturing challenges are unique to bispecific antibodies?",
settings=Settings(temperature=0.5, paper_directory="my_papers"),
)
query="What manufacturing challenges are unique to bispecific antibodies?",
settings=Settings(temperature=0.5, paper_directory="my_papers"),
)
```

Expand Down Expand Up @@ -682,7 +680,6 @@ import os

from paperqa import Settings
from paperqa.agents.main import agent_query
from paperqa.agents.models import QueryRequest
from paperqa.agents.search import get_directory_index


Expand All @@ -696,15 +693,12 @@ async def amain(folder_of_papers: str | os.PathLike) -> None:

# 2. Use the settings as many times as you want with ask
answer_response_1 = await agent_query(
query=QueryRequest(
query="What is the best way to make a vaccine?", settings=settings
)
query="What is the best way to make a vaccine?",
settings=settings,
)
answer_response_2 = await agent_query(
query=QueryRequest(
query="What manufacturing challenges are unique to bispecific antibodies?",
settings=settings,
)
query="What manufacturing challenges are unique to bispecific antibodies?",
settings=settings,
)
```

Expand All @@ -726,15 +720,13 @@ from ldp.agent import SimpleAgent
from ldp.alg.callbacks import MeanMetricsCallback
from ldp.alg.runners import Evaluator, EvaluatorConfig

from paperqa import QueryRequest, Settings
from paperqa import Settings
from paperqa.agents.task import TASK_DATASET_NAME


async def evaluate(folder_of_litqa_v2_papers: str | os.PathLike) -> None:
base_query = QueryRequest(
settings=Settings(paper_directory=folder_of_litqa_v2_papers)
)
dataset = TaskDataset.from_name(TASK_DATASET_NAME, base_query=base_query)
settings = Settings(paper_directory=folder_of_litqa_v2_papers)
dataset = TaskDataset.from_name(TASK_DATASET_NAME, settings=settings)
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)

evaluator = Evaluator(
Expand Down
2 changes: 0 additions & 2 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from paperqa.agents import ask
from paperqa.agents.main import agent_query
from paperqa.agents.models import QueryRequest
from paperqa.docs import Docs, PQASession, print_callback
from paperqa.llms import (
NumpyVectorStore,
Expand Down Expand Up @@ -46,7 +45,6 @@
"NumpyVectorStore",
"PQASession",
"QdrantVectorStore",
"QueryRequest",
"SentenceTransformerEmbeddingModel",
"Settings",
"SparseEmbeddingModel",
Expand Down
7 changes: 2 additions & 5 deletions paperqa/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paperqa.version import __version__

from .main import agent_query, index_search
from .models import AnswerResponse, QueryRequest
from .models import AnswerResponse
from .search import SearchIndex, get_directory_index

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,10 +102,7 @@ def ask(query: str | MultipleChoiceQuestion, settings: Settings) -> AnswerRespon
"""Query PaperQA via an agent."""
configure_cli_logging(settings)
return get_loop().run_until_complete(
agent_query(
QueryRequest(query=query, settings=settings),
agent_type=settings.agent.agent_type,
)
agent_query(query, settings, agent_type=settings.agent.agent_type)
)


Expand Down
37 changes: 23 additions & 14 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from copy import deepcopy
from typing import Any, ClassVar, Self, cast
from uuid import UUID

from aviary.core import (
Environment,
Expand All @@ -23,7 +24,6 @@
from paperqa.types import PQASession
from paperqa.utils import get_year

from .models import QueryRequest
from .tools import (
AVAILABLE_TOOL_NAME_TO_CLASS,
DEFAULT_TOOL_NAMES,
Expand Down Expand Up @@ -207,25 +207,27 @@ class PaperQAEnvironment(Environment[EnvironmentState]):

def __init__(
self,
query: QueryRequest,
query: str | MultipleChoiceQuestion,
settings: Settings,
docs: Docs,
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
session_id: UUID | None = None,
**env_kwargs,
):
super().__init__(**env_kwargs)
# Hold onto QueryRequest to create fresh tools and answer during each reset
self._query = query
# Hold onto Docs to clear and reuse in state during each reset
self._settings = settings
self._docs = docs
self._llm_model = llm_model
self._summary_llm_model = summary_llm_model
self._embedding_model = embedding_model
self._session_id = session_id

def make_tools(self) -> list[Tool]:
return settings_to_tools(
settings=self._query.settings,
settings=self._settings,
llm_model=self._llm_model,
summary_llm_model=self._summary_llm_model,
embedding_model=self._embedding_model,
Expand All @@ -235,17 +237,23 @@ def make_initial_state(self) -> EnvironmentState:
status_fn = None

if ClinicalTrialsSearch.TOOL_FN_NAME in (
self._query.settings.agent.tool_names or DEFAULT_TOOL_NAMES
self._settings.agent.tool_names or DEFAULT_TOOL_NAMES
):
status_fn = clinical_trial_status

query: str | MultipleChoiceQuestion = self._query.query
session_kwargs: dict[str, Any] = {}
if self._session_id:
session_kwargs["id"] = self._session_id
return EnvironmentState(
docs=self._docs,
session=PQASession(
question=query if isinstance(query, str) else query.question_prompt,
config_md5=self._query.settings.md5,
id=self._query.id,
question=(
self._query
if isinstance(self._query, str)
else self._query.question_prompt
),
config_md5=self._settings.md5,
**session_kwargs,
),
status_fn=status_fn,
)
Expand All @@ -259,7 +267,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
return (
[
Message(
content=self._query.settings.agent.agent_prompt.format(
content=self._settings.agent.agent_prompt.format(
question=self.state.session.question,
status=self.state.status,
complete_tool_name=Complete.TOOL_FN_NAME,
Expand All @@ -273,15 +281,15 @@ def export_frame(self) -> Frame:
return Frame(state=self.state, info={"query": self._query})

def _has_excess_answer_failures(self) -> bool:
if self._query.settings.answer.max_answer_attempts is None:
if self._settings.answer.max_answer_attempts is None:
return False
return (
sum(
tn == GenerateAnswer.gen_answer.__name__
for s in self.state.session.tool_history
for tn in s
)
> self._query.settings.answer.max_answer_attempts
> self._settings.answer.max_answer_attempts
)

USE_POST_PROCESSED_REWARD: ClassVar[float] = 0.0
Expand Down Expand Up @@ -331,7 +339,8 @@ def __deepcopy__(self, memo) -> Self:
)
}
copy_self = type(self)(
query=deepcopy(self._query, memo), # deepcopy for _docs_name
query=self._query, # No need to copy since we read only
settings=deepcopy(self._settings, memo), # Deepcopy just to be safe
docs=copy_state.docs,
**env_model_kwargs,
)
Expand Down
Loading

0 comments on commit ec4e486

Please sign in to comment.