diff --git a/graphrag/query/llm/oai/chat_openai.py b/graphrag/query/llm/oai/chat_openai.py
index ad626a3974..60cf105b79 100644
--- a/graphrag/query/llm/oai/chat_openai.py
+++ b/graphrag/query/llm/oai/chat_openai.py
@@ -22,6 +22,7 @@
     OPENAI_RETRY_ERROR_TYPES,
     OpenaiApiType,
 )
+import litellm
 
 _MODEL_REQUIRED_MSG = "model is required"
 
@@ -193,7 +194,7 @@ def _generate(
         model = self.model
         if not model:
             raise ValueError(_MODEL_REQUIRED_MSG)
-        response = self.sync_client.chat.completions.create(  # type: ignore
+        response = litellm.completion(  # type: ignore
             model=model,
             messages=messages,  # type: ignore
             stream=streaming,
@@ -233,7 +234,7 @@ def _stream_generate(
         model = self.model
         if not model:
             raise ValueError(_MODEL_REQUIRED_MSG)
-        response = self.sync_client.chat.completions.create(  # type: ignore
+        response = litellm.completion(  # type: ignore
             model=model,
             messages=messages,  # type: ignore
             stream=True,
@@ -265,7 +266,7 @@ async def _agenerate(
         model = self.model
         if not model:
             raise ValueError(_MODEL_REQUIRED_MSG)
-        response = await self.async_client.chat.completions.create(  # type: ignore
+        response = await litellm.acompletion(  # type: ignore
             model=model,
             messages=messages,  # type: ignore
             stream=streaming,
@@ -306,7 +307,7 @@ async def _astream_generate(
         model = self.model
         if not model:
             raise ValueError(_MODEL_REQUIRED_MSG)
-        response = await self.async_client.chat.completions.create(  # type: ignore
+        response = await litellm.acompletion(  # type: ignore
             model=model,
             messages=messages,  # type: ignore
             stream=True,
diff --git a/graphrag/query/llm/oai/openai.py b/graphrag/query/llm/oai/openai.py
index 76bb5fe52c..588f4185d5 100644
--- a/graphrag/query/llm/oai/openai.py
+++ b/graphrag/query/llm/oai/openai.py
@@ -6,6 +6,7 @@
 import logging
 from typing import Any
 
+import litellm
 from tenacity import (
     AsyncRetrying,
     RetryError,
@@ -117,7 +118,7 @@ def _generate(
         callbacks: list[BaseLLMCallback] | None = None,
         **kwargs: Any,
     ) -> str:
-        response = self.sync_client.chat.completions.create(  # type: ignore
+        response = litellm.completion(  # type: ignore
             model=self.model,
             messages=messages,  # type: ignore
             stream=streaming,
@@ -155,7 +156,7 @@ async def _agenerate(
         callbacks: list[BaseLLMCallback] | None = None,
         **kwargs: Any,
     ) -> str:
-        response = await self.async_client.chat.completions.create(  # type: ignore
+        response = await litellm.acompletion(  # type: ignore
             model=self.model,
             messages=messages,  # type: ignore
             stream=streaming,
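For reference, a minimal standalone sketch of the `litellm.completion` / `litellm.acompletion` calls this patch swaps in (the model string and prompt below are placeholders, not part of the PR). litellm returns OpenAI-format responses, which is why the surrounding kwargs and `response.choices[...]` handling at the call sites stay unchanged:

```python
import asyncio

import litellm


def sync_example() -> str:
    # Non-streaming call; litellm.completion accepts the same kwargs the
    # OpenAI chat-completions client did at these call sites.
    response = litellm.completion(
        model="gpt-4o-mini",  # placeholder; any litellm-routed model string
        messages=[{"role": "user", "content": "Say hello."}],
        stream=False,
    )
    return response.choices[0].message.content


async def async_example() -> str:
    # Async counterpart, as used by the _agenerate/_astream_generate paths.
    response = await litellm.acompletion(
        model="gpt-4o-mini",  # placeholder model string
        messages=[{"role": "user", "content": "Say hello."}],
        stream=False,
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    print(sync_example())
    print(asyncio.run(async_example()))
```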