waveterm/pkg/aiusechat/tools.go
Copilot 4449895424
Add Google Gemini backend for AI chat (#2602)
- [x] Add new API type constant for Google Gemini in uctypes.go
- [x] Create gemini directory under pkg/aiusechat/
- [x] Implement gemini-backend.go with streaming chat support
- [x] Implement gemini-convertmessage.go for message conversion
- [x] Implement gemini-types.go for Google-specific types
- [x] Add gemini backend to usechat-backend.go
- [x] Support tool calling with structured arguments
- [x] Support image upload (base64 inline data)
- [x] Support PDF upload (base64 inline data)
- [x] Support file upload (text files, directory listings)
- [x] Build verification passed
- [x] Add documentation for Gemini backend usage
- [x] Security scan passed (CodeQL found 0 issues)
- [x] Code review passed with no comments
- [x] Revert tsunami demo go.mod/go.sum files (per review feedback; reverted twice)
- [x] Add `--gemini` flag to main-testai.go for testing
- [x] Fix schema validation for tool calling (clean unsupported fields)
- [x] Preserve non-map property values in schema cleaning

## Summary

Successfully implemented a complete Google Gemini backend for WaveTerm's
AI chat system. The implementation:

- **Follows existing patterns**: Matches the structure of OpenAI and
Anthropic backends
- **Fully featured**: Supports all required capabilities including tool
calling, images, PDFs, and files
- **Build-verified**: Compiles successfully with no errors or warnings
- **Secure**: Passed CodeQL security scanning with 0 issues
- **Well documented**: Includes comprehensive package documentation with
usage examples
- **Minimal changes**: Only affects backend code under pkg/aiusechat,
plus the `--gemini` test flag in cmd/testai (tsunami demo files reverted twice)
- **Testable**: Added `--gemini` flag to main-testai.go for easy testing
with SSE output
- **Schema compatible**: Cleans JSON schemas to remove fields
unsupported by Gemini API while preserving valid structure

## Testing

To test the Gemini backend using main-testai.go:
```bash
export GOOGLE_APIKEY="your-api-key"
cd cmd/testai
go run main-testai.go --gemini 'What is 2+2?'
go run main-testai.go --gemini --model gemini-1.5-pro 'Explain quantum computing'
go run main-testai.go --gemini --tools 'Help me configure GitHub Actions monitoring'
```


Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: sawka <2722291+sawka@users.noreply.github.com>
2025-12-05 12:43:42 -08:00

318 lines
9 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package aiusechat
import (
"context"
"fmt"
"os/user"
"strings"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
"github.com/wavetermdev/waveterm/pkg/blockcontroller"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
// makeTerminalBlockDesc builds a one-line, human-readable description of a
// terminal ("term") block for the AI tab-state prompt.
//
// The description includes:
//   - the connection name, or "local" when there is none;
//   - the shell type and version, when present in the block's runtime info;
//   - the shell-integration state, with a preview of the running command
//     truncated to at most 30 characters;
//   - the block's cmd:cwd, but only when the shell has reported a current cwd.
func makeTerminalBlockDesc(block *waveobj.Block) string {
	connection, hasConnection := block.Meta["connection"].(string)
	cwd, hasCwd := block.Meta["cmd:cwd"].(string)
	blockORef := waveobj.MakeORef(waveobj.OType_Block, block.OID)
	rtInfo := wstore.GetRTInfo(blockORef)
	// cmd:cwd is only reported when the shell has told us it has a current cwd.
	hasCurCwd := rtInfo != nil && rtInfo.ShellHasCurCwd

	var desc string
	if hasConnection && connection != "" {
		desc = fmt.Sprintf("CLI terminal connected to %q", connection)
	} else {
		desc = "local CLI terminal"
	}
	if rtInfo != nil && rtInfo.ShellType != "" {
		desc += fmt.Sprintf(" (%s", rtInfo.ShellType)
		if rtInfo.ShellVersion != "" {
			desc += fmt.Sprintf(" %s", rtInfo.ShellVersion)
		}
		desc += ")"
	}
	if rtInfo != nil {
		if rtInfo.ShellIntegration {
			var stateStr string
			switch rtInfo.ShellState {
			case "ready":
				stateStr = "waiting for input"
			case "running-command":
				stateStr = "running command"
				if rtInfo.ShellLastCmd != "" {
					// Truncate by runes rather than bytes so a multi-byte UTF-8
					// character is never split (byte slicing can yield invalid UTF-8).
					cmdStr := rtInfo.ShellLastCmd
					if cmdRunes := []rune(cmdStr); len(cmdRunes) > 30 {
						cmdStr = string(cmdRunes[:27]) + "..."
					}
					cmdJSON := utilfn.MarshalJSONString(cmdStr)
					stateStr = fmt.Sprintf("running command %s", cmdJSON)
				}
			default:
				stateStr = "state unknown"
			}
			desc += fmt.Sprintf(", %s", stateStr)
		} else {
			desc += ", no shell integration"
		}
	}
	if hasCurCwd && hasCwd && cwd != "" {
		desc += fmt.Sprintf(", in directory %q", cwd)
	}
	return desc
}
// MakeBlockShortDesc returns a short, human-readable description of a block
// for the AI's view of the current tab.
//
// It returns "" in three cases:
//   - the block has no metadata;
//   - its "view" entry is missing or is not a string;
//   - the view is one the AI should not see (aifilediff).
//
// Unrecognized view types are described as unknown widgets.
func MakeBlockShortDesc(block *waveobj.Block) string {
	if block.Meta == nil {
		return ""
	}
	viewType, isStr := block.Meta["view"].(string)
	if !isStr {
		return ""
	}
	// metaStr reads a string meta value, yielding "" when absent or not a string.
	metaStr := func(key string) string {
		val, _ := block.Meta[key].(string)
		return val
	}
	switch viewType {
	case "term":
		return makeTerminalBlockDesc(block)
	case "preview":
		file := metaStr("file")
		conn := metaStr("connection")
		switch {
		case conn != "" && file != "":
			return fmt.Sprintf("preview widget viewing %q on %q", file, conn)
		case conn != "":
			return fmt.Sprintf("preview widget viewing files on %q", conn)
		case file != "":
			return fmt.Sprintf("preview widget viewing %q", file)
		default:
			return "file and directory preview widget"
		}
	case "web":
		if target := metaStr("url"); target != "" {
			return fmt.Sprintf("web browser widget pointing at %q", target)
		}
		return "web browser widget"
	case "waveai":
		return "AI chat widget"
	case "cpuplot":
		if conn := metaStr("connection"); conn != "" {
			return fmt.Sprintf("cpu graph for %q", conn)
		}
		return "cpu graph"
	case "tips":
		return "Wave quick tips widget"
	case "help":
		return "Wave documentation widget"
	case "launcher":
		return "placeholder widget used to launch other widgets"
	case "tsunami":
		return handleTsunamiBlockDesc(block)
	case "aifilediff":
		// diff views are intentionally hidden from the AI
		return ""
	case "waveconfig":
		if file := metaStr("file"); file != "" {
			return fmt.Sprintf("wave config editor for %q", file)
		}
		return "wave config editor"
	}
	return fmt.Sprintf("unknown widget with type %q", viewType)
}
// GenerateTabStateAndTools builds the <current_tab_state> prompt fragment and
// the list of tool definitions the AI may use for the given tab.
//
// If tabid is empty it returns ("", nil, nil).
//
// When widgetAccess is false, the tab is not read. The prompt only states that
// widget context is not shared, and no tools are returned.
//
// When widgetAccess is true:
//   - tabid must parse as a UUID and the tab must exist, otherwise an error
//     is returned;
//   - the tab's blocks are loaded, skipping any that fail to load or come
//     back nil;
//   - tools are selected from the open widget types and from chatOpts.
func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bool, chatOpts *uctypes.WaveChatOpts) (string, []uctypes.ToolDefinition, error) {
	if tabid == "" {
		return "", nil, nil
	}
	var blocks []*waveobj.Block
	if widgetAccess {
		if _, err := uuid.Parse(tabid); err != nil {
			return "", nil, fmt.Errorf("tabid must be a valid UUID")
		}
		tabObj, err := wstore.DBMustGet[*waveobj.Tab](ctx, tabid)
		if err != nil {
			return "", nil, fmt.Errorf("error getting tab: %w", err)
		}
		blocks = make([]*waveobj.Block, 0, len(tabObj.BlockIds))
		for _, blockId := range tabObj.BlockIds {
			block, err := wstore.DBGet[*waveobj.Block](ctx, blockId)
			// Best-effort: skip blocks that fail to load or come back nil.
			// A nil entry would otherwise panic on block.Meta below and in the
			// prompt builder.
			if err != nil || block == nil {
				continue
			}
			blocks = append(blocks, block)
		}
	}
	tabState := GenerateCurrentTabStatePrompt(blocks, widgetAccess)
	// for debugging
	// log.Printf("TABPROMPT %s\n", tabState)
	var tools []uctypes.ToolDefinition
	if widgetAccess {
		// Only add the screenshot tool for:
		// - openai-responses API type
		// - google-gemini API type with models that support image tool
		//   results (Gemini 3+)
		// chatOpts is nil-checked so a missing options value omits the tool
		// instead of panicking.
		if chatOpts != nil && (chatOpts.Config.APIType == uctypes.APIType_OpenAIResponses ||
			(chatOpts.Config.APIType == uctypes.APIType_GoogleGemini && aiutil.GeminiSupportsImageToolResults(chatOpts.Config.Model))) {
			tools = append(tools, GetCaptureScreenshotToolDefinition(tabid))
		}
		tools = append(tools, GetReadTextFileToolDefinition())
		tools = append(tools, GetReadDirToolDefinition())
		tools = append(tools, GetWriteTextFileToolDefinition())
		tools = append(tools, GetEditTextFileToolDefinition())
		tools = append(tools, GetDeleteTextFileToolDefinition())
		viewTypes := make(map[string]bool)
		for _, block := range blocks {
			if block.Meta == nil {
				continue
			}
			viewType, ok := block.Meta["view"].(string)
			if !ok {
				continue
			}
			viewTypes[viewType] = true
			if viewType == "tsunami" {
				// tsunami tools are per-block, so they are added for every tsunami widget
				blockTools := generateToolsForTsunamiBlock(block)
				tools = append(tools, blockTools...)
			}
		}
		// term and web tools are tab-scoped, so they are added at most once
		if viewTypes["term"] {
			tools = append(tools, GetTermGetScrollbackToolDefinition(tabid))
			// tools = append(tools, GetTermCommandOutputToolDefinition(tabid))
		}
		if viewTypes["web"] {
			tools = append(tools, GetWebNavigateToolDefinition(tabid))
		}
	}
	return tabState, tools, nil
}
// GenerateCurrentTabStatePrompt renders the <current_tab_state> section of the
// AI prompt.
//
// When widgetAccess is false it returns a fixed notice that widget context is
// not shared. Otherwise it writes the local machine summary, adding the OS
// username when it can be determined. It then writes one line per block that
// has a non-empty short description, prefixed with (up to) the first 8
// characters of the block's OID. Nil entries in blocks are ignored.
func GenerateCurrentTabStatePrompt(blocks []*waveobj.Block, widgetAccess bool) string {
	if !widgetAccess {
		return `<current_tab_state>The user has chosen not to share widget context with you</current_tab_state>`
	}
	var widgetDescriptions []string
	for _, block := range blocks {
		if block == nil {
			continue
		}
		desc := MakeBlockShortDesc(block)
		if desc == "" {
			continue
		}
		// Guard the slice so a short or malformed OID cannot cause a
		// slice-bounds panic.
		blockIdPrefix := block.OID
		if len(blockIdPrefix) > 8 {
			blockIdPrefix = blockIdPrefix[:8]
		}
		widgetDescriptions = append(widgetDescriptions, fmt.Sprintf("(%s) %s", blockIdPrefix, desc))
	}
	var prompt strings.Builder
	prompt.WriteString("<current_tab_state>\n")
	systemInfo := wavebase.GetSystemSummary()
	if currentUser, err := user.Current(); err == nil && currentUser.Username != "" {
		fmt.Fprintf(&prompt, "Local Machine: %s, User: %s\n", systemInfo, currentUser.Username)
	} else {
		// the username is optional context; omit it if it can't be determined
		fmt.Fprintf(&prompt, "Local Machine: %s\n", systemInfo)
	}
	if len(widgetDescriptions) == 0 {
		prompt.WriteString("No widgets open\n")
	} else {
		prompt.WriteString("Open Widgets:\n")
		for _, desc := range widgetDescriptions {
			prompt.WriteString("* ")
			prompt.WriteString(desc)
			prompt.WriteString("\n")
		}
	}
	prompt.WriteString("</current_tab_state>")
	return prompt.String()
}
// generateToolsForTsunamiBlock returns the tsunami get-data, get-config and
// set-config tool definitions for a tsunami block, in that order.
//
// It returns nil when the block controller's runtime status is unavailable,
// is not running, or has no positive tsunami port. A definition builder that
// returns nil simply omits that tool.
func generateToolsForTsunamiBlock(block *waveobj.Block) []uctypes.ToolDefinition {
	status := blockcontroller.GetBlockControllerRuntimeStatus(block.OID)
	isServing := status != nil &&
		status.ShellProcStatus == blockcontroller.Status_Running &&
		status.TsunamiPort > 0
	if !isServing {
		return nil
	}
	rtInfo := wstore.GetRTInfo(waveobj.MakeORef(waveobj.OType_Block, block.OID))
	candidates := []*uctypes.ToolDefinition{
		GetTsunamiGetDataToolDefinition(block, rtInfo, status),
		GetTsunamiGetConfigToolDefinition(block, rtInfo, status),
		GetTsunamiSetConfigToolDefinition(block, rtInfo, status),
	}
	var result []uctypes.ToolDefinition
	for _, def := range candidates {
		if def == nil {
			continue
		}
		result = append(result, *def)
	}
	return result
}
// GetAdderToolDefinition returns the "adder" tool, which sums an array of
// numbers. It is used for internal testing of tool-call loops.
//
// The callback expects input as a map[string]any whose "values" entry is a
// []any of float64. Each element is converted with int() (truncating any
// fractional part) before summing, and the result is an int. An empty array
// sums to 0.
func GetAdderToolDefinition() uctypes.ToolDefinition {
	sumValues := func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
		args, isMap := input.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid input format")
		}
		rawValues, present := args["values"]
		if !present {
			return nil, fmt.Errorf("missing values parameter")
		}
		items, isSlice := rawValues.([]any)
		if !isSlice {
			return nil, fmt.Errorf("values must be an array")
		}
		if len(items) == 0 {
			return 0, nil
		}
		total := 0
		for idx := range items {
			num, isFloat := items[idx].(float64)
			if !isFloat {
				return nil, fmt.Errorf("value at index %d is not a number", idx)
			}
			total += int(num)
		}
		return total, nil
	}

	schema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"values": map[string]any{
				"type": "array",
				"items": map[string]any{
					"type": "integer",
				},
				"description": "Array of numbers to add together",
			},
		},
		"required":             []string{"values"},
		"additionalProperties": false,
	}

	return uctypes.ToolDefinition{
		Name:            "adder",
		DisplayName:     "Adder",
		Description:     "Add an array of numbers together and return their sum",
		ToolLogName:     "gen:adder",
		Strict:          true,
		InputSchema:     schema,
		ToolAnyCallback: sumValues,
	}
}