openai native web search tool enabled (#2410)

This commit is contained in:
Mike Sawka 2025-10-09 15:06:40 -07:00 committed by GitHub
parent fd0e75a984
commit ef6366f6c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 105 additions and 72 deletions

View file

@ -122,10 +122,11 @@ func (m *OpenAIChatMessage) GetUsage() *uctypes.AIUsage {
return nil
}
return &uctypes.AIUsage{
APIType: "openai",
Model: m.Usage.Model,
InputTokens: m.Usage.InputTokens,
OutputTokens: m.Usage.OutputTokens,
APIType: "openai",
Model: m.Usage.Model,
InputTokens: m.Usage.InputTokens,
OutputTokens: m.Usage.OutputTokens,
NativeWebSearchCount: m.Usage.NativeWebSearchCount,
}
}
@ -281,12 +282,13 @@ type openaiTextFormat struct {
}
type OpenAIUsage struct {
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"` // internal field (not from OpenAI API)
}
type openaiInputTokensDetails struct {
@ -323,12 +325,13 @@ type openaiBlockState struct {
}
type openaiStreamingState struct {
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
msgID string
model string
stepStarted bool
chatOpts uctypes.WaveChatOpts
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
msgID string
model string
stepStarted bool
chatOpts uctypes.WaveChatOpts
webSearchCount int
}
// ---------- Public entrypoint ----------
@ -759,7 +762,7 @@ func handleOpenAIEvent(
}
// Extract partial message if available
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state)
_ = sse.AiMsgError(errorMsg)
return &uctypes.WaveStopReason{
@ -772,7 +775,7 @@ func handleOpenAIEvent(
}
// Extract the final message and tool calls from the response output
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state)
stopKind := uctypes.StopKindDone
if len(toolCalls) > 0 {
@ -820,6 +823,19 @@ func handleOpenAIEvent(
}
return nil, nil
case "response.web_search_call.in_progress":
return nil, nil
case "response.web_search_call.searching":
return nil, nil
case "response.web_search_call.completed":
state.webSearchCount++
return nil, nil
case "response.output_text.annotation.added":
return nil, nil
default:
// log unknown events for debugging
log.Printf("OpenAI: unknown event: %s, data: %s", eventName, data)
@ -857,9 +873,8 @@ func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinit
return toolUseData
}
// extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response
func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[string]*uctypes.UIMessageDataToolUse) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
var messageContent []OpenAIMessageContent
var toolCalls []uctypes.WaveToolCall
var messages []*OpenAIChatMessage
@ -893,7 +908,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
}
// Attach UIToolUseData if available
if data, ok := toolUseData[outputItem.CallId]; ok {
if data, ok := state.toolUseData[outputItem.CallId]; ok {
toolCall.ToolUseData = data
} else {
log.Printf("AI no data-tooluse for %s (callid: %s)\n", outputItem.Id, outputItem.CallId)
@ -907,7 +922,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
argsStr = outputItem.Arguments
}
var toolUseDataPtr *uctypes.UIMessageDataToolUse
if data, ok := toolUseData[outputItem.CallId]; ok {
if data, ok := state.toolUseData[outputItem.CallId]; ok {
toolUseDataPtr = data
}
functionCallMsg := &OpenAIChatMessage{
@ -925,17 +940,20 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
}
// Create OpenAIChatMessage with assistant message (first in slice)
if resp.Usage != nil {
usage := resp.Usage
if usage != nil {
resp.Usage.Model = resp.Model
if state.webSearchCount > 0 {
usage.NativeWebSearchCount = state.webSearchCount
}
}
assistantMessage := &OpenAIChatMessage{
MessageId: uuid.New().String(),
Message: &OpenAIMessage{
Role: "assistant",
Content: messageContent,
},
Usage: resp.Usage,
Usage: usage,
}
// Return assistant message first, followed by function call messages

View file

@ -75,11 +75,11 @@ type OpenAIRequest struct {
}
type OpenAIRequestTool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters any `json:"parameters"`
Strict bool `json:"strict"`
Type string `json:"type"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters any `json:"parameters,omitempty"`
Strict bool `json:"strict,omitempty"`
}
// ConvertToolDefinitionToOpenAI converts a generic ToolDefinition to OpenAI format
@ -113,13 +113,13 @@ func debugPrintReq(req *OpenAIRequest, endpoint string) {
// buildOpenAIHTTPRequest creates a complete HTTP request for the OpenAI API
func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*http.Request, error) {
opts := chatOpts.Config
// If continuing from premium rate limit, downgrade to default model and low thinking
if cont != nil && cont.ContinueFromKind == uctypes.StopKindPremiumRateLimit {
opts.Model = uctypes.DefaultOpenAIModel
opts.ThinkingLevel = uctypes.ThinkingLevelLow
}
if opts.Model == "" {
return nil, errors.New("opts.model is required")
}
@ -183,6 +183,14 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
reqBody.Tools = append(reqBody.Tools, convertedTool)
}
// Add native web search tool if enabled
if chatOpts.AllowNativeWebSearch {
webSearchTool := OpenAIRequestTool{
Type: "web_search",
}
reqBody.Tools = append(reqBody.Tools, webSearchTool)
}
// Set reasoning based on thinking level
if opts.ThinkingLevel != "" {
reqBody.Reasoning = &ReasoningType{

View file

@ -222,10 +222,11 @@ type AIChat struct {
}
type AIUsage struct {
APIType string `json:"apitype"`
Model string `json:"model"`
InputTokens int `json:"inputtokens,omitempty"`
OutputTokens int `json:"outputtokens,omitempty"`
APIType string `json:"apitype"`
Model string `json:"model"`
InputTokens int `json:"inputtokens,omitempty"`
OutputTokens int `json:"outputtokens,omitempty"`
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"`
}
type AIMetrics struct {
@ -424,6 +425,7 @@ type WaveChatOpts struct {
TabStateGenerator func() (string, []ToolDefinition, error)
WidgetAccess bool
RegisterToolApproval func(string)
AllowNativeWebSearch bool
// emphemeral to the step
TabState string

View file

@ -191,6 +191,7 @@ func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
} else {
rtn.InputTokens += usage.InputTokens
rtn.OutputTokens += usage.OutputTokens
rtn.NativeWebSearchCount += usage.NativeWebSearchCount
}
}
}
@ -369,9 +370,10 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
}
if len(rtnMessage) > 0 {
usage := getUsage(rtnMessage)
log.Printf("usage: input=%d output=%d\n", usage.InputTokens, usage.OutputTokens)
log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount)
metrics.Usage.InputTokens += usage.InputTokens
metrics.Usage.OutputTokens += usage.OutputTokens
metrics.Usage.NativeWebSearchCount += usage.NativeWebSearchCount
if usage.Model != "" && metrics.Usage.Model != usage.Model {
metrics.Usage.Model = "mixed"
}
@ -526,24 +528,25 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me
func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
event := telemetrydata.MakeTEvent("waveai:post", telemetrydata.TEventProps{
WaveAIAPIType: metrics.Usage.APIType,
WaveAIModel: metrics.Usage.Model,
WaveAIInputTokens: metrics.Usage.InputTokens,
WaveAIOutputTokens: metrics.Usage.OutputTokens,
WaveAIRequestCount: metrics.RequestCount,
WaveAIToolUseCount: metrics.ToolUseCount,
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
WaveAIToolDetail: metrics.ToolDetail,
WaveAIPremiumReq: metrics.PremiumReqCount,
WaveAIProxyReq: metrics.ProxyReqCount,
WaveAIHadError: metrics.HadError,
WaveAIImageCount: metrics.ImageCount,
WaveAIPDFCount: metrics.PDFCount,
WaveAITextDocCount: metrics.TextDocCount,
WaveAITextLen: metrics.TextLen,
WaveAIFirstByteMs: metrics.FirstByteLatency,
WaveAIRequestDurMs: metrics.RequestDuration,
WaveAIWidgetAccess: metrics.WidgetAccess,
WaveAIAPIType: metrics.Usage.APIType,
WaveAIModel: metrics.Usage.Model,
WaveAIInputTokens: metrics.Usage.InputTokens,
WaveAIOutputTokens: metrics.Usage.OutputTokens,
WaveAINativeWebSearchCount: metrics.Usage.NativeWebSearchCount,
WaveAIRequestCount: metrics.RequestCount,
WaveAIToolUseCount: metrics.ToolUseCount,
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
WaveAIToolDetail: metrics.ToolDetail,
WaveAIPremiumReq: metrics.PremiumReqCount,
WaveAIProxyReq: metrics.ProxyReqCount,
WaveAIHadError: metrics.HadError,
WaveAIImageCount: metrics.ImageCount,
WaveAIPDFCount: metrics.PDFCount,
WaveAITextDocCount: metrics.TextDocCount,
WaveAITextLen: metrics.TextLen,
WaveAIFirstByteMs: metrics.FirstByteLatency,
WaveAIRequestDurMs: metrics.RequestDuration,
WaveAIWidgetAccess: metrics.WidgetAccess,
})
_ = telemetry.RecordTEvent(ctx, event)
}
@ -602,6 +605,7 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
Config: *aiOpts,
WidgetAccess: req.WidgetAccess,
RegisterToolApproval: RegisterToolApproval,
AllowNativeWebSearch: true,
}
if chatOpts.Config.APIType == APIType_OpenAI {
chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI}

View file

@ -101,24 +101,25 @@ type TEventProps struct {
CountWSLConn int `json:"count:wslconn,omitempty"`
CountViews map[string]int `json:"count:views,omitempty"`
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
WaveAIModel string `json:"waveai:model,omitempty"`
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
WaveAITextLen int `json:"waveai:textlen,omitempty"`
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
WaveAIModel string `json:"waveai:model,omitempty"`
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
WaveAINativeWebSearchCount int `json:"waveai:nativewebsearchcount,omitempty"`
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
WaveAITextLen int `json:"waveai:textlen,omitempty"`
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
UserSet *TEventUserProps `json:"$set,omitempty"`
UserSetOnce *TEventUserProps `json:"$set_once,omitempty"`