Skip to content

Commit

Permalink
update LlmProcessor for AOT
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakerMohd committed Jan 15, 2025
1 parent e49ddaf commit ddd68f3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ internal class AWSLlmModelProcessor
"Specify StringComparison for clarity",
"CA1307",
Justification = "Adding StringComparison only works for NET Core but not the framework.")]
internal static void ProcessGenAiAttributes<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)]T>(Activity activity, T message, string modelName, bool isRequest, AWSSemanticConventions awsSemanticConventions)
#else
internal static void ProcessGenAiAttributes<T>(Activity activity, T message, string modelName, bool isRequest, AWSSemanticConventions awsSemanticConventions)
#endif
internal static void ProcessGenAiAttributes(Activity activity, MemoryStream body, string modelName, bool isRequest, AWSSemanticConventions awsSemanticConventions)
{
// message can be either a request or a response. isRequest is used by the model-specific methods to determine
// whether to extract the request or response attributes.
Expand All @@ -31,65 +29,52 @@ internal static void ProcessGenAiAttributes<T>(Activity activity, T message, str
// the response body. For the Claude, Command, and Mistral models, the input and output tokens are not provided
// in the response body, so we approximate their values by dividing the input and output lengths by 6, based on
// the Bedrock documentation here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html

if (message is null)
{
return;
}

var messageBodyProperty = typeof(T).GetProperty("Body");
if (messageBodyProperty != null)
try
{
if (messageBodyProperty.GetValue(message) is MemoryStream body)
{
try
{
var jsonString = Encoding.UTF8.GetString(body.ToArray());
var jsonString = Encoding.UTF8.GetString(body.ToArray());
#if NET
var jsonObject = JsonSerializer.Deserialize(jsonString, SourceGenerationContext.Default.DictionaryStringJsonElement);
var jsonObject = JsonSerializer.Deserialize(jsonString, SourceGenerationContext.Default.DictionaryStringJsonElement);
#else
var jsonObject = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(jsonString);
var jsonObject = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(jsonString);
#endif
if (jsonObject == null)
{
return;
}
if (jsonObject == null)
{
return;
}

// extract model specific attributes based on model name
if (modelName.Contains("amazon.nova"))
{
ProcessNovaModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("amazon.titan"))
{
ProcessTitanModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("anthropic.claude"))
{
ProcessClaudeModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("meta.llama3"))
{
ProcessLlamaModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("cohere.command"))
{
ProcessCommandModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("ai21.jamba"))
{
ProcessJambaModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("mistral.mistral"))
{
ProcessMistralModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
}
catch (Exception ex)
{
AWSInstrumentationEventSource.Log.JsonParserException(nameof(AWSLlmModelProcessor), ex);
}
// extract model specific attributes based on model name
if (modelName.Contains("amazon.nova"))
{
ProcessNovaModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("amazon.titan"))
{
ProcessTitanModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("anthropic.claude"))
{
ProcessClaudeModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("meta.llama3"))
{
ProcessLlamaModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("cohere.command"))
{
ProcessCommandModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("ai21.jamba"))
{
ProcessJambaModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
else if (modelName.Contains("mistral.mistral"))
{
ProcessMistralModelAttributes(activity, jsonObject, isRequest, awsSemanticConventions);
}
}
catch (Exception ex)
{
AWSInstrumentationEventSource.Log.JsonParserException(nameof(AWSLlmModelProcessor), ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

using System.Diagnostics;
using Amazon.BedrockRuntime.Model;
using Amazon.Runtime;
using Amazon.Runtime.Internal;
using Amazon.Runtime.Telemetry;
Expand Down Expand Up @@ -138,7 +139,8 @@ private void AddResponseSpecificInformation(Activity activity, IExecutionContext
var modelString = model.ToString();
if (modelString != null)
{
AWSLlmModelProcessor.ProcessGenAiAttributes(activity, responseContext.Response, modelString, false, this.awsSemanticConventions);
var response = (InvokeModelResponse)responseContext.Response;
AWSLlmModelProcessor.ProcessGenAiAttributes(activity, response.Body, modelString, false, this.awsSemanticConventions);
}
}
}
Expand Down Expand Up @@ -183,7 +185,8 @@ private void AddRequestSpecificInformation(Activity activity, IRequestContext re
var modelString = model.ToString();
if (modelString != null)
{
AWSLlmModelProcessor.ProcessGenAiAttributes(activity, request, modelString, true, this.awsSemanticConventions);
var invokeModelRequest = (InvokeModelRequest)request;
AWSLlmModelProcessor.ProcessGenAiAttributes(activity, invokeModelRequest.Body, modelString, true, this.awsSemanticConventions);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<ItemGroup>
<PackageReference Include="AWSSDK.Core" Version="3.7.400" />
<PackageReference Include="AWSSDK.SimpleNotificationService" Version="3.7.400" />
<PackageReference Include="AWSSDK.BedrockRuntime" Version="3.7.400" />
<PackageReference Include="AWSSDK.SQS" Version="3.7.400" />
</ItemGroup>

Expand Down

0 comments on commit ddd68f3

Please sign in to comment.