Skip to content

Commit

Permalink
Support Azure OpenAI API endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
debanjum committed Jan 9, 2025
1 parent 266d274 commit f6948a2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
22 changes: 10 additions & 12 deletions src/khoj/processor/conversation/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_client,
is_promptrace_enabled,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,10 +55,7 @@ def completion_with_backoff(
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
if not client:
client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client

formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
Expand Down Expand Up @@ -158,14 +159,11 @@ def llm_thread(
):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients:
client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client
else:
if client_key in openai_clients:
client = openai_clients[client_key]
else:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client

formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

Expand Down
18 changes: 18 additions & 0 deletions src/khoj/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse

import openai
import psutil
import requests
import torch
Expand Down Expand Up @@ -596,3 +597,20 @@ def get_chat_usage_metrics(
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
}


def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
client = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base_url,
api_version="2024-10-21",
)
else:
client = openai.OpenAI(
api_key=api_key,
base_url=api_base_url,
)
return client

0 comments on commit f6948a2

Please sign in to comment.