mirror of
https://github.com/wavetermdev/waveterm
synced 2026-05-23 08:48:28 +00:00
openai native web search tool enabled (#2410)
This commit is contained in:
parent
fd0e75a984
commit
ef6366f6c6
5 changed files with 105 additions and 72 deletions
|
|
@ -122,10 +122,11 @@ func (m *OpenAIChatMessage) GetUsage() *uctypes.AIUsage {
|
|||
return nil
|
||||
}
|
||||
return &uctypes.AIUsage{
|
||||
APIType: "openai",
|
||||
Model: m.Usage.Model,
|
||||
InputTokens: m.Usage.InputTokens,
|
||||
OutputTokens: m.Usage.OutputTokens,
|
||||
APIType: "openai",
|
||||
Model: m.Usage.Model,
|
||||
InputTokens: m.Usage.InputTokens,
|
||||
OutputTokens: m.Usage.OutputTokens,
|
||||
NativeWebSearchCount: m.Usage.NativeWebSearchCount,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -281,12 +282,13 @@ type openaiTextFormat struct {
|
|||
}
|
||||
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens,omitempty"`
|
||||
OutputTokens int `json:"output_tokens,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
|
||||
InputTokens int `json:"input_tokens,omitempty"`
|
||||
OutputTokens int `json:"output_tokens,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
|
||||
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"` // internal field (not from OpenAI API)
|
||||
}
|
||||
|
||||
type openaiInputTokensDetails struct {
|
||||
|
|
@ -323,12 +325,13 @@ type openaiBlockState struct {
|
|||
}
|
||||
|
||||
type openaiStreamingState struct {
|
||||
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
|
||||
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
|
||||
msgID string
|
||||
model string
|
||||
stepStarted bool
|
||||
chatOpts uctypes.WaveChatOpts
|
||||
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
|
||||
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
|
||||
msgID string
|
||||
model string
|
||||
stepStarted bool
|
||||
chatOpts uctypes.WaveChatOpts
|
||||
webSearchCount int
|
||||
}
|
||||
|
||||
// ---------- Public entrypoint ----------
|
||||
|
|
@ -759,7 +762,7 @@ func handleOpenAIEvent(
|
|||
}
|
||||
|
||||
// Extract partial message if available
|
||||
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
|
||||
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state)
|
||||
|
||||
_ = sse.AiMsgError(errorMsg)
|
||||
return &uctypes.WaveStopReason{
|
||||
|
|
@ -772,7 +775,7 @@ func handleOpenAIEvent(
|
|||
}
|
||||
|
||||
// Extract the final message and tool calls from the response output
|
||||
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
|
||||
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state)
|
||||
|
||||
stopKind := uctypes.StopKindDone
|
||||
if len(toolCalls) > 0 {
|
||||
|
|
@ -820,6 +823,19 @@ func handleOpenAIEvent(
|
|||
}
|
||||
return nil, nil
|
||||
|
||||
case "response.web_search_call.in_progress":
|
||||
return nil, nil
|
||||
|
||||
case "response.web_search_call.searching":
|
||||
return nil, nil
|
||||
|
||||
case "response.web_search_call.completed":
|
||||
state.webSearchCount++
|
||||
return nil, nil
|
||||
|
||||
case "response.output_text.annotation.added":
|
||||
return nil, nil
|
||||
|
||||
default:
|
||||
// log unknown events for debugging
|
||||
log.Printf("OpenAI: unknown event: %s, data: %s", eventName, data)
|
||||
|
|
@ -857,9 +873,8 @@ func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinit
|
|||
return toolUseData
|
||||
}
|
||||
|
||||
|
||||
// extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response
|
||||
func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[string]*uctypes.UIMessageDataToolUse) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
|
||||
func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
|
||||
var messageContent []OpenAIMessageContent
|
||||
var toolCalls []uctypes.WaveToolCall
|
||||
var messages []*OpenAIChatMessage
|
||||
|
|
@ -893,7 +908,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
|
|||
}
|
||||
|
||||
// Attach UIToolUseData if available
|
||||
if data, ok := toolUseData[outputItem.CallId]; ok {
|
||||
if data, ok := state.toolUseData[outputItem.CallId]; ok {
|
||||
toolCall.ToolUseData = data
|
||||
} else {
|
||||
log.Printf("AI no data-tooluse for %s (callid: %s)\n", outputItem.Id, outputItem.CallId)
|
||||
|
|
@ -907,7 +922,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
|
|||
argsStr = outputItem.Arguments
|
||||
}
|
||||
var toolUseDataPtr *uctypes.UIMessageDataToolUse
|
||||
if data, ok := toolUseData[outputItem.CallId]; ok {
|
||||
if data, ok := state.toolUseData[outputItem.CallId]; ok {
|
||||
toolUseDataPtr = data
|
||||
}
|
||||
functionCallMsg := &OpenAIChatMessage{
|
||||
|
|
@ -925,17 +940,20 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
|
|||
}
|
||||
|
||||
// Create OpenAIChatMessage with assistant message (first in slice)
|
||||
if resp.Usage != nil {
|
||||
usage := resp.Usage
|
||||
if usage != nil {
|
||||
resp.Usage.Model = resp.Model
|
||||
if state.webSearchCount > 0 {
|
||||
usage.NativeWebSearchCount = state.webSearchCount
|
||||
}
|
||||
}
|
||||
|
||||
assistantMessage := &OpenAIChatMessage{
|
||||
MessageId: uuid.New().String(),
|
||||
Message: &OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: messageContent,
|
||||
},
|
||||
Usage: resp.Usage,
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
// Return assistant message first, followed by function call messages
|
||||
|
|
|
|||
|
|
@ -75,11 +75,11 @@ type OpenAIRequest struct {
|
|||
}
|
||||
|
||||
type OpenAIRequestTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters any `json:"parameters"`
|
||||
Strict bool `json:"strict"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ConvertToolDefinitionToOpenAI converts a generic ToolDefinition to OpenAI format
|
||||
|
|
@ -113,13 +113,13 @@ func debugPrintReq(req *OpenAIRequest, endpoint string) {
|
|||
// buildOpenAIHTTPRequest creates a complete HTTP request for the OpenAI API
|
||||
func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*http.Request, error) {
|
||||
opts := chatOpts.Config
|
||||
|
||||
|
||||
// If continuing from premium rate limit, downgrade to default model and low thinking
|
||||
if cont != nil && cont.ContinueFromKind == uctypes.StopKindPremiumRateLimit {
|
||||
opts.Model = uctypes.DefaultOpenAIModel
|
||||
opts.ThinkingLevel = uctypes.ThinkingLevelLow
|
||||
}
|
||||
|
||||
|
||||
if opts.Model == "" {
|
||||
return nil, errors.New("opts.model is required")
|
||||
}
|
||||
|
|
@ -183,6 +183,14 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
|
|||
reqBody.Tools = append(reqBody.Tools, convertedTool)
|
||||
}
|
||||
|
||||
// Add native web search tool if enabled
|
||||
if chatOpts.AllowNativeWebSearch {
|
||||
webSearchTool := OpenAIRequestTool{
|
||||
Type: "web_search",
|
||||
}
|
||||
reqBody.Tools = append(reqBody.Tools, webSearchTool)
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking level
|
||||
if opts.ThinkingLevel != "" {
|
||||
reqBody.Reasoning = &ReasoningType{
|
||||
|
|
|
|||
|
|
@ -222,10 +222,11 @@ type AIChat struct {
|
|||
}
|
||||
|
||||
type AIUsage struct {
|
||||
APIType string `json:"apitype"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"inputtokens,omitempty"`
|
||||
OutputTokens int `json:"outputtokens,omitempty"`
|
||||
APIType string `json:"apitype"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"inputtokens,omitempty"`
|
||||
OutputTokens int `json:"outputtokens,omitempty"`
|
||||
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"`
|
||||
}
|
||||
|
||||
type AIMetrics struct {
|
||||
|
|
@ -424,6 +425,7 @@ type WaveChatOpts struct {
|
|||
TabStateGenerator func() (string, []ToolDefinition, error)
|
||||
WidgetAccess bool
|
||||
RegisterToolApproval func(string)
|
||||
AllowNativeWebSearch bool
|
||||
|
||||
// emphemeral to the step
|
||||
TabState string
|
||||
|
|
|
|||
|
|
@ -191,6 +191,7 @@ func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
|
|||
} else {
|
||||
rtn.InputTokens += usage.InputTokens
|
||||
rtn.OutputTokens += usage.OutputTokens
|
||||
rtn.NativeWebSearchCount += usage.NativeWebSearchCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -369,9 +370,10 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
|
|||
}
|
||||
if len(rtnMessage) > 0 {
|
||||
usage := getUsage(rtnMessage)
|
||||
log.Printf("usage: input=%d output=%d\n", usage.InputTokens, usage.OutputTokens)
|
||||
log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount)
|
||||
metrics.Usage.InputTokens += usage.InputTokens
|
||||
metrics.Usage.OutputTokens += usage.OutputTokens
|
||||
metrics.Usage.NativeWebSearchCount += usage.NativeWebSearchCount
|
||||
if usage.Model != "" && metrics.Usage.Model != usage.Model {
|
||||
metrics.Usage.Model = "mixed"
|
||||
}
|
||||
|
|
@ -526,24 +528,25 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me
|
|||
|
||||
func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
|
||||
event := telemetrydata.MakeTEvent("waveai:post", telemetrydata.TEventProps{
|
||||
WaveAIAPIType: metrics.Usage.APIType,
|
||||
WaveAIModel: metrics.Usage.Model,
|
||||
WaveAIInputTokens: metrics.Usage.InputTokens,
|
||||
WaveAIOutputTokens: metrics.Usage.OutputTokens,
|
||||
WaveAIRequestCount: metrics.RequestCount,
|
||||
WaveAIToolUseCount: metrics.ToolUseCount,
|
||||
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
|
||||
WaveAIToolDetail: metrics.ToolDetail,
|
||||
WaveAIPremiumReq: metrics.PremiumReqCount,
|
||||
WaveAIProxyReq: metrics.ProxyReqCount,
|
||||
WaveAIHadError: metrics.HadError,
|
||||
WaveAIImageCount: metrics.ImageCount,
|
||||
WaveAIPDFCount: metrics.PDFCount,
|
||||
WaveAITextDocCount: metrics.TextDocCount,
|
||||
WaveAITextLen: metrics.TextLen,
|
||||
WaveAIFirstByteMs: metrics.FirstByteLatency,
|
||||
WaveAIRequestDurMs: metrics.RequestDuration,
|
||||
WaveAIWidgetAccess: metrics.WidgetAccess,
|
||||
WaveAIAPIType: metrics.Usage.APIType,
|
||||
WaveAIModel: metrics.Usage.Model,
|
||||
WaveAIInputTokens: metrics.Usage.InputTokens,
|
||||
WaveAIOutputTokens: metrics.Usage.OutputTokens,
|
||||
WaveAINativeWebSearchCount: metrics.Usage.NativeWebSearchCount,
|
||||
WaveAIRequestCount: metrics.RequestCount,
|
||||
WaveAIToolUseCount: metrics.ToolUseCount,
|
||||
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
|
||||
WaveAIToolDetail: metrics.ToolDetail,
|
||||
WaveAIPremiumReq: metrics.PremiumReqCount,
|
||||
WaveAIProxyReq: metrics.ProxyReqCount,
|
||||
WaveAIHadError: metrics.HadError,
|
||||
WaveAIImageCount: metrics.ImageCount,
|
||||
WaveAIPDFCount: metrics.PDFCount,
|
||||
WaveAITextDocCount: metrics.TextDocCount,
|
||||
WaveAITextLen: metrics.TextLen,
|
||||
WaveAIFirstByteMs: metrics.FirstByteLatency,
|
||||
WaveAIRequestDurMs: metrics.RequestDuration,
|
||||
WaveAIWidgetAccess: metrics.WidgetAccess,
|
||||
})
|
||||
_ = telemetry.RecordTEvent(ctx, event)
|
||||
}
|
||||
|
|
@ -602,6 +605,7 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
|
|||
Config: *aiOpts,
|
||||
WidgetAccess: req.WidgetAccess,
|
||||
RegisterToolApproval: RegisterToolApproval,
|
||||
AllowNativeWebSearch: true,
|
||||
}
|
||||
if chatOpts.Config.APIType == APIType_OpenAI {
|
||||
chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI}
|
||||
|
|
|
|||
|
|
@ -101,24 +101,25 @@ type TEventProps struct {
|
|||
CountWSLConn int `json:"count:wslconn,omitempty"`
|
||||
CountViews map[string]int `json:"count:views,omitempty"`
|
||||
|
||||
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
|
||||
WaveAIModel string `json:"waveai:model,omitempty"`
|
||||
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
|
||||
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
|
||||
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
|
||||
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
|
||||
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
|
||||
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
|
||||
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
|
||||
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
|
||||
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
|
||||
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
|
||||
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
|
||||
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
|
||||
WaveAITextLen int `json:"waveai:textlen,omitempty"`
|
||||
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
|
||||
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
|
||||
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
|
||||
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
|
||||
WaveAIModel string `json:"waveai:model,omitempty"`
|
||||
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
|
||||
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
|
||||
WaveAINativeWebSearchCount int `json:"waveai:nativewebsearchcount,omitempty"`
|
||||
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
|
||||
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
|
||||
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
|
||||
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
|
||||
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
|
||||
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
|
||||
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
|
||||
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
|
||||
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
|
||||
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
|
||||
WaveAITextLen int `json:"waveai:textlen,omitempty"`
|
||||
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
|
||||
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
|
||||
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
|
||||
|
||||
UserSet *TEventUserProps `json:"$set,omitempty"`
|
||||
UserSetOnce *TEventUserProps `json:"$set_once,omitempty"`
|
||||
|
|
|
|||
Loading…
Reference in a new issue