Skip to content

Commit

Permalink
warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Dec 16, 2024
1 parent 0e72b9b commit 58317fc
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ exports[`streamText > multiple stream consumption > should support text stream,
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -459,6 +460,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > should
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"textDelta": "Hello, ",
Expand Down Expand Up @@ -488,6 +490,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > should
"promptTokens": 1,
"totalTokens": 6,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -1260,6 +1263,7 @@ exports[`streamText > options.maxSteps > 4 steps: initial, continue, continue, c
"promptTokens": 10,
"totalTokens": 30,
},
"warnings": undefined,
},
{
"textDelta": "no-whitespace",
Expand All @@ -1283,6 +1287,7 @@ exports[`streamText > options.maxSteps > 4 steps: initial, continue, continue, c
"promptTokens": 30,
"totalTokens": 35,
},
"warnings": undefined,
},
{
"textDelta": "immediatefollow ",
Expand All @@ -1308,6 +1313,7 @@ exports[`streamText > options.maxSteps > 4 steps: initial, continue, continue, c
"promptTokens": 3,
"totalTokens": 5,
},
"warnings": undefined,
},
{
"textDelta": "final ",
Expand Down Expand Up @@ -1346,6 +1352,7 @@ exports[`streamText > options.maxSteps > 4 steps: initial, continue, continue, c
"promptTokens": 3,
"totalTokens": 5,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -1943,6 +1950,7 @@ exports[`streamText > result.fullStream > should filter out empty text deltas 1`
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -1992,6 +2000,7 @@ exports[`streamText > result.fullStream > should not send tool call deltas when
"promptTokens": 53,
"totalTokens": 70,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -2050,6 +2059,7 @@ exports[`streamText > result.fullStream > should send delayed asynchronous tool
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -2103,6 +2113,7 @@ exports[`streamText > result.fullStream > should send text deltas 1`] = `
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -2199,6 +2210,7 @@ exports[`streamText > result.fullStream > should send tool call deltas when tool
"promptTokens": 53,
"totalTokens": 70,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -2248,6 +2260,7 @@ exports[`streamText > result.fullStream > should send tool calls 1`] = `
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -2306,6 +2319,7 @@ exports[`streamText > result.fullStream > should send tool results 1`] = `
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -2359,6 +2373,7 @@ exports[`streamText > result.fullStream > should use fallback response metadata
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down Expand Up @@ -3016,6 +3031,7 @@ exports[`streamText > tools with custom schema > should send tool calls 1`] = `
"promptTokens": 3,
"totalTokens": 13,
},
"warnings": undefined,
},
{
"experimental_providerMetadata": undefined,
Expand Down
2 changes: 1 addition & 1 deletion packages/ai/core/generate-text/stream-text-result.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ export type TextStreamPart<TOOLS extends Record<string, CoreTool>> =
usage: LanguageModelUsage;
request: LanguageModelRequestMetadata;
response: LanguageModelResponseMetadata;
// TODO warnings
warnings: CallWarning[] | undefined;
experimental_providerMetadata?: ProviderMetadata;
isContinued: boolean;
}
Expand Down
9 changes: 5 additions & 4 deletions packages/ai/core/generate-text/stream-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2950,22 +2950,23 @@ describe('streamText', () => {
await convertAsyncIterableToArray(result.fullStream),
).toStrictEqual([
{
type: 'tool-call',
args: {
value: 'value',
},
toolCallId: 'call-1',
toolName: 'tool1',
type: 'tool-call',
},
{
type: 'error',
error: new ToolExecutionError({
toolName: 'tool1',
toolArgs: { value: 'value' },
cause: new Error('test error'),
}),
type: 'error',
},
{
type: 'step-finish',
experimental_providerMetadata: undefined,
finishReason: 'stop',
isContinued: false,
Expand All @@ -2977,14 +2978,15 @@ describe('streamText', () => {
timestamp: new Date(0),
headers: undefined,
},
type: 'step-finish',
warnings: undefined,
usage: {
completionTokens: 10,
promptTokens: 3,
totalTokens: 13,
},
},
{
type: 'finish',
experimental_providerMetadata: undefined,
finishReason: 'stop',
logprobs: undefined,
Expand All @@ -2994,7 +2996,6 @@ describe('streamText', () => {
timestamp: new Date(0),
headers: undefined,
},
type: 'finish',
usage: {
completionTokens: 10,
promptTokens: 3,
Expand Down
61 changes: 33 additions & 28 deletions packages/ai/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,6 @@ class DefaultStreamTextResult<TOOLS extends Record<string, CoreTool>>
});
}

// warnings from last invoked step. should ideally be part of stream chunks
// e.g. step-finish, finish
let recordedWarnings: Array<LanguageModelV1CallWarning> | undefined =
undefined;

// event processor for telemetry, invoking callbacks, etc.
// The event processor reads the transformed stream to enable correct
// recording of the final transformed outputs.
Expand Down Expand Up @@ -481,7 +476,7 @@ class DefaultStreamTextResult<TOOLS extends Record<string, CoreTool>>
toolResults: recordedToolResults,
finishReason: chunk.finishReason,
usage: chunk.usage,
warnings: recordedWarnings, // TODO buggy
warnings: chunk.warnings,
logprobs: chunk.logprobs,
request: chunk.request,
response: {
Expand Down Expand Up @@ -518,25 +513,37 @@ class DefaultStreamTextResult<TOOLS extends Record<string, CoreTool>>
}
},

flush() {
self.warningsPromise.resolve(recordedWarnings ?? []);
self.finishReasonPromise.resolve(recordedFinishReason ?? 'unknown');
self.textPromise.resolve(recordedFullText);
self.requestPromise.resolve(recordedRequest ?? {});
self.responsePromise.resolve(recordedResponse);
self.toolCallsPromise.resolve(recordedSteps.at(-1)?.toolCalls ?? []);
self.toolResultsPromise.resolve(
recordedSteps.at(-1)?.toolResults ?? [],
);
self.providerMetadataPromise.resolve(recordedProviderMetadata);
self.usagePromise.resolve(
recordedUsage ?? {
completionTokens: NaN,
promptTokens: NaN,
totalTokens: NaN,
},
);
self.stepsPromise.resolve(recordedSteps);
flush(controller) {
try {
// from last step (when there are errors there may be no last step)
const lastStep = recordedSteps[recordedSteps.length - 1];
if (lastStep) {
self.warningsPromise.resolve(lastStep.warnings);
self.requestPromise.resolve(lastStep.request);
self.responsePromise.resolve(lastStep.response);
self.toolCallsPromise.resolve(lastStep.toolCalls);
self.toolResultsPromise.resolve(lastStep.toolResults);
self.providerMetadataPromise.resolve(
lastStep.experimental_providerMetadata,
);
}

// from finish:
self.finishReasonPromise.resolve(recordedFinishReason ?? 'unknown');
self.usagePromise.resolve(
recordedUsage ?? {
completionTokens: NaN,
promptTokens: NaN,
totalTokens: NaN,
},
);

// aggregate results:
self.textPromise.resolve(recordedFullText);
self.stepsPromise.resolve(recordedSteps);
} catch (error) {
controller.error(error);
}
},
});

Expand Down Expand Up @@ -977,6 +984,7 @@ class DefaultStreamTextResult<TOOLS extends Record<string, CoreTool>>
...stepResponse,
headers: rawResponse?.headers,
},
warnings,
isContinued: nextStepType === 'continue',
});

Expand Down Expand Up @@ -1090,9 +1098,6 @@ class DefaultStreamTextResult<TOOLS extends Record<string, CoreTool>>
}),
);

// update warnings
recordedWarnings = warnings;

// call onFinish callback:
await onFinish?.({
finishReason: stepFinishReason,
Expand Down

0 comments on commit 58317fc

Please sign in to comment.