feat: add distributed mode (#9124)

* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood of healthy checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2026-03-30 00:47:27 +02:00 committed by GitHub
parent 4c870288d9
commit 59108fbe32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
389 changed files with 276305 additions and 246521 deletions

View file

@ -406,7 +406,7 @@ func getHuggingFaceAvatarURL(author string) string {
} }
// Parse the response to get avatar URL // Parse the response to get avatar URL
var userInfo map[string]interface{} var userInfo map[string]any
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "" return ""

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"strings" "strings"
"time" "time"
) )
@ -13,11 +13,11 @@ func runSyntheticMode() error {
generator := NewSyntheticDataGenerator() generator := NewSyntheticDataGenerator()
// Generate a random number of synthetic models (1-3) // Generate a random number of synthetic models (1-3)
numModels := generator.rand.Intn(3) + 1 numModels := generator.rand.IntN(3) + 1
fmt.Printf("Generating %d synthetic models for testing...\n", numModels) fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
var models []ProcessedModel var models []ProcessedModel
for i := 0; i < numModels; i++ { for i := range numModels {
model := generator.GenerateProcessedModel() model := generator.GenerateProcessedModel()
models = append(models, model) models = append(models, model)
fmt.Printf("Generated synthetic model: %s\n", model.ModelID) fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
// NewSyntheticDataGenerator creates a new synthetic data generator // NewSyntheticDataGenerator creates a new synthetic data generator
func NewSyntheticDataGenerator() *SyntheticDataGenerator { func NewSyntheticDataGenerator() *SyntheticDataGenerator {
return &SyntheticDataGenerator{ return &SyntheticDataGenerator{
rand: rand.New(rand.NewSource(time.Now().UnixNano())), rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
} }
} }
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile // GenerateProcessedModelFile creates a synthetic ProcessedModelFile
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile { func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
fileTypes := []string{"model", "readme", "other"} fileTypes := []string{"model", "readme", "other"}
fileType := fileTypes[g.rand.Intn(len(fileTypes))] fileType := fileTypes[g.rand.IntN(len(fileTypes))]
var path string var path string
var isReadme bool var isReadme bool
@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
return ProcessedModelFile{ return ProcessedModelFile{
Path: path, Path: path,
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
SHA256: g.randomSHA256(), SHA256: g.randomSHA256(),
IsReadme: isReadme, IsReadme: isReadme,
FileType: fileType, FileType: fileType,
@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"} authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"} modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
author := authors[g.rand.Intn(len(authors))] author := authors[g.rand.IntN(len(authors))]
modelName := modelNames[g.rand.Intn(len(modelNames))] modelName := modelNames[g.rand.IntN(len(modelNames))]
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6)) modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
// Generate files // Generate files
numFiles := g.rand.Intn(5) + 2 // 2-6 files numFiles := g.rand.IntN(5) + 2 // 2-6 files
files := make([]ProcessedModelFile, numFiles) files := make([]ProcessedModelFile, numFiles)
// Ensure at least one model file and one readme // Ensure at least one model file and one readme
hasModelFile := false hasModelFile := false
hasReadme := false hasReadme := false
for i := 0; i < numFiles; i++ { for i := range numFiles {
files[i] = g.GenerateProcessedModelFile() files[i] = g.GenerateProcessedModelFile()
if files[i].FileType == "model" { if files[i].FileType == "model" {
hasModelFile = true hasModelFile = true
@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
// Generate sample metadata // Generate sample metadata
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""} licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
license := licenses[g.rand.Intn(len(licenses))] license := licenses[g.rand.IntN(len(licenses))]
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"} sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
numTags := g.rand.Intn(4) + 3 // 3-6 tags numTags := g.rand.IntN(4) + 3 // 3-6 tags
tags := make([]string, numTags) tags := make([]string, numTags)
for i := 0; i < numTags; i++ { for i := range numTags {
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))] tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
} }
// Remove duplicates // Remove duplicates
tags = g.removeDuplicates(tags) tags = g.removeDuplicates(tags)
// Optionally include icon (50% chance) // Optionally include icon (50% chance)
icon := "" icon := ""
if g.rand.Intn(2) == 0 { if g.rand.IntN(2) == 0 {
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24)) icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
} }
return ProcessedModel{ return ProcessedModel{
ModelID: modelID, ModelID: modelID,
Author: author, Author: author,
Downloads: g.rand.Intn(1000000) + 1000, Downloads: g.rand.IntN(1000000) + 1000,
LastModified: g.randomDate(), LastModified: g.randomDate(),
Files: files, Files: files,
PreferredModelFile: preferredModelFile, PreferredModelFile: preferredModelFile,
@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789" const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length) b := make([]byte, length)
for i := range b { for i := range b {
b[i] = charset[g.rand.Intn(len(charset))] b[i] = charset[g.rand.IntN(len(charset))]
} }
return string(b) return string(b)
} }
@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
const charset = "0123456789abcdef" const charset = "0123456789abcdef"
b := make([]byte, 64) b := make([]byte, 64)
for i := range b { for i := range b {
b[i] = charset[g.rand.Intn(len(charset))] b[i] = charset[g.rand.IntN(len(charset))]
} }
return string(b) return string(b)
} }
func (g *SyntheticDataGenerator) randomDate() string { func (g *SyntheticDataGenerator) randomDate() string {
now := time.Now() now := time.Now()
daysAgo := g.rand.Intn(365) // Random date within last year daysAgo := g.rand.IntN(365) // Random date within last year
pastDate := now.AddDate(0, 0, -daysAgo) pastDate := now.AddDate(0, 0, -daysAgo)
return pastDate.Format("2006-01-02T15:04:05.000Z") return pastDate.Format("2006-01-02T15:04:05.000Z")
} }
@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author), fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
} }
return templates[g.rand.Intn(len(templates))] return templates[g.rand.IntN(len(templates))]
} }

View file

@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
go-version: ['1.25.x'] go-version: ['1.26.x']
steps: steps:
- name: Free Disk Space (Ubuntu) - name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main uses: jlumbroso/free-disk-space@main
@ -179,7 +179,7 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
strategy: strategy:
matrix: matrix:
go-version: ['1.25.x'] go-version: ['1.26.x']
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v6 uses: actions/checkout@v6

View file

@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it. # The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
FROM requirements-drivers AS build-requirements FROM requirements-drivers AS build-requirements
ARG GO_VERSION=1.25.4 ARG GO_VERSION=1.26.0
ARG CMAKE_VERSION=3.31.10 ARG CMAKE_VERSION=3.31.10
ARG CMAKE_FROM_SOURCE=false ARG CMAKE_FROM_SOURCE=false
ARG TARGETARCH ARG TARGETARCH
@ -319,7 +319,6 @@ COPY ./.git ./.git
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here # Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
COPY ./pkg/grpc ./pkg/grpc COPY ./pkg/grpc ./pkg/grpc
COPY ./pkg/utils ./pkg/utils COPY ./pkg/utils ./pkg/utils
COPY ./pkg/langchain ./pkg/langchain
RUN ls -l ./ RUN ls -l ./
RUN make protogen-go RUN make protogen-go

View file

@ -154,6 +154,7 @@ For older news and full release notes, see [GitHub Releases](https://github.com/
- [Object Detection](https://localai.io/features/object-detection/) - [Object Detection](https://localai.io/features/object-detection/)
- [Reranker API](https://localai.io/features/reranker/) - [Reranker API](https://localai.io/features/reranker/)
- [P2P Inferencing](https://localai.io/features/distribute/) - [P2P Inferencing](https://localai.io/features/distribute/)
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io) - [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images - [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images

View file

@ -51,6 +51,7 @@ service Backend {
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {} rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {} rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
rpc StopQuantization(QuantizationStopRequest) returns (Result) {} rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
} }
// Define the empty request // Define the empty request
@ -676,3 +677,4 @@ message QuantizationProgressUpdate {
message QuantizationStopRequest { message QuantizationStopRequest {
string job_id = 1; string job_id = 1;
} }

View file

@ -22,8 +22,10 @@
#include <grpcpp/ext/proto_server_reflection_plugin.h> #include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h> #include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h> #include <grpcpp/health_check_service_interface.h>
#include <grpcpp/security/server_credentials.h>
#include <regex> #include <regex>
#include <atomic> #include <atomic>
#include <cstdlib>
#include <mutex> #include <mutex>
#include <signal.h> #include <signal.h>
#include <thread> #include <thread>
@ -37,6 +39,47 @@ using grpc::Server;
using grpc::ServerBuilder; using grpc::ServerBuilder;
using grpc::ServerContext; using grpc::ServerContext;
using grpc::Status; using grpc::Status;
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
// requests without a matching "authorization: Bearer <token>" metadata header.
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
public:
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
bool IsBlocking() const override { return false; }
grpc::Status Process(const InputMetadata& auth_metadata,
grpc::AuthContext* /*context*/,
OutputMetadata* /*consumed_auth_metadata*/,
OutputMetadata* /*response_metadata*/) override {
auto it = auth_metadata.find("authorization");
if (it != auth_metadata.end()) {
std::string expected = "Bearer " + token_;
std::string got(it->second.data(), it->second.size());
// Constant-time comparison
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
return grpc::Status::OK;
}
}
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
}
private:
std::string token_;
// Minimal constant-time comparison (avoids OpenSSL dependency)
static int ct_memcmp(const void* a, const void* b, size_t n) {
const unsigned char* pa = static_cast<const unsigned char*>(a);
const unsigned char* pb = static_cast<const unsigned char*>(b);
unsigned char result = 0;
for (size_t i = 0; i < n; i++) {
result |= pa[i] ^ pb[i];
}
return result;
}
};
// END LocalAI // END LocalAI
@ -2760,11 +2803,24 @@ int main(int argc, char** argv) {
BackendServiceImpl service(ctx_server); BackendServiceImpl service(ctx_server);
ServerBuilder builder; ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); // Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
std::shared_ptr<grpc::ServerCredentials> creds;
if (auth_token != nullptr && auth_token[0] != '\0') {
creds = grpc::InsecureServerCredentials();
creds->SetAuthMetadataProcessor(
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
} else {
creds = grpc::InsecureServerCredentials();
}
builder.AddListeningPort(server_address, creds);
builder.RegisterService(&service); builder.RegisterService(&service);
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
std::unique_ptr<Server> server(builder.BuildAndStart()); std::unique_ptr<Server> server(builder.BuildAndStart());
// run the HTTP server in a thread - see comment below // run the HTTP server in a thread - see comment below
std::thread t([&]() std::thread t([&]()

View file

@ -134,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.RemoveAll(tmpDir) t.Cleanup(func() { os.RemoveAll(tmpDir) })
outputFile := filepath.Join(tmpDir, "output.wav") outputFile := filepath.Join(tmpDir, "output.wav")

View file

@ -11,7 +11,7 @@ import (
) )
var ( var (
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
) )
@ -29,18 +29,18 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
var textEncoderModel, ditModel, vaeModel string var textEncoderModel, ditModel, vaeModel string
for _, oo := range opts.Options { for _, oo := range opts.Options {
parts := strings.SplitN(oo, ":", 2) key, value, found := strings.Cut(oo, ":")
if len(parts) != 2 { if !found {
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo) fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
continue continue
} }
switch parts[0] { switch key {
case "text_encoder_model": case "text_encoder_model":
textEncoderModel = parts[1] textEncoderModel = value
case "dit_model": case "dit_model":
ditModel = parts[1] ditModel = value
case "vae_model": case "vae_model":
vaeModel = parts[1] vaeModel = value
default: default:
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo) fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
} }

View file

@ -18,7 +18,6 @@ type LLM struct {
draftModel *llama.LLama draftModel *llama.LLama
} }
// Free releases GPU resources and frees the llama model // Free releases GPU resources and frees the llama model
// This should be called when the model is being unloaded to properly release VRAM // This should be called when the model is being unloaded to properly release VRAM
func (llm *LLM) Free() error { func (llm *LLM) Free() error {

View file

@ -1,5 +1,4 @@
//go:build debug //go:build debug
// +build debug
package main package main

View file

@ -1,5 +1,4 @@
//go:build !debug //go:build !debug
// +build !debug
package main package main

View file

@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot float32 var dot float32
for i := 0; i < len(k1); i++ { for i := range len(k1) {
dot += k1[i] * k2[i] dot += k1[i] * k2[i]
} }
@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot, mag2 float64 var dot, mag2 float64
for i := 0; i < len(k1); i++ { for i := range len(k1) {
dot += float64(k1[i] * k2[i]) dot += float64(k1[i] * k2[i])
mag2 += float64(k2[i] * k2[i]) mag2 += float64(k2[i] * k2[i])
} }

View file

@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
// to one-shot (only difference is resampler batch boundaries). // to one-shot (only difference is resampler batch boundaries).
var maxDiff float64 var maxDiff float64
var sumDiffSq float64 var sumDiffSq float64
for i := 0; i < minLen; i++ { for i := range minLen {
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i])) diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
if diff > maxDiff { if diff > maxDiff {
maxDiff = diff maxDiff = diff
@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
minLen := min(len(refTail), min(len(persistentTail), len(freshTail))) minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
var persistentMaxDiff, freshMaxDiff float64 var persistentMaxDiff, freshMaxDiff float64
for i := 0; i < minLen; i++ { for i := range minLen {
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i])) pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i])) fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
if pd > persistentMaxDiff { if pd > persistentMaxDiff {
@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n", GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
mean, stddev, stddev/mean, 16000.0/440.0/2.0) mean, stddev, stddev/mean, 16000.0/440.0/2.0)
Expect(stddev / mean).To(BeNumerically("<", 0.15), Expect(stddev/mean).To(BeNumerically("<", 0.15),
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean)) fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
// Also check frequency is correct // Also check frequency is correct
@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
// Every sample must be identical — the resampler is deterministic // Every sample must be identical — the resampler is deterministic
var maxDiff float64 var maxDiff float64
for i := 0; i < len(oneShot); i++ { for i := range len(oneShot) {
diff := math.Abs(float64(oneShot[i]) - float64(batched[i])) diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
if diff > maxDiff { if diff > maxDiff {
maxDiff = diff maxDiff = diff
@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen)) binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
copy(hdr[8:12], "WAVE") copy(hdr[8:12], "WAVE")
copy(hdr[12:16], "fmt ") copy(hdr[12:16], "fmt ")
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
copy(hdr[36:40], "data") copy(hdr[36:40], "data")
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen)) binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
) )
pcm := make([]byte, toneNumSamples*2) pcm := make([]byte, toneNumSamples*2)
for i := 0; i < toneNumSamples; i++ { for i := range toneNumSamples {
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate))) sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
} }

View file

@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.RemoveAll(tmpDir) t.Cleanup(func() { os.RemoveAll(tmpDir) })
// Download sample audio — JFK "ask not what your country can do for you" clip // Download sample audio — JFK "ask not what your country can do for you" clip
audioFile := filepath.Join(tmpDir, "sample.wav") audioFile := filepath.Join(tmpDir, "sample.wav")

View file

@ -19,6 +19,10 @@ import tempfile
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from acestep.inference import ( from acestep.inference import (
GenerationParams, GenerationParams,
GenerationConfig, GenerationConfig,
@ -444,6 +448,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), ("grpc.max_send_message_length", 50 * 1024 * 1024),
("grpc.max_receive_message_length", 50 * 1024 * 1024), ("grpc.max_receive_message_length", 50 * 1024 * 1024),
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -16,6 +16,10 @@ import torchaudio as ta
from chatterbox.tts import ChatterboxTTS from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS from chatterbox.mtl_tts import ChatterboxMultilingualTTS
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import tempfile import tempfile
def is_float(s): def is_float(s):
@ -225,7 +229,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -0,0 +1,78 @@
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
a valid Bearer token in the 'authorization' metadata header are rejected with
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
performed (backward compatible).
"""
import hmac
import os
import grpc
class _AbortHandler(grpc.RpcMethodHandler):
"""A method handler that immediately aborts with UNAUTHENTICATED."""
def __init__(self):
self.request_streaming = False
self.response_streaming = False
self.request_deserializer = None
self.response_serializer = None
self.unary_unary = self._abort
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
@staticmethod
def _abort(request, context):
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
class TokenAuthInterceptor(grpc.ServerInterceptor):
"""Sync gRPC server interceptor that validates a bearer token."""
def __init__(self, token: str):
self._token = token
self._abort_handler = _AbortHandler()
def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)
auth = metadata.get("authorization", "")
expected = "Bearer " + self._token
if not hmac.compare_digest(auth, expected):
return self._abort_handler
return continuation(handler_call_details)
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
"""Async gRPC server interceptor that validates a bearer token."""
def __init__(self, token: str):
self._token = token
async def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)
auth = metadata.get("authorization", "")
expected = "Bearer " + self._token
if not hmac.compare_digest(auth, expected):
return _AbortHandler()
return await continuation(handler_call_details)
def get_auth_interceptors(*, aio: bool = False):
"""Return a list of gRPC interceptors for bearer token auth.
Args:
aio: If True, return async-compatible interceptors for grpc.aio.server().
If False (default), return sync interceptors for grpc.server().
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
"""
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
if not token:
return []
if aio:
return [AsyncTokenAuthInterceptor(token)]
return [TokenAuthInterceptor(token)]

View file

@ -15,6 +15,10 @@ import torch
from TTS.api import TTS from TTS.api import TTS
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -93,7 +97,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -22,6 +22,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
# Import dynamic loader for pipeline discovery # Import dynamic loader for pipeline discovery
from diffusers_dynamic_loader import ( from diffusers_dynamic_loader import (
@ -1042,7 +1046,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -15,6 +15,10 @@ import torch
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@ -165,6 +169,8 @@ def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
] ]
,
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -14,6 +14,10 @@ import torch
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -70,7 +74,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -19,6 +19,10 @@ import numpy as np
import json import json
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@ -424,6 +428,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -16,6 +16,10 @@ from kittentts import KittenTTS
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -77,7 +81,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -16,6 +16,10 @@ from kokoro import KPipeline
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -84,7 +88,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -17,6 +17,10 @@ import time
from concurrent import futures from concurrent import futures
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
@ -398,7 +402,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -15,6 +15,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_audio.tts.utils import load_model from mlx_audio.tts.utils import load_model
import soundfile as sf import soundfile as sf
import numpy as np import numpy as np
@ -436,7 +440,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View file

@ -23,6 +23,10 @@ import tempfile
from typing import List from typing import List
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
@ -468,6 +472,8 @@ async def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
], ],
interceptors=get_auth_interceptors(aio=True),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_vlm import load, generate, stream_generate from mlx_vlm import load, generate, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config, load_image from mlx_vlm.utils import load_config, load_image
@ -446,7 +450,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View file

@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_lm import load, generate, stream_generate from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
@ -421,7 +425,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View file

@ -17,6 +17,10 @@ from moonshine_voice import (
) )
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -128,7 +132,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -14,6 +14,10 @@ import torch
import nemo.collections.asr as nemo_asr import nemo.collections.asr as nemo_asr
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@ -119,7 +123,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@ -130,7 +134,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -14,6 +14,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import outetts import outetts
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -116,7 +120,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(aio=True),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -16,6 +16,10 @@ import torch
from pocket_tts import TTSModel from pocket_tts import TTSModel
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@ -225,7 +229,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -14,6 +14,10 @@ import torch
from qwen_asr import Qwen3ASRModel from qwen_asr import Qwen3ASRModel
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@ -184,7 +188,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -23,6 +23,10 @@ import hashlib
import pickle import pickle
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@ -900,6 +904,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -14,6 +14,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from rerankers import Reranker from rerankers import Reranker
@ -97,7 +101,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -13,6 +13,10 @@ import base64
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import requests import requests
@ -139,7 +143,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -16,6 +16,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import torch import torch
import torch.cuda import torch.cuda
@ -532,7 +536,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View file

@ -17,6 +17,10 @@ import uuid
from concurrent import futures from concurrent import futures
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
@ -832,6 +836,8 @@ def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View file

@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@ -724,7 +728,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -27,6 +27,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from vllm_omni.entrypoints.omni import Omni from vllm_omni.entrypoints.omni import Omni
from vllm_omni.outputs import OmniRequestOutput from vllm_omni.outputs import OmniRequestOutput
@ -650,7 +654,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@ -338,7 +342,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View file

@ -18,6 +18,10 @@ import backend_pb2_grpc
import torch import torch
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@ -297,7 +301,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -13,6 +13,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -137,7 +141,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View file

@ -3,7 +3,7 @@ package application
import ( import (
"time" "time"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/xlog" "github.com/mudler/xlog"
) )
@ -22,13 +22,23 @@ func (a *Application) RestartAgentJobService() error {
} }
// Create new service instance // Create new service instance
agentJobService := services.NewAgentJobService( agentJobService := agentpool.NewAgentJobService(
a.ApplicationConfig(), a.ApplicationConfig(),
a.ModelLoader(), a.ModelLoader(),
a.ModelConfigLoader(), a.ModelConfigLoader(),
a.TemplatesEvaluator(), a.TemplatesEvaluator(),
) )
// Re-apply distributed wiring if available (matches startup.go logic)
if d := a.Distributed(); d != nil {
if d.Dispatcher != nil {
agentJobService.SetDistributedBackends(d.Dispatcher)
}
if d.JobStore != nil {
agentJobService.SetDistributedJobStore(d.JobStore)
}
}
// Start the service // Start the service
err := agentJobService.Start(a.ApplicationConfig().Context) err := agentJobService.Start(a.ApplicationConfig().Context)
if err != nil { if err != nil {

View file

@ -2,12 +2,16 @@ package application
import ( import (
"context" "context"
"math/rand/v2"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog" "github.com/mudler/xlog"
@ -20,9 +24,9 @@ type Application struct {
applicationConfig *config.ApplicationConfig applicationConfig *config.ApplicationConfig
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading) startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
templatesEvaluator *templates.Evaluator templatesEvaluator *templates.Evaluator
galleryService *services.GalleryService galleryService *galleryop.GalleryService
agentJobService *services.AgentJobService agentJobService *agentpool.AgentJobService
agentPoolService atomic.Pointer[services.AgentPoolService] agentPoolService atomic.Pointer[agentpool.AgentPoolService]
authDB *gorm.DB authDB *gorm.DB
watchdogMutex sync.Mutex watchdogMutex sync.Mutex
watchdogStop chan bool watchdogStop chan bool
@ -30,6 +34,9 @@ type Application struct {
p2pCtx context.Context p2pCtx context.Context
p2pCancel context.CancelFunc p2pCancel context.CancelFunc
agentJobMutex sync.Mutex agentJobMutex sync.Mutex
// Distributed mode services (nil when not in distributed mode)
distributed *DistributedServices
} }
func newApplication(appConfig *config.ApplicationConfig) *Application { func newApplication(appConfig *config.ApplicationConfig) *Application {
@ -64,15 +71,15 @@ func (a *Application) TemplatesEvaluator() *templates.Evaluator {
return a.templatesEvaluator return a.templatesEvaluator
} }
func (a *Application) GalleryService() *services.GalleryService { func (a *Application) GalleryService() *galleryop.GalleryService {
return a.galleryService return a.galleryService
} }
func (a *Application) AgentJobService() *services.AgentJobService { func (a *Application) AgentJobService() *agentpool.AgentJobService {
return a.agentJobService return a.agentJobService
} }
func (a *Application) AgentPoolService() *services.AgentPoolService { func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
return a.agentPoolService.Load() return a.agentPoolService.Load()
} }
@ -86,8 +93,53 @@ func (a *Application) StartupConfig() *config.ApplicationConfig {
return a.startupConfig return a.startupConfig
} }
// Distributed exposes the distributed-mode service bundle. It is nil unless
// the application was started in distributed mode.
func (a *Application) Distributed() *DistributedServices {
	return a.distributed
}

// IsDistributed reports whether distributed-mode services were initialized.
func (a *Application) IsDistributed() bool {
	return a.Distributed() != nil
}
// waitForHealthyWorker blocks until at least one healthy backend worker is
// registered, or until the configured timeout elapses. This prevents the
// agent pool from failing during startup when workers haven't connected yet
// (collections initialization calls embeddings, which require a worker).
//
// The wait is bounded by Distributed.WorkerWaitTimeoutOrDefault() and aborts
// early when the application context is canceled. Polls are spaced basePoll
// plus 0-1s of jitter apart to avoid hammering the node registry when many
// frontends start at the same time.
func (a *Application) waitForHealthyWorker() {
	const basePoll = 2 * time.Second

	// Defensive guard: callers should only invoke this in distributed mode,
	// but a nil registry would otherwise panic below.
	if a.distributed == nil || a.distributed.Registry == nil {
		return
	}

	ctx := a.applicationConfig.Context
	maxWait := a.applicationConfig.Distributed.WorkerWaitTimeoutOrDefault()
	xlog.Info("Waiting for at least one healthy backend worker before starting agent pool")

	deadline := time.Now().Add(maxWait)
	for time.Now().Before(deadline) {
		// Use the application context (not context.Background()) so a
		// shutdown interrupts the registry query too, not just the sleep.
		registered, err := a.distributed.Registry.List(ctx)
		if err == nil {
			for _, n := range registered {
				if n.NodeType == nodes.NodeTypeBackend && n.Status == nodes.StatusHealthy {
					xlog.Info("Healthy backend worker found", "node", n.Name)
					return
				}
			}
		}
		// Add 0-1s jitter to prevent thundering-herd on the node registry
		jitter := time.Duration(rand.Int64N(int64(time.Second)))
		select {
		case <-ctx.Done():
			return
		case <-time.After(basePoll + jitter):
		}
	}
	xlog.Warn("No healthy backend worker found after waiting, proceeding anyway")
}
// InstanceID returns the unique identifier for this frontend instance,
// as recorded in the distributed configuration.
func (a *Application) InstanceID() string {
	cfg := a.applicationConfig
	return cfg.Distributed.InstanceID
}
func (a *Application) start() error { func (a *Application) start() error {
galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader()) galleryService := galleryop.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState) err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
if err != nil { if err != nil {
return err return err
@ -95,19 +147,14 @@ func (a *Application) start() error {
a.galleryService = galleryService a.galleryService = galleryService
// Initialize agent job service // Initialize agent job service (Start() is deferred to after distributed wiring)
agentJobService := services.NewAgentJobService( agentJobService := agentpool.NewAgentJobService(
a.ApplicationConfig(), a.ApplicationConfig(),
a.ModelLoader(), a.ModelLoader(),
a.ModelConfigLoader(), a.ModelConfigLoader(),
a.TemplatesEvaluator(), a.TemplatesEvaluator(),
) )
err = agentJobService.Start(a.ApplicationConfig().Context)
if err != nil {
return err
}
a.agentJobService = agentJobService a.agentJobService = agentJobService
return nil return nil
@ -120,27 +167,56 @@ func (a *Application) StartAgentPool() {
if !a.applicationConfig.AgentPool.Enabled { if !a.applicationConfig.AgentPool.Enabled {
return return
} }
aps, err := services.NewAgentPoolService(a.applicationConfig) // Build options struct from available dependencies
opts := agentpool.AgentPoolOptions{
AuthDB: a.authDB,
}
if d := a.Distributed(); d != nil {
if d.DistStores != nil && d.DistStores.Skills != nil {
opts.SkillStore = d.DistStores.Skills
}
opts.NATSClient = d.Nats
opts.EventBridge = d.AgentBridge
opts.AgentStore = d.AgentStore
}
aps, err := agentpool.NewAgentPoolService(a.applicationConfig, opts)
if err != nil { if err != nil {
xlog.Error("Failed to create agent pool service", "error", err) xlog.Error("Failed to create agent pool service", "error", err)
return return
} }
if a.authDB != nil {
aps.SetAuthDB(a.authDB) // Wire distributed mode components
if d := a.Distributed(); d != nil {
// Wait for at least one healthy backend worker before starting the agent pool.
// Collections initialization calls embeddings which require a worker.
if d.Registry != nil {
a.waitForHealthyWorker()
}
} }
if err := aps.Start(a.applicationConfig.Context); err != nil { if err := aps.Start(a.applicationConfig.Context); err != nil {
xlog.Error("Failed to start agent pool", "error", err) xlog.Error("Failed to start agent pool", "error", err)
return return
} }
// Wire per-user scoped services so collections, skills, and jobs are isolated per user // Wire per-user scoped services so collections, skills, and jobs are isolated per user
usm := services.NewUserServicesManager( usm := agentpool.NewUserServicesManager(
aps.UserStorage(), aps.UserStorage(),
a.applicationConfig, a.applicationConfig,
a.modelLoader, a.modelLoader,
a.backendLoader, a.backendLoader,
a.templatesEvaluator, a.templatesEvaluator,
) )
// Wire distributed backends to per-user job services
if a.agentJobService != nil {
if d := a.agentJobService.Dispatcher(); d != nil {
usm.SetJobDispatcher(d)
}
if s := a.agentJobService.DBStore(); s != nil {
usm.SetJobDBStore(s)
}
}
aps.SetUserServicesManager(usm) aps.SetUserServicesManager(usm)
a.agentPoolService.Store(aps) a.agentPoolService.Store(aps)

View file

@ -0,0 +1,267 @@
package application
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"path/filepath"
	"strings"
	"sync"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/services/agents"
	"github.com/mudler/LocalAI/core/services/distributed"
	"github.com/mudler/LocalAI/core/services/jobs"
	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/core/services/nodes"
	"github.com/mudler/LocalAI/core/services/storage"
	"github.com/mudler/LocalAI/pkg/sanitize"
	"github.com/mudler/xlog"
	"gorm.io/gorm"
)
// DistributedServices holds all services initialized for distributed mode.
// A nil *DistributedServices means the application runs standalone. All
// fields are populated once by initDistributed; Shutdown tears them down.
type DistributedServices struct {
	Nats         *messaging.Client    // shared NATS connection for cross-node messaging
	Store        storage.ObjectStore  // object storage (S3, or filesystem fallback)
	Registry     *nodes.NodeRegistry  // auth-DB-backed registry of worker nodes
	Router       *nodes.SmartRouter   // routes model requests across registered nodes
	Health       *nodes.HealthMonitor // periodically flags stale/unhealthy nodes
	JobStore     *jobs.JobStore       // persistent job records (auth DB)
	Dispatcher   *jobs.Dispatcher     // dispatches jobs to workers over NATS
	AgentStore   *agents.AgentStore   // persistent agent records (auth DB)
	AgentBridge  *agents.EventBridge  // bridges agent events between NATS and the store
	DistStores   *distributed.Stores  // MCP / gallery / fine-tune / skills stores
	FileMgr      *storage.FileManager // object-store access with a local cache
	FileStager   nodes.FileStager     // stages files to/from worker nodes (S3+NATS or HTTP)
	ModelAdapter *nodes.ModelRouterAdapter    // adapts Router for the model loader
	Unloader     *nodes.RemoteUnloaderAdapter // instructs remote nodes to unload models
	// shutdownOnce guards Shutdown so teardown runs at most once.
	shutdownOnce sync.Once
}
// Shutdown stops all distributed services in reverse initialization order.
// It is safe to call on a nil receiver and is idempotent (uses sync.Once).
func (ds *DistributedServices) Shutdown() {
	if ds == nil {
		return
	}
	ds.shutdownOnce.Do(ds.teardown)
}

// teardown performs the actual shutdown work; Shutdown guarantees it runs
// at most once. Services are stopped in the reverse of the order they were
// created by initDistributed.
func (ds *DistributedServices) teardown() {
	if mon := ds.Health; mon != nil {
		mon.Stop()
	}
	if disp := ds.Dispatcher; disp != nil {
		disp.Stop()
	}
	// The object store is only closed when its concrete type supports it.
	if closer, ok := ds.Store.(io.Closer); ok {
		closer.Close()
	}
	// AgentBridge has no Close method — its NATS subscriptions are cleaned up
	// when the NATS client is closed below.
	if nc := ds.Nats; nc != nil {
		nc.Close()
	}
	xlog.Info("Distributed services shut down")
}
// initDistributed validates distributed mode prerequisites and initializes
// NATS, object storage, node registry, and instance identity.
// Returns (nil, nil) if distributed mode is not enabled.
//
// Initialization order matters: NATS first (everything messages through it),
// then object storage, the node registry, health monitor, job plumbing,
// agent plumbing, file transfer, and finally the SmartRouter which consumes
// most of the above. If any step fails after NATS connects, the deferred
// cleanup closes the client so the connection does not leak.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
	if !cfg.Distributed.Enabled {
		return nil, nil
	}

	xlog.Info("Distributed mode enabled — validating prerequisites")

	// Validate distributed config (NATS URL, S3 credential pairing, durations, etc.)
	if err := cfg.Distributed.Validate(); err != nil {
		return nil, err
	}

	// Auth must be enabled and backed by PostgreSQL: the node registry, job
	// store, and agent store all persist into the auth database.
	if !cfg.Auth.Enabled {
		return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)")
	}
	if !isPostgresURL(cfg.Auth.DatabaseURL) {
		return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL))
	}

	// Generate instance ID if not set
	if cfg.Distributed.InstanceID == "" {
		cfg.Distributed.InstanceID = uuid.New().String()
	}
	xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)

	// Connect to NATS
	natsClient, err := messaging.New(cfg.Distributed.NatsURL)
	if err != nil {
		return nil, fmt.Errorf("connecting to NATS: %w", err)
	}
	xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL))

	// Ensure NATS is closed if any subsequent initialization step fails.
	success := false
	defer func() {
		if !success {
			natsClient.Close()
		}
	}()

	// Initialize object storage
	var store storage.ObjectStore
	if cfg.Distributed.StorageURL != "" {
		if cfg.Distributed.StorageBucket == "" {
			return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured")
		}
		s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
			Endpoint:        cfg.Distributed.StorageURL,
			Region:          cfg.Distributed.StorageRegion,
			Bucket:          cfg.Distributed.StorageBucket,
			AccessKeyID:     cfg.Distributed.StorageAccessKey,
			SecretAccessKey: cfg.Distributed.StorageSecretKey,
			ForcePathStyle:  true, // required for MinIO
		})
		if err != nil {
			return nil, fmt.Errorf("initializing S3 storage: %w", err)
		}
		xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket)
		store = s3Store
	} else {
		// Fallback to filesystem storage in distributed mode (useful for single-node testing)
		objectStorePath := filepath.Join(cfg.DataPath, "objectstore")
		fsStore, err := storage.NewFilesystemStore(objectStorePath)
		if err != nil {
			return nil, fmt.Errorf("initializing filesystem storage: %w", err)
		}
		xlog.Info("Object storage initialized (filesystem fallback)", "path", objectStorePath)
		store = fsStore
	}

	// Initialize node registry (requires the auth DB which is PostgreSQL)
	if authDB == nil {
		return nil, fmt.Errorf("distributed mode requires auth database to be initialized first")
	}
	registry, err := nodes.NewNodeRegistry(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing node registry: %w", err)
	}
	xlog.Info("Node registry initialized")

	// Collect SmartRouter option values; the router itself is created after all
	// dependencies (including FileStager and Unloader) are ready.
	routerAuthToken := cfg.Distributed.RegistrationToken

	var routerGalleriesJSON string
	if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil {
		routerGalleriesJSON = string(galleriesJSON)
	} else {
		// Non-fatal: the router simply receives no gallery hints.
		xlog.Warn("Failed to marshal backend galleries for SmartRouter", "error", err)
	}

	healthMon := nodes.NewHealthMonitor(registry, authDB,
		cfg.Distributed.HealthCheckIntervalOrDefault(),
		cfg.Distributed.StaleNodeThresholdOrDefault(),
		routerAuthToken,
		cfg.Distributed.PerModelHealthCheck,
	)

	// Initialize job store
	jobStore, err := jobs.NewJobStore(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing job store: %w", err)
	}
	xlog.Info("Distributed job store initialized")

	// Initialize job dispatcher
	dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency)

	// Initialize agent store
	agentStore, err := agents.NewAgentStore(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing agent store: %w", err)
	}
	xlog.Info("Distributed agent store initialized")

	// Initialize agent event bridge
	agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID)

	// Start observable persister — captures observable_update events from workers
	// (which have no DB access) and persists them to PostgreSQL.
	if err := agentBridge.StartObservablePersister(); err != nil {
		xlog.Warn("Failed to start observable persister", "error", err)
	} else {
		xlog.Info("Observable persister started")
	}

	// Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills)
	distStores, err := distributed.InitStores(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing distributed stores: %w", err)
	}

	// Initialize file manager with local cache
	cacheDir := filepath.Join(cfg.DataPath, "cache")
	fileMgr, err := storage.NewFileManager(store, cacheDir)
	if err != nil {
		return nil, fmt.Errorf("initializing file manager: %w", err)
	}
	xlog.Info("File manager initialized", "cacheDir", cacheDir)

	// Create FileStager for distributed file transfer
	var fileStager nodes.FileStager
	if cfg.Distributed.StorageURL != "" {
		fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient)
		xlog.Info("File stager initialized (S3+NATS)")
	} else {
		// HTTP direct transfer: resolve the target node's HTTP address on demand.
		fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			node, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			if node.HTTPAddress == "" {
				return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID)
			}
			return node.HTTPAddress, nil
		}, cfg.Distributed.RegistrationToken)
		xlog.Info("File stager initialized (HTTP direct transfer)")
	}

	// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
	remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)

	// All dependencies ready — build SmartRouter with all options at once
	router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
		Unloader:      remoteUnloader,
		FileStager:    fileStager,
		GalleriesJSON: routerGalleriesJSON,
		AuthToken:     routerAuthToken,
		DB:            authDB,
	})

	// Create ModelRouterAdapter to wire into ModelLoader
	modelAdapter := nodes.NewModelRouterAdapter(router)

	success = true
	return &DistributedServices{
		Nats:         natsClient,
		Store:        store,
		Registry:     registry,
		Router:       router,
		Health:       healthMon,
		JobStore:     jobStore,
		Dispatcher:   dispatcher,
		AgentStore:   agentStore,
		AgentBridge:  agentBridge,
		DistStores:   distStores,
		FileMgr:      fileMgr,
		FileStager:   fileStager,
		ModelAdapter: modelAdapter,
		Unloader:     remoteUnloader,
	}, nil
}
// isPostgresURL reports whether url uses a PostgreSQL connection scheme
// ("postgres://" or "postgresql://"). The check is case-sensitive.
func isPostgresURL(url string) bool {
	for _, scheme := range []string{"postgres://", "postgresql://"} {
		if strings.HasPrefix(url, scheme) {
			return true
		}
	}
	return false
}

View file

@ -11,7 +11,7 @@ import (
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/edgevpn/pkg/node" "github.com/mudler/edgevpn/pkg/node"
"github.com/mudler/xlog" "github.com/mudler/xlog"
@ -146,22 +146,14 @@ func (a *Application) RestartP2P() error {
return fmt.Errorf("P2P token is not set") return fmt.Errorf("P2P token is not set")
} }
// Create new context for P2P
ctx, cancel := context.WithCancel(appConfig.Context)
a.p2pCtx = ctx
a.p2pCancel = cancel
// Get API address from config
address := appConfig.APIAddress
if address == "" {
address = "127.0.0.1:8080" // default
}
// Start P2P stack in a goroutine // Start P2P stack in a goroutine
// Note: StartP2P creates its own context and assigns a.p2pCtx/a.p2pCancel
go func() { go func() {
if err := a.StartP2P(); err != nil { if err := a.StartP2P(); err != nil {
xlog.Error("Failed to start P2P stack", "error", err) xlog.Error("Failed to start P2P stack", "error", err)
cancel() // Cancel context on error if a.p2pCancel != nil {
a.p2pCancel()
}
} }
}() }()
xlog.Info("P2P stack restarted with new settings") xlog.Info("P2P stack restarted with new settings")
@ -228,7 +220,7 @@ func syncState(ctx context.Context, n *node.Node, app *Application) error {
continue continue
} }
app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ app.GalleryService().ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uuid.String(), ID: uuid.String(),
GalleryElementName: model, GalleryElementName: model,
Galleries: app.ApplicationConfig().Galleries, Galleries: app.ApplicationConfig().Galleries,

View file

@ -13,11 +13,15 @@ import (
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/jobs"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/storage"
coreStartup "github.com/mudler/LocalAI/core/startup" coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/mudler/xlog" "github.com/mudler/xlog"
) )
@ -101,7 +105,7 @@ func New(opts ...config.AppOption) (*Application, error) {
return nil, fmt.Errorf("failed to initialize auth database: %w", err) return nil, fmt.Errorf("failed to initialize auth database: %w", err)
} }
application.authDB = authDB application.authDB = authDB
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL) xlog.Info("Auth enabled", "database", sanitize.URL(options.Auth.DatabaseURL))
// Start session and expired API key cleanup goroutine // Start session and expired API key cleanup goroutine
go func() { go func() {
@ -123,12 +127,92 @@ func New(opts ...config.AppOption) (*Application, error) {
}() }()
} }
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
if application.authDB != nil && application.agentJobService != nil {
dbJobStore, err := jobs.NewJobStore(application.authDB)
if err != nil {
xlog.Error("Failed to create job store for auth DB", "error", err)
} else {
application.agentJobService.SetDistributedJobStore(dbJobStore)
}
}
// Initialize distributed mode services (NATS, object storage, node registry)
distSvc, err := initDistributed(options, application.authDB)
if err != nil {
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
}
if distSvc != nil {
application.distributed = distSvc
// Wire remote model unloader so ShutdownModel works for remote nodes
// Uses NATS to tell serve-backend nodes to Free + kill their backend process
application.modelLoader.SetRemoteUnloader(distSvc.Unloader)
// Wire ModelRouter so grpcModel() delegates to SmartRouter in distributed mode
application.modelLoader.SetModelRouter(distSvc.ModelAdapter.AsModelRouter())
// Wire DistributedModelStore so shutdown/list/watchdog can find remote models
distStore := nodes.NewDistributedModelStore(
model.NewInMemoryModelStore(),
distSvc.Registry,
)
application.modelLoader.SetModelStore(distStore)
// Start health monitor
distSvc.Health.Start(options.Context)
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
// but does NOT set a workerFn — agent workers consume jobs from the same NATS queue.
// Wire model config loader so job events include model config for agent workers
distSvc.Dispatcher.SetModelConfigLoader(application.backendLoader)
// Start job dispatcher — abort startup if it fails, as jobs would be accepted but never dispatched
if err := distSvc.Dispatcher.Start(options.Context); err != nil {
return nil, fmt.Errorf("starting job dispatcher: %w", err)
}
// Start ephemeral file cleanup
storage.StartEphemeralCleanup(options.Context, distSvc.FileMgr, 0, 0)
// Wire distributed backends into AgentJobService (before Start)
if application.agentJobService != nil {
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
}
// Wire skill store into AgentPoolService (wired at pool start time via closure)
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
// Wire NATS and gallery store into GalleryService for cross-instance progress/cancel
if application.galleryService != nil {
application.galleryService.SetNATSClient(distSvc.Nats)
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
// Clean up stale in-progress operations from previous crashed instances
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
xlog.Warn("Failed to clean stale gallery operations", "error", err)
}
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
}
// Wire distributed model/backend managers so delete propagates to workers
application.galleryService.SetModelManager(
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
)
application.galleryService.SetBackendManager(
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
)
}
}
// Start AgentJobService (after distributed wiring so it knows whether to use local or NATS)
if application.agentJobService != nil {
if err := application.agentJobService.Start(options.Context); err != nil {
return nil, fmt.Errorf("starting agent job service: %w", err)
}
}
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
xlog.Error("error installing models", "error", err) xlog.Error("error installing models", "error", err)
} }
for _, backend := range options.ExternalBackends { for _, backend := range options.ExternalBackends {
if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil { if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
xlog.Error("error installing external backend", "error", err) xlog.Error("error installing external backend", "error", err)
} }
} }
@ -154,13 +238,13 @@ func New(opts ...config.AppOption) (*Application, error) {
} }
if options.PreloadJSONModels != "" { if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
return nil, err return nil, err
} }
} }
if options.PreloadModelsFromPath != "" { if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
return nil, err return nil, err
} }
} }
@ -184,6 +268,7 @@ func New(opts ...config.AppOption) (*Application, error) {
go func() { go func() {
<-options.Context.Done() <-options.Context.Done()
xlog.Debug("Context canceled, shutting down") xlog.Debug("Context canceled, shutting down")
application.distributed.Shutdown()
err := application.ModelLoader().StopAllGRPC() err := application.ModelLoader().StopAllGRPC()
if err != nil { if err != nil {
xlog.Error("error while stopping all grpc backends", "error", err) xlog.Error("error while stopping all grpc backends", "error", err)
@ -207,7 +292,7 @@ func New(opts ...config.AppOption) (*Application, error) {
var backendErr error var backendErr error
_, backendErr = application.ModelLoader().Load(o...) _, backendErr = application.ModelLoader().Load(o...)
if backendErr != nil { if backendErr != nil {
return nil, err return nil, backendErr
} }
} }
} }

View file

@ -13,9 +13,9 @@ import (
"github.com/mudler/xlog" "github.com/mudler/xlog"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/grpc/proto"
@ -27,7 +27,7 @@ type LLMResponse struct {
Response string // should this be []byte? Response string // should this be []byte?
Usage TokenUsage Usage TokenUsage
AudioOutput string AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response Logprobs *schema.Logprobs // Logprobs from the backend response
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
} }
@ -47,14 +47,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
// Check if the modelFile exists, if it doesn't try to load it from the gallery // Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental if o.AutoloadGalleries { // experimental
modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS) modelNames, err := galleryop.ListModels(cl, loader, nil, galleryop.SKIP_ALWAYS)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !slices.Contains(modelNames, c.Name) { modelName := c.Name
if modelName == "" {
modelName = c.Model
}
if !slices.Contains(modelNames, modelName) {
utils.ResetDownloadTimers() utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it // if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil { if err != nil {
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile) xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
//return nil, err //return nil, err
@ -252,12 +256,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
trace.InitBackendTracingIfEnabled(o.TracingMaxItems) trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
traceData := map[string]any{ traceData := map[string]any{
"chat_template": c.TemplateConfig.Chat, "chat_template": c.TemplateConfig.Chat,
"function_template": c.TemplateConfig.Functions, "function_template": c.TemplateConfig.Functions,
"streaming": tokenCallback != nil, "streaming": tokenCallback != nil,
"images_count": len(images), "images_count": len(images),
"videos_count": len(videos), "videos_count": len(videos),
"audios_count": len(audios), "audios_count": len(audios),
} }
if len(messages) > 0 { if len(messages) > 0 {

View file

@ -1,7 +1,7 @@
package backend package backend
import ( import (
"math/rand" "math/rand/v2"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -86,7 +86,7 @@ func getSeed(c config.ModelConfig) int32 {
} }
if seed == config.RAND_SEED { if seed == config.RAND_SEED {
seed = rand.Int31() seed = rand.Int32()
} }
return seed return seed

View file

@ -4,8 +4,8 @@ import (
"time" "time"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
) )

View file

@ -8,11 +8,11 @@ import (
"os/signal" "os/signal"
"syscall" "syscall"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAGI/core/state" "github.com/mudler/LocalAGI/core/state"
coreTypes "github.com/mudler/LocalAGI/core/types" coreTypes "github.com/mudler/LocalAGI/core/types"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/xlog" "github.com/mudler/xlog"
) )
@ -59,7 +59,7 @@ func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
appConfig := r.buildAppConfig() appConfig := r.buildAppConfig()
poolService, err := services.NewAgentPoolService(appConfig) poolService, err := agentpool.NewAgentPoolService(appConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to create agent pool service: %w", err) return fmt.Errorf("failed to create agent pool service: %w", err)
} }

463
core/cli/agent_worker.go Normal file
View file

@ -0,0 +1,463 @@
package cli
import (
"cmp"
"context"
"encoding/json"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/workerregistry"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/LocalAI/core/services/jobs"
mcpRemote "github.com/mudler/LocalAI/core/services/mcp"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/cogito"
"github.com/mudler/cogito/clients"
"github.com/mudler/xlog"
)
// AgentWorkerCMD starts a dedicated agent worker process for distributed mode.
// It registers with the frontend, subscribes to the NATS agent execution queue,
// and executes agent chats using cogito. The worker is a pure executor — it
// receives the full agent config and skills in the NATS job payload, so it
// does not need direct database access.
//
// Usage:
//
//	localai agent-worker --nats-url nats://... --register-to http://localai:8080
type AgentWorkerCMD struct {
	// NATS (required)
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
	// Registration (required)
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
	// API access — both values can be derived/provisioned automatically (see Run).
	APIURL   string `env:"LOCALAI_API_URL" help:"LocalAI API URL for inference (auto-derived from RegisterTo if not set)" group:"api"`
	APIToken string `env:"LOCALAI_API_TOKEN" help:"API token for LocalAI inference (auto-provisioned during registration if not set)" group:"api"`
	// NATS subjects — subject/queue pair the worker consumes jobs from.
	Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
	Queue   string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
	// Timeouts
	MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
}
// Run executes the agent-worker command. It registers with the frontend,
// starts a heartbeat loop, connects to NATS, wires up all subscriptions
// (agent execution, MCP tool calls, MCP discovery, MCP CI jobs, and
// backend-stop cleanup), then blocks until SIGINT/SIGTERM and shuts down
// in order: cancel background goroutines, stop the dispatcher, close all
// cached MCP sessions, and deregister from the frontend.
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
	xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)

	// Resolve API URL: an explicit --api-url wins, otherwise derive it from
	// the registration endpoint (trailing slashes trimmed).
	apiURL := cmp.Or(cmd.APIURL, strings.TrimRight(cmd.RegisterTo, "/"))

	// Register with frontend
	regClient := &workerregistry.RegistrationClient{
		FrontendURL:       cmd.RegisterTo,
		RegistrationToken: cmd.RegistrationToken,
	}
	nodeName := cmd.NodeName
	if nodeName == "" {
		// Hostname errors are deliberately ignored; worst case the node
		// registers as "agent-".
		hostname, _ := os.Hostname()
		nodeName = "agent-" + hostname
	}
	registrationBody := map[string]any{
		"name":      nodeName,
		"node_type": "agent",
	}
	if cmd.RegistrationToken != "" {
		registrationBody["token"] = cmd.RegistrationToken
	}
	// Retry registration up to 10 times before giving up.
	nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
	if err != nil {
		return fmt.Errorf("registration failed: %w", err)
	}
	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)

	// Use provisioned API token if none was set
	if cmd.APIToken == "" {
		cmd.APIToken = apiToken
	}

	// Start heartbeat. On a parse error ParseDuration returns 0, so the
	// cmp.Or below falls back to 10s after the warning.
	heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
	if err != nil && cmd.HeartbeatInterval != "" {
		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
	}
	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)

	// Context cancelled on shutdown — used by heartbeat and other background goroutines
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer shutdownCancel()
	go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })

	// Connect to NATS
	natsClient, err := messaging.New(cmd.NatsURL)
	if err != nil {
		return fmt.Errorf("connecting to NATS: %w", err)
	}
	defer natsClient.Close()

	// Create event bridge for publishing results back via NATS
	eventBridge := agents.NewEventBridge(natsClient, nil, "agent-worker-"+nodeID)

	// Start cancel listener. Failure is non-fatal: the worker still runs,
	// it just won't honor remote cancellation requests.
	cancelSub, err := eventBridge.StartCancelListener()
	if err != nil {
		xlog.Warn("Failed to start cancel listener", "error", err)
	} else {
		defer cancelSub.Unsubscribe()
	}

	// Create and start the NATS dispatcher.
	// No ConfigProvider or SkillStore needed — config and skills arrive in the job payload.
	dispatcher := agents.NewNATSDispatcher(
		natsClient,
		eventBridge,
		nil, // no ConfigProvider: config comes in the enriched NATS payload
		apiURL, cmd.APIToken,
		cmd.Subject, cmd.Queue,
		0, // no concurrency limit (CLI worker)
	)
	if err := dispatcher.Start(shutdownCtx); err != nil {
		return fmt.Errorf("starting dispatcher: %w", err)
	}

	// Subscribe to MCP tool execution requests (load-balanced across workers).
	// The frontend routes model-level MCP tool calls here via NATS request-reply.
	if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
		handleMCPToolRequest(data, reply)
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPToolExecute, err)
	}

	// Subscribe to MCP discovery requests (load-balanced across workers).
	if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
		handleMCPDiscoveryRequest(data, reply)
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPDiscovery, err)
	}

	// Subscribe to MCP CI job execution (load-balanced across agent workers).
	// In distributed mode, MCP CI jobs are routed here because the frontend
	// cannot create MCP sessions (e.g., stdio servers using docker).
	// On a parse error ParseDuration returns 0, so cmp.Or falls back to the
	// configured default.
	mcpCIJobTimeout, err := time.ParseDuration(cmd.MCPCIJobTimeout)
	if err != nil && cmd.MCPCIJobTimeout != "" {
		xlog.Warn("invalid MCP CI job timeout, using default 10m", "input", cmd.MCPCIJobTimeout, "error", err)
	}
	mcpCIJobTimeout = cmp.Or(mcpCIJobTimeout, config.DefaultMCPCIJobTimeout)
	// NOTE(review): this subscription uses messaging.QueueWorkers while the
	// tool/discovery subscriptions above use messaging.QueueAgentWorkers —
	// confirm this is the intended queue group.
	if _, err := natsClient.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func(data []byte) {
		handleMCPCIJob(shutdownCtx, data, apiURL, cmd.APIToken, natsClient, mcpCIJobTimeout)
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPCIJobsNew, err)
	}

	// Subscribe to backend stop events to clean up cached MCP sessions.
	// In the main application this is done via ml.OnModelUnload, but the agent
	// worker has no model loader — we listen for the NATS stop event instead.
	if _, err := natsClient.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func(data []byte) {
		var req struct {
			Backend string `json:"backend"`
		}
		// Malformed payloads and empty backend names are silently ignored.
		if json.Unmarshal(data, &req) == nil && req.Backend != "" {
			mcpTools.CloseMCPSessions(req.Backend)
		}
	}); err != nil {
		return fmt.Errorf("subscribing to %s: %w", messaging.SubjectNodeBackendStop(nodeID), err)
	}

	xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)

	// Wait for shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	xlog.Info("Shutting down agent worker")
	shutdownCancel() // stop heartbeat loop immediately
	dispatcher.Stop()
	mcpTools.CloseAllMCPSessions()
	regClient.GracefulDeregister(nodeID)
	return nil
}
// handleMCPToolRequest services one NATS request-reply for MCP tool
// execution. It builds (or reuses cached) MCP sessions from the serialized
// config carried in the request, discovers the available tools to route the
// call, and invokes the requested tool. Every failure path is reported back
// through the reply payload rather than returned to the caller.
func handleMCPToolRequest(data []byte, reply func([]byte)) {
	var request mcpRemote.MCPToolRequest
	if err := json.Unmarshal(data, &request); err != nil {
		sendMCPToolReply(reply, "", fmt.Sprintf("unmarshal error: %v", err))
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPToolTimeout)
	defer cancel()

	// Create/cache named MCP sessions for the servers listed in the request.
	sessions, err := mcpTools.NamedSessionsFromMCPConfig(request.ModelName, request.RemoteServers, request.StdioServers, nil)
	if err != nil {
		sendMCPToolReply(reply, "", fmt.Sprintf("session error: %v", err))
		return
	}

	// Enumerate tools so the call can be routed to the session that owns
	// request.ToolName.
	available, err := mcpTools.DiscoverMCPTools(ctx, sessions)
	if err != nil {
		sendMCPToolReply(reply, "", fmt.Sprintf("discovery error: %v", err))
		return
	}

	// If marshaling the arguments fails, the tool is invoked with an empty
	// argument string (best effort).
	encodedArgs, _ := json.Marshal(request.Arguments)
	output, err := mcpTools.ExecuteMCPToolCall(ctx, available, request.ToolName, string(encodedArgs))
	if err != nil {
		sendMCPToolReply(reply, "", err.Error())
		return
	}
	sendMCPToolReply(reply, output, "")
}
// sendMCPToolReply serializes an MCPToolResponse carrying either a result or
// an error message and hands the bytes to the NATS reply callback. The
// marshal error is ignored: the response is a plain string struct.
func sendMCPToolReply(reply func([]byte), result, errMsg string) {
	payload, _ := json.Marshal(mcpRemote.MCPToolResponse{
		Result: result,
		Error:  errMsg,
	})
	reply(payload)
}
// handleMCPDiscoveryRequest services one NATS request-reply asking which MCP
// servers, tools, prompts, and resources are available for a model. Sessions
// are created (or reused from cache) from the config embedded in the request,
// and both the per-server inventory and the flat tool schemas are returned.
func handleMCPDiscoveryRequest(data []byte, reply func([]byte)) {
	var request mcpRemote.MCPDiscoveryRequest
	if err := json.Unmarshal(data, &request); err != nil {
		sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("unmarshal error: %v", err))
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPDiscoveryTimeout)
	defer cancel()

	// Create/cache named MCP sessions for the requested model.
	sessions, err := mcpTools.NamedSessionsFromMCPConfig(request.ModelName, request.RemoteServers, request.StdioServers, nil)
	if err != nil {
		sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("session error: %v", err))
		return
	}

	// Full per-server inventory (tools/prompts/resources).
	infos, err := mcpTools.ListMCPServers(ctx, sessions)
	if err != nil {
		sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("list error: %v", err))
		return
	}

	// Flat tool function schemas for the frontend. Best effort: a discovery
	// error here simply yields an empty tool list.
	discovered, _ := mcpTools.DiscoverMCPTools(ctx, sessions)

	// Slices are declared nil (not pre-allocated) so an empty result encodes
	// as JSON null.
	var toolDefs []mcpRemote.MCPToolDef
	for _, tool := range discovered {
		toolDefs = append(toolDefs, mcpRemote.MCPToolDef{
			ServerName: tool.ServerName,
			ToolName:   tool.ToolName,
			Function:   tool.Function,
		})
	}
	var servers []mcpRemote.MCPServerInfo
	for _, info := range infos {
		servers = append(servers, mcpRemote.MCPServerInfo{
			Name:      info.Name,
			Type:      info.Type,
			Tools:     info.Tools,
			Prompts:   info.Prompts,
			Resources: info.Resources,
		})
	}
	sendMCPDiscoveryReply(reply, servers, toolDefs, "")
}
// sendMCPDiscoveryReply serializes an MCPDiscoveryResponse (server inventory,
// tool schemas, optional error message) and hands the bytes to the NATS reply
// callback. The marshal error is ignored: the payload is plain data.
func sendMCPDiscoveryReply(reply func([]byte), servers []mcpRemote.MCPServerInfo, tools []mcpRemote.MCPToolDef, errMsg string) {
	payload, _ := json.Marshal(mcpRemote.MCPDiscoveryResponse{
		Servers: servers,
		Tools:   tools,
		Error:   errMsg,
	})
	reply(payload)
}
// handleMCPCIJob processes an MCP CI job event on the agent worker.
// The agent worker can create MCP sessions locally (it has docker) and calls
// back into the LocalAI API for inference.
//
// Flow: decode the enriched job event → validate job/task/model config →
// create MCP sessions → expand {{.key}} placeholders in the prompt (task
// cron parameters first, then job parameters) → run cogito tool execution,
// streaming progress/trace events on the job's progress subject → publish
// the final result. All failures are reported via publishJobResult; the
// function returns nothing.
func handleMCPCIJob(shutdownCtx context.Context, data []byte, apiURL, apiToken string, natsClient messaging.MessagingClient, jobTimeout time.Duration) {
	var evt jobs.JobEvent
	if err := json.Unmarshal(data, &evt); err != nil {
		xlog.Error("Failed to unmarshal job event", "error", err)
		return
	}
	job := evt.Job
	task := evt.Task
	if job == nil || task == nil {
		xlog.Error("MCP CI job missing enriched data", "jobID", evt.JobID)
		publishJobResult(natsClient, evt.JobID, "failed", "", "job or task data missing from NATS event")
		return
	}
	modelCfg := evt.ModelConfig
	if modelCfg == nil {
		publishJobResult(natsClient, evt.JobID, "failed", "", "model config missing from job event")
		return
	}
	xlog.Info("Processing MCP CI job", "jobID", evt.JobID, "taskID", evt.TaskID, "model", task.Model)

	// Publish running status so progress subscribers see the job picked up.
	natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
		JobID: evt.JobID, Status: "running", Message: "Job started on agent worker",
	})

	// The model must have at least one MCP server (remote or stdio) configured.
	if modelCfg.MCP.Servers == "" && modelCfg.MCP.Stdio == "" {
		publishJobResult(natsClient, evt.JobID, "failed", "", "no MCP servers configured for model")
		return
	}
	remote, stdio, err := modelCfg.MCP.MCPConfigFromYAML()
	if err != nil {
		publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("failed to parse MCP config: %v", err))
		return
	}

	// Create MCP sessions locally (agent worker has docker).
	sessions, err := mcpTools.SessionsFromMCPConfig(modelCfg.Name, remote, stdio)
	if err != nil || len(sessions) == 0 {
		errMsg := "no working MCP servers found"
		if err != nil {
			errMsg = fmt.Sprintf("failed to create MCP sessions: %v", err)
		}
		publishJobResult(natsClient, evt.JobID, "failed", "", errMsg)
		return
	}

	// Build the prompt: expand task-level (cron) parameters first, then
	// job-level parameters. Both use the same {{.key}} placeholder syntax.
	prompt := applyPromptParams(task.Prompt, task.CronParametersJSON)
	prompt = applyPromptParams(prompt, job.ParametersJSON)

	// Create LLM client pointing back to the frontend API.
	llm := clients.NewLocalAILLM(task.Model, apiToken, apiURL)

	// Bound the whole execution by the job timeout, nested under the
	// worker's shutdown context so a shutdown also cancels the job.
	ctx, cancel := context.WithTimeout(shutdownCtx, jobTimeout)
	defer cancel()

	// Mark the job running via the shared jobs publisher as well.
	// NOTE(review): presumably a server-side consumer persists this status
	// to the DB — confirm.
	publishJobStatus(natsClient, evt.JobID, "running", "")

	// Buffer stream tokens and flush them as complete reasoning/content
	// blocks, so subscribers receive coherent chunks instead of one event
	// per token.
	var reasoningBuf, contentBuf strings.Builder
	var lastStreamType cogito.StreamEventType
	flushStreamBuf := func() {
		if reasoningBuf.Len() > 0 {
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "reasoning", TraceContent: reasoningBuf.String(),
			})
			reasoningBuf.Reset()
		}
		if contentBuf.Len() > 0 {
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "content", TraceContent: contentBuf.String(),
			})
			contentBuf.Reset()
		}
	}
	cogitoOpts := modelCfg.BuildCogitoOptions()
	cogitoOpts = append(cogitoOpts,
		cogito.WithContext(ctx),
		cogito.WithMCPs(sessions...),
		cogito.WithStatusCallback(func(status string) {
			flushStreamBuf()
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "status", TraceContent: status,
			})
		}),
		cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
			flushStreamBuf()
			natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
				JobID: evt.JobID, TraceType: "tool_result", TraceContent: fmt.Sprintf("%s: %s", t.Name, t.Result),
			})
		}),
		cogito.WithStreamCallback(func(ev cogito.StreamEvent) {
			// Flush if stream type changed (e.g., reasoning → content)
			if ev.Type != lastStreamType {
				flushStreamBuf()
				lastStreamType = ev.Type
			}
			switch ev.Type {
			case cogito.StreamEventReasoning:
				reasoningBuf.WriteString(ev.Content)
			case cogito.StreamEventContent:
				contentBuf.WriteString(ev.Content)
			case cogito.StreamEventToolCall:
				// Tool calls are published immediately, not buffered.
				natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
					JobID: evt.JobID, TraceType: "tool_call", TraceContent: fmt.Sprintf("%s(%s)", ev.ToolName, ev.ToolArgs),
				})
			}
		}),
	)

	// Execute via cogito
	fragment := cogito.NewEmptyFragment()
	fragment = fragment.AddMessage("user", prompt)
	f, err := cogito.ExecuteTools(llm, fragment, cogitoOpts...)
	flushStreamBuf() // flush any remaining buffered tokens
	if err != nil {
		publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("cogito execution failed: %v", err))
		return
	}
	result := ""
	if msg := f.LastMessage(); msg != nil {
		result = msg.Content
	}
	publishJobResult(natsClient, evt.JobID, "completed", result, "")
	xlog.Info("MCP CI job completed", "jobID", evt.JobID, "resultLen", len(result))
}

// applyPromptParams expands {{.key}} placeholders in prompt using the
// string-to-string parameter map encoded in paramsJSON. An empty document
// leaves the prompt untouched. A malformed document logs a warning and then
// substitutes whatever entries (if any) were decoded — best-effort semantics.
func applyPromptParams(prompt, paramsJSON string) string {
	if paramsJSON == "" {
		return prompt
	}
	var params map[string]string
	if err := json.Unmarshal([]byte(paramsJSON), &params); err != nil {
		xlog.Warn("Failed to unmarshal parameters", "error", err)
	}
	for k, v := range params {
		prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
	}
	return prompt
}
// publishJobStatus publishes a job status change (with an optional message)
// via the shared jobs progress publisher. Thin indirection kept for
// call-site readability in this file.
func publishJobStatus(nc messaging.MessagingClient, jobID, status, message string) {
	jobs.PublishJobProgress(nc, jobID, status, message)
}
// publishJobResult publishes the terminal outcome of a job (status plus a
// result payload or an error message) via the shared jobs publisher. Thin
// indirection kept for call-site readability in this file.
func publishJobResult(nc messaging.MessagingClient, jobID, status, result, errMsg string) {
	jobs.PublishJobResult(nc, jobID, status, result, errMsg)
}

View file

@ -8,7 +8,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context" cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
@ -103,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
} }
modelLoader := model.NewModelLoader(systemState) modelLoader := model.NewModelLoader(systemState)
err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil { if err != nil {
return err return err
} }

View file

@ -15,7 +15,9 @@ var CLI struct {
TTS TTSCMD `cmd:"" help:"Convert text to speech"` TTS TTSCMD `cmd:"" help:"Convert text to speech"`
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"` SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"` Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"` P2PWorker worker.Worker `cmd:"" name:"p2p-worker" help:"Run workers to distribute workload via p2p (llama.cpp-only)"`
Worker WorkerCMD `cmd:"" help:"Start a worker for distributed mode (generic, backend-agnostic)"`
AgentWorker AgentWorkerCMD `cmd:"" name:"agent-worker" help:"Start an agent worker for distributed mode (executes agent chats via NATS)"`
Util UtilCMD `cmd:"" help:"Utility commands"` Util UtilCMD `cmd:"" help:"Utility commands"`
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"` Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"` Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`

View file

@ -186,9 +186,9 @@ _local_ai_completions()
} }
subcmds := []string{} subcmds := []string{}
for _, sub := range cmds { for _, sub := range cmds {
parts := strings.SplitN(sub.fullName, " ", 2) parent, child, found := strings.Cut(sub.fullName, " ")
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { if found && parent == cmd.name && !strings.Contains(child, " ") {
subcmds = append(subcmds, parts[1]) subcmds = append(subcmds, child)
} }
} }
if len(subcmds) > 0 { if len(subcmds) > 0 {
@ -279,8 +279,8 @@ _local_ai() {
// Check for subcommands // Check for subcommands
subcmds := []commandInfo{} subcmds := []commandInfo{}
for _, sub := range cmds { for _, sub := range cmds {
parts := strings.SplitN(sub.fullName, " ", 2) parent, child, found := strings.Cut(sub.fullName, " ")
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { if found && parent == cmd.name && !strings.Contains(child, " ") {
subcmds = append(subcmds, sub) subcmds = append(subcmds, sub)
} }
} }
@ -289,11 +289,11 @@ _local_ai() {
sb.WriteString(" local -a subcmds\n") sb.WriteString(" local -a subcmds\n")
sb.WriteString(" subcmds=(\n") sb.WriteString(" subcmds=(\n")
for _, sub := range subcmds { for _, sub := range subcmds {
parts := strings.SplitN(sub.fullName, " ", 2) _, child, _ := strings.Cut(sub.fullName, " ")
help := strings.ReplaceAll(sub.help, "'", "'\\''") help := strings.ReplaceAll(sub.help, "'", "'\\''")
help = strings.ReplaceAll(help, "[", "\\[") help = strings.ReplaceAll(help, "[", "\\[")
help = strings.ReplaceAll(help, "]", "\\]") help = strings.ReplaceAll(help, "]", "\\]")
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", parts[1], help)) sb.WriteString(fmt.Sprintf(" '%s:%s'\n", child, help))
} }
sb.WriteString(" )\n") sb.WriteString(" )\n")
sb.WriteString(" _describe -t commands 'subcommands' subcmds\n") sb.WriteString(" _describe -t commands 'subcommands' subcmds\n")
@ -372,10 +372,10 @@ func generateFishCompletion(app *kong.Application) string {
// Subcommands // Subcommands
for _, sub := range cmds { for _, sub := range cmds {
parts := strings.SplitN(sub.fullName, " ", 2) parent, child, found := strings.Cut(sub.fullName, " ")
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { if found && parent == cmd.name && !strings.Contains(child, " ") {
help := strings.ReplaceAll(sub.help, "'", "\\'") help := strings.ReplaceAll(sub.help, "'", "\\'")
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, parts[1], help)) sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, child, help))
} }
} }

View file

@ -9,8 +9,8 @@ import (
func getTestApp() *kong.Application { func getTestApp() *kong.Application {
var testCLI struct { var testCLI struct {
Run struct{} `cmd:"" help:"Run the server"` Run struct{} `cmd:"" help:"Run the server"`
Models struct { Models struct {
List struct{} `cmd:"" help:"List models"` List struct{} `cmd:"" help:"List models"`
Install struct{} `cmd:"" help:"Install a model"` Install struct{} `cmd:"" help:"Install a model"`
} `cmd:"" help:"Manage models"` } `cmd:"" help:"Manage models"`

View file

@ -8,7 +8,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context" cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/core/startup"
@ -80,7 +80,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
return err return err
} }
galleryService := services.NewGalleryService(&config.ApplicationConfig{ galleryService := galleryop.NewGalleryService(&config.ApplicationConfig{
SystemState: systemState, SystemState: systemState,
}, model.NewModelLoader(systemState)) }, model.NewModelLoader(systemState))
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState) err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)

View file

@ -44,9 +44,9 @@ type RunCMD struct {
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"` AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"` AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"` BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"` BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"` BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"` PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"` Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"` PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
@ -100,7 +100,7 @@ type RunCMD struct {
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"` OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
// Agent Pool (LocalAGI) // Agent Pool (LocalAGI)
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"` DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"` AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"`
AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"` AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"`
AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"` AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"`
@ -109,17 +109,17 @@ type RunCMD struct {
AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"` AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"`
AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"` AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"`
AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"` AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"`
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"` AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"` AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"` AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"` AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"` AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"` AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"` AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"` AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"` AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
// Authentication // Authentication
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
@ -136,6 +136,18 @@ type RunCMD struct {
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"` AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"` DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
// Distributed / Horizontal Scaling
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
Version bool Version bool
} }
@ -210,6 +222,38 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}), }),
} }
// Distributed mode
if r.Distributed {
opts = append(opts, config.EnableDistributed)
}
if r.InstanceID != "" {
opts = append(opts, config.WithDistributedInstanceID(r.InstanceID))
}
if r.NatsURL != "" {
opts = append(opts, config.WithNatsURL(r.NatsURL))
}
if r.StorageURL != "" {
opts = append(opts, config.WithStorageURL(r.StorageURL))
}
if r.StorageBucket != "" {
opts = append(opts, config.WithStorageBucket(r.StorageBucket))
}
if r.StorageRegion != "" {
opts = append(opts, config.WithStorageRegion(r.StorageRegion))
}
if r.StorageAccessKey != "" {
opts = append(opts, config.WithStorageAccessKey(r.StorageAccessKey))
}
if r.StorageSecretKey != "" {
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
}
if r.RegistrationToken != "" {
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
}
if r.AutoApproveNodes {
opts = append(opts, config.EnableAutoApproveNodes)
}
if r.DisableMetricsEndpoint { if r.DisableMetricsEndpoint {
opts = append(opts, config.DisableMetricsEndpoint) opts = append(opts, config.DisableMetricsEndpoint)
} }
@ -218,10 +262,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.DisableRuntimeSettings) opts = append(opts, config.DisableRuntimeSettings)
} }
if r.EnableTracing {
opts = append(opts, config.EnableTracing)
}
if r.EnableTracing { if r.EnableTracing {
opts = append(opts, config.EnableTracing) opts = append(opts, config.EnableTracing)
} }
@ -479,6 +519,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
if err := app.ModelLoader().StopAllGRPC(); err != nil { if err := app.ModelLoader().StopAllGRPC(); err != nil {
xlog.Error("error while stopping all grpc backends", "error", err) xlog.Error("error while stopping all grpc backends", "error", err)
} }
// Clean up distributed services (idempotent — safe if already called)
if d := app.Distributed(); d != nil {
d.Shutdown()
}
}) })
// Start the agent pool after the HTTP server is listening, because // Start the agent pool after the HTTP server is listening, because

View file

@ -12,7 +12,6 @@ import (
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/format"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog" "github.com/mudler/xlog"
@ -80,7 +79,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
switch t.ResponseFormat { switch t.ResponseFormat {
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText: case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat)) fmt.Println(schema.TranscriptionResponse(tr, t.ResponseFormat))
case schema.TranscriptionResponseFormatJson: case schema.TranscriptionResponseFormatJson:
tr.Segments = nil tr.Segments = nil
fallthrough fallthrough

897
core/cli/worker.go Normal file
View file

@ -0,0 +1,897 @@
package cli
import (
"cmp"
"context"
"encoding/json"
"fmt"
"maps"
"net"
"os"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"syscall"
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/cli/workerregistry"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/storage"
grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/xsysinfo"
process "github.com/mudler/go-processmanager"
"github.com/mudler/xlog"
)
// isPathAllowed reports whether path resolves to a location inside one of the
// allowed directories.
//
// Both the candidate path and each allowed directory are normalized the same
// way: made absolute, then symlink-resolved (best-effort — a path that does
// not exist yet falls back to its absolute form). Resolving the directories
// too fixes the case where an allowed directory is itself a symlink: a file
// reached through it resolves to the real location, which previously never
// matched the unresolved directory prefix.
func isPathAllowed(path string, allowedDirs []string) bool {
	// resolve normalizes p to an absolute, symlink-free path when possible.
	resolve := func(p string) (string, bool) {
		abs, err := filepath.Abs(p)
		if err != nil {
			return "", false
		}
		if real, err := filepath.EvalSymlinks(abs); err == nil {
			return real, true
		}
		// Path may not exist yet; use the absolute path.
		return abs, true
	}
	resolved, ok := resolve(path)
	if !ok {
		return false
	}
	for _, dir := range allowedDirs {
		resolvedDir, ok := resolve(dir)
		if !ok {
			continue
		}
		// Separator-suffixed prefix check prevents "/a/bc" matching dir "/a/b".
		if resolved == resolvedDir || strings.HasPrefix(resolved, resolvedDir+string(filepath.Separator)) {
			return true
		}
	}
	return false
}
// WorkerCMD starts a generic worker process for distributed mode.
// Workers are backend-agnostic — they wait for backend.install NATS events
// from the SmartRouter to install and start the required backend.
//
// NATS is required. The worker acts as a process supervisor:
// - Receives backend.install → installs backend from gallery, starts gRPC process, replies success
// - Receives backend.stop → stops the gRPC process
// - Receives stop → full shutdown (deregister + exit)
//
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
//
// Field groups (mirrors the kong `group` tags):
//   - server: local bind address plus backend/model paths and galleries.
//   - registration: how this node announces itself to the frontend
//     (RegisterTo is mandatory; AdvertiseAddr/AdvertiseHTTPAddr override
//     the auto-derived reachable addresses).
//   - distributed: NATS URL (mandatory) and optional S3 credentials used
//     for file staging between frontend and workers.
type WorkerCMD struct {
	Addr               string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
	BackendsPath       string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
	BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
	BackendGalleries   string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
	ModelsPath         string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`
	// HTTP file transfer
	HTTPAddr          string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
	AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`
	// Registration (required)
	AdvertiseAddr     string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
	// NATS (required)
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
	// S3 storage for distributed file transfer
	StorageURL       string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
	StorageBucket    string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"`
	StorageRegion    string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"`
	StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"`
	StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"`
}
// Run starts the worker. The startup order matters:
//  1. build system/model state and register already-installed backends,
//  2. self-register with the frontend (with retry) to obtain a nodeID,
//  3. start the HTTP file-transfer server, then connect to NATS,
//  4. launch the heartbeat loop and the backend process supervisor,
//  5. optionally subscribe to file-staging subjects when S3 is configured,
//
// then block on SIGINT/SIGTERM (a NATS "stop" event feeds the same channel)
// and shut down in order: heartbeat → deregister → backends → HTTP server.
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
	xlog.Info("Starting worker", "addr", cmd.Addr)
	systemState, err := system.GetSystemState(
		system.WithModelPath(cmd.ModelsPath),
		system.WithBackendPath(cmd.BackendsPath),
		system.WithBackendSystemPath(cmd.BackendsSystemPath),
	)
	if err != nil {
		return fmt.Errorf("getting system state: %w", err)
	}
	ml := model.NewModelLoader(systemState)
	ml.SetBackendLoggingEnabled(true)
	// Register already-installed backends so they are usable immediately.
	gallery.RegisterBackends(systemState, ml)
	// Parse galleries config; a bad JSON list is non-fatal (gallery installs
	// requested over NATS may still carry their own gallery list).
	var galleries []config.Gallery
	if err := json.Unmarshal([]byte(cmd.BackendGalleries), &galleries); err != nil {
		xlog.Warn("Failed to parse backend galleries", "error", err)
	}
	// Self-registration with frontend (with retry)
	regClient := &workerregistry.RegistrationClient{
		FrontendURL:       cmd.RegisterTo,
		RegistrationToken: cmd.RegistrationToken,
	}
	registrationBody := cmd.registrationBody()
	nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
	if err != nil {
		return fmt.Errorf("failed to register with frontend: %w", err)
	}
	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
	heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
	if err != nil && cmd.HeartbeatInterval != "" {
		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
	}
	// Zero value (unset or unparsable) falls back to 10s.
	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
	// Context cancelled on shutdown — used by heartbeat and other background goroutines
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer shutdownCancel()
	// Start HTTP file transfer server. Staging/data dirs live next to the
	// models directory.
	httpAddr := cmd.resolveHTTPAddr()
	stagingDir := filepath.Join(cmd.ModelsPath, "..", "staging")
	dataDir := filepath.Join(cmd.ModelsPath, "..", "data")
	httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cmd.ModelsPath, dataDir, cmd.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs())
	if err != nil {
		return fmt.Errorf("starting HTTP file transfer server: %w", err)
	}
	// Connect to NATS
	xlog.Info("Connecting to NATS", "url", sanitize.URL(cmd.NatsURL))
	natsClient, err := messaging.New(cmd.NatsURL)
	if err != nil {
		// HTTP server was already started; tear it down before bailing out.
		nodes.ShutdownFileTransferServer(httpServer)
		return fmt.Errorf("connecting to NATS: %w", err)
	}
	defer natsClient.Close()
	// Start heartbeat goroutine (after NATS is connected so IsConnected check works)
	go func() {
		ticker := time.NewTicker(heartbeatInterval)
		defer ticker.Stop()
		for {
			select {
			case <-shutdownCtx.Done():
				return
			case <-ticker.C:
				if !natsClient.IsConnected() {
					xlog.Warn("Skipping heartbeat: NATS disconnected")
					continue
				}
				body := cmd.heartbeatBody()
				if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil {
					xlog.Warn("Heartbeat failed", "error", err)
				}
			}
		}
	}()
	// Process supervisor — manages multiple backend gRPC processes on different ports
	basePort := 50051
	if cmd.Addr != "" {
		// Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
		if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
			if p, err := strconv.Atoi(portStr); err == nil {
				basePort = p
			}
		}
	}
	// Buffered so NATS stop handler can send without blocking
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	// Set the registration token once before any backends are started
	if cmd.RegistrationToken != "" {
		os.Setenv(grpc.AuthTokenEnvVar, cmd.RegistrationToken)
	}
	supervisor := &backendSupervisor{
		cmd:         cmd,
		ml:          ml,
		systemState: systemState,
		galleries:   galleries,
		nodeID:      nodeID,
		nats:        natsClient,
		sigCh:       sigCh,
		processes:   make(map[string]*backendProcess),
		nextPort:    basePort,
	}
	supervisor.subscribeLifecycleEvents()
	// Subscribe to file staging NATS subjects if S3 is configured
	if cmd.StorageURL != "" {
		if err := cmd.subscribeFileStaging(natsClient, nodeID); err != nil {
			xlog.Error("Failed to subscribe to file staging subjects", "error", err)
		}
	}
	xlog.Info("Worker ready, waiting for backend.install events")
	// Block until a signal arrives (OS signal or NATS stop event via sigCh).
	<-sigCh
	xlog.Info("Shutting down worker")
	shutdownCancel() // stop heartbeat loop immediately
	regClient.GracefulDeregister(nodeID)
	supervisor.stopAllBackends()
	nodes.ShutdownFileTransferServer(httpServer)
	return nil
}
// subscribeFileStaging subscribes to NATS file staging subjects for this node.
//
// Four request-reply subjects are registered, all scoped to nodeID:
//   - files.ensure:  download an S3 key to local disk, reply {"local_path"}
//   - files.stage:   upload a local path to S3, reply {"key"}
//   - files.temp:    allocate a temp file under the cache dir, reply {"local_path"}
//   - files.listdir: list files under a key prefix, reply {"files": [...]}
//
// All replies carry an "error" field on failure instead of failing the
// subscription. Returns an error only if the S3 store or file manager cannot
// be initialized.
func (cmd *WorkerCMD) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error {
	// Create FileManager with same S3 config as the frontend
	// TODO: propagate a caller-provided context once WorkerCMD carries one
	s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
		Endpoint:        cmd.StorageURL,
		Region:          cmd.StorageRegion,
		Bucket:          cmd.StorageBucket,
		AccessKeyID:     cmd.StorageAccessKey,
		SecretAccessKey: cmd.StorageSecretKey,
		ForcePathStyle:  true,
	})
	if err != nil {
		return fmt.Errorf("initializing S3 store: %w", err)
	}
	// Cache dir sits next to the models directory.
	cacheDir := filepath.Join(cmd.ModelsPath, "..", "cache")
	fm, err := storage.NewFileManager(s3Store, cacheDir)
	if err != nil {
		return fmt.Errorf("initializing file manager: %w", err)
	}
	// Subscribe: files.ensure — download S3 key to local, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			Key string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}
		localPath, err := fm.Download(context.Background(), req.Key)
		if err != nil {
			xlog.Error("File ensure failed", "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}
		xlog.Debug("File ensured locally", "key", req.Key, "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})
	// Subscribe: files.stage — upload local path to S3, reply with key
	natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			LocalPath string `json:"local_path"`
			Key       string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}
		// Only cache and models dirs may be staged — reject anything else
		// since LocalPath comes from the network.
		allowedDirs := []string{cacheDir}
		if cmd.ModelsPath != "" {
			allowedDirs = append(allowedDirs, cmd.ModelsPath)
		}
		if !isPathAllowed(req.LocalPath, allowedDirs) {
			replyJSON(reply, map[string]string{"error": "path outside allowed directories"})
			return
		}
		if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil {
			xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}
		xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key)
		replyJSON(reply, map[string]string{"key": req.Key})
	})
	// Subscribe: files.temp — allocate temp file, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) {
		tmpDir := filepath.Join(cacheDir, "staging-tmp")
		if err := os.MkdirAll(tmpDir, 0750); err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)})
			return
		}
		f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp")
		if err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)})
			return
		}
		// Only the name is needed; the file is written later by the requester.
		localPath := f.Name()
		f.Close()
		xlog.Debug("Allocated temp file", "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})
	// Subscribe: files.listdir — list files in a local directory, reply with relative paths
	natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			KeyPrefix string `json:"key_prefix"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]any{"error": "invalid request"})
			return
		}
		// Resolve key prefix to local directory: model keys map into the
		// models dir, data keys into the data dir, everything else into cache.
		dirPath := filepath.Join(cacheDir, req.KeyPrefix)
		if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cmd.ModelsPath != "" {
			dirPath = filepath.Join(cmd.ModelsPath, rel)
		} else if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.DataKeyPrefix); ok {
			dirPath = filepath.Join(cacheDir, "..", "data", rel)
		}
		// Sanitize to prevent directory traversal via crafted key_prefix
		dirPath = filepath.Clean(dirPath)
		cleanCache := filepath.Clean(cacheDir)
		cleanModels := filepath.Clean(cmd.ModelsPath)
		cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data"))
		// cleanModels == "." means ModelsPath was empty — exclude it then.
		if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) ||
			dirPath == cleanCache ||
			(cleanModels != "." && strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) ||
			dirPath == cleanModels ||
			strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) ||
			dirPath == cleanData) {
			replyJSON(reply, map[string]any{"error": "invalid key prefix"})
			return
		}
		// Best-effort walk: unreadable entries are skipped, not reported.
		var files []string
		filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error {
			if err != nil {
				return nil
			}
			if !d.IsDir() {
				rel, err := filepath.Rel(dirPath, path)
				if err == nil {
					files = append(files, rel)
				}
			}
			return nil
		})
		xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files))
		replyJSON(reply, map[string]any{"files": files})
	})
	xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID)
	return nil
}
// replyJSON serializes v as JSON and hands the bytes to reply. A marshal
// failure is logged and replaced by a canned error payload so the requester
// always receives a response instead of timing out.
func replyJSON(reply func([]byte), v any) {
	payload, err := json.Marshal(v)
	if err != nil {
		xlog.Error("Failed to marshal NATS reply", "error", err)
		payload = []byte(`{"error":"internal marshal error"}`)
	}
	reply(payload)
}
// backendProcess represents a single gRPC backend process.
type backendProcess struct {
	proc    *process.Process // handle to the supervised OS process
	backend string           // backend (or per-model process) key this process serves
	addr    string           // gRPC address (host:port)
}
// backendSupervisor manages multiple backend gRPC processes on different ports.
// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port.
type backendSupervisor struct {
	cmd         *WorkerCMD
	ml          *model.ModelLoader
	systemState *system.SystemState
	galleries   []config.Gallery // default galleries; may be overridden per install request
	nodeID      string           // this worker's ID as assigned by the frontend
	nats        messaging.MessagingClient
	sigCh       chan<- os.Signal // send shutdown signal instead of os.Exit
	// mu guards processes, nextPort and freePorts.
	mu        sync.Mutex
	processes map[string]*backendProcess // key: backend name
	nextPort  int                        // next available port for new backends
	freePorts []int                      // ports freed by stopBackend, reused before nextPort
}
// startBackend starts a gRPC backend process on a dynamically allocated port
// and returns the client-facing gRPC address (127.0.0.1:port).
//
// If a live process already exists for this backend, its address is returned
// unchanged. A dead process is cleaned up (its port returned to the free pool
// — previously the port leaked on every crash/restart) and a new one started.
// After spawning, the gRPC server is polled for up to ~4s; if it never
// reports healthy the address is returned anyway (best-effort).
func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) {
	s.mu.Lock()
	// Already running?
	if bp, ok := s.processes[backend]; ok {
		if bp.proc != nil && bp.proc.IsAlive() {
			s.mu.Unlock()
			return bp.addr, nil
		}
		// Process died — recycle its port (mirrors stopBackend) and restart.
		xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend)
		if _, portStr, err := net.SplitHostPort(bp.addr); err == nil {
			if p, err := strconv.Atoi(portStr); err == nil {
				s.freePorts = append(s.freePorts, p)
			}
		}
		delete(s.processes, backend)
	}
	// Allocate port — recycle freed ports first, then grow upward from basePort
	var port int
	if len(s.freePorts) > 0 {
		port = s.freePorts[len(s.freePorts)-1]
		s.freePorts = s.freePorts[:len(s.freePorts)-1]
	} else {
		port = s.nextPort
		s.nextPort++
	}
	bindAddr := fmt.Sprintf("0.0.0.0:%d", port)
	clientAddr := fmt.Sprintf("127.0.0.1:%d", port)
	proc, err := s.ml.StartProcess(backendPath, backend, bindAddr)
	if err != nil {
		s.mu.Unlock()
		return "", fmt.Errorf("starting backend process: %w", err)
	}
	s.processes[backend] = &backendProcess{
		proc:    proc,
		backend: backend,
		addr:    clientAddr,
	}
	xlog.Info("Backend process started", "backend", backend, "addr", clientAddr)
	// Capture reference before unlocking for race-safe health check.
	// Another goroutine could stopBackend and recycle the port while we poll.
	bp := s.processes[backend]
	s.mu.Unlock()
	// Wait for the gRPC server to be ready
	client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken)
	for range 20 {
		time.Sleep(200 * time.Millisecond)
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		if ok, _ := client.HealthCheck(ctx); ok {
			cancel()
			// Verify the process wasn't stopped/replaced while health-checking
			s.mu.Lock()
			currentBP, exists := s.processes[backend]
			s.mu.Unlock()
			if !exists || currentBP != bp {
				return "", fmt.Errorf("backend %s was stopped during startup", backend)
			}
			xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr)
			return clientAddr, nil
		}
		cancel()
	}
	xlog.Warn("Backend gRPC server not ready after waiting, proceeding anyway", "backend", backend, "addr", clientAddr)
	return clientAddr, nil
}
// stopBackend tears down the gRPC process for one backend: the map entry is
// removed and the port recycled while holding the mutex, then a best-effort
// gRPC Free() releases model memory, and finally the process is killed. All
// network I/O happens after the lock is released.
func (s *backendSupervisor) stopBackend(backend string) {
	s.mu.Lock()
	bp, ok := s.processes[backend]
	if !ok || bp.proc == nil {
		s.mu.Unlock()
		return
	}
	delete(s.processes, backend)
	// Return the port to the free pool for reuse by startBackend.
	if _, portStr, splitErr := net.SplitHostPort(bp.addr); splitErr == nil {
		if port, convErr := strconv.Atoi(portStr); convErr == nil {
			s.freePorts = append(s.freePorts, port)
		}
	}
	s.mu.Unlock()
	// Best-effort Free() before killing the process.
	client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
	if freer, supported := client.(interface{ Free(context.Context) error }); supported {
		xlog.Debug("Calling Free() before stopping backend", "backend", backend)
		if err := freer.Free(context.Background()); err != nil {
			xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
		}
	}
	xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
	if err := bp.proc.Stop(); err != nil {
		xlog.Error("Error stopping backend process", "backend", backend, "error", err)
	}
}
// stopAllBackends stops every supervised backend process. The key snapshot
// is taken under the lock; the stops themselves run unlocked because
// stopBackend performs network I/O.
func (s *backendSupervisor) stopAllBackends() {
	s.mu.Lock()
	names := make([]string, 0, len(s.processes))
	for name := range s.processes {
		names = append(names, name)
	}
	s.mu.Unlock()
	for _, name := range names {
		s.stopBackend(name)
	}
}
// isRunning reports whether the named backend currently has a live process.
func (s *backendSupervisor) isRunning(backend string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if bp, found := s.processes[backend]; found && bp.proc != nil {
		return bp.proc.IsAlive()
	}
	return false
}
// getAddr returns the gRPC address registered for a backend, or the empty
// string when no process is tracked for it.
func (s *backendSupervisor) getAddr(backend string) string {
	s.mu.Lock()
	defer s.mu.Unlock()
	bp, found := s.processes[backend]
	if !found {
		return ""
	}
	return bp.addr
}
// installBackend handles the backend.install flow:
// 1. If already running for this model, return existing address
// 2. Install backend from gallery (if not already installed)
// 3. Find backend binary
// 4. Start gRPC process on a new port
// Returns the gRPC address of the backend process.
func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) {
	// Process key: use ModelID if provided (per-model process), else backend name
	processKey := req.ModelID
	if processKey == "" {
		processKey = req.Backend
	}
	// If already running for this model, return its address
	if addr := s.getAddr(processKey); addr != "" {
		xlog.Info("Backend already running for model", "backend", req.Backend, "model", req.ModelID, "addr", addr)
		return addr, nil
	}
	// Parse galleries from request (override local config if provided).
	// Unparsable request galleries silently fall back to the local list.
	galleries := s.galleries
	if req.BackendGalleries != "" {
		var reqGalleries []config.Gallery
		if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil {
			galleries = reqGalleries
		}
	}
	// Try to find the backend binary
	backendPath := s.findBackend(req.Backend)
	if backendPath == "" {
		// Backend not found locally — try auto-installing from gallery
		xlog.Info("Backend not found locally, attempting gallery install", "backend", req.Backend)
		if err := gallery.InstallBackendFromGallery(
			context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, false,
		); err != nil {
			return "", fmt.Errorf("installing backend from gallery: %w", err)
		}
		// Re-register after install and retry
		gallery.RegisterBackends(s.systemState, s.ml)
		backendPath = s.findBackend(req.Backend)
	}
	if backendPath == "" {
		return "", fmt.Errorf("backend %q not found after install attempt", req.Backend)
	}
	xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey)
	// Start the gRPC process on a new port (keyed by model, not just backend)
	return s.startBackend(processKey, backendPath)
}
// findBackend resolves the on-disk binary for a backend. An externally
// registered URI wins; otherwise the user backends tree and the system
// backends tree are probed, each in both "<dir>/<name>" and
// "<dir>/<name>/<name>" layouts. Returns "" when nothing matches.
func (s *backendSupervisor) findBackend(backend string) string {
	if uri := s.ml.GetExternalBackend(backend); uri != "" {
		if info, err := os.Stat(uri); err == nil && !info.IsDir() {
			return uri
		}
	}
	roots := []string{s.cmd.BackendsPath, s.cmd.BackendsSystemPath}
	for _, root := range roots {
		for _, candidate := range []string{
			filepath.Join(root, backend),
			filepath.Join(root, backend, backend),
		} {
			if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
				return candidate
			}
		}
	}
	return ""
}
// subscribeLifecycleEvents subscribes to NATS backend lifecycle events.
//
// Subjects handled (all scoped to this node's ID):
//   - backend.install (request-reply): install + start, reply with gRPC address
//   - backend.stop: stop one backend, or all when no backend is named
//   - backend.delete (request-reply): stop + remove files from disk
//   - backend.list (request-reply): enumerate installed backends
//   - model.unload (request-reply): best-effort gRPC Free() to release memory
//   - model.delete (request-reply): remove staged model files
//   - stop: feed SIGTERM into sigCh so the normal shutdown path runs
func (s *backendSupervisor) subscribeLifecycleEvents() {
	// backend.install — install backend + start gRPC process (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS backend.install event")
		var req messaging.BackendInstallRequest
		if err := json.Unmarshal(data, &req); err != nil {
			resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
			replyJSON(reply, resp)
			return
		}
		addr, err := s.installBackend(req)
		if err != nil {
			xlog.Error("Failed to install backend via NATS", "error", err)
			resp := messaging.BackendInstallReply{Success: false, Error: err.Error()}
			replyJSON(reply, resp)
			return
		}
		// Return the gRPC address so the router knows which port to use
		advertiseAddr := addr
		if s.cmd.AdvertiseAddr != "" {
			// Replace 0.0.0.0 with the advertised host but keep the dynamic port
			_, port, _ := net.SplitHostPort(addr)
			advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
			advertiseAddr = net.JoinHostPort(advertiseHost, port)
		}
		resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
		replyJSON(reply, resp)
	})
	// backend.stop — stop a specific backend process
	s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), func(data []byte) {
		// Try to parse backend name from payload; if empty, stop all
		var req struct {
			Backend string `json:"backend"`
		}
		if json.Unmarshal(data, &req) == nil && req.Backend != "" {
			xlog.Info("Received NATS backend.stop event", "backend", req.Backend)
			s.stopBackend(req.Backend)
		} else {
			xlog.Info("Received NATS backend.stop event (all)")
			s.stopAllBackends()
		}
	})
	// backend.delete — stop backend + delete files (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS backend.delete event")
		var req messaging.BackendDeleteRequest
		if err := json.Unmarshal(data, &req); err != nil {
			resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
			replyJSON(reply, resp)
			return
		}
		// Stop if running this backend
		if s.isRunning(req.Backend) {
			s.stopBackend(req.Backend)
		}
		// Delete the backend files
		if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil {
			xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err)
			resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()}
			replyJSON(reply, resp)
			return
		}
		// Re-register backends after deletion
		gallery.RegisterBackends(s.systemState, s.ml)
		resp := messaging.BackendDeleteReply{Success: true}
		replyJSON(reply, resp)
	})
	// backend.list — list installed backends (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS backend.list event")
		backends, err := gallery.ListSystemBackends(s.systemState)
		if err != nil {
			resp := messaging.BackendListReply{Error: err.Error()}
			replyJSON(reply, resp)
			return
		}
		var infos []messaging.NodeBackendInfo
		for name, b := range backends {
			info := messaging.NodeBackendInfo{
				Name:     name,
				IsSystem: b.IsSystem,
				IsMeta:   b.IsMeta,
			}
			// Metadata is optional — only set for gallery-installed backends.
			if b.Metadata != nil {
				info.InstalledAt = b.Metadata.InstalledAt
				info.GalleryURL = b.Metadata.GalleryURL
			}
			infos = append(infos, info)
		}
		resp := messaging.BackendListReply{Backends: infos}
		replyJSON(reply, resp)
	})
	// model.unload — call gRPC Free() to release GPU memory (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS model.unload event")
		var req messaging.ModelUnloadRequest
		if err := json.Unmarshal(data, &req); err != nil {
			resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
			replyJSON(reply, resp)
			return
		}
		// Find the backend address for this model's backend type
		// The request includes an Address field if the router knows which process to target
		targetAddr := req.Address
		if targetAddr == "" {
			// Fallback: try all running backends
			// NOTE(review): map iteration order is random — this picks an
			// arbitrary process when several are running; confirm intended.
			s.mu.Lock()
			for _, bp := range s.processes {
				targetAddr = bp.addr
				break
			}
			s.mu.Unlock()
		}
		if targetAddr != "" {
			// Best-effort gRPC Free()
			client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
			if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
				if err := freeFunc.Free(context.Background()); err != nil {
					xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
				}
			}
		}
		// Unload is best-effort: always reported as success.
		resp := messaging.ModelUnloadReply{Success: true}
		replyJSON(reply, resp)
	})
	// model.delete — remove model files from disk (request-reply)
	s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), func(data []byte, reply func([]byte)) {
		xlog.Info("Received NATS model.delete event")
		var req messaging.ModelDeleteRequest
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"})
			return
		}
		if err := gallery.DeleteStagedModelFiles(s.cmd.ModelsPath, req.ModelName); err != nil {
			xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err)
			replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()})
			return
		}
		replyJSON(reply, messaging.ModelDeleteReply{Success: true})
	})
	// stop — trigger the normal shutdown path via sigCh so deferred cleanup runs
	s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), func(data []byte) {
		xlog.Info("Received NATS stop event — signaling shutdown")
		// Non-blocking send: sigCh is buffered(1); a duplicate stop is dropped.
		select {
		case s.sigCh <- syscall.SIGTERM:
		default:
			xlog.Debug("Shutdown already signaled, ignoring duplicate stop")
		}
	})
}
// advertiseAddr returns the address the frontend should use to reach this
// node. An explicit AdvertiseAddr always wins. Otherwise, when Addr binds a
// wildcard host ("0.0.0.0", "::" or empty), the host is replaced with the
// machine's hostname so remote peers get a routable address; any failure
// falls back to Addr unchanged.
func (cmd *WorkerCMD) advertiseAddr() string {
	if cmd.AdvertiseAddr != "" {
		return cmd.AdvertiseAddr
	}
	// net.SplitHostPort (unlike the previous strings.Cut on the first ":")
	// correctly handles IPv6 literals such as "[::]:50051".
	host, port, err := net.SplitHostPort(cmd.Addr)
	if err == nil && (host == "0.0.0.0" || host == "::" || host == "") {
		if hostname, hostErr := os.Hostname(); hostErr == nil {
			return net.JoinHostPort(hostname, port)
		}
	}
	return cmd.Addr
}
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports
// which grow upward from basePort.
//
// Fixes: the previous strings.Cut split IPv6 literals at the wrong colon, and
// an ignored Atoi error produced the invalid address "host:-1" for
// non-numeric ports; both now fall back to the default "0.0.0.0:50050".
func (cmd *WorkerCMD) resolveHTTPAddr() string {
	if cmd.HTTPAddr != "" {
		return cmd.HTTPAddr
	}
	host, portStr, err := net.SplitHostPort(cmd.Addr)
	if err != nil {
		return "0.0.0.0:50050"
	}
	portNum, err := strconv.Atoi(portStr)
	if err != nil {
		return "0.0.0.0:50050"
	}
	// JoinHostPort re-brackets IPv6 hosts correctly.
	return net.JoinHostPort(host, strconv.Itoa(portNum-1))
}
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
// this node for file transfer. An explicit AdvertiseHTTPAddr wins; otherwise
// a wildcard host ("0.0.0.0", "::" or empty) in the resolved bind address is
// swapped for the machine's hostname. Uses net.SplitHostPort/JoinHostPort so
// IPv6 literals are handled correctly.
func (cmd *WorkerCMD) advertiseHTTPAddr() string {
	if cmd.AdvertiseHTTPAddr != "" {
		return cmd.AdvertiseHTTPAddr
	}
	httpAddr := cmd.resolveHTTPAddr()
	host, port, err := net.SplitHostPort(httpAddr)
	if err == nil && (host == "0.0.0.0" || host == "::" || host == "") {
		if hostname, hostErr := os.Hostname(); hostErr == nil {
			return net.JoinHostPort(hostname, port)
		}
	}
	return httpAddr
}
// registrationBody builds the JSON payload sent to the frontend when this
// worker registers: node name, reachable gRPC/HTTP addresses, detected GPU
// capacity (or system RAM when no GPU is present), and the registration
// token when configured.
func (cmd *WorkerCMD) registrationBody() map[string]any {
	name := cmd.NodeName
	if name == "" {
		// Default to the hostname; fall back to a PID-derived name.
		if hostname, err := os.Hostname(); err == nil {
			name = hostname
		} else {
			name = fmt.Sprintf("node-%d", os.Getpid())
		}
	}
	// Detect GPU info for VRAM-aware scheduling
	totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
	gpuVendor, _ := xsysinfo.DetectGPUVendor()
	body := map[string]any{
		"name":           name,
		"address":        cmd.advertiseAddr(),
		"http_address":   cmd.advertiseHTTPAddr(),
		"total_vram":     totalVRAM,
		"available_vram": totalVRAM, // initially all VRAM is available
		"gpu_vendor":     gpuVendor,
	}
	// If no GPU detected, report system RAM so the scheduler/UI has capacity info
	if totalVRAM == 0 {
		if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
			body["total_ram"] = ramInfo.Total
			body["available_ram"] = ramInfo.Available
		}
	}
	if cmd.RegistrationToken != "" {
		body["token"] = cmd.RegistrationToken
	}
	return body
}
// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads.
func (cmd *WorkerCMD) heartbeatBody() map[string]any {
	agg := xsysinfo.GetGPUAggregateInfo()
	available := agg.FreeVRAM
	if agg.TotalVRAM == 0 {
		// No per-GPU usage data: report the total as available, since actual
		// usage tracking is not possible.
		available, _ = xsysinfo.TotalAvailableVRAM()
	}
	body := map[string]any{"available_vram": available}
	// CPU-only node: report system RAM availability instead.
	if agg.TotalVRAM == 0 {
		if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
			body["available_ram"] = ramInfo.Available
		}
	}
	return body
}

View file

@ -0,0 +1,272 @@
// Package workerregistry provides a shared HTTP client for worker node
// registration, heartbeating, draining, and deregistration against a
// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent
// worker (AgentWorkerCMD) use this instead of duplicating the logic.
package workerregistry
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/mudler/xlog"
)
// RegistrationClient talks to the frontend's /api/node/* endpoints.
type RegistrationClient struct {
	FrontendURL       string
	RegistrationToken string
	HTTPTimeout       time.Duration // used for registration calls; defaults to 10s

	client     *http.Client
	clientOnce sync.Once
}

// httpTimeout returns the configured timeout or a sensible default (10s).
func (c *RegistrationClient) httpTimeout() time.Duration {
	if c.HTTPTimeout <= 0 {
		return 10 * time.Second
	}
	return c.HTTPTimeout
}

// httpClient lazily builds the shared HTTP client on first use and returns it
// on every subsequent call.
func (c *RegistrationClient) httpClient() *http.Client {
	c.clientOnce.Do(func() {
		c.client = &http.Client{Timeout: c.httpTimeout()}
	})
	return c.client
}

// baseURL returns FrontendURL with any trailing slash stripped, so paths can
// be appended with a single "/".
func (c *RegistrationClient) baseURL() string {
	return strings.TrimRight(c.FrontendURL, "/")
}

// setAuth adds an Authorization header when a token is configured.
func (c *RegistrationClient) setAuth(req *http.Request) {
	if tok := c.RegistrationToken; tok != "" {
		req.Header.Set("Authorization", "Bearer "+tok)
	}
}
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
	// ID is the frontend-assigned node identifier, used in subsequent
	// /api/node/:id/* calls (heartbeat, drain, deregister).
	ID string `json:"id"`
	// APIToken is an auto-provisioned API token; empty when the frontend
	// did not issue one.
	APIToken string `json:"api_token,omitempty"`
}
// Register sends a single registration request and returns the node ID and
// (optionally) an auto-provisioned API token.
//
// Errors are wrapped with the stage that failed (encoding, request creation,
// transport, HTTP status, or response decoding).
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
	// Previously the marshal error was discarded; surface it instead of
	// posting a nil body.
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", "", fmt.Errorf("encoding registration body: %w", err)
	}
	url := c.baseURL() + "/api/node/register"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return "", "", fmt.Errorf("creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return "", "", fmt.Errorf("posting to %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
	}
	var result RegisterResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", "", fmt.Errorf("decoding response: %w", err)
	}
	return result.ID, result.APIToken, nil
}
// RegisterWithRetry retries registration with exponential backoff (starting
// at 2s, capped at 30s). It returns the node ID and API token on success, the
// last error after maxRetries failed attempts, or ctx.Err() if the context is
// cancelled while waiting between attempts.
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
	// Guard: with maxRetries < 1 the loop below would never run and the
	// previous implementation reported success with an empty node ID.
	if maxRetries < 1 {
		return "", "", fmt.Errorf("maxRetries must be >= 1, got %d", maxRetries)
	}
	backoff := 2 * time.Second
	const maxBackoff = 30 * time.Second
	for attempt := 1; ; attempt++ {
		nodeID, apiToken, err := c.Register(ctx, body)
		if err == nil {
			return nodeID, apiToken, nil
		}
		if attempt == maxRetries {
			return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
		}
		xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
		select {
		case <-ctx.Done():
			return "", "", ctx.Err()
		case <-time.After(backoff):
		}
		backoff = min(backoff*2, maxBackoff)
	}
}
// Heartbeat sends a single heartbeat POST with the given body.
//
// A non-2xx response is now reported as an error — previously any HTTP status
// was silently treated as success, which is inconsistent with Register, Drain
// and Deregister and hides rejected heartbeats from callers.
func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return fmt.Errorf("encoding heartbeat body: %w", err)
	}
	url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return fmt.Errorf("creating heartbeat request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("heartbeat failed with status %d", resp.StatusCode)
	}
	return nil
}
// HeartbeatLoop runs heartbeats at the given interval until ctx is cancelled.
// bodyFn is called each tick to build the heartbeat payload (e.g. VRAM stats).
// Failed heartbeats are logged and the loop keeps going.
func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := c.Heartbeat(ctx, nodeID, bodyFn()); err != nil {
				xlog.Warn("Heartbeat failed", "error", err)
			}
		case <-ctx.Done():
			return
		}
	}
}
// Drain sets the node to draining status via POST /api/node/:id/drain.
// Any response other than 200 OK is reported as an error.
func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error {
	endpoint := c.baseURL() + "/api/node/" + nodeID + "/drain"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return fmt.Errorf("creating drain request: %w", err)
	}
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("drain failed with status %d", status)
	}
	return nil
}
// WaitForDrain polls GET /api/node/:id/models until all models report 0
// in-flight requests, or until timeout elapses.
//
// Fix: a failed poll is now retried rather than mistaken for "zero in-flight".
// Previously the decode error was ignored and the status code never checked,
// so a 404 or malformed body left the slice empty, summed to 0, and
// WaitForDrain declared the node drained immediately.
func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) {
	url := c.baseURL() + "/api/node/" + nodeID + "/models"
	deadline := time.Now().Add(timeout)
	// sleep waits one second between polls; returns false when ctx is cancelled.
	sleep := func() bool {
		select {
		case <-ctx.Done():
			xlog.Warn("Drain wait cancelled")
			return false
		case <-time.After(1 * time.Second):
			return true
		}
	}
	for time.Now().Before(deadline) {
		total, err := c.inFlightCount(ctx, url)
		if err != nil {
			xlog.Warn("Drain poll failed, will retry", "error", err)
			if !sleep() {
				return
			}
			continue
		}
		if total == 0 {
			xlog.Info("All in-flight requests drained")
			return
		}
		xlog.Info("Waiting for in-flight requests", "count", total)
		if !sleep() {
			return
		}
	}
	xlog.Warn("Drain timeout reached, proceeding with shutdown")
}

// inFlightCount fetches the node's model list from url and sums the in-flight
// request counters. Any transport, status, or decode failure is returned as an
// error so the caller can retry instead of treating it as "drained".
func (c *RegistrationClient) inFlightCount(ctx context.Context, url string) (int, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return 0, fmt.Errorf("creating drain poll request: %w", err)
	}
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("drain poll returned status %d", resp.StatusCode)
	}
	var models []struct {
		InFlight int `json:"in_flight"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&models); err != nil {
		return 0, fmt.Errorf("decoding models response: %w", err)
	}
	total := 0
	for _, m := range models {
		total += m.InFlight
	}
	return total, nil
}
// Deregister marks the node as offline via POST /api/node/:id/deregister.
// The node row is preserved in the database so re-registration restores
// approval status. Any response other than 200 OK is reported as an error.
func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error {
	endpoint := c.baseURL() + "/api/node/" + nodeID + "/deregister"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return fmt.Errorf("creating deregister request: %w", err)
	}
	c.setAuth(req)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("deregistration failed with status %d", status)
	}
	return nil
}
// GracefulDeregister performs drain -> wait -> deregister in sequence.
// This is the standard shutdown sequence for backend workers. It is a no-op
// when the client has no frontend URL or the node ID is empty, and the whole
// sequence is bounded by a 60s timeout (with up to 30s spent waiting for
// in-flight requests to drain).
func (c *RegistrationClient) GracefulDeregister(nodeID string) {
	if nodeID == "" || c.FrontendURL == "" {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	if err := c.Drain(ctx, nodeID); err != nil {
		xlog.Warn("Failed to set drain status", "error", err)
	} else {
		c.WaitForDrain(ctx, nodeID, 30*time.Second)
	}
	if err := c.Deregister(ctx, nodeID); err != nil {
		xlog.Error("Failed to deregister", "error", err)
		return
	}
	xlog.Info("Deregistered from frontend")
}

View file

@ -94,7 +94,7 @@ func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) {
} }
// Helper function to perform a request without expecting a response body // Helper function to perform a request without expecting a response body
func (c *StoreClient) doRequest(path string, data interface{}) error { func (c *StoreClient) doRequest(path string, data any) error {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
return err return err
@ -120,7 +120,7 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
} }
// Helper function to perform a request and parse the response body // Helper function to perform a request and parse the response body
func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) { func (c *StoreClient) doRequestWithResponse(path string, data any) ([]byte, error) {
jsonData, err := json.Marshal(data) jsonData, err := json.Marshal(data)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -83,8 +83,8 @@ type ApplicationConfig struct {
APIAddress string APIAddress string
LlamaCPPTunnelCallback func(tunnels []string) LlamaCPPTunnelCallback func(tunnels []string)
MLXTunnelCallback func(tunnels []string) MLXTunnelCallback func(tunnels []string)
DisableRuntimeSettings bool DisableRuntimeSettings bool
@ -99,47 +99,50 @@ type ApplicationConfig struct {
// Authentication & Authorization // Authentication & Authorization
Auth AuthConfig Auth AuthConfig
// Distributed / Horizontal Scaling
Distributed DistributedConfig
} }
// AuthConfig holds configuration for user authentication and authorization. // AuthConfig holds configuration for user authentication and authorization.
type AuthConfig struct { type AuthConfig struct {
Enabled bool Enabled bool
DatabaseURL string // "postgres://..." or file path for SQLite DatabaseURL string // "postgres://..." or file path for SQLite
GitHubClientID string GitHubClientID string
GitHubClientSecret string GitHubClientSecret string
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com) OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
OIDCClientID string OIDCClientID string
OIDCClientSecret string OIDCClientSecret string
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080") BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
AdminEmail string // auto-promote to admin on login AdminEmail string // auto-promote to admin on login
RegistrationMode string // "open", "approval" (default when empty), "invite" RegistrationMode string // "open", "approval" (default when empty), "invite"
DisableLocalAuth bool // disable local email/password registration and login DisableLocalAuth bool // disable local email/password registration and login
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry
} }
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration. // AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
type AgentPoolConfig struct { type AgentPoolConfig struct {
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true) Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder) StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>) APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
APIKey string // default: first API key from LocalAI config APIKey string // default: first API key from LocalAI config
DefaultModel string DefaultModel string
MultimodalModel string MultimodalModel string
TranscriptionModel string TranscriptionModel string
TranscriptionLanguage string TranscriptionLanguage string
TTSModel string TTSModel string
Timeout string // default: "5m" Timeout string // default: "5m"
EnableSkills bool EnableSkills bool
EnableLogs bool EnableLogs bool
CustomActionsDir string CustomActionsDir string
CollectionDBPath string CollectionDBPath string
VectorEngine string // default: "chromem" VectorEngine string // default: "chromem"
EmbeddingModel string // default: "granite-embedding-107m-multilingual" EmbeddingModel string // default: "granite-embedding-107m-multilingual"
MaxChunkingSize int // default: 400 MaxChunkingSize int // default: 400
ChunkOverlap int // default: 0 ChunkOverlap int // default: 0
DatabaseURL string DatabaseURL string
AgentHubURL string // default: "https://agenthub.localai.io" AgentHubURL string // default: "https://agenthub.localai.io"
} }
type AppOption func(*ApplicationConfig) type AppOption func(*ApplicationConfig)
@ -155,12 +158,12 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
TracingMaxItems: 1024, TracingMaxItems: 1024,
AgentPool: AgentPoolConfig{ AgentPool: AgentPoolConfig{
Enabled: true, Enabled: true,
Timeout: "5m", Timeout: "5m",
VectorEngine: "chromem", VectorEngine: "chromem",
EmbeddingModel: "granite-embedding-107m-multilingual", EmbeddingModel: "granite-embedding-107m-multilingual",
MaxChunkingSize: 400, MaxChunkingSize: 400,
AgentHubURL: "https://agenthub.localai.io", AgentHubURL: "https://agenthub.localai.io",
}, },
PathWithoutAuth: []string{ PathWithoutAuth: []string{
"/static/", "/static/",
@ -904,40 +907,40 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath
return RuntimeSettings{ return RuntimeSettings{
WatchdogEnabled: &watchdogEnabled, WatchdogEnabled: &watchdogEnabled,
WatchdogIdleEnabled: &watchdogIdle, WatchdogIdleEnabled: &watchdogIdle,
WatchdogBusyEnabled: &watchdogBusy, WatchdogBusyEnabled: &watchdogBusy,
WatchdogIdleTimeout: &idleTimeout, WatchdogIdleTimeout: &idleTimeout,
WatchdogBusyTimeout: &busyTimeout, WatchdogBusyTimeout: &busyTimeout,
WatchdogInterval: &watchdogInterval, WatchdogInterval: &watchdogInterval,
SingleBackend: &singleBackend, SingleBackend: &singleBackend,
MaxActiveBackends: &maxActiveBackends, MaxActiveBackends: &maxActiveBackends,
ParallelBackendRequests: &parallelBackendRequests, ParallelBackendRequests: &parallelBackendRequests,
MemoryReclaimerEnabled: &memoryReclaimerEnabled, MemoryReclaimerEnabled: &memoryReclaimerEnabled,
MemoryReclaimerThreshold: &memoryReclaimerThreshold, MemoryReclaimerThreshold: &memoryReclaimerThreshold,
ForceEvictionWhenBusy: &forceEvictionWhenBusy, ForceEvictionWhenBusy: &forceEvictionWhenBusy,
LRUEvictionMaxRetries: &lruEvictionMaxRetries, LRUEvictionMaxRetries: &lruEvictionMaxRetries,
LRUEvictionRetryInterval: &lruEvictionRetryInterval, LRUEvictionRetryInterval: &lruEvictionRetryInterval,
Threads: &threads, Threads: &threads,
ContextSize: &contextSize, ContextSize: &contextSize,
F16: &f16, F16: &f16,
Debug: &debug, Debug: &debug,
TracingMaxItems: &tracingMaxItems, TracingMaxItems: &tracingMaxItems,
EnableTracing: &enableTracing, EnableTracing: &enableTracing,
EnableBackendLogging: &enableBackendLogging, EnableBackendLogging: &enableBackendLogging,
CORS: &cors, CORS: &cors,
CSRF: &csrf, CSRF: &csrf,
CORSAllowOrigins: &corsAllowOrigins, CORSAllowOrigins: &corsAllowOrigins,
P2PToken: &p2pToken, P2PToken: &p2pToken,
P2PNetworkID: &p2pNetworkID, P2PNetworkID: &p2pNetworkID,
Federated: &federated, Federated: &federated,
Galleries: &galleries, Galleries: &galleries,
BackendGalleries: &backendGalleries, BackendGalleries: &backendGalleries,
AutoloadGalleries: &autoloadGalleries, AutoloadGalleries: &autoloadGalleries,
AutoloadBackendGalleries: &autoloadBackendGalleries, AutoloadBackendGalleries: &autoloadBackendGalleries,
ApiKeys: &apiKeys, ApiKeys: &apiKeys,
AgentJobRetentionDays: &agentJobRetentionDays, AgentJobRetentionDays: &agentJobRetentionDays,
OpenResponsesStoreTTL: &openResponsesStoreTTL, OpenResponsesStoreTTL: &openResponsesStoreTTL,
AgentPoolEnabled: &agentPoolEnabled, AgentPoolEnabled: &agentPoolEnabled,
AgentPoolDefaultModel: &agentPoolDefaultModel, AgentPoolDefaultModel: &agentPoolDefaultModel,
AgentPoolEmbeddingModel: &agentPoolEmbeddingModel, AgentPoolEmbeddingModel: &agentPoolEmbeddingModel,

View file

@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
F16: true, F16: true,
Debug: true, Debug: true,
CORS: true, CORS: true,
DisableCSRF: true, DisableCSRF: true,
CORSAllowOrigins: "https://example.com", CORSAllowOrigins: "https://example.com",
P2PToken: "test-token", P2PToken: "test-token",
P2PNetworkID: "test-network", P2PNetworkID: "test-network",
@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
F16: true, F16: true,
Debug: false, Debug: false,
CORS: true, CORS: true,
DisableCSRF: false, DisableCSRF: false,
CORSAllowOrigins: "https://test.com", CORSAllowOrigins: "https://test.com",
P2PToken: "round-trip-token", P2PToken: "round-trip-token",
P2PNetworkID: "round-trip-network", P2PNetworkID: "round-trip-network",

View file

@ -0,0 +1,188 @@
package config
import (
"cmp"
"fmt"
"time"
"github.com/mudler/xlog"
)
// DistributedConfig holds configuration for horizontal scaling mode.
// When Enabled is true, PostgreSQL and NATS are required.
//
// Zero values for the Duration fields mean "use the package default"; the
// corresponding *OrDefault accessors resolve them.
type DistributedConfig struct {
	Enabled           bool   // --distributed / LOCALAI_DISTRIBUTED
	InstanceID        string // --instance-id / LOCALAI_INSTANCE_ID (auto-generated UUID if empty)
	NatsURL           string // --nats-url / LOCALAI_NATS_URL
	StorageURL        string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
	RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
	AutoApproveNodes  bool   // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)

	// S3 configuration (used when StorageURL is set)
	StorageBucket    string // --storage-bucket / LOCALAI_STORAGE_BUCKET
	StorageRegion    string // --storage-region / LOCALAI_STORAGE_REGION
	StorageAccessKey string // --storage-access-key / LOCALAI_STORAGE_ACCESS_KEY
	StorageSecretKey string // --storage-secret-key / LOCALAI_STORAGE_SECRET_KEY

	// Timeout configuration (all have sensible defaults — zero means use default)
	MCPToolTimeout      time.Duration // MCP tool execution timeout (default 360s)
	MCPDiscoveryTimeout time.Duration // MCP discovery timeout (default 60s)
	WorkerWaitTimeout   time.Duration // Max wait for healthy worker at startup (default 5m)
	DrainTimeout        time.Duration // Time to wait for in-flight requests during drain (default 30s)
	HealthCheckInterval time.Duration // Health monitor check interval (default 15s)
	StaleNodeThreshold  time.Duration // Time before a node is considered stale (default 60s)
	PerModelHealthCheck bool          // Enable per-model backend health checking (default false)
	MCPCIJobTimeout     time.Duration // MCP CI job execution timeout (default 10m)
	MaxUploadSize       int64         // Maximum upload body size in bytes (default 50 GB)

	// NOTE(review): the two fields below carry yaml/json/env struct tags while
	// the rest of the struct does not — confirm whether that asymmetry is
	// intentional or the other fields are missing their tags.
	AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
	JobWorkerConcurrency   int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
}
// Validate checks that the distributed configuration is internally consistent.
// It returns nil if distributed mode is disabled. A missing registration token
// is not an error, but it is logged loudly since it leaves node endpoints
// unprotected.
func (c DistributedConfig) Validate() error {
	if !c.Enabled {
		return nil
	}
	if c.NatsURL == "" {
		return fmt.Errorf("distributed mode requires --nats-url / LOCALAI_NATS_URL")
	}
	// S3 credentials must be paired: having exactly one of the two set is invalid.
	if (c.StorageAccessKey == "") != (c.StorageSecretKey == "") {
		return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
	}
	if c.RegistrationToken == "" {
		xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
	}
	// None of the duration knobs may be negative (zero means "use default").
	durations := map[string]time.Duration{
		"mcp-tool-timeout":      c.MCPToolTimeout,
		"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
		"worker-wait-timeout":   c.WorkerWaitTimeout,
		"drain-timeout":         c.DrainTimeout,
		"health-check-interval": c.HealthCheckInterval,
		"stale-node-threshold":  c.StaleNodeThreshold,
		"mcp-ci-job-timeout":    c.MCPCIJobTimeout,
	}
	for name, d := range durations {
		if d < 0 {
			return fmt.Errorf("%s must not be negative", name)
		}
	}
	return nil
}
// Distributed config options

// EnableDistributed turns on distributed (horizontal scaling) mode.
var EnableDistributed = func(cfg *ApplicationConfig) {
	cfg.Distributed.Enabled = true
}

// WithDistributedInstanceID sets this instance's identifier.
func WithDistributedInstanceID(id string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.InstanceID = id
	}
}

// WithNatsURL sets the NATS server URL.
func WithNatsURL(url string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.NatsURL = url
	}
}

// WithRegistrationToken sets the token required for node registration.
func WithRegistrationToken(token string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.RegistrationToken = token
	}
}

// WithStorageURL sets the S3 storage endpoint.
func WithStorageURL(url string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.StorageURL = url
	}
}

// WithStorageBucket sets the S3 storage bucket.
func WithStorageBucket(bucket string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.StorageBucket = bucket
	}
}

// WithStorageRegion sets the S3 storage region.
func WithStorageRegion(region string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.StorageRegion = region
	}
}

// WithStorageAccessKey sets the S3 storage access key.
func WithStorageAccessKey(key string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.StorageAccessKey = key
	}
}

// WithStorageSecretKey sets the S3 storage secret key.
func WithStorageSecretKey(key string) AppOption {
	return func(cfg *ApplicationConfig) {
		cfg.Distributed.StorageSecretKey = key
	}
}

// EnableAutoApproveNodes skips admin approval for newly registered workers.
var EnableAutoApproveNodes = func(cfg *ApplicationConfig) {
	cfg.Distributed.AutoApproveNodes = true
}
// Defaults for distributed timeouts.
// A zero value in the corresponding DistributedConfig field means "use the
// default below" (resolved by the *OrDefault accessors).
const (
	DefaultMCPToolTimeout      = 360 * time.Second
	DefaultMCPDiscoveryTimeout = 60 * time.Second
	DefaultWorkerWaitTimeout   = 5 * time.Minute
	DefaultDrainTimeout        = 30 * time.Second
	DefaultHealthCheckInterval = 15 * time.Second
	DefaultStaleNodeThreshold  = 60 * time.Second
	DefaultMCPCIJobTimeout     = 10 * time.Minute
)

// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
const DefaultMaxUploadSize int64 = 50 << 30
// The accessors below resolve zero-valued config fields to their package
// defaults via cmp.Or, which returns the first non-zero argument. They are
// the intended way to read these settings; reading the struct fields directly
// skips the defaulting.

// MCPToolTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
	return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
}

// MCPDiscoveryTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) MCPDiscoveryTimeoutOrDefault() time.Duration {
	return cmp.Or(c.MCPDiscoveryTimeout, DefaultMCPDiscoveryTimeout)
}

// WorkerWaitTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) WorkerWaitTimeoutOrDefault() time.Duration {
	return cmp.Or(c.WorkerWaitTimeout, DefaultWorkerWaitTimeout)
}

// DrainTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) DrainTimeoutOrDefault() time.Duration {
	return cmp.Or(c.DrainTimeout, DefaultDrainTimeout)
}

// HealthCheckIntervalOrDefault returns the configured interval or the default.
func (c DistributedConfig) HealthCheckIntervalOrDefault() time.Duration {
	return cmp.Or(c.HealthCheckInterval, DefaultHealthCheckInterval)
}

// StaleNodeThresholdOrDefault returns the configured threshold or the default.
func (c DistributedConfig) StaleNodeThresholdOrDefault() time.Duration {
	return cmp.Or(c.StaleNodeThreshold, DefaultStaleNodeThreshold)
}

// MCPCIJobTimeoutOrDefault returns the configured MCP CI job timeout or the default.
func (c DistributedConfig) MCPCIJobTimeoutOrDefault() time.Duration {
	return cmp.Or(c.MCPCIJobTimeout, DefaultMCPCIJobTimeout)
}

// MaxUploadSizeOrDefault returns the configured max upload size or the default.
func (c DistributedConfig) MaxUploadSizeOrDefault() int64 {
	return cmp.Or(c.MaxUploadSize, DefaultMaxUploadSize)
}

View file

@ -46,11 +46,11 @@ type ModelConfig struct {
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"` KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"` Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
PromptStrings, InputStrings []string `yaml:"-" json:"-"` PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"` InputToken [][]int `yaml:"-" json:"-"`
functionCallString, functionCallNameString string `yaml:"-" json:"-"` functionCallString, functionCallNameString string `yaml:"-" json:"-"`
ResponseFormat string `yaml:"-" json:"-"` ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"` ResponseFormatMap map[string]any `yaml:"-" json:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"` FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"` ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
@ -105,6 +105,11 @@ type AgentConfig struct {
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"` ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
} }
// HasMCPServers returns true if any MCP servers (remote or stdio) are configured.
func (c MCPConfig) HasMCPServers() bool {
return c.Servers != "" || c.Stdio != ""
}
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) { func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
var remote MCPGenericConfig[MCPRemoteServers] var remote MCPGenericConfig[MCPRemoteServers]
var stdio MCPGenericConfig[MCPSTDIOServers] var stdio MCPGenericConfig[MCPSTDIOServers]
@ -619,15 +624,32 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half. // In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently. // This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool { func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
// Backends that are clearly not text-generation
nonTextGenBackends := []string{
"whisper", "piper", "kokoro",
"diffusers", "stablediffusion", "stablediffusion-ggml",
"rerankers", "silero-vad", "rfdetr",
"transformers-musicgen", "ace-step", "acestep-cpp",
}
if (u & FLAG_CHAT) == FLAG_CHAT { if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate { if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
return false return false
} }
if slices.Contains(nonTextGenBackends, c.Backend) {
return false
}
if c.Embeddings != nil && *c.Embeddings {
return false
}
} }
if (u & FLAG_COMPLETION) == FLAG_COMPLETION { if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
if c.TemplateConfig.Completion == "" { if c.TemplateConfig.Completion == "" {
return false return false
} }
if slices.Contains(nonTextGenBackends, c.Backend) {
return false
}
} }
if (u & FLAG_EDIT) == FLAG_EDIT { if (u & FLAG_EDIT) == FLAG_EDIT {
if c.TemplateConfig.Edit == "" { if c.TemplateConfig.Edit == "" {

View file

@ -1,12 +1,13 @@
package config package config
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
"path/filepath" "path/filepath"
"sort" "slices"
"strings" "strings"
"sync" "sync"
@ -215,8 +216,8 @@ func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
res = append(res, v) res = append(res, v)
} }
sort.SliceStable(res, func(i, j int) bool { slices.SortStableFunc(res, func(a, b ModelConfig) int {
return res[i].Name < res[j].Name return cmp.Compare(a.Name, b.Name)
}) })
return res return res

View file

@ -27,15 +27,15 @@ type RuntimeSettings struct {
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%) MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
// Eviction settings // Eviction settings
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety) ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30) LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s) LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
// Performance settings // Performance settings
Threads *int `json:"threads,omitempty"` Threads *int `json:"threads,omitempty"`
ContextSize *int `json:"context_size,omitempty"` ContextSize *int `json:"context_size,omitempty"`
F16 *bool `json:"f16,omitempty"` F16 *bool `json:"f16,omitempty"`
Debug *bool `json:"debug,omitempty"` Debug *bool `json:"debug,omitempty"`
EnableTracing *bool `json:"enable_tracing,omitempty"` EnableTracing *bool `json:"enable_tracing,omitempty"`
TracingMaxItems *int `json:"tracing_max_items,omitempty"` TracingMaxItems *int `json:"tracing_max_items,omitempty"`
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"` EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
@ -66,11 +66,11 @@ type RuntimeSettings struct {
OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration) OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration)
// Agent Pool settings // Agent Pool settings
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"` AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"` AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"` AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"` AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"` AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"` AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"` AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"`
} }

View file

@ -3,9 +3,10 @@ package explorer
// A simple JSON database for storing and retrieving p2p network tokens and a name and description. // A simple JSON database for storing and retrieving p2p network tokens and a name and description.
import ( import (
"cmp"
"encoding/json" "encoding/json"
"os" "os"
"sort" "slices"
"sync" "sync"
"github.com/gofrs/flock" "github.com/gofrs/flock"
@ -89,9 +90,8 @@ func (db *Database) TokenList() []string {
tokens = append(tokens, k) tokens = append(tokens, k)
} }
sort.Slice(tokens, func(i, j int) bool { slices.SortFunc(tokens, func(a, b string) int {
// sort by token return cmp.Compare(a, b)
return tokens[i] < tokens[j]
}) })
return tokens return tokens

View file

@ -15,7 +15,7 @@ import (
// modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config. // modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config.
type modelConfigCacheEntry struct { type modelConfigCacheEntry struct {
configMap map[string]interface{} configMap map[string]any
lastUpdated time.Time lastUpdated time.Time
} }
@ -57,7 +57,7 @@ func resolveBackend(m *GalleryModel, basePath string) string {
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string // fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
// inside it, and returns the result as a map. Results are cached for 1 hour. // inside it, and returns the result as a map. Results are cached for 1 hour.
// Local file:// URLs skip the cache so edits are picked up immediately. // Local file:// URLs skip the cache so edits are picked up immediately.
func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} { func fetchModelConfigMap(modelURL, basePath string) map[string]any {
// Check cache (skip for file:// URLs so local edits are picked up immediately) // Check cache (skip for file:// URLs so local edits are picked up immediately)
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix) isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
if !isLocal && modelConfigCache.Exists(modelURL) { if !isLocal && modelConfigCache.Exists(modelURL) {
@ -75,15 +75,15 @@ func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
// Cache the failure for remote URLs to avoid repeated fetch attempts // Cache the failure for remote URLs to avoid repeated fetch attempts
if !isLocal { if !isLocal {
modelConfigCache.Set(modelURL, modelConfigCacheEntry{ modelConfigCache.Set(modelURL, modelConfigCacheEntry{
configMap: map[string]interface{}{}, configMap: map[string]any{},
lastUpdated: time.Now(), lastUpdated: time.Now(),
}) })
} }
return map[string]interface{}{} return map[string]any{}
} }
// Parse the config_file YAML string into a map // Parse the config_file YAML string into a map
configMap := make(map[string]interface{}) configMap := make(map[string]any)
if modelConfig.ConfigFile != "" { if modelConfig.ConfigFile != "" {
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil { if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err) xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
@ -108,13 +108,11 @@ func prefetchModelConfigs(urls []string, basePath string) {
sem := make(chan struct{}, maxConcurrency) sem := make(chan struct{}, maxConcurrency)
var wg sync.WaitGroup var wg sync.WaitGroup
for _, url := range urls { for _, url := range urls {
wg.Add(1) wg.Go(func() {
go func(u string) {
defer wg.Done()
sem <- struct{}{} sem <- struct{}{}
defer func() { <-sem }() defer func() { <-sem }()
fetchModelConfigMap(u, basePath) fetchModelConfigMap(url, basePath)
}(url) })
} }
wg.Wait() wg.Wait()
} }

View file

@ -4,10 +4,10 @@ package gallery
import ( import (
"context" "context"
"os"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@ -20,6 +20,9 @@ import (
cp "github.com/otiai10/copy" cp "github.com/otiai10/copy"
) )
// ErrBackendNotFound is returned when a backend is not found in the system.
var ErrBackendNotFound = errors.New("backend not found")
const ( const (
metadataFile = "metadata.json" metadataFile = "metadata.json"
runFile = "run.sh" runFile = "run.sh"
@ -198,9 +201,16 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
} else { } else {
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath) xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil { if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
// Don't remove backendPath here — fallback OCI extractions need the directory to exist
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err) xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
// resetBackendPath cleans up partial state from a failed OCI extraction
// so the next download attempt starts fresh. The directory is re-created
// because OCI image extractors need it to exist for writing files into.
resetBackendPath := func() {
os.RemoveAll(backendPath)
os.MkdirAll(backendPath, 0750)
}
success := false success := false
// Try to download from mirrors // Try to download from mirrors
for _, mirror := range config.Mirrors { for _, mirror := range config.Mirrors {
@ -210,6 +220,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
return ctx.Err() return ctx.Err()
default: default:
} }
resetBackendPath()
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true success = true
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath) xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
@ -221,28 +232,22 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
// Try fallback: replace latestTag + "-" with masterTag + "-" in the URI // Try fallback: replace latestTag + "-" with masterTag + "-" in the URI
fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1) fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1)
if fallbackURI != string(config.URI) { if fallbackURI != string(config.URI) {
xlog.Debug("Trying fallback URI", "original", config.URI, "fallback", fallbackURI) resetBackendPath()
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath) xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
success = true success = true
} else { } else {
// Try another fallback: add "-" + devSuffix suffix to the backend name xlog.Info("Fallback URI failed", "fallback", fallbackURI, "error", err)
// For example: master-gpu-nvidia-cuda-13-ace-step -> master-gpu-nvidia-cuda-13-ace-step-development
if !strings.Contains(fallbackURI, "-"+devSuffix) { if !strings.Contains(fallbackURI, "-"+devSuffix) {
// Extract backend name from URI and add -development resetBackendPath()
parts := strings.Split(fallbackURI, "-") devFallbackURI := fallbackURI + "-" + devSuffix
if len(parts) >= 2 { xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
// Find where the backend name ends (usually the last part before the tag) if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
// Pattern: quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
lastDash := strings.LastIndex(fallbackURI, "-") success = true
if lastDash > 0 { } else {
devFallbackURI := fallbackURI[:lastDash] + "-" + devSuffix xlog.Info("Development fallback URI failed", "fallback", devFallbackURI, "error", err)
xlog.Debug("Trying development fallback URI", "fallback", devFallbackURI)
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
success = true
}
}
} }
} }
} }
@ -295,7 +300,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
backend, ok := backends.Get(name) backend, ok := backends.Get(name)
if !ok { if !ok {
return fmt.Errorf("backend %q not found", name) return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
} }
if backend.IsSystem { if backend.IsSystem {

View file

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"sort" "slices"
"strings" "strings"
"time" "time"
@ -106,64 +106,64 @@ func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
} }
func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] { func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool { slices.SortFunc(gm, func(a, b T) int {
if sortOrder == "asc" { r := strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName()) if sortOrder == "desc" {
} else { return -r
return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName())
} }
return r
}) })
return gm return gm
} }
func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] { func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool { slices.SortFunc(gm, func(a, b T) int {
if sortOrder == "asc" { r := strings.Compare(strings.ToLower(a.GetGallery().Name), strings.ToLower(b.GetGallery().Name))
return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name) if sortOrder == "desc" {
} else { return -r
return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name)
} }
return r
}) })
return gm return gm
} }
func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] { func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool { slices.SortFunc(gm, func(a, b T) int {
licenseI := gm[i].GetLicense() licenseA := a.GetLicense()
licenseJ := gm[j].GetLicense() licenseB := b.GetLicense()
var result bool var r int
if licenseI == "" && licenseJ != "" { if licenseA == "" && licenseB != "" {
return sortOrder == "desc" r = 1
} else if licenseI != "" && licenseJ == "" { } else if licenseA != "" && licenseB == "" {
return sortOrder == "asc" r = -1
} else if licenseI == "" && licenseJ == "" {
return false
} else { } else {
result = strings.ToLower(licenseI) < strings.ToLower(licenseJ) r = strings.Compare(strings.ToLower(licenseA), strings.ToLower(licenseB))
} }
if sortOrder == "desc" { if sortOrder == "desc" {
return !result return -r
} else {
return result
} }
return r
}) })
return gm return gm
} }
func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] { func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool { slices.SortFunc(gm, func(a, b T) int {
var result bool var r int
// Sort by installed status: installed items first (true > false) // Sort by installed status: installed items first (true > false)
if gm[i].GetInstalled() != gm[j].GetInstalled() { if a.GetInstalled() != b.GetInstalled() {
result = gm[i].GetInstalled() if a.GetInstalled() {
r = -1
} else {
r = 1
}
} else { } else {
result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName()) r = strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
} }
if sortOrder == "desc" { if sortOrder == "desc" {
return !result return -r
} else {
return result
} }
return r
}) })
return gm return gm
} }

View file

@ -27,7 +27,7 @@ var _ = Describe("Gallery", func() {
Describe("ReadConfigFile", func() { Describe("ReadConfigFile", func() {
It("should read and unmarshal a valid YAML file", func() { It("should read and unmarshal a valid YAML file", func() {
testConfig := map[string]interface{}{ testConfig := map[string]any{
"name": "test-model", "name": "test-model",
"description": "A test model", "description": "A test model",
"license": "MIT", "license": "MIT",
@ -39,8 +39,8 @@ var _ = Describe("Gallery", func() {
err = os.WriteFile(filePath, yamlData, 0644) err = os.WriteFile(filePath, yamlData, 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var result map[string]interface{} var result map[string]any
config, err := ReadConfigFile[map[string]interface{}](filePath) config, err := ReadConfigFile[map[string]any](filePath)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(config).NotTo(BeNil()) Expect(config).NotTo(BeNil())
result = *config result = *config
@ -50,7 +50,7 @@ var _ = Describe("Gallery", func() {
}) })
It("should return error when file does not exist", func() { It("should return error when file does not exist", func() {
_, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml") _, err := ReadConfigFile[map[string]any]("nonexistent.yaml")
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
@ -59,7 +59,7 @@ var _ = Describe("Gallery", func() {
err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644) err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
_, err = ReadConfigFile[map[string]interface{}](filePath) _, err = ReadConfigFile[map[string]any](filePath)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })
@ -552,23 +552,23 @@ var _ = Describe("Gallery", func() {
// Verify first model // Verify first model
Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8")) Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8"))
Expect(models[0].Overrides).NotTo(BeNil()) Expect(models[0].Overrides).NotTo(BeNil())
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{})) Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
params := models[0].Overrides["parameters"].(map[string]interface{}) params := models[0].Overrides["parameters"].(map[string]any)
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf")) Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf"))
// Verify second model (merged) // Verify second model (merged)
Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4")) Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4"))
Expect(models[1].Overrides).NotTo(BeNil()) Expect(models[1].Overrides).NotTo(BeNil())
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{})) Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
params = models[1].Overrides["parameters"].(map[string]interface{}) params = models[1].Overrides["parameters"].(map[string]any)
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf")) Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
// Simulate the mergo.Merge call that was failing in models.go:251 // Simulate the mergo.Merge call that was failing in models.go:251
// This should not panic with yaml.v3 // This should not panic with yaml.v3
configMap := make(map[string]interface{}) configMap := make(map[string]any)
configMap["name"] = "test" configMap["name"] = "test"
configMap["backend"] = "llama-cpp" configMap["backend"] = "llama-cpp"
configMap["parameters"] = map[string]interface{}{ configMap["parameters"] = map[string]any{
"model": "original.gguf", "model": "original.gguf",
} }
@ -577,7 +577,7 @@ var _ = Describe("Gallery", func() {
Expect(configMap["parameters"]).NotTo(BeNil()) Expect(configMap["parameters"]).NotTo(BeNil())
// Verify the merge worked correctly // Verify the merge worked correctly
mergedParams := configMap["parameters"].(map[string]interface{}) mergedParams := configMap["parameters"].(map[string]any)
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf")) Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
}) })
}) })

View file

@ -59,7 +59,7 @@ var _ = Describe("ImportLocalPath", func() {
adapterConfig := map[string]any{ adapterConfig := map[string]any{
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf", "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
"peft_type": "LORA", "peft_type": "LORA",
} }
data, _ := json.Marshal(adapterConfig) data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

View file

@ -158,7 +158,7 @@ func InstallModelFromGallery(
return applyModel(model) return applyModel(model)
} }
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) { func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]any, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist // Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750) err := os.MkdirAll(basePath, 0750)
@ -239,7 +239,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
configFilePath := filepath.Join(basePath, name+".yaml") configFilePath := filepath.Join(basePath, name+".yaml")
// Read and update config file as map[string]interface{} // Read and update config file as map[string]interface{}
configMap := make(map[string]interface{}) configMap := make(map[string]any)
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err) return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err)

View file

@ -35,7 +35,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) _, err = InstallModel(context.TODO(), systemState, "", c, map[string]any{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@ -43,7 +43,7 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
content := map[string]interface{}{} content := map[string]any{}
dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml")) dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -95,7 +95,7 @@ var _ = Describe("Model test", func() {
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]any{}
err = yaml.Unmarshal(dat, &content) err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
@ -130,7 +130,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -150,7 +150,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -158,7 +158,7 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
content := map[string]interface{}{} content := map[string]any{}
dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml")) dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -180,7 +180,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) _, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })

View file

@ -12,9 +12,9 @@ import (
type GalleryModel struct { type GalleryModel struct {
Metadata `json:",inline" yaml:",inline"` Metadata `json:",inline" yaml:",inline"`
// config_file is read in the situation where URL is blank - and therefore this is a base config. // config_file is read in the situation where URL is blank - and therefore this is a base config.
ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"` ConfigFile map[string]any `json:"config_file,omitempty" yaml:"config_file,omitempty"`
// Overrides are used to override the configuration of the model located at URL // Overrides are used to override the configuration of the model located at URL
Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"` Overrides map[string]any `json:"overrides,omitempty" yaml:"overrides,omitempty"`
} }
func (m *GalleryModel) GetInstalled() bool { func (m *GalleryModel) GetInstalled() bool {

66
core/gallery/worker.go Normal file
View file

@ -0,0 +1,66 @@
package gallery
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/xlog"
)
// DeleteStagedModelFiles deletes every file that was staged for modelName
// under a worker's modelsPath.
//
// Staged models normally live in a subdirectory named after the model's
// tracking key (created by stageModelFiles in the router), so the common
// case is a single RemoveAll of that directory. Single-file models and
// legacy layouts are handled by removing the exact file plus any
// "<name>.*" siblings (e.g. an mmproj companion next to a gguf).
//
// Workers receive model files via S3/HTTP file staging rather than a
// gallery install, so the YAML-config-based DeleteModelFromSystem cannot
// be used for them.
//
// The resolved path must stay inside modelsPath; names that escape the
// directory (e.g. via "..") are rejected with an error.
func DeleteStagedModelFiles(modelsPath, modelName string) error {
	if modelName == "" {
		return fmt.Errorf("empty model name")
	}

	root := filepath.Clean(modelsPath)
	prefix := root + string(filepath.Separator)
	target := filepath.Clean(filepath.Join(modelsPath, modelName))

	// Containment check: the cleaned target must sit strictly below root.
	// NOTE(review): symlinks are not resolved here — assumes the models
	// directory contains no hostile links; confirm if it is shared.
	if !strings.HasPrefix(target, prefix) {
		return fmt.Errorf("model name %q escapes models directory", modelName)
	}

	deletedAny := false
	if info, err := os.Stat(target); err == nil {
		if info.IsDir() {
			// Directory layout: everything staged for the model lives here.
			return os.RemoveAll(target)
		}
		// Single-file layout: delete the file itself.
		if rmErr := os.Remove(target); rmErr != nil {
			xlog.Warn("Failed to remove model file", "path", target, "error", rmErr)
		} else {
			deletedAny = true
		}
	}

	// Sweep companion files such as "model.gguf.mmproj" next to "model.gguf".
	siblings, _ := filepath.Glob(target + ".*")
	for _, sibling := range siblings {
		cleaned := filepath.Clean(sibling)
		if !strings.HasPrefix(cleaned, prefix) {
			// Defensive: never touch anything outside the models directory.
			continue
		}
		if rmErr := os.Remove(cleaned); rmErr != nil {
			xlog.Warn("Failed to remove model-related file", "path", cleaned, "error", rmErr)
		} else {
			deletedAny = true
		}
	}

	if !deletedAny {
		xlog.Debug("No files found to delete for model", "model", modelName, "path", target)
	}
	return nil
}

View file

@ -0,0 +1,99 @@
package gallery_test
import (
"os"
"path/filepath"
"testing"
"github.com/mudler/LocalAI/core/gallery"
)
// TestDeleteStagedModelFiles exercises DeleteStagedModelFiles across its
// input-validation rules and the deletion layouts it supports: a model
// subdirectory, a single file, and a single file with glob-matched siblings.
func TestDeleteStagedModelFiles(t *testing.T) {
	// mustWrite creates a fixture file and fails the test immediately if the
	// setup itself goes wrong, so the deletion assertions below never run
	// against a half-built tree. (The original ignored WriteFile errors.)
	mustWrite := func(t *testing.T, path string, data []byte) {
		t.Helper()
		if err := os.WriteFile(path, data, 0o644); err != nil {
			t.Fatalf("failed to write fixture %s: %v", path, err)
		}
	}

	t.Run("rejects empty model name", func(t *testing.T) {
		dir := t.TempDir()
		if err := gallery.DeleteStagedModelFiles(dir, ""); err == nil {
			t.Fatal("expected error for empty model name")
		}
	})

	t.Run("rejects path traversal via ..", func(t *testing.T) {
		dir := t.TempDir()
		if err := gallery.DeleteStagedModelFiles(dir, "../../etc/passwd"); err == nil {
			t.Fatal("expected error for path traversal attempt")
		}
	})

	t.Run("rejects path traversal via ../foo", func(t *testing.T) {
		dir := t.TempDir()
		if err := gallery.DeleteStagedModelFiles(dir, "../foo"); err == nil {
			t.Fatal("expected error for path traversal attempt")
		}
	})

	t.Run("removes model subdirectory with all files", func(t *testing.T) {
		dir := t.TempDir()
		modelDir := filepath.Join(dir, "my-model", "sd-cpp", "models")
		if err := os.MkdirAll(modelDir, 0o755); err != nil {
			t.Fatal(err)
		}

		// Model files staged inside the tracking-key subdirectory.
		mustWrite(t, filepath.Join(modelDir, "flux.gguf"), []byte("model"))
		mustWrite(t, filepath.Join(modelDir, "flux.gguf.mmproj"), []byte("mmproj"))

		if err := gallery.DeleteStagedModelFiles(dir, "my-model"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		// The entire my-model directory should be gone.
		if _, err := os.Stat(filepath.Join(dir, "my-model")); !os.IsNotExist(err) {
			t.Fatal("expected model directory to be removed")
		}
	})

	t.Run("removes single file model", func(t *testing.T) {
		dir := t.TempDir()
		modelFile := filepath.Join(dir, "model.gguf")
		mustWrite(t, modelFile, []byte("model"))

		if err := gallery.DeleteStagedModelFiles(dir, "model.gguf"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
			t.Fatal("expected model file to be removed")
		}
	})

	t.Run("removes sibling files via glob", func(t *testing.T) {
		dir := t.TempDir()
		modelFile := filepath.Join(dir, "model.gguf")
		siblingFile := filepath.Join(dir, "model.gguf.mmproj")
		mustWrite(t, modelFile, []byte("model"))
		mustWrite(t, siblingFile, []byte("mmproj"))

		if err := gallery.DeleteStagedModelFiles(dir, "model.gguf"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
			t.Fatal("expected model file to be removed")
		}
		if _, err := os.Stat(siblingFile); !os.IsNotExist(err) {
			t.Fatal("expected sibling file to be removed")
		}
	})

	t.Run("no error when model does not exist", func(t *testing.T) {
		dir := t.TempDir()
		if err := gallery.DeleteStagedModelFiles(dir, "nonexistent"); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})
}

View file

@ -16,12 +16,17 @@ import (
"github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/localai"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/finetune"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/monitoring"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/quantization"
"github.com/mudler/xlog" "github.com/mudler/xlog"
) )
@ -155,7 +160,7 @@ func API(application *application.Application) (*echo.Echo, error) {
// Metrics middleware // Metrics middleware
if !application.ApplicationConfig().DisableMetrics { if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService() metricsService, err := monitoring.NewLocalAIMetricsService()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -295,9 +300,9 @@ func API(application *application.Application) (*echo.Echo, error) {
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Create opcache for tracking UI operations (used by both UI and LocalAI routes) // Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache var opcache *galleryop.OpCache
if !application.ApplicationConfig().DisableWebUI { if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService()) opcache = galleryop.NewOpCache(application.GalleryService())
} }
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP) mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
@ -305,22 +310,51 @@ func API(application *application.Application) (*echo.Echo, error) {
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw) routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
// Fine-tuning routes // Fine-tuning routes
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning) fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
ftService := services.NewFineTuneService( ftService := finetune.NewFineTuneService(
application.ApplicationConfig(), application.ApplicationConfig(),
application.ModelLoader(), application.ModelLoader(),
application.ModelConfigLoader(), application.ModelConfigLoader(),
) )
if d := application.Distributed(); d != nil {
ftService.SetNATSClient(d.Nats)
if d.DistStores != nil && d.DistStores.FineTune != nil {
ftService.SetFineTuneStore(d.DistStores.FineTune)
}
}
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw) routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
// Quantization routes // Quantization routes
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization) quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
qService := services.NewQuantizationService( qService := quantization.NewQuantizationService(
application.ApplicationConfig(), application.ApplicationConfig(),
application.ModelLoader(), application.ModelLoader(),
application.ModelConfigLoader(), application.ModelConfigLoader(),
) )
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw) routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
// Node management routes (distributed mode)
distCfg := application.ApplicationConfig().Distributed
var registry *nodes.NodeRegistry
var remoteUnloader nodes.NodeCommandSender
if d := application.Distributed(); d != nil {
registry = d.Registry
if d.Router != nil {
remoteUnloader = d.Router.Unloader()
}
}
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
// Distributed SSE routes (job progress + agent events via NATS)
if d := application.Distributed(); d != nil {
if d.Dispatcher != nil {
e.GET("/api/agent/jobs/:id/progress", d.Dispatcher.SSEHandler(), mcpJobsMw)
}
if d.AgentBridge != nil {
e.GET("/api/agents/:name/sse/distributed", d.AgentBridge.SSEHandler(), agentsMw)
}
}
routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterOpenAIRoutes(e, requestExtractor, application)
routes.RegisterAnthropicRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application)
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)

View file

@ -44,14 +44,14 @@ Say hello.
### Response:` ### Response:`
type modelApplyRequest struct { type modelApplyRequest struct {
ID string `json:"id"` ID string `json:"id"`
URL string `json:"url"` URL string `json:"url"`
ConfigURL string `json:"config_url"` ConfigURL string `json:"config_url"`
Name string `json:"name"` Name string `json:"name"`
Overrides map[string]interface{} `json:"overrides"` Overrides map[string]any `json:"overrides"`
} }
func getModelStatus(url string) (response map[string]interface{}) { func getModelStatus(url string) (response map[string]any) {
// Create the HTTP request // Create the HTTP request
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@ -94,7 +94,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
return response, err return response, err
} }
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]any) {
//url := "http://localhost:AI/models/apply" //url := "http://localhost:AI/models/apply"
@ -336,7 +336,7 @@ var _ = Describe("API test", func() {
Name: "bert", Name: "bert",
URL: bertEmbeddingsURL, URL: bertEmbeddingsURL,
}, },
Overrides: map[string]interface{}{"backend": "llama-cpp"}, Overrides: map[string]any{"backend": "llama-cpp"},
}, },
{ {
Metadata: gallery.Metadata{ Metadata: gallery.Metadata{
@ -344,7 +344,7 @@ var _ = Describe("API test", func() {
URL: bertEmbeddingsURL, URL: bertEmbeddingsURL,
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}}, AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}},
}, },
Overrides: map[string]interface{}{"foo": "bar"}, Overrides: map[string]any{"foo": "bar"},
}, },
} }
out, err := yaml.Marshal(g) out, err := yaml.Marshal(g)
@ -464,7 +464,7 @@ var _ = Describe("API test", func() {
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string) uuid := response["uuid"].(string)
resp := map[string]interface{}{} resp := map[string]any{}
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response) fmt.Println(response)
@ -479,7 +479,7 @@ var _ = Describe("API test", func() {
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml")) _, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]any{}
err = yaml.Unmarshal(dat, &content) err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
@ -503,7 +503,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL, URL: bertEmbeddingsURL,
Name: "bert", Name: "bert",
Overrides: map[string]interface{}{ Overrides: map[string]any{
"backend": "llama", "backend": "llama",
}, },
}) })
@ -520,7 +520,7 @@ var _ = Describe("API test", func() {
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]any{}
err = yaml.Unmarshal(dat, &content) err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("llama")) Expect(content["backend"]).To(Equal("llama"))
@ -529,7 +529,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: bertEmbeddingsURL, URL: bertEmbeddingsURL,
Name: "bert", Name: "bert",
Overrides: map[string]interface{}{}, Overrides: map[string]any{},
}) })
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@ -544,7 +544,7 @@ var _ = Describe("API test", func() {
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]any{}
err = yaml.Unmarshal(dat, &content) err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
@ -586,7 +586,7 @@ parameters:
Expect(response.ID).ToNot(BeEmpty()) Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID uuid := response.ID
resp := map[string]interface{}{} resp := map[string]any{}
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response resp = response
@ -601,7 +601,7 @@ parameters:
dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]any{}
err = yaml.Unmarshal(dat, &content) err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["name"]).To(Equal("test-import-model")) Expect(content["name"]).To(Equal("test-import-model"))
@ -657,7 +657,7 @@ parameters:
Expect(response.ID).ToNot(BeEmpty()) Expect(response.ID).ToNot(BeEmpty())
uuid := response.ID uuid := response.ID
resp := map[string]interface{}{} resp := map[string]any{}
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
resp = response resp = response
@ -1248,7 +1248,7 @@ parameters:
Context("Agent Jobs", Label("agent-jobs"), func() { Context("Agent Jobs", Label("agent-jobs"), func() {
It("creates and manages tasks", func() { It("creates and manages tasks", func() {
// Create a task // Create a task
taskBody := map[string]interface{}{ taskBody := map[string]any{
"name": "Test Task", "name": "Test Task",
"description": "Test Description", "description": "Test Description",
"model": "testmodel.ggml", "model": "testmodel.ggml",
@ -1256,7 +1256,7 @@ parameters:
"enabled": true, "enabled": true,
} }
var createResp map[string]interface{} var createResp map[string]any
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(createResp["id"]).ToNot(BeEmpty()) Expect(createResp["id"]).ToNot(BeEmpty())
@ -1302,20 +1302,20 @@ parameters:
It("executes and monitors jobs", func() { It("executes and monitors jobs", func() {
// Create a task first // Create a task first
taskBody := map[string]interface{}{ taskBody := map[string]any{
"name": "Job Test Task", "name": "Job Test Task",
"model": "testmodel.ggml", "model": "testmodel.ggml",
"prompt": "Say hello", "prompt": "Say hello",
"enabled": true, "enabled": true,
} }
var createResp map[string]interface{} var createResp map[string]any
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
taskID := createResp["id"].(string) taskID := createResp["id"].(string)
// Execute a job // Execute a job
jobBody := map[string]interface{}{ jobBody := map[string]any{
"task_id": taskID, "task_id": taskID,
"parameters": map[string]string{}, "parameters": map[string]string{},
} }
@ -1357,14 +1357,14 @@ parameters:
It("executes task by name", func() { It("executes task by name", func() {
// Create a task with a specific name // Create a task with a specific name
taskBody := map[string]interface{}{ taskBody := map[string]any{
"name": "Named Task", "name": "Named Task",
"model": "testmodel.ggml", "model": "testmodel.ggml",
"prompt": "Hello", "prompt": "Hello",
"enabled": true, "enabled": true,
} }
var createResp map[string]interface{} var createResp map[string]any
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -516,6 +516,17 @@ func isExemptPath(path string, appConfig *config.ApplicationConfig) bool {
return true return true
} }
// Node self-service endpoints — authenticated via registration token, not global auth.
// Only exempt the specific known endpoints, not the entire prefix.
if strings.HasPrefix(path, "/api/node/") {
if path == "/api/node/register" ||
strings.HasSuffix(path, "/heartbeat") ||
strings.HasSuffix(path, "/drain") ||
strings.HasSuffix(path, "/deregister") {
return true
}
}
// Check configured exempt paths // Check configured exempt paths
for _, p := range appConfig.PathWithoutAuth { for _, p := range appConfig.PathWithoutAuth {
if strings.HasPrefix(path, p) { if strings.HasPrefix(path, p) {
@ -540,6 +551,14 @@ func isAPIPath(path string) bool {
strings.HasPrefix(path, "/system") || strings.HasPrefix(path, "/system") ||
strings.HasPrefix(path, "/ws/") || strings.HasPrefix(path, "/ws/") ||
strings.HasPrefix(path, "/generated-") || strings.HasPrefix(path, "/generated-") ||
strings.HasPrefix(path, "/chat/") ||
strings.HasPrefix(path, "/completions") ||
strings.HasPrefix(path, "/edits") ||
strings.HasPrefix(path, "/embeddings") ||
strings.HasPrefix(path, "/audio/") ||
strings.HasPrefix(path, "/images/") ||
strings.HasPrefix(path, "/messages") ||
strings.HasPrefix(path, "/responses") ||
path == "/metrics" path == "/metrics"
} }

View file

@ -9,24 +9,25 @@ import (
// Auth provider constants. // Auth provider constants.
const ( const (
ProviderLocal = "local" ProviderLocal = "local"
ProviderGitHub = "github" ProviderGitHub = "github"
ProviderOIDC = "oidc" ProviderOIDC = "oidc"
ProviderAgentWorker = "agent-worker"
) )
// User represents an authenticated user. // User represents an authenticated user.
type User struct { type User struct {
ID string `gorm:"primaryKey;size:36"` ID string `gorm:"primaryKey;size:36"`
Email string `gorm:"size:255;index"` Email string `gorm:"size:255;index"`
Name string `gorm:"size:255"` Name string `gorm:"size:255"`
AvatarURL string `gorm:"size:512"` AvatarURL string `gorm:"size:512"`
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
Subject string `gorm:"size:255"` // provider-specific user ID Subject string `gorm:"size:255"` // provider-specific user ID
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
Role string `gorm:"size:20;default:user"` Role string `gorm:"size:20;default:user"`
Status string `gorm:"size:20;default:active"` // "active", "pending" Status string `gorm:"size:20;default:active"` // "active", "pending"
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
// Session represents a user login session. // Session represents a user login session.
@ -90,16 +91,16 @@ func (p *PermissionMap) Scan(value any) error {
// InviteCode represents an admin-generated invitation for user registration. // InviteCode represents an admin-generated invitation for user registration.
type InviteCode struct { type InviteCode struct {
ID string `gorm:"primaryKey;size:36"` ID string `gorm:"primaryKey;size:36"`
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
CreatedBy string `gorm:"size:36;not null"` CreatedBy string `gorm:"size:36;not null"`
UsedBy *string `gorm:"size:36"` UsedBy *string `gorm:"size:36"`
UsedAt *time.Time UsedAt *time.Time
ExpiresAt time.Time `gorm:"not null;index"` ExpiresAt time.Time `gorm:"not null;index"`
CreatedAt time.Time CreatedAt time.Time
Creator User `gorm:"foreignKey:CreatedBy"` Creator User `gorm:"foreignKey:CreatedBy"`
Consumer *User `gorm:"foreignKey:UsedBy"` Consumer *User `gorm:"foreignKey:UsedBy"`
} }
// ModelAllowlist controls which models a user can access. // ModelAllowlist controls which models a user can access.

View file

@ -33,24 +33,24 @@ const (
FeatureMCPJobs = "mcp_jobs" FeatureMCPJobs = "mcp_jobs"
// General features (default OFF for new users) // General features (default OFF for new users)
FeatureFineTuning = "fine_tuning" FeatureFineTuning = "fine_tuning"
FeatureQuantization = "quantization" FeatureQuantization = "quantization"
// API features (default ON for new users) // API features (default ON for new users)
FeatureChat = "chat" FeatureChat = "chat"
FeatureImages = "images" FeatureImages = "images"
FeatureAudioSpeech = "audio_speech" FeatureAudioSpeech = "audio_speech"
FeatureAudioTranscription = "audio_transcription" FeatureAudioTranscription = "audio_transcription"
FeatureVAD = "vad" FeatureVAD = "vad"
FeatureDetection = "detection" FeatureDetection = "detection"
FeatureVideo = "video" FeatureVideo = "video"
FeatureEmbeddings = "embeddings" FeatureEmbeddings = "embeddings"
FeatureSound = "sound" FeatureSound = "sound"
FeatureRealtime = "realtime" FeatureRealtime = "realtime"
FeatureRerank = "rerank" FeatureRerank = "rerank"
FeatureTokenize = "tokenize" FeatureTokenize = "tokenize"
FeatureMCP = "mcp" FeatureMCP = "mcp"
FeatureStores = "stores" FeatureStores = "stores"
) )
// AgentFeatures lists agent-related features (default OFF). // AgentFeatures lists agent-related features (default OFF).

View file

@ -24,14 +24,14 @@ type QuotaRule struct {
// QuotaStatus is returned to clients with current usage included. // QuotaStatus is returned to clients with current usage included.
type QuotaStatus struct { type QuotaStatus struct {
ID string `json:"id"` ID string `json:"id"`
Model string `json:"model"` Model string `json:"model"`
MaxRequests *int64 `json:"max_requests"` MaxRequests *int64 `json:"max_requests"`
MaxTotalTokens *int64 `json:"max_total_tokens"` MaxTotalTokens *int64 `json:"max_total_tokens"`
Window string `json:"window"` Window string `json:"window"`
CurrentRequests int64 `json:"current_requests"` CurrentRequests int64 `json:"current_requests"`
CurrentTokens int64 `json:"current_total_tokens"` CurrentTokens int64 `json:"current_total_tokens"`
ResetsAt string `json:"resets_at,omitempty"` ResetsAt string `json:"resets_at,omitempty"`
} }
// ── CRUD ── // ── CRUD ──
@ -209,9 +209,9 @@ func QuotaExceeded(db *gorm.DB, userID, model string) (bool, int64, string) {
var quotaCache = newQuotaCacheStore() var quotaCache = newQuotaCacheStore()
type quotaCacheStore struct { type quotaCacheStore struct {
mu sync.RWMutex mu sync.RWMutex
rules map[string]cachedRules // userID -> rules rules map[string]cachedRules // userID -> rules
usage map[string]cachedUsage // "userID|model|windowStart" -> counts usage map[string]cachedUsage // "userID|model|windowStart" -> counts
} }
type cachedRules struct { type cachedRules struct {

View file

@ -13,7 +13,7 @@ import (
const ( const (
sessionDuration = 30 * 24 * time.Hour // 30 days sessionDuration = 30 * 24 * time.Hour // 30 days
sessionIDBytes = 32 // 32 bytes = 64 hex chars sessionIDBytes = 32 // 32 bytes = 64 hex chars
sessionCookie = "session" sessionCookie = "session"
sessionRotationInterval = 1 * time.Hour sessionRotationInterval = 1 * time.Hour
) )

View file

@ -10,15 +10,15 @@ import (
// UsageRecord represents a single API request's token usage. // UsageRecord represents a single API request's token usage.
type UsageRecord struct { type UsageRecord struct {
ID uint `gorm:"primaryKey;autoIncrement"` ID uint `gorm:"primaryKey;autoIncrement"`
UserID string `gorm:"size:36;index:idx_usage_user_time"` UserID string `gorm:"size:36;index:idx_usage_user_time"`
UserName string `gorm:"size:255"` UserName string `gorm:"size:255"`
Model string `gorm:"size:255;index"` Model string `gorm:"size:255;index"`
Endpoint string `gorm:"size:255"` Endpoint string `gorm:"size:255"`
PromptTokens int64 PromptTokens int64
CompletionTokens int64 CompletionTokens int64
TotalTokens int64 TotalTokens int64
Duration int64 // milliseconds Duration int64 // milliseconds
CreatedAt time.Time `gorm:"index:idx_usage_user_time"` CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
} }
@ -127,10 +127,10 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
query := db.Model(&UsageRecord{}). query := db.Model(&UsageRecord{}).
Select(bucketExpr+", model, user_id, user_name, "+ Select(bucketExpr + ", model, user_id, user_name, " +
"SUM(prompt_tokens) as prompt_tokens, "+ "SUM(prompt_tokens) as prompt_tokens, " +
"SUM(completion_tokens) as completion_tokens, "+ "SUM(completion_tokens) as completion_tokens, " +
"SUM(total_tokens) as total_tokens, "+ "SUM(total_tokens) as total_tokens, " +
"COUNT(*) as request_count"). "COUNT(*) as request_count").
Group("bucket, model, user_id, user_name"). Group("bucket, model, user_id, user_name").
Order("bucket ASC") Order("bucket ASC")

View file

@ -36,7 +36,7 @@ var _ = Describe("Usage", func() {
db := testDB() db := testDB()
// Insert records for two users // Insert records for two users
for i := 0; i < 3; i++ { for range 3 {
err := auth.RecordUsage(db, &auth.UsageRecord{ err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-a", UserID: "user-a",
UserName: "Alice", UserName: "Alice",

View file

@ -3,7 +3,6 @@ package anthropic
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -25,7 +24,7 @@ import (
// @Param request body schema.AnthropicRequest true "query params" // @Param request body schema.AnthropicRequest true "query params"
// @Success 200 {object} schema.AnthropicResponse "Response" // @Success 200 {object} schema.AnthropicResponse "Response"
// @Router /v1/messages [post] // @Router /v1/messages [post]
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
id := uuid.New().String() id := uuid.New().String()
@ -52,7 +51,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
funcs, shouldUseFn := convertAnthropicTools(input, cfg) funcs, shouldUseFn := convertAnthropicTools(input, cfg)
// MCP injection: prompts, resources, and tools // MCP injection: prompts, resources, and tools
var mcpToolInfos []mcpTools.MCPToolInfo var mcpExecutor mcpTools.ToolExecutor
mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata) mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata)
mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata) mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata)
mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata) mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata)
@ -60,76 +59,29 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") { if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") {
remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML() remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
if mcpErr == nil { if mcpErr == nil {
mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, cfg.Name, remote, stdio, mcpServers)
// Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode)
namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers) namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers)
if sessErr == nil && len(namedSessions) > 0 { if sessErr == nil && len(namedSessions) > 0 {
// Prompt injection mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs)
if mcpPromptName != "" { if mcpCtx != nil {
prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) openAIMessages = append(mcpCtx.PromptMessages, openAIMessages...)
if discErr == nil { mcpTools.AppendResourceSuffix(openAIMessages, mcpCtx.ResourceSuffix)
promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs)
if getErr == nil {
var injected []schema.Message
for _, pm := range promptMsgs {
injected = append(injected, schema.Message{
Role: string(pm.Role),
Content: mcpTools.PromptMessageToText(pm),
})
}
openAIMessages = append(injected, openAIMessages...)
xlog.Debug("Anthropic MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected))
} else {
xlog.Error("Failed to get MCP prompt", "error", getErr)
}
}
} }
}
// Resource injection // Tool injection via executor
if len(mcpResourceURIs) > 0 { if mcpExecutor.HasTools() {
resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context())
if discErr == nil { if discErr == nil {
var resourceTexts []string for _, fn := range mcpFuncs {
for _, uri := range mcpResourceURIs { funcs = append(funcs, fn)
content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri)
if readErr != nil {
xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri)
continue
}
name := uri
for _, r := range resources {
if r.URI == uri {
name = r.Name
break
}
}
resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content))
}
if len(resourceTexts) > 0 && len(openAIMessages) > 0 {
lastIdx := len(openAIMessages) - 1
suffix := "\n\n" + strings.Join(resourceTexts, "\n\n")
switch ct := openAIMessages[lastIdx].Content.(type) {
case string:
openAIMessages[lastIdx].Content = ct + suffix
default:
openAIMessages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix)
}
xlog.Debug("Anthropic MCP resources injected", "count", len(resourceTexts))
}
}
}
// Tool injection
if len(mcpServers) > 0 {
discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions)
if discErr == nil {
mcpToolInfos = discovered
for _, ti := range mcpToolInfos {
funcs = append(funcs, ti.Function)
}
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs))
} else {
xlog.Error("Failed to discover MCP tools", "error", discErr)
} }
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs))
} else {
xlog.Error("Failed to discover MCP tools", "error", discErr)
} }
} }
} else { } else {
@ -177,19 +129,19 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput) xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
if input.Stream { if input.Stream {
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
} }
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
} }
} }
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
mcpMaxIterations := 10 mcpMaxIterations := 10
if cfg.Agent.MaxIterations > 0 { if cfg.Agent.MaxIterations > 0 {
mcpMaxIterations = cfg.Agent.MaxIterations mcpMaxIterations = cfg.Agent.MaxIterations
} }
hasMCPTools := len(mcpToolInfos) > 0 hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
// Re-template on each MCP iteration since messages may have changed // Re-template on each MCP iteration since messages may have changed
@ -227,7 +179,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
if hasMCPTools && shouldUseFn && len(toolCalls) > 0 { if hasMCPTools && shouldUseFn && len(toolCalls) > 0 {
var hasMCPCalls bool var hasMCPCalls bool
for _, tc := range toolCalls { for _, tc := range toolCalls {
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
hasMCPCalls = true hasMCPCalls = true
break break
} }
@ -257,13 +209,12 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
// Execute each MCP tool call and append results // Execute each MCP tool call and append results
for _, tc := range assistantMsg.ToolCalls { for _, tc := range assistantMsg.ToolCalls {
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
continue continue
} }
xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration) xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
toolResult, toolErr := mcpTools.ExecuteMCPToolCall( toolResult, toolErr := mcpExecutor.ExecuteTool(
c.Request().Context(), mcpToolInfos, c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
) )
if toolErr != nil { if toolErr != nil {
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
@ -290,10 +241,10 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
if shouldUseFn && len(toolCalls) > 0 { if shouldUseFn && len(toolCalls) > 0 {
stopReason = "tool_use" stopReason = "tool_use"
for _, tc := range toolCalls { for _, tc := range toolCalls {
var inputArgs map[string]interface{} var inputArgs map[string]any
if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil { if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments) xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
inputArgs = map[string]interface{}{"raw": tc.Arguments} inputArgs = map[string]any{"raw": tc.Arguments}
} }
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{ contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
Type: "tool_use", Type: "tool_use",
@ -316,9 +267,9 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped}) contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped})
} }
for i, fc := range parsed { for i, fc := range parsed {
var inputArgs map[string]interface{} var inputArgs map[string]any
if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil { if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil {
inputArgs = map[string]interface{}{"raw": fc.Arguments} inputArgs = map[string]any{"raw": fc.Arguments}
} }
toolCallID := fc.ID toolCallID := fc.ID
if toolCallID == "" { if toolCallID == "" {
@ -365,7 +316,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached") return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
} }
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive") c.Response().Header().Set("Connection", "keep-alive")
@ -388,7 +339,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
if cfg.Agent.MaxIterations > 0 { if cfg.Agent.MaxIterations > 0 {
mcpMaxIterations = cfg.Agent.MaxIterations mcpMaxIterations = cfg.Agent.MaxIterations
} }
hasMCPTools := len(mcpToolInfos) > 0 hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
// Re-template on MCP iterations // Re-template on MCP iterations
@ -483,7 +434,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback) _, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
if err != nil { if err != nil {
xlog.Error("Anthropic stream model inference failed", "error", err) xlog.Error("Anthropic stream model inference failed", "error", err)
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) sendAnthropicSSE(c, schema.AnthropicStreamEvent{
Type: "error",
Error: &schema.AnthropicError{
Type: "api_error",
Message: fmt.Sprintf("model inference failed: %v", err),
},
})
return nil
} }
// Also check chat deltas for tool calls // Also check chat deltas for tool calls
@ -495,7 +453,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
if hasMCPTools && len(collectedToolCalls) > 0 { if hasMCPTools && len(collectedToolCalls) > 0 {
var hasMCPCalls bool var hasMCPCalls bool
for _, tc := range collectedToolCalls { for _, tc := range collectedToolCalls {
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
hasMCPCalls = true hasMCPCalls = true
break break
} }
@ -525,13 +483,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
// Execute MCP tool calls // Execute MCP tool calls
for _, tc := range assistantMsg.ToolCalls { for _, tc := range assistantMsg.ToolCalls {
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
continue continue
} }
xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration) xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
toolResult, toolErr := mcpTools.ExecuteMCPToolCall( toolResult, toolErr := mcpExecutor.ExecuteTool(
c.Request().Context(), mcpToolInfos, c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
) )
if toolErr != nil { if toolErr != nil {
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
@ -686,7 +643,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
case string: case string:
openAIMsg.StringContent = content openAIMsg.StringContent = content
openAIMsg.Content = content openAIMsg.Content = content
case []interface{}: case []any:
// Handle array of content blocks // Handle array of content blocks
var textContent string var textContent string
var stringImages []string var stringImages []string
@ -694,7 +651,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
toolCallIndex := 0 toolCallIndex := 0
for _, block := range content { for _, block := range content {
if blockMap, ok := block.(map[string]interface{}); ok { if blockMap, ok := block.(map[string]any); ok {
blockType, _ := blockMap["type"].(string) blockType, _ := blockMap["type"].(string)
switch blockType { switch blockType {
case "text": case "text":
@ -703,7 +660,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
} }
case "image": case "image":
// Handle image content // Handle image content
if source, ok := blockMap["source"].(map[string]interface{}); ok { if source, ok := blockMap["source"].(map[string]any); ok {
if sourceType, ok := source["type"].(string); ok && sourceType == "base64" { if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
if data, ok := source["data"].(string); ok { if data, ok := source["data"].(string); ok {
mediaType, _ := source["media_type"].(string) mediaType, _ := source["media_type"].(string)
@ -751,10 +708,10 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
switch rc := resultContent.(type) { switch rc := resultContent.(type) {
case string: case string:
resultText = rc resultText = rc
case []interface{}: case []any:
// Array of content blocks // Array of content blocks
for _, cb := range rc { for _, cb := range rc {
if cbMap, ok := cb.(map[string]interface{}); ok { if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" { if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok { if text, ok := cbMap["text"].(string); ok {
resultText += text resultText += text
@ -823,7 +780,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
return nil, false return nil, false
} }
// "auto" is the default - let model decide // "auto" is the default - let model decide
case map[string]interface{}: case map[string]any:
// Specific tool selection: {"type": "tool", "name": "tool_name"} // Specific tool selection: {"type": "tool", "name": "tool_name"}
if tcType, ok := tc["type"].(string); ok && tcType == "tool" { if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
if name, ok := tc["name"].(string); ok { if name, ok := tc["name"].(string); ok {

View file

@ -1,9 +1,10 @@
package explorer package explorer
import ( import (
"cmp"
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"sort" "slices"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -14,7 +15,7 @@ import (
func Dashboard() echo.HandlerFunc { func Dashboard() echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
summary := map[string]interface{}{ summary := map[string]any{
"Title": "LocalAI API - " + internal.PrintableVersion(), "Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(), "Version": internal.PrintableVersion(),
"BaseURL": middleware.BaseURL(c), "BaseURL": middleware.BaseURL(c),
@ -61,8 +62,8 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
} }
// order by number of clusters // order by number of clusters
sort.Slice(results, func(i, j int) bool { slices.SortFunc(results, func(a, b Network) int {
return len(results[i].Clusters) > len(results[j].Clusters) return cmp.Compare(len(b.Clusters), len(a.Clusters))
}) })
return c.JSON(http.StatusOK, results) return c.JSON(http.StatusOK, results)
@ -73,36 +74,36 @@ func AddNetwork(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
request := new(AddNetworkRequest) request := new(AddNetworkRequest)
if err := c.Bind(request); err != nil { if err := c.Bind(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"}) return c.JSON(http.StatusBadRequest, map[string]any{"error": "Cannot parse JSON"})
} }
if request.Token == "" { if request.Token == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"}) return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token is required"})
} }
if request.Name == "" { if request.Name == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"}) return c.JSON(http.StatusBadRequest, map[string]any{"error": "Name is required"})
} }
if request.Description == "" { if request.Description == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"}) return c.JSON(http.StatusBadRequest, map[string]any{"error": "Description is required"})
} }
// TODO: check if token is valid, otherwise reject // TODO: check if token is valid, otherwise reject
// try to decode the token from base64 // try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token) _, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"}) return c.JSON(http.StatusBadRequest, map[string]any{"error": "Invalid token"})
} }
if _, exists := db.Get(request.Token); exists { if _, exists := db.Get(request.Token); exists {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"}) return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token already exists"})
} }
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description}) err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"}) return c.JSON(http.StatusInternalServerError, map[string]any{"error": "Cannot add token"})
} }
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"}) return c.JSON(http.StatusOK, map[string]any{"message": "Token added"})
} }
} }

View file

@ -1,6 +1,7 @@
package localai package localai
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
@ -8,12 +9,12 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/agentpool"
) )
// getJobService returns the job service for the current user. // getJobService returns the job service for the current user.
// Falls back to the global service when no user is authenticated. // Falls back to the global service when no user is authenticated.
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService { func getJobService(app *application.Application, c echo.Context) *agentpool.AgentJobService {
userID := getUserID(c) userID := getUserID(c)
if userID == "" { if userID == "" {
return app.AgentJobService() return app.AgentJobService()
@ -54,7 +55,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
} }
if err := getJobService(app, c).UpdateTask(id, task); err != nil { if err := getJobService(app, c).UpdateTask(id, task); err != nil {
if err.Error() == "task not found: "+id { if errors.Is(err, agentpool.ErrTaskNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
@ -68,7 +69,7 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
id := c.Param("id") id := c.Param("id")
if err := getJobService(app, c).DeleteTask(id); err != nil { if err := getJobService(app, c).DeleteTask(id); err != nil {
if err.Error() == "task not found: "+id { if errors.Is(err, agentpool.ErrTaskNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
@ -244,7 +245,7 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
id := c.Param("id") id := c.Param("id")
if err := getJobService(app, c).CancelJob(id); err != nil { if err := getJobService(app, c).CancelJob(id); err != nil {
if err.Error() == "job not found: "+id { if errors.Is(err, agentpool.ErrJobNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
@ -258,7 +259,7 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
id := c.Param("id") id := c.Param("id")
if err := getJobService(app, c).DeleteJob(id); err != nil { if err := getJobService(app, c).DeleteJob(id); err != nil {
if err.Error() == "job not found: "+id { if errors.Is(err, agentpool.ErrJobNotFound) {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
@ -275,7 +276,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
if c.Request().ContentLength > 0 { if c.Request().ContentLength > 0 {
if err := c.Bind(&params); err != nil { if err := c.Bind(&params); err != nil {
body := make(map[string]interface{}) body := make(map[string]any)
if err := c.Bind(&body); err == nil { if err := c.Bind(&body); err == nil {
params = make(map[string]string) params = make(map[string]string)
for k, v := range body { for k, v := range body {

View file

@ -2,6 +2,7 @@ package localai
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -10,8 +11,9 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
coreTypes "github.com/mudler/LocalAGI/core/types" coreTypes "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/xlog" "github.com/mudler/xlog"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
@ -50,55 +52,105 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
return next(c) return next(c)
} }
// Check if this model name is an agent // Check if this model name is an agent — try in-process agent first,
ag := svc.GetAgent(req.Model) // fall back to config lookup (covers distributed mode where agents
if ag == nil { // don't run in-process).
return next(c)
}
// This is an agent — handle the request directly
messages := parseInputToMessages(req.Input) messages := parseInputToMessages(req.Input)
if len(messages) == 0 { userID := effectiveUserID(c)
return c.JSON(http.StatusBadRequest, map[string]any{ ag := svc.GetAgent(req.Model)
"error": map[string]string{ if ag == nil && svc.GetAgentConfigForUser(userID, req.Model) == nil {
"type": "invalid_request_error", return next(c) // not an agent
"message": "no input messages provided",
},
})
} }
jobOptions := []coreTypes.JobOption{ // Extract the last user message for the executor
coreTypes.WithConversationHistory(messages), var userMessage string
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
userMessage = messages[i].Content
break
}
} }
res := ag.Ask(jobOptions...) var responseText string
if res == nil { if ag != nil {
return c.JSON(http.StatusInternalServerError, map[string]any{ // Local mode: use LocalAGI agent directly
"error": map[string]string{ jobOptions := []coreTypes.JobOption{
"type": "server_error", coreTypes.WithConversationHistory(messages),
"message": "agent request failed or was cancelled", }
},
}) res := ag.Ask(jobOptions...)
} if res == nil {
if res.Error != nil { return c.JSON(http.StatusInternalServerError, map[string]any{
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error) "error": map[string]string{
return c.JSON(http.StatusInternalServerError, map[string]any{ "type": "server_error",
"error": map[string]string{ "message": "agent request failed or was cancelled",
"type": "server_error", },
"message": res.Error.Error(), })
}, }
if res.Error != nil {
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{
"type": "server_error",
"message": res.Error.Error(),
},
})
}
responseText = res.Response
} else {
// Distributed mode: dispatch via NATS + wait for response synchronously
var bridge *agents.EventBridge
if d := app.Distributed(); d != nil {
bridge = d.AgentBridge
}
if bridge == nil {
return next(c)
}
// Subscribe BEFORE dispatching so we never miss a fast response
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Minute)
defer cancel()
responseCh := make(chan string, 1)
sub, err := bridge.SubscribeEvents(req.Model, userID, func(evt agents.AgentEvent) {
if evt.EventType == "json_message" && evt.Sender == "agent" {
responseCh <- evt.Content
}
}) })
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{"type": "server_error", "message": "failed to subscribe to agent events"},
})
}
defer sub.Unsubscribe()
// Now dispatch via ChatForUser (publishes to NATS)
_, err = svc.ChatForUser(userID, req.Model, userMessage)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]any{
"error": map[string]string{"type": "server_error", "message": err.Error()},
})
}
select {
case responseText = <-responseCh:
// Got the response
case <-ctx.Done():
return c.JSON(http.StatusGatewayTimeout, map[string]any{
"error": map[string]string{"type": "server_error", "message": "agent response timeout"},
})
}
} }
id := fmt.Sprintf("resp_%s", uuid.New().String()) id := fmt.Sprintf("resp_%s", uuid.New().String())
return c.JSON(http.StatusOK, map[string]any{ return c.JSON(http.StatusOK, map[string]any{
"id": id, "id": id,
"object": "response", "object": "response",
"created_at": time.Now().Unix(), "created_at": time.Now().Unix(),
"status": "completed", "status": "completed",
"model": req.Model, "model": req.Model,
"previous_response_id": nil, "previous_response_id": nil,
"output": []any{ "output": []any{
map[string]any{ map[string]any{
@ -109,7 +161,7 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
"content": []map[string]any{ "content": []map[string]any{
{ {
"type": "output_text", "type": "output_text",
"text": res.Response, "text": responseText,
"annotations": []any{}, "annotations": []any{},
}, },
}, },

View file

@ -7,6 +7,7 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/application"
skillsManager "github.com/mudler/LocalAI/core/services/skills"
skilldomain "github.com/mudler/skillserver/pkg/domain" skilldomain "github.com/mudler/skillserver/pkg/domain"
) )
@ -41,27 +42,48 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse {
return out return out
} }
// getSkillManager returns a SkillManager for the request's user.
func getSkillManager(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
svc := app.AgentPoolService()
userID := getUserID(c)
return svc.SkillManagerForUser(userID)
}
func getSkillManagerEffective(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
return svc.SkillManagerForUser(userID)
}
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
skills, err := svc.ListSkillsForUser(userID) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
skills, err := mgr.List()
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
} }
// Admin cross-user aggregation // Admin cross-user aggregation
if wantsAllUsers(c) { if wantsAllUsers(c) {
svc := app.AgentPoolService()
usm := svc.UserServicesManager() usm := svc.UserServicesManager()
if usm != nil { if usm != nil {
userIDs, _ := usm.ListAllUserIDs() userIDs, _ := usm.ListAllUserIDs()
userGroups := map[string]any{} userGroups := map[string]any{}
userID := getUserID(c)
for _, uid := range userIDs { for _, uid := range userIDs {
if uid == userID { if uid == userID {
continue continue
} }
userSkills, err := svc.ListSkillsForUser(uid) uidMgr, mgrErr := svc.SkillManagerForUser(uid)
if err != nil || len(userSkills) == 0 { if mgrErr != nil {
continue
}
userSkills, listErr := uidMgr.List()
if listErr != nil || len(userSkills) == 0 {
continue continue
} }
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)} userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
@ -76,25 +98,28 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
} }
} }
return c.JSON(http.StatusOK, skillsToResponses(skills)) return c.JSON(http.StatusOK, map[string]any{"skills": skillsToResponses(skills)})
} }
} }
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
cfg := svc.GetSkillsConfigForUser(userID) return c.JSON(http.StatusOK, map[string]string{})
return c.JSON(http.StatusOK, cfg) }
return c.JSON(http.StatusOK, mgr.GetConfig())
} }
} }
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
query := c.QueryParam("q") query := c.QueryParam("q")
skills, err := svc.SearchSkillsForUser(userID, query) skills, err := mgr.Search(query)
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
} }
@ -104,8 +129,10 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct { var payload struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
@ -118,7 +145,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil { if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) skill, err := mgr.Create(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "already exists") { if strings.Contains(err.Error(), "already exists") {
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()}) return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
@ -131,9 +158,11 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManagerEffective(c, app)
userID := effectiveUserID(c) if err != nil {
skill, err := svc.GetSkillForUser(userID, c.Param("name")) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
skill, err := mgr.Get(c.Param("name"))
if err != nil { if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
@ -143,8 +172,10 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManagerEffective(c, app)
userID := effectiveUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct { var payload struct {
Description string `json:"description"` Description string `json:"description"`
Content string `json:"content"` Content string `json:"content"`
@ -156,7 +187,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil { if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) skill, err := mgr.Update(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@ -169,9 +200,11 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManagerEffective(c, app)
userID := effectiveUserID(c) if err != nil {
if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.Delete(c.Param("name")); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@ -180,10 +213,12 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManagerEffective(c, app)
userID := effectiveUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
name := c.Param("*") name := c.Param("*")
data, err := svc.ExportSkillForUser(userID, name) data, err := mgr.Export(name)
if err != nil { if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
@ -195,8 +230,10 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
@ -210,7 +247,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
skill, err := svc.ImportSkillForUser(userID, data) skill, err := mgr.Import(data)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
@ -222,9 +259,11 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManagerEffective(c, app)
userID := effectiveUserID(c) if err != nil {
resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name")) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
resources, skill, err := mgr.ListResources(c.Param("name"))
if err != nil { if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
@ -260,9 +299,11 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManagerEffective(c, app)
userID := effectiveUserID(c) if err != nil {
content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*")) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
content, info, err := mgr.GetResource(c.Param("name"), c.Param("*"))
if err != nil { if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
@ -281,10 +322,12 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c)
file, err := c.FormFile("file")
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
file, fileErr := c.FormFile("file")
if fileErr != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"}) return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
} }
path := c.FormValue("path") path := c.FormValue("path")
@ -300,7 +343,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
} }
if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil { if err := mgr.CreateResource(c.Param("name"), path, data); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusCreated, map[string]string{"path": path}) return c.JSON(http.StatusCreated, map[string]string{"path": path})
@ -309,15 +352,17 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct { var payload struct {
Content string `json:"content"` Content string `json:"content"`
} }
if err := c.Bind(&payload); err != nil { if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil { if err := mgr.UpdateResource(c.Param("name"), c.Param("*"), payload.Content); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@ -326,9 +371,11 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.DeleteResource(c.Param("name"), c.Param("*")); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@ -339,9 +386,11 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
repos, err := svc.ListGitReposForUser(userID) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
repos, err := mgr.ListGitRepos()
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
} }
@ -351,15 +400,17 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct { var payload struct {
URL string `json:"url"` URL string `json:"url"`
} }
if err := c.Bind(&payload); err != nil { if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
repo, err := svc.AddGitRepoForUser(userID, payload.URL) repo, err := mgr.AddGitRepo(payload.URL)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
@ -369,8 +420,10 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
var payload struct { var payload struct {
URL string `json:"url"` URL string `json:"url"`
Enabled *bool `json:"enabled"` Enabled *bool `json:"enabled"`
@ -378,7 +431,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil { if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled) repo, err := mgr.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@ -391,9 +444,11 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.DeleteGitRepo(c.Param("id")); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
@ -405,9 +460,11 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := mgr.SyncGitRepo(c.Param("id")); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"}) return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
@ -416,9 +473,11 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() mgr, err := getSkillManager(c, app)
userID := getUserID(c) if err != nil {
repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id")) return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
repo, err := mgr.ToggleGitRepo(c.Param("id"))
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
} }

View file

@ -4,20 +4,23 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"maps"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"sort" "slices"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAGI/core/state" "github.com/mudler/LocalAGI/core/state"
coreTypes "github.com/mudler/LocalAGI/core/types" coreTypes "github.com/mudler/LocalAGI/core/types"
agiServices "github.com/mudler/LocalAGI/services" agiServices "github.com/mudler/LocalAGI/services"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services/agentpool"
"github.com/mudler/LocalAI/core/services/agents"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
) )
// getUserID extracts the scoped user ID from the request context. // getUserID extracts the scoped user ID from the request context.
@ -42,25 +45,39 @@ func wantsAllUsers(c echo.Context) bool {
} }
// effectiveUserID returns the user ID to scope operations to. // effectiveUserID returns the user ID to scope operations to.
// SECURITY: Only admins may supply ?user_id=<id> to operate on another user's // SECURITY: Only admins and agent-worker service accounts may supply
// resources. Non-admin callers always get their own ID regardless of query params. // ?user_id=<id> to operate on another user's resources. Agent-worker users are
// created exclusively server-side during node registration and need to access
// collections on behalf of the user whose agent they are executing.
// Regular callers always get their own ID regardless of query params.
func effectiveUserID(c echo.Context) string { func effectiveUserID(c echo.Context) string {
if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) { if targetUID := c.QueryParam("user_id"); targetUID != "" && canImpersonateUser(c) {
if callerID := getUserID(c); callerID != targetUID {
xlog.Info("User impersonation", "caller", callerID, "target", targetUID, "path", c.Path())
}
return targetUID return targetUID
} }
return getUserID(c) return getUserID(c)
} }
// canImpersonateUser returns true if the caller is allowed to use ?user_id= to
// scope operations to another user. Allowed for admins and agent-worker service
// accounts (ProviderAgentWorker is set server-side during node registration and
// cannot be self-assigned).
func canImpersonateUser(c echo.Context) bool {
user := auth.GetUser(c)
if user == nil {
return false
}
return user.Role == auth.RoleAdmin || user.Provider == auth.ProviderAgentWorker
}
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() svc := app.AgentPoolService()
userID := getUserID(c) userID := getUserID(c)
statuses := svc.ListAgentsForUser(userID) statuses := svc.ListAgentsForUser(userID)
agents := make([]string, 0, len(statuses)) agents := slices.Sorted(maps.Keys(statuses))
for name := range statuses {
agents = append(agents, name)
}
sort.Strings(agents)
resp := map[string]any{ resp := map[string]any{
"agents": agents, "agents": agents,
"agentCount": len(agents), "agentCount": len(agents),
@ -111,13 +128,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
svc := app.AgentPoolService() svc := app.AgentPoolService()
userID := effectiveUserID(c) userID := effectiveUserID(c)
name := c.Param("name") name := c.Param("name")
ag := svc.GetAgentForUser(userID, name)
if ag == nil { statuses := svc.ListAgentsForUser(userID)
active, exists := statuses[name]
if !exists {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
} }
return c.JSON(http.StatusOK, map[string]any{ return c.JSON(http.StatusOK, map[string]any{"active": active})
"active": !ag.Paused(),
})
} }
} }
@ -192,9 +209,13 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
svc := app.AgentPoolService() svc := app.AgentPoolService()
userID := effectiveUserID(c) userID := effectiveUserID(c)
name := c.Param("name") name := c.Param("name")
history := svc.GetAgentStatusForUser(userID, name) history := svc.GetAgentStatusForUser(userID, name)
if history == nil { if history == nil {
history = &state.Status{ActionResults: []coreTypes.ActionState{}} return c.JSON(http.StatusOK, map[string]any{
"Name": name,
"History": []string{},
})
} }
entries := []string{} entries := []string{}
for i := len(history.Results()) - 1; i >= 0; i-- { for i := len(history.Results()) - 1; i >= 0; i-- {
@ -221,10 +242,14 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
svc := app.AgentPoolService() svc := app.AgentPoolService()
userID := effectiveUserID(c) userID := effectiveUserID(c)
name := c.Param("name") name := c.Param("name")
history, err := svc.GetAgentObservablesForUser(userID, name) history, err := svc.GetAgentObservablesForUser(userID, name)
if err != nil { if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
} }
if history == nil {
history = []json.RawMessage{}
}
return c.JSON(http.StatusOK, map[string]any{ return c.JSON(http.StatusOK, map[string]any{
"Name": name, "Name": name,
"History": history, "History": history,
@ -278,26 +303,30 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
svc := app.AgentPoolService() svc := app.AgentPoolService()
userID := effectiveUserID(c) userID := effectiveUserID(c)
name := c.Param("name") name := c.Param("name")
manager := svc.GetSSEManagerForUser(userID, name)
if manager == nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
return services.HandleSSE(c, manager)
}
}
type agentConfigMetaResponse struct { // Try local SSE manager first
state.AgentConfigMeta manager := svc.GetSSEManagerForUser(userID, name)
OutputsDir string `json:"OutputsDir"` if manager != nil {
return agentpool.HandleSSE(c, manager)
}
// Fall back to distributed EventBridge SSE
var bridge *agents.EventBridge
if d := app.Distributed(); d != nil {
bridge = d.AgentBridge
}
if bridge != nil {
return bridge.HandleSSE(c, name, userID)
}
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
} }
func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
svc := app.AgentPoolService() svc := app.AgentPoolService()
return c.JSON(http.StatusOK, agentConfigMetaResponse{ return c.JSON(http.StatusOK, svc.GetConfigMetaResult())
AgentConfigMeta: svc.GetConfigMeta(),
OutputsDir: svc.OutputsDir(),
})
} }
} }

View file

@ -10,7 +10,7 @@ import (
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog" "github.com/mudler/xlog"
) )
@ -19,14 +19,14 @@ type BackendEndpointService struct {
galleries []config.Gallery galleries []config.Gallery
backendPath string backendPath string
backendSystemPath string backendSystemPath string
backendApplier *services.GalleryService backendApplier *galleryop.GalleryService
} }
type GalleryBackend struct { type GalleryBackend struct {
ID string `json:"id"` ID string `json:"id"`
} }
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService { func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService) BackendEndpointService {
return BackendEndpointService{ return BackendEndpointService{
galleries: galleries, galleries: galleries,
backendPath: systemState.Backend.BackendsPath, backendPath: systemState.Backend.BackendsPath,
@ -37,7 +37,7 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// GetOpStatusEndpoint returns the job status // GetOpStatusEndpoint returns the job status
// @Summary Returns the job status // @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response" // @Success 200 {object} galleryop.OpStatus "Response"
// @Router /backends/jobs/{uuid} [get] // @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
@ -51,7 +51,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
// GetAllStatusEndpoint returns all the jobs status progress // GetAllStatusEndpoint returns all the jobs status progress
// @Summary Returns all the jobs status progress // @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response" // @Success 200 {object} map[string]galleryop.OpStatus "Response"
// @Router /backends/jobs [get] // @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
@ -76,7 +76,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
if err != nil { if err != nil {
return err return err
} }
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
ID: uuid.String(), ID: uuid.String(),
GalleryElementName: input.ID, GalleryElementName: input.ID,
Galleries: mgs.galleries, Galleries: mgs.galleries,
@ -95,7 +95,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
backendName := c.Param("name") backendName := c.Param("name")
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
Delete: true, Delete: true,
GalleryElementName: backendName, GalleryElementName: backendName,
Galleries: mgs.galleries, Galleries: mgs.galleries,
@ -114,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
// @Summary List all Backends // @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response" // @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get] // @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
backends, err := gallery.ListSystemBackends(systemState) backends, err := mgs.backendApplier.ListBackends()
if err != nil { if err != nil {
return err return err
} }

View file

@ -3,7 +3,7 @@ package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services/monitoring"
) )
// BackendMonitorEndpoint returns the status of the specified backend // BackendMonitorEndpoint returns the status of the specified backend
@ -11,7 +11,7 @@ import (
// @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response" // @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get] // @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest) input := new(schema.BackendMonitorRequest)
@ -32,7 +32,7 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc
// @Summary Backend monitor endpoint // @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post] // @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest) input := new(schema.BackendMonitorRequest)
// Get input data from the request body // Get input data from the request body

Some files were not shown because too many files have changed in this diff Show more