From f6948a2402e1c720c29e476488e0119048db36e9 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 8 Jan 2025 20:48:39 +0700 Subject: [PATCH] Support Azure OpenAI API endpoint --- .../processor/conversation/openai/utils.py | 22 +++++++++---------- src/khoj/utils/helpers.py | 18 +++++++++++++++ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 8af836f10..f2dc42d23 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -19,7 +19,11 @@ ThreadedGenerator, commit_conversation_trace, ) -from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled +from khoj.utils.helpers import ( + get_chat_usage_metrics, + get_openai_client, + is_promptrace_enabled, +) logger = logging.getLogger(__name__) @@ -51,10 +55,7 @@ def completion_with_backoff( client_key = f"{openai_api_key}--{api_base_url}" client: openai.OpenAI | None = openai_clients.get(client_key) if not client: - client = openai.OpenAI( - api_key=openai_api_key, - base_url=api_base_url, - ) + client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client formatted_messages = [{"role": message.role, "content": message.content} for message in messages] @@ -158,14 +159,11 @@ def llm_thread( ): try: client_key = f"{openai_api_key}--{api_base_url}" - if client_key not in openai_clients: - client = openai.OpenAI( - api_key=openai_api_key, - base_url=api_base_url, - ) - openai_clients[client_key] = client - else: + if client_key in openai_clients: client = openai_clients[client_key] + else: + client = get_openai_client(openai_api_key, api_base_url) + openai_clients[client_key] = client formatted_messages = [{"role": message.role, "content": message.content} for message in messages] diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 6214e5f50..b78dc9d7f 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from urllib.parse import urlparse +import openai import psutil import requests import torch @@ -596,3 +597,20 @@ def get_chat_usage_metrics( "output_tokens": prev_usage["output_tokens"] + output_tokens, "cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]), } + + +def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]: + """Get OpenAI or AzureOpenAI client based on the API Base URL""" + parsed_url = urlparse(api_base_url) + if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"): + client = openai.AzureOpenAI( + api_key=api_key, + azure_endpoint=api_base_url, + api_version="2024-10-21", + ) + else: + client = openai.OpenAI( + api_key=api_key, + base_url=api_base_url, + ) + return client