mirror of
https://github.com/mudler/LocalAI
synced 2026-04-21 13:27:21 +00:00
feat: add distributed mode (#9124)
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
4c870288d9
commit
59108fbe32
389 changed files with 276305 additions and 246521 deletions
2
.github/gallery-agent/agent.go
vendored
2
.github/gallery-agent/agent.go
vendored
|
|
@ -406,7 +406,7 @@ func getHuggingFaceAvatarURL(author string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the response to get avatar URL
|
// Parse the response to get avatar URL
|
||||||
var userInfo map[string]interface{}
|
var userInfo map[string]any
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
|
|
|
||||||
40
.github/gallery-agent/testing.go
vendored
40
.github/gallery-agent/testing.go
vendored
|
|
@ -3,7 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand/v2"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -13,11 +13,11 @@ func runSyntheticMode() error {
|
||||||
generator := NewSyntheticDataGenerator()
|
generator := NewSyntheticDataGenerator()
|
||||||
|
|
||||||
// Generate a random number of synthetic models (1-3)
|
// Generate a random number of synthetic models (1-3)
|
||||||
numModels := generator.rand.Intn(3) + 1
|
numModels := generator.rand.IntN(3) + 1
|
||||||
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||||
|
|
||||||
var models []ProcessedModel
|
var models []ProcessedModel
|
||||||
for i := 0; i < numModels; i++ {
|
for i := range numModels {
|
||||||
model := generator.GenerateProcessedModel()
|
model := generator.GenerateProcessedModel()
|
||||||
models = append(models, model)
|
models = append(models, model)
|
||||||
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||||
|
|
@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
|
||||||
// NewSyntheticDataGenerator creates a new synthetic data generator
|
// NewSyntheticDataGenerator creates a new synthetic data generator
|
||||||
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
||||||
return &SyntheticDataGenerator{
|
return &SyntheticDataGenerator{
|
||||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
||||||
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
||||||
fileTypes := []string{"model", "readme", "other"}
|
fileTypes := []string{"model", "readme", "other"}
|
||||||
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
|
fileType := fileTypes[g.rand.IntN(len(fileTypes))]
|
||||||
|
|
||||||
var path string
|
var path string
|
||||||
var isReadme bool
|
var isReadme bool
|
||||||
|
|
@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
|
||||||
|
|
||||||
return ProcessedModelFile{
|
return ProcessedModelFile{
|
||||||
Path: path,
|
Path: path,
|
||||||
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
|
Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
|
||||||
SHA256: g.randomSHA256(),
|
SHA256: g.randomSHA256(),
|
||||||
IsReadme: isReadme,
|
IsReadme: isReadme,
|
||||||
FileType: fileType,
|
FileType: fileType,
|
||||||
|
|
@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||||
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
||||||
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
||||||
|
|
||||||
author := authors[g.rand.Intn(len(authors))]
|
author := authors[g.rand.IntN(len(authors))]
|
||||||
modelName := modelNames[g.rand.Intn(len(modelNames))]
|
modelName := modelNames[g.rand.IntN(len(modelNames))]
|
||||||
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
||||||
|
|
||||||
// Generate files
|
// Generate files
|
||||||
numFiles := g.rand.Intn(5) + 2 // 2-6 files
|
numFiles := g.rand.IntN(5) + 2 // 2-6 files
|
||||||
files := make([]ProcessedModelFile, numFiles)
|
files := make([]ProcessedModelFile, numFiles)
|
||||||
|
|
||||||
// Ensure at least one model file and one readme
|
// Ensure at least one model file and one readme
|
||||||
hasModelFile := false
|
hasModelFile := false
|
||||||
hasReadme := false
|
hasReadme := false
|
||||||
|
|
||||||
for i := 0; i < numFiles; i++ {
|
for i := range numFiles {
|
||||||
files[i] = g.GenerateProcessedModelFile()
|
files[i] = g.GenerateProcessedModelFile()
|
||||||
if files[i].FileType == "model" {
|
if files[i].FileType == "model" {
|
||||||
hasModelFile = true
|
hasModelFile = true
|
||||||
|
|
@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||||
|
|
||||||
// Generate sample metadata
|
// Generate sample metadata
|
||||||
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
||||||
license := licenses[g.rand.Intn(len(licenses))]
|
license := licenses[g.rand.IntN(len(licenses))]
|
||||||
|
|
||||||
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
||||||
numTags := g.rand.Intn(4) + 3 // 3-6 tags
|
numTags := g.rand.IntN(4) + 3 // 3-6 tags
|
||||||
tags := make([]string, numTags)
|
tags := make([]string, numTags)
|
||||||
for i := 0; i < numTags; i++ {
|
for i := range numTags {
|
||||||
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
|
tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
|
||||||
}
|
}
|
||||||
// Remove duplicates
|
// Remove duplicates
|
||||||
tags = g.removeDuplicates(tags)
|
tags = g.removeDuplicates(tags)
|
||||||
|
|
||||||
// Optionally include icon (50% chance)
|
// Optionally include icon (50% chance)
|
||||||
icon := ""
|
icon := ""
|
||||||
if g.rand.Intn(2) == 0 {
|
if g.rand.IntN(2) == 0 {
|
||||||
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
||||||
}
|
}
|
||||||
|
|
||||||
return ProcessedModel{
|
return ProcessedModel{
|
||||||
ModelID: modelID,
|
ModelID: modelID,
|
||||||
Author: author,
|
Author: author,
|
||||||
Downloads: g.rand.Intn(1000000) + 1000,
|
Downloads: g.rand.IntN(1000000) + 1000,
|
||||||
LastModified: g.randomDate(),
|
LastModified: g.randomDate(),
|
||||||
Files: files,
|
Files: files,
|
||||||
PreferredModelFile: preferredModelFile,
|
PreferredModelFile: preferredModelFile,
|
||||||
|
|
@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
|
||||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
b := make([]byte, length)
|
b := make([]byte, length)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
b[i] = charset[g.rand.Intn(len(charset))]
|
b[i] = charset[g.rand.IntN(len(charset))]
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
|
||||||
const charset = "0123456789abcdef"
|
const charset = "0123456789abcdef"
|
||||||
b := make([]byte, 64)
|
b := make([]byte, 64)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
b[i] = charset[g.rand.Intn(len(charset))]
|
b[i] = charset[g.rand.IntN(len(charset))]
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *SyntheticDataGenerator) randomDate() string {
|
func (g *SyntheticDataGenerator) randomDate() string {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
daysAgo := g.rand.Intn(365) // Random date within last year
|
daysAgo := g.rand.IntN(365) // Random date within last year
|
||||||
pastDate := now.AddDate(0, 0, -daysAgo)
|
pastDate := now.AddDate(0, 0, -daysAgo)
|
||||||
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
||||||
}
|
}
|
||||||
|
|
@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
|
||||||
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
||||||
}
|
}
|
||||||
|
|
||||||
return templates[g.rand.Intn(len(templates))]
|
return templates[g.rand.IntN(len(templates))]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
|
@ -21,7 +21,7 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ['1.25.x']
|
go-version: ['1.26.x']
|
||||||
steps:
|
steps:
|
||||||
- name: Free Disk Space (Ubuntu)
|
- name: Free Disk Space (Ubuntu)
|
||||||
uses: jlumbroso/free-disk-space@main
|
uses: jlumbroso/free-disk-space@main
|
||||||
|
|
@ -179,7 +179,7 @@ jobs:
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ['1.25.x']
|
go-version: ['1.26.x']
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|
|
||||||
|
|
@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
|
||||||
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
|
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
|
||||||
FROM requirements-drivers AS build-requirements
|
FROM requirements-drivers AS build-requirements
|
||||||
|
|
||||||
ARG GO_VERSION=1.25.4
|
ARG GO_VERSION=1.26.0
|
||||||
ARG CMAKE_VERSION=3.31.10
|
ARG CMAKE_VERSION=3.31.10
|
||||||
ARG CMAKE_FROM_SOURCE=false
|
ARG CMAKE_FROM_SOURCE=false
|
||||||
ARG TARGETARCH
|
ARG TARGETARCH
|
||||||
|
|
@ -319,7 +319,6 @@ COPY ./.git ./.git
|
||||||
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
|
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
|
||||||
COPY ./pkg/grpc ./pkg/grpc
|
COPY ./pkg/grpc ./pkg/grpc
|
||||||
COPY ./pkg/utils ./pkg/utils
|
COPY ./pkg/utils ./pkg/utils
|
||||||
COPY ./pkg/langchain ./pkg/langchain
|
|
||||||
|
|
||||||
RUN ls -l ./
|
RUN ls -l ./
|
||||||
RUN make protogen-go
|
RUN make protogen-go
|
||||||
|
|
|
||||||
|
|
@ -154,6 +154,7 @@ For older news and full release notes, see [GitHub Releases](https://github.com/
|
||||||
- [Object Detection](https://localai.io/features/object-detection/)
|
- [Object Detection](https://localai.io/features/object-detection/)
|
||||||
- [Reranker API](https://localai.io/features/reranker/)
|
- [Reranker API](https://localai.io/features/reranker/)
|
||||||
- [P2P Inferencing](https://localai.io/features/distribute/)
|
- [P2P Inferencing](https://localai.io/features/distribute/)
|
||||||
|
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
|
||||||
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
|
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
|
||||||
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
|
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
|
||||||
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
|
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ service Backend {
|
||||||
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
||||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define the empty request
|
// Define the empty request
|
||||||
|
|
@ -676,3 +677,4 @@ message QuantizationProgressUpdate {
|
||||||
message QuantizationStopRequest {
|
message QuantizationStopRequest {
|
||||||
string job_id = 1;
|
string job_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,10 @@
|
||||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
#include <grpcpp/grpcpp.h>
|
#include <grpcpp/grpcpp.h>
|
||||||
#include <grpcpp/health_check_service_interface.h>
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
|
#include <grpcpp/security/server_credentials.h>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <cstdlib>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
@ -37,6 +39,47 @@ using grpc::Server;
|
||||||
using grpc::ServerBuilder;
|
using grpc::ServerBuilder;
|
||||||
using grpc::ServerContext;
|
using grpc::ServerContext;
|
||||||
using grpc::Status;
|
using grpc::Status;
|
||||||
|
|
||||||
|
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
|
||||||
|
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||||
|
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||||
|
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
|
||||||
|
public:
|
||||||
|
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
|
||||||
|
|
||||||
|
bool IsBlocking() const override { return false; }
|
||||||
|
|
||||||
|
grpc::Status Process(const InputMetadata& auth_metadata,
|
||||||
|
grpc::AuthContext* /*context*/,
|
||||||
|
OutputMetadata* /*consumed_auth_metadata*/,
|
||||||
|
OutputMetadata* /*response_metadata*/) override {
|
||||||
|
auto it = auth_metadata.find("authorization");
|
||||||
|
if (it != auth_metadata.end()) {
|
||||||
|
std::string expected = "Bearer " + token_;
|
||||||
|
std::string got(it->second.data(), it->second.size());
|
||||||
|
// Constant-time comparison
|
||||||
|
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string token_;
|
||||||
|
|
||||||
|
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||||
|
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||||
|
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||||
|
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||||
|
unsigned char result = 0;
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
result |= pa[i] ^ pb[i];
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// END LocalAI
|
// END LocalAI
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2760,11 +2803,24 @@ int main(int argc, char** argv) {
|
||||||
BackendServiceImpl service(ctx_server);
|
BackendServiceImpl service(ctx_server);
|
||||||
|
|
||||||
ServerBuilder builder;
|
ServerBuilder builder;
|
||||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||||
|
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||||
|
std::shared_ptr<grpc::ServerCredentials> creds;
|
||||||
|
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||||
|
creds = grpc::InsecureServerCredentials();
|
||||||
|
creds->SetAuthMetadataProcessor(
|
||||||
|
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
|
||||||
|
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||||
|
} else {
|
||||||
|
creds = grpc::InsecureServerCredentials();
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.AddListeningPort(server_address, creds);
|
||||||
builder.RegisterService(&service);
|
builder.RegisterService(&service);
|
||||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
|
|
||||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||||
// run the HTTP server in a thread - see comment below
|
// run the HTTP server in a thread - see comment below
|
||||||
std::thread t([&]()
|
std::thread t([&]()
|
||||||
|
|
|
||||||
|
|
@ -134,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tmpDir)
|
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||||
|
|
||||||
outputFile := filepath.Join(tmpDir, "output.wav")
|
outputFile := filepath.Join(tmpDir, "output.wav")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||||
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -29,18 +29,18 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
||||||
var textEncoderModel, ditModel, vaeModel string
|
var textEncoderModel, ditModel, vaeModel string
|
||||||
|
|
||||||
for _, oo := range opts.Options {
|
for _, oo := range opts.Options {
|
||||||
parts := strings.SplitN(oo, ":", 2)
|
key, value, found := strings.Cut(oo, ":")
|
||||||
if len(parts) != 2 {
|
if !found {
|
||||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch parts[0] {
|
switch key {
|
||||||
case "text_encoder_model":
|
case "text_encoder_model":
|
||||||
textEncoderModel = parts[1]
|
textEncoderModel = value
|
||||||
case "dit_model":
|
case "dit_model":
|
||||||
ditModel = parts[1]
|
ditModel = value
|
||||||
case "vae_model":
|
case "vae_model":
|
||||||
vaeModel = parts[1]
|
vaeModel = value
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ type LLM struct {
|
||||||
draftModel *llama.LLama
|
draftModel *llama.LLama
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Free releases GPU resources and frees the llama model
|
// Free releases GPU resources and frees the llama model
|
||||||
// This should be called when the model is being unloaded to properly release VRAM
|
// This should be called when the model is being unloaded to properly release VRAM
|
||||||
func (llm *LLM) Free() error {
|
func (llm *LLM) Free() error {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
//go:build debug
|
//go:build debug
|
||||||
// +build debug
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
//go:build !debug
|
//go:build !debug
|
||||||
// +build !debug
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||||
|
|
||||||
var dot float32
|
var dot float32
|
||||||
for i := 0; i < len(k1); i++ {
|
for i := range len(k1) {
|
||||||
dot += k1[i] * k2[i]
|
dot += k1[i] * k2[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||||
|
|
||||||
var dot, mag2 float64
|
var dot, mag2 float64
|
||||||
for i := 0; i < len(k1); i++ {
|
for i := range len(k1) {
|
||||||
dot += float64(k1[i] * k2[i])
|
dot += float64(k1[i] * k2[i])
|
||||||
mag2 += float64(k2[i] * k2[i])
|
mag2 += float64(k2[i] * k2[i])
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
|
||||||
// to one-shot (only difference is resampler batch boundaries).
|
// to one-shot (only difference is resampler batch boundaries).
|
||||||
var maxDiff float64
|
var maxDiff float64
|
||||||
var sumDiffSq float64
|
var sumDiffSq float64
|
||||||
for i := 0; i < minLen; i++ {
|
for i := range minLen {
|
||||||
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
||||||
if diff > maxDiff {
|
if diff > maxDiff {
|
||||||
maxDiff = diff
|
maxDiff = diff
|
||||||
|
|
@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
|
||||||
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
||||||
|
|
||||||
var persistentMaxDiff, freshMaxDiff float64
|
var persistentMaxDiff, freshMaxDiff float64
|
||||||
for i := 0; i < minLen; i++ {
|
for i := range minLen {
|
||||||
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
||||||
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
||||||
if pd > persistentMaxDiff {
|
if pd > persistentMaxDiff {
|
||||||
|
|
@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
|
||||||
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
||||||
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
||||||
|
|
||||||
Expect(stddev / mean).To(BeNumerically("<", 0.15),
|
Expect(stddev/mean).To(BeNumerically("<", 0.15),
|
||||||
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
||||||
|
|
||||||
// Also check frequency is correct
|
// Also check frequency is correct
|
||||||
|
|
@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
|
||||||
|
|
||||||
// Every sample must be identical — the resampler is deterministic
|
// Every sample must be identical — the resampler is deterministic
|
||||||
var maxDiff float64
|
var maxDiff float64
|
||||||
for i := 0; i < len(oneShot); i++ {
|
for i := range len(oneShot) {
|
||||||
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
||||||
if diff > maxDiff {
|
if diff > maxDiff {
|
||||||
maxDiff = diff
|
maxDiff = diff
|
||||||
|
|
@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
|
||||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
||||||
copy(hdr[8:12], "WAVE")
|
copy(hdr[8:12], "WAVE")
|
||||||
copy(hdr[12:16], "fmt ")
|
copy(hdr[12:16], "fmt ")
|
||||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||||
copy(hdr[36:40], "data")
|
copy(hdr[36:40], "data")
|
||||||
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
||||||
|
|
||||||
|
|
@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
pcm := make([]byte, toneNumSamples*2)
|
pcm := make([]byte, toneNumSamples*2)
|
||||||
for i := 0; i < toneNumSamples; i++ {
|
for i := range toneNumSamples {
|
||||||
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
||||||
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tmpDir)
|
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||||
|
|
||||||
// Download sample audio — JFK "ask not what your country can do for you" clip
|
// Download sample audio — JFK "ask not what your country can do for you" clip
|
||||||
audioFile := filepath.Join(tmpDir, "sample.wav")
|
audioFile := filepath.Join(tmpDir, "sample.wav")
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,10 @@ import tempfile
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from acestep.inference import (
|
from acestep.inference import (
|
||||||
GenerationParams,
|
GenerationParams,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
|
|
@ -444,6 +448,8 @@ def serve(address):
|
||||||
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
||||||
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,10 @@ import torchaudio as ta
|
||||||
from chatterbox.tts import ChatterboxTTS
|
from chatterbox.tts import ChatterboxTTS
|
||||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
|
@ -225,7 +229,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
78
backend/python/common/grpc_auth.py
Normal file
78
backend/python/common/grpc_auth.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
|
||||||
|
|
||||||
|
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
|
||||||
|
a valid Bearer token in the 'authorization' metadata header are rejected with
|
||||||
|
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
|
||||||
|
performed (backward compatible).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
|
||||||
|
class _AbortHandler(grpc.RpcMethodHandler):
|
||||||
|
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.request_streaming = False
|
||||||
|
self.response_streaming = False
|
||||||
|
self.request_deserializer = None
|
||||||
|
self.response_serializer = None
|
||||||
|
self.unary_unary = self._abort
|
||||||
|
self.unary_stream = None
|
||||||
|
self.stream_unary = None
|
||||||
|
self.stream_stream = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _abort(request, context):
|
||||||
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenAuthInterceptor(grpc.ServerInterceptor):
|
||||||
|
"""Sync gRPC server interceptor that validates a bearer token."""
|
||||||
|
|
||||||
|
def __init__(self, token: str):
|
||||||
|
self._token = token
|
||||||
|
self._abort_handler = _AbortHandler()
|
||||||
|
|
||||||
|
def intercept_service(self, continuation, handler_call_details):
|
||||||
|
metadata = dict(handler_call_details.invocation_metadata)
|
||||||
|
auth = metadata.get("authorization", "")
|
||||||
|
expected = "Bearer " + self._token
|
||||||
|
if not hmac.compare_digest(auth, expected):
|
||||||
|
return self._abort_handler
|
||||||
|
return continuation(handler_call_details)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
|
||||||
|
"""Async gRPC server interceptor that validates a bearer token."""
|
||||||
|
|
||||||
|
def __init__(self, token: str):
|
||||||
|
self._token = token
|
||||||
|
|
||||||
|
async def intercept_service(self, continuation, handler_call_details):
|
||||||
|
metadata = dict(handler_call_details.invocation_metadata)
|
||||||
|
auth = metadata.get("authorization", "")
|
||||||
|
expected = "Bearer " + self._token
|
||||||
|
if not hmac.compare_digest(auth, expected):
|
||||||
|
return _AbortHandler()
|
||||||
|
return await continuation(handler_call_details)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_interceptors(*, aio: bool = False):
|
||||||
|
"""Return a list of gRPC interceptors for bearer token auth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aio: If True, return async-compatible interceptors for grpc.aio.server().
|
||||||
|
If False (default), return sync interceptors for grpc.server().
|
||||||
|
|
||||||
|
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||||
|
"""
|
||||||
|
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||||
|
if not token:
|
||||||
|
return []
|
||||||
|
if aio:
|
||||||
|
return [AsyncTokenAuthInterceptor(token)]
|
||||||
|
return [TokenAuthInterceptor(token)]
|
||||||
|
|
@ -15,6 +15,10 @@ import torch
|
||||||
from TTS.api import TTS
|
from TTS.api import TTS
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -93,7 +97,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
# Import dynamic loader for pipeline discovery
|
# Import dynamic loader for pipeline discovery
|
||||||
from diffusers_dynamic_loader import (
|
from diffusers_dynamic_loader import (
|
||||||
|
|
@ -1042,7 +1046,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,10 @@ import torch
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
|
@ -165,6 +169,8 @@ def serve(address):
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
]
|
]
|
||||||
|
,
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,10 @@ import torch
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -70,7 +74,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,10 @@ import numpy as np
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
|
@ -424,6 +428,8 @@ def serve(address):
|
||||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,10 @@ from kittentts import KittenTTS
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -77,7 +81,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,10 @@ from kokoro import KPipeline
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -84,7 +88,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,10 @@ import time
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
|
@ -398,7 +402,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
|
||||||
|
|
||||||
def serve(address):
|
def serve(address):
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from mlx_audio.tts.utils import load_model
|
from mlx_audio.tts.utils import load_model
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -436,7 +440,9 @@ async def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,10 @@ import tempfile
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
@ -468,6 +472,8 @@ async def serve(address):
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from mlx_vlm import load, generate, stream_generate
|
from mlx_vlm import load, generate, stream_generate
|
||||||
from mlx_vlm.prompt_utils import apply_chat_template
|
from mlx_vlm.prompt_utils import apply_chat_template
|
||||||
from mlx_vlm.utils import load_config, load_image
|
from mlx_vlm.utils import load_config, load_image
|
||||||
|
|
@ -446,7 +450,9 @@ async def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from mlx_lm import load, generate, stream_generate
|
from mlx_lm import load, generate, stream_generate
|
||||||
from mlx_lm.sample_utils import make_sampler
|
from mlx_lm.sample_utils import make_sampler
|
||||||
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
||||||
|
|
@ -421,7 +425,9 @@ async def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,10 @@ from moonshine_voice import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -128,7 +132,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,10 @@ import torch
|
||||||
import nemo.collections.asr as nemo_asr
|
import nemo.collections.asr as nemo_asr
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
|
@ -119,7 +123,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
|
|
@ -130,7 +134,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import outetts
|
import outetts
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -116,7 +120,9 @@ async def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,10 @@ import torch
|
||||||
from pocket_tts import TTSModel
|
from pocket_tts import TTSModel
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
|
|
@ -225,7 +229,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,10 @@ import torch
|
||||||
from qwen_asr import Qwen3ASRModel
|
from qwen_asr import Qwen3ASRModel
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
|
@ -184,7 +188,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,10 @@ import hashlib
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
|
@ -900,6 +904,8 @@ def serve(address):
|
||||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
from rerankers import Reranker
|
from rerankers import Reranker
|
||||||
|
|
||||||
|
|
@ -97,7 +101,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,10 @@ import base64
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
@ -139,7 +143,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
|
|
@ -532,7 +536,9 @@ async def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,10 @@ import uuid
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
|
@ -832,6 +836,8 @@ def serve(address):
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
|
||||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
|
|
@ -724,7 +728,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
from vllm_omni.entrypoints.omni import Omni
|
from vllm_omni.entrypoints.omni import Omni
|
||||||
from vllm_omni.outputs import OmniRequestOutput
|
from vllm_omni.outputs import OmniRequestOutput
|
||||||
|
|
@ -650,7 +654,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
@ -338,7 +342,9 @@ async def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,10 @@ import backend_pb2_grpc
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
|
|
@ -297,7 +301,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,10 @@ import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
@ -137,7 +141,9 @@ def serve(address):
|
||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package application
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -22,13 +22,23 @@ func (a *Application) RestartAgentJobService() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new service instance
|
// Create new service instance
|
||||||
agentJobService := services.NewAgentJobService(
|
agentJobService := agentpool.NewAgentJobService(
|
||||||
a.ApplicationConfig(),
|
a.ApplicationConfig(),
|
||||||
a.ModelLoader(),
|
a.ModelLoader(),
|
||||||
a.ModelConfigLoader(),
|
a.ModelConfigLoader(),
|
||||||
a.TemplatesEvaluator(),
|
a.TemplatesEvaluator(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Re-apply distributed wiring if available (matches startup.go logic)
|
||||||
|
if d := a.Distributed(); d != nil {
|
||||||
|
if d.Dispatcher != nil {
|
||||||
|
agentJobService.SetDistributedBackends(d.Dispatcher)
|
||||||
|
}
|
||||||
|
if d.JobStore != nil {
|
||||||
|
agentJobService.SetDistributedJobStore(d.JobStore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start the service
|
// Start the service
|
||||||
err := agentJobService.Start(a.ApplicationConfig().Context)
|
err := agentJobService.Start(a.ApplicationConfig().Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,16 @@ package application
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"math/rand/v2"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||||
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
"github.com/mudler/LocalAI/core/templates"
|
"github.com/mudler/LocalAI/core/templates"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
|
|
@ -20,9 +24,9 @@ type Application struct {
|
||||||
applicationConfig *config.ApplicationConfig
|
applicationConfig *config.ApplicationConfig
|
||||||
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
|
startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
|
||||||
templatesEvaluator *templates.Evaluator
|
templatesEvaluator *templates.Evaluator
|
||||||
galleryService *services.GalleryService
|
galleryService *galleryop.GalleryService
|
||||||
agentJobService *services.AgentJobService
|
agentJobService *agentpool.AgentJobService
|
||||||
agentPoolService atomic.Pointer[services.AgentPoolService]
|
agentPoolService atomic.Pointer[agentpool.AgentPoolService]
|
||||||
authDB *gorm.DB
|
authDB *gorm.DB
|
||||||
watchdogMutex sync.Mutex
|
watchdogMutex sync.Mutex
|
||||||
watchdogStop chan bool
|
watchdogStop chan bool
|
||||||
|
|
@ -30,6 +34,9 @@ type Application struct {
|
||||||
p2pCtx context.Context
|
p2pCtx context.Context
|
||||||
p2pCancel context.CancelFunc
|
p2pCancel context.CancelFunc
|
||||||
agentJobMutex sync.Mutex
|
agentJobMutex sync.Mutex
|
||||||
|
|
||||||
|
// Distributed mode services (nil when not in distributed mode)
|
||||||
|
distributed *DistributedServices
|
||||||
}
|
}
|
||||||
|
|
||||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||||
|
|
@ -64,15 +71,15 @@ func (a *Application) TemplatesEvaluator() *templates.Evaluator {
|
||||||
return a.templatesEvaluator
|
return a.templatesEvaluator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Application) GalleryService() *services.GalleryService {
|
func (a *Application) GalleryService() *galleryop.GalleryService {
|
||||||
return a.galleryService
|
return a.galleryService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Application) AgentJobService() *services.AgentJobService {
|
func (a *Application) AgentJobService() *agentpool.AgentJobService {
|
||||||
return a.agentJobService
|
return a.agentJobService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Application) AgentPoolService() *services.AgentPoolService {
|
func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
||||||
return a.agentPoolService.Load()
|
return a.agentPoolService.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,8 +93,53 @@ func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||||
return a.startupConfig
|
return a.startupConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Distributed returns the distributed services, or nil if not in distributed mode.
|
||||||
|
func (a *Application) Distributed() *DistributedServices {
|
||||||
|
return a.distributed
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDistributed returns true if the application is running in distributed mode.
|
||||||
|
func (a *Application) IsDistributed() bool {
|
||||||
|
return a.distributed != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForHealthyWorker blocks until at least one healthy backend worker is registered.
|
||||||
|
// This prevents the agent pool from failing during startup when workers haven't connected yet.
|
||||||
|
func (a *Application) waitForHealthyWorker() {
|
||||||
|
maxWait := a.applicationConfig.Distributed.WorkerWaitTimeoutOrDefault()
|
||||||
|
const basePoll = 2 * time.Second
|
||||||
|
|
||||||
|
xlog.Info("Waiting for at least one healthy backend worker before starting agent pool")
|
||||||
|
deadline := time.Now().Add(maxWait)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
registered, err := a.distributed.Registry.List(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
for _, n := range registered {
|
||||||
|
if n.NodeType == nodes.NodeTypeBackend && n.Status == nodes.StatusHealthy {
|
||||||
|
xlog.Info("Healthy backend worker found", "node", n.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Add 0-1s jitter to prevent thundering-herd on the node registry
|
||||||
|
jitter := time.Duration(rand.Int64N(int64(time.Second)))
|
||||||
|
select {
|
||||||
|
case <-a.applicationConfig.Context.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(basePoll + jitter):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlog.Warn("No healthy backend worker found after waiting, proceeding anyway")
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstanceID returns the unique identifier for this frontend instance.
|
||||||
|
func (a *Application) InstanceID() string {
|
||||||
|
return a.applicationConfig.Distributed.InstanceID
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Application) start() error {
|
func (a *Application) start() error {
|
||||||
galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
|
galleryService := galleryop.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
|
||||||
err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
|
err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -95,19 +147,14 @@ func (a *Application) start() error {
|
||||||
|
|
||||||
a.galleryService = galleryService
|
a.galleryService = galleryService
|
||||||
|
|
||||||
// Initialize agent job service
|
// Initialize agent job service (Start() is deferred to after distributed wiring)
|
||||||
agentJobService := services.NewAgentJobService(
|
agentJobService := agentpool.NewAgentJobService(
|
||||||
a.ApplicationConfig(),
|
a.ApplicationConfig(),
|
||||||
a.ModelLoader(),
|
a.ModelLoader(),
|
||||||
a.ModelConfigLoader(),
|
a.ModelConfigLoader(),
|
||||||
a.TemplatesEvaluator(),
|
a.TemplatesEvaluator(),
|
||||||
)
|
)
|
||||||
|
|
||||||
err = agentJobService.Start(a.ApplicationConfig().Context)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
a.agentJobService = agentJobService
|
a.agentJobService = agentJobService
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -120,27 +167,56 @@ func (a *Application) StartAgentPool() {
|
||||||
if !a.applicationConfig.AgentPool.Enabled {
|
if !a.applicationConfig.AgentPool.Enabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aps, err := services.NewAgentPoolService(a.applicationConfig)
|
// Build options struct from available dependencies
|
||||||
|
opts := agentpool.AgentPoolOptions{
|
||||||
|
AuthDB: a.authDB,
|
||||||
|
}
|
||||||
|
if d := a.Distributed(); d != nil {
|
||||||
|
if d.DistStores != nil && d.DistStores.Skills != nil {
|
||||||
|
opts.SkillStore = d.DistStores.Skills
|
||||||
|
}
|
||||||
|
opts.NATSClient = d.Nats
|
||||||
|
opts.EventBridge = d.AgentBridge
|
||||||
|
opts.AgentStore = d.AgentStore
|
||||||
|
}
|
||||||
|
|
||||||
|
aps, err := agentpool.NewAgentPoolService(a.applicationConfig, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Failed to create agent pool service", "error", err)
|
xlog.Error("Failed to create agent pool service", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if a.authDB != nil {
|
|
||||||
aps.SetAuthDB(a.authDB)
|
// Wire distributed mode components
|
||||||
|
if d := a.Distributed(); d != nil {
|
||||||
|
// Wait for at least one healthy backend worker before starting the agent pool.
|
||||||
|
// Collections initialization calls embeddings which require a worker.
|
||||||
|
if d.Registry != nil {
|
||||||
|
a.waitForHealthyWorker()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := aps.Start(a.applicationConfig.Context); err != nil {
|
if err := aps.Start(a.applicationConfig.Context); err != nil {
|
||||||
xlog.Error("Failed to start agent pool", "error", err)
|
xlog.Error("Failed to start agent pool", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
|
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
|
||||||
usm := services.NewUserServicesManager(
|
usm := agentpool.NewUserServicesManager(
|
||||||
aps.UserStorage(),
|
aps.UserStorage(),
|
||||||
a.applicationConfig,
|
a.applicationConfig,
|
||||||
a.modelLoader,
|
a.modelLoader,
|
||||||
a.backendLoader,
|
a.backendLoader,
|
||||||
a.templatesEvaluator,
|
a.templatesEvaluator,
|
||||||
)
|
)
|
||||||
|
// Wire distributed backends to per-user job services
|
||||||
|
if a.agentJobService != nil {
|
||||||
|
if d := a.agentJobService.Dispatcher(); d != nil {
|
||||||
|
usm.SetJobDispatcher(d)
|
||||||
|
}
|
||||||
|
if s := a.agentJobService.DBStore(); s != nil {
|
||||||
|
usm.SetJobDBStore(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
aps.SetUserServicesManager(usm)
|
aps.SetUserServicesManager(usm)
|
||||||
|
|
||||||
a.agentPoolService.Store(aps)
|
a.agentPoolService.Store(aps)
|
||||||
|
|
|
||||||
267
core/application/distributed.go
Normal file
267
core/application/distributed.go
Normal file
|
|
@ -0,0 +1,267 @@
|
||||||
|
package application
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/services/agents"
|
||||||
|
"github.com/mudler/LocalAI/core/services/distributed"
|
||||||
|
"github.com/mudler/LocalAI/core/services/jobs"
|
||||||
|
"github.com/mudler/LocalAI/core/services/messaging"
|
||||||
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
|
"github.com/mudler/LocalAI/core/services/storage"
|
||||||
|
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DistributedServices holds all services initialized for distributed mode.
|
||||||
|
type DistributedServices struct {
|
||||||
|
Nats *messaging.Client
|
||||||
|
Store storage.ObjectStore
|
||||||
|
Registry *nodes.NodeRegistry
|
||||||
|
Router *nodes.SmartRouter
|
||||||
|
Health *nodes.HealthMonitor
|
||||||
|
JobStore *jobs.JobStore
|
||||||
|
Dispatcher *jobs.Dispatcher
|
||||||
|
AgentStore *agents.AgentStore
|
||||||
|
AgentBridge *agents.EventBridge
|
||||||
|
DistStores *distributed.Stores
|
||||||
|
FileMgr *storage.FileManager
|
||||||
|
FileStager nodes.FileStager
|
||||||
|
ModelAdapter *nodes.ModelRouterAdapter
|
||||||
|
Unloader *nodes.RemoteUnloaderAdapter
|
||||||
|
|
||||||
|
shutdownOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown stops all distributed services in reverse initialization order.
|
||||||
|
// It is safe to call on a nil receiver and is idempotent (uses sync.Once).
|
||||||
|
func (ds *DistributedServices) Shutdown() {
|
||||||
|
if ds == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ds.shutdownOnce.Do(func() {
|
||||||
|
if ds.Health != nil {
|
||||||
|
ds.Health.Stop()
|
||||||
|
}
|
||||||
|
if ds.Dispatcher != nil {
|
||||||
|
ds.Dispatcher.Stop()
|
||||||
|
}
|
||||||
|
if closer, ok := ds.Store.(io.Closer); ok {
|
||||||
|
closer.Close()
|
||||||
|
}
|
||||||
|
// AgentBridge has no Close method — its NATS subscriptions are cleaned up
|
||||||
|
// when the NATS client is closed below.
|
||||||
|
if ds.Nats != nil {
|
||||||
|
ds.Nats.Close()
|
||||||
|
}
|
||||||
|
xlog.Info("Distributed services shut down")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// initDistributed validates distributed mode prerequisites and initializes
|
||||||
|
// NATS, object storage, node registry, and instance identity.
|
||||||
|
// Returns nil if distributed mode is not enabled.
|
||||||
|
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
|
||||||
|
if !cfg.Distributed.Enabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Info("Distributed mode enabled — validating prerequisites")
|
||||||
|
|
||||||
|
// Validate distributed config (NATS URL, S3 credential pairing, durations, etc.)
|
||||||
|
if err := cfg.Distributed.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate PostgreSQL is configured (auth DB must be PostgreSQL for distributed mode)
|
||||||
|
if !cfg.Auth.Enabled {
|
||||||
|
return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)")
|
||||||
|
}
|
||||||
|
if !isPostgresURL(cfg.Auth.DatabaseURL) {
|
||||||
|
return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate instance ID if not set
|
||||||
|
if cfg.Distributed.InstanceID == "" {
|
||||||
|
cfg.Distributed.InstanceID = uuid.New().String()
|
||||||
|
}
|
||||||
|
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||||
|
|
||||||
|
// Connect to NATS
|
||||||
|
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL))
|
||||||
|
|
||||||
|
// Ensure NATS is closed if any subsequent initialization step fails.
|
||||||
|
success := false
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
natsClient.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Initialize object storage
|
||||||
|
var store storage.ObjectStore
|
||||||
|
if cfg.Distributed.StorageURL != "" {
|
||||||
|
if cfg.Distributed.StorageBucket == "" {
|
||||||
|
return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured")
|
||||||
|
}
|
||||||
|
s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
|
||||||
|
Endpoint: cfg.Distributed.StorageURL,
|
||||||
|
Region: cfg.Distributed.StorageRegion,
|
||||||
|
Bucket: cfg.Distributed.StorageBucket,
|
||||||
|
AccessKeyID: cfg.Distributed.StorageAccessKey,
|
||||||
|
SecretAccessKey: cfg.Distributed.StorageSecretKey,
|
||||||
|
ForcePathStyle: true, // required for MinIO
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing S3 storage: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket)
|
||||||
|
store = s3Store
|
||||||
|
} else {
|
||||||
|
// Fallback to filesystem storage in distributed mode (useful for single-node testing)
|
||||||
|
fsStore, err := storage.NewFilesystemStore(cfg.DataPath + "/objectstore")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing filesystem storage: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Object storage initialized (filesystem fallback)", "path", cfg.DataPath+"/objectstore")
|
||||||
|
store = fsStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize node registry (requires the auth DB which is PostgreSQL)
|
||||||
|
if authDB == nil {
|
||||||
|
return nil, fmt.Errorf("distributed mode requires auth database to be initialized first")
|
||||||
|
}
|
||||||
|
|
||||||
|
registry, err := nodes.NewNodeRegistry(authDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing node registry: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Node registry initialized")
|
||||||
|
|
||||||
|
// Collect SmartRouter option values; the router itself is created after all
|
||||||
|
// dependencies (including FileStager and Unloader) are ready.
|
||||||
|
var routerAuthToken string
|
||||||
|
if cfg.Distributed.RegistrationToken != "" {
|
||||||
|
routerAuthToken = cfg.Distributed.RegistrationToken
|
||||||
|
}
|
||||||
|
var routerGalleriesJSON string
|
||||||
|
if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil {
|
||||||
|
routerGalleriesJSON = string(galleriesJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
healthMon := nodes.NewHealthMonitor(registry, authDB,
|
||||||
|
cfg.Distributed.HealthCheckIntervalOrDefault(),
|
||||||
|
cfg.Distributed.StaleNodeThresholdOrDefault(),
|
||||||
|
routerAuthToken,
|
||||||
|
cfg.Distributed.PerModelHealthCheck,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize job store
|
||||||
|
jobStore, err := jobs.NewJobStore(authDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing job store: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Distributed job store initialized")
|
||||||
|
|
||||||
|
// Initialize job dispatcher
|
||||||
|
dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency)
|
||||||
|
|
||||||
|
// Initialize agent store
|
||||||
|
agentStore, err := agents.NewAgentStore(authDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing agent store: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Distributed agent store initialized")
|
||||||
|
|
||||||
|
// Initialize agent event bridge
|
||||||
|
agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID)
|
||||||
|
|
||||||
|
// Start observable persister — captures observable_update events from workers
|
||||||
|
// (which have no DB access) and persists them to PostgreSQL.
|
||||||
|
if err := agentBridge.StartObservablePersister(); err != nil {
|
||||||
|
xlog.Warn("Failed to start observable persister", "error", err)
|
||||||
|
} else {
|
||||||
|
xlog.Info("Observable persister started")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills)
|
||||||
|
distStores, err := distributed.InitStores(authDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing distributed stores: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize file manager with local cache
|
||||||
|
cacheDir := cfg.DataPath + "/cache"
|
||||||
|
fileMgr, err := storage.NewFileManager(store, cacheDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing file manager: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("File manager initialized", "cacheDir", cacheDir)
|
||||||
|
|
||||||
|
// Create FileStager for distributed file transfer
|
||||||
|
var fileStager nodes.FileStager
|
||||||
|
if cfg.Distributed.StorageURL != "" {
|
||||||
|
fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient)
|
||||||
|
xlog.Info("File stager initialized (S3+NATS)")
|
||||||
|
} else {
|
||||||
|
fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||||
|
node, err := registry.Get(context.Background(), nodeID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if node.HTTPAddress == "" {
|
||||||
|
return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID)
|
||||||
|
}
|
||||||
|
return node.HTTPAddress, nil
|
||||||
|
}, cfg.Distributed.RegistrationToken)
|
||||||
|
xlog.Info("File stager initialized (HTTP direct transfer)")
|
||||||
|
}
|
||||||
|
// Create RemoteUnloaderAdapter — needed by SmartRouter and startup.go
|
||||||
|
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||||
|
|
||||||
|
// All dependencies ready — build SmartRouter with all options at once
|
||||||
|
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||||
|
Unloader: remoteUnloader,
|
||||||
|
FileStager: fileStager,
|
||||||
|
GalleriesJSON: routerGalleriesJSON,
|
||||||
|
AuthToken: routerAuthToken,
|
||||||
|
DB: authDB,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create ModelRouterAdapter to wire into ModelLoader
|
||||||
|
modelAdapter := nodes.NewModelRouterAdapter(router)
|
||||||
|
|
||||||
|
success = true
|
||||||
|
return &DistributedServices{
|
||||||
|
Nats: natsClient,
|
||||||
|
Store: store,
|
||||||
|
Registry: registry,
|
||||||
|
Router: router,
|
||||||
|
Health: healthMon,
|
||||||
|
JobStore: jobStore,
|
||||||
|
Dispatcher: dispatcher,
|
||||||
|
AgentStore: agentStore,
|
||||||
|
AgentBridge: agentBridge,
|
||||||
|
DistStores: distStores,
|
||||||
|
FileMgr: fileMgr,
|
||||||
|
FileStager: fileStager,
|
||||||
|
ModelAdapter: modelAdapter,
|
||||||
|
Unloader: remoteUnloader,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPostgresURL(url string) bool {
|
||||||
|
return strings.HasPrefix(url, "postgres://") || strings.HasPrefix(url, "postgresql://")
|
||||||
|
}
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/p2p"
|
"github.com/mudler/LocalAI/core/p2p"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
|
||||||
"github.com/mudler/edgevpn/pkg/node"
|
"github.com/mudler/edgevpn/pkg/node"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
|
|
@ -146,22 +146,14 @@ func (a *Application) RestartP2P() error {
|
||||||
return fmt.Errorf("P2P token is not set")
|
return fmt.Errorf("P2P token is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new context for P2P
|
|
||||||
ctx, cancel := context.WithCancel(appConfig.Context)
|
|
||||||
a.p2pCtx = ctx
|
|
||||||
a.p2pCancel = cancel
|
|
||||||
|
|
||||||
// Get API address from config
|
|
||||||
address := appConfig.APIAddress
|
|
||||||
if address == "" {
|
|
||||||
address = "127.0.0.1:8080" // default
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start P2P stack in a goroutine
|
// Start P2P stack in a goroutine
|
||||||
|
// Note: StartP2P creates its own context and assigns a.p2pCtx/a.p2pCancel
|
||||||
go func() {
|
go func() {
|
||||||
if err := a.StartP2P(); err != nil {
|
if err := a.StartP2P(); err != nil {
|
||||||
xlog.Error("Failed to start P2P stack", "error", err)
|
xlog.Error("Failed to start P2P stack", "error", err)
|
||||||
cancel() // Cancel context on error
|
if a.p2pCancel != nil {
|
||||||
|
a.p2pCancel()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
xlog.Info("P2P stack restarted with new settings")
|
xlog.Info("P2P stack restarted with new settings")
|
||||||
|
|
@ -228,7 +220,7 @@ func syncState(ctx context.Context, n *node.Node, app *Application) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
app.GalleryService().ModelGalleryChannel <- galleryop.ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||||
ID: uuid.String(),
|
ID: uuid.String(),
|
||||||
GalleryElementName: model,
|
GalleryElementName: model,
|
||||||
Galleries: app.ApplicationConfig().Galleries,
|
Galleries: app.ApplicationConfig().Galleries,
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,15 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/http/auth"
|
"github.com/mudler/LocalAI/core/http/auth"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
"github.com/mudler/LocalAI/core/services/jobs"
|
||||||
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
|
"github.com/mudler/LocalAI/core/services/storage"
|
||||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||||
"github.com/mudler/LocalAI/internal"
|
"github.com/mudler/LocalAI/internal"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
@ -101,7 +105,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||||
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
|
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
|
||||||
}
|
}
|
||||||
application.authDB = authDB
|
application.authDB = authDB
|
||||||
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)
|
xlog.Info("Auth enabled", "database", sanitize.URL(options.Auth.DatabaseURL))
|
||||||
|
|
||||||
// Start session and expired API key cleanup goroutine
|
// Start session and expired API key cleanup goroutine
|
||||||
go func() {
|
go func() {
|
||||||
|
|
@ -123,12 +127,92 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wire JobStore for DB-backed task/job persistence whenever auth DB is available.
|
||||||
|
// This ensures tasks and jobs survive restarts in both single-node and distributed modes.
|
||||||
|
if application.authDB != nil && application.agentJobService != nil {
|
||||||
|
dbJobStore, err := jobs.NewJobStore(application.authDB)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("Failed to create job store for auth DB", "error", err)
|
||||||
|
} else {
|
||||||
|
application.agentJobService.SetDistributedJobStore(dbJobStore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize distributed mode services (NATS, object storage, node registry)
|
||||||
|
distSvc, err := initDistributed(options, application.authDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
|
||||||
|
}
|
||||||
|
if distSvc != nil {
|
||||||
|
application.distributed = distSvc
|
||||||
|
// Wire remote model unloader so ShutdownModel works for remote nodes
|
||||||
|
// Uses NATS to tell serve-backend nodes to Free + kill their backend process
|
||||||
|
application.modelLoader.SetRemoteUnloader(distSvc.Unloader)
|
||||||
|
// Wire ModelRouter so grpcModel() delegates to SmartRouter in distributed mode
|
||||||
|
application.modelLoader.SetModelRouter(distSvc.ModelAdapter.AsModelRouter())
|
||||||
|
// Wire DistributedModelStore so shutdown/list/watchdog can find remote models
|
||||||
|
distStore := nodes.NewDistributedModelStore(
|
||||||
|
model.NewInMemoryModelStore(),
|
||||||
|
distSvc.Registry,
|
||||||
|
)
|
||||||
|
application.modelLoader.SetModelStore(distStore)
|
||||||
|
// Start health monitor
|
||||||
|
distSvc.Health.Start(options.Context)
|
||||||
|
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
|
||||||
|
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
|
||||||
|
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
|
||||||
|
// but does NOT set a workerFn — agent workers consume jobs from the same NATS queue.
|
||||||
|
|
||||||
|
// Wire model config loader so job events include model config for agent workers
|
||||||
|
distSvc.Dispatcher.SetModelConfigLoader(application.backendLoader)
|
||||||
|
|
||||||
|
// Start job dispatcher — abort startup if it fails, as jobs would be accepted but never dispatched
|
||||||
|
if err := distSvc.Dispatcher.Start(options.Context); err != nil {
|
||||||
|
return nil, fmt.Errorf("starting job dispatcher: %w", err)
|
||||||
|
}
|
||||||
|
// Start ephemeral file cleanup
|
||||||
|
storage.StartEphemeralCleanup(options.Context, distSvc.FileMgr, 0, 0)
|
||||||
|
// Wire distributed backends into AgentJobService (before Start)
|
||||||
|
if application.agentJobService != nil {
|
||||||
|
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
||||||
|
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
||||||
|
}
|
||||||
|
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
||||||
|
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
||||||
|
|
||||||
|
// Wire NATS and gallery store into GalleryService for cross-instance progress/cancel
|
||||||
|
if application.galleryService != nil {
|
||||||
|
application.galleryService.SetNATSClient(distSvc.Nats)
|
||||||
|
if distSvc.DistStores != nil && distSvc.DistStores.Gallery != nil {
|
||||||
|
// Clean up stale in-progress operations from previous crashed instances
|
||||||
|
if err := distSvc.DistStores.Gallery.CleanStale(30 * time.Minute); err != nil {
|
||||||
|
xlog.Warn("Failed to clean stale gallery operations", "error", err)
|
||||||
|
}
|
||||||
|
application.galleryService.SetGalleryStore(distSvc.DistStores.Gallery)
|
||||||
|
}
|
||||||
|
// Wire distributed model/backend managers so delete propagates to workers
|
||||||
|
application.galleryService.SetModelManager(
|
||||||
|
nodes.NewDistributedModelManager(options, application.modelLoader, distSvc.Unloader),
|
||||||
|
)
|
||||||
|
application.galleryService.SetBackendManager(
|
||||||
|
nodes.NewDistributedBackendManager(options, application.modelLoader, distSvc.Unloader, distSvc.Registry),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start AgentJobService (after distributed wiring so it knows whether to use local or NATS)
|
||||||
|
if application.agentJobService != nil {
|
||||||
|
if err := application.agentJobService.Start(options.Context); err != nil {
|
||||||
|
return nil, fmt.Errorf("starting agent job service: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||||
xlog.Error("error installing models", "error", err)
|
xlog.Error("error installing models", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, backend := range options.ExternalBackends {
|
for _, backend := range options.ExternalBackends {
|
||||||
if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||||
xlog.Error("error installing external backend", "error", err)
|
xlog.Error("error installing external backend", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -154,13 +238,13 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.PreloadJSONModels != "" {
|
if options.PreloadJSONModels != "" {
|
||||||
if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
if err := galleryop.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.PreloadModelsFromPath != "" {
|
if options.PreloadModelsFromPath != "" {
|
||||||
if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
if err := galleryop.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -184,6 +268,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||||
go func() {
|
go func() {
|
||||||
<-options.Context.Done()
|
<-options.Context.Done()
|
||||||
xlog.Debug("Context canceled, shutting down")
|
xlog.Debug("Context canceled, shutting down")
|
||||||
|
application.distributed.Shutdown()
|
||||||
err := application.ModelLoader().StopAllGRPC()
|
err := application.ModelLoader().StopAllGRPC()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("error while stopping all grpc backends", "error", err)
|
xlog.Error("error while stopping all grpc backends", "error", err)
|
||||||
|
|
@ -207,7 +292,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||||
var backendErr error
|
var backendErr error
|
||||||
_, backendErr = application.ModelLoader().Load(o...)
|
_, backendErr = application.ModelLoader().Load(o...)
|
||||||
if backendErr != nil {
|
if backendErr != nil {
|
||||||
return nil, err
|
return nil, backendErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,9 @@ import (
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/trace"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
"github.com/mudler/LocalAI/core/trace"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
|
@ -27,7 +27,7 @@ type LLMResponse struct {
|
||||||
Response string // should this be []byte?
|
Response string // should this be []byte?
|
||||||
Usage TokenUsage
|
Usage TokenUsage
|
||||||
AudioOutput string
|
AudioOutput string
|
||||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||||
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
|
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,14 +47,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||||
|
|
||||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||||
if o.AutoloadGalleries { // experimental
|
if o.AutoloadGalleries { // experimental
|
||||||
modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
|
modelNames, err := galleryop.ListModels(cl, loader, nil, galleryop.SKIP_ALWAYS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !slices.Contains(modelNames, c.Name) {
|
modelName := c.Name
|
||||||
|
if modelName == "" {
|
||||||
|
modelName = c.Model
|
||||||
|
}
|
||||||
|
if !slices.Contains(modelNames, modelName) {
|
||||||
utils.ResetDownloadTimers()
|
utils.ResetDownloadTimers()
|
||||||
// if we failed to load the model, we try to download it
|
// if we failed to load the model, we try to download it
|
||||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, modelName, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
|
||||||
//return nil, err
|
//return nil, err
|
||||||
|
|
@ -252,12 +256,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||||
|
|
||||||
traceData := map[string]any{
|
traceData := map[string]any{
|
||||||
"chat_template": c.TemplateConfig.Chat,
|
"chat_template": c.TemplateConfig.Chat,
|
||||||
"function_template": c.TemplateConfig.Functions,
|
"function_template": c.TemplateConfig.Functions,
|
||||||
"streaming": tokenCallback != nil,
|
"streaming": tokenCallback != nil,
|
||||||
"images_count": len(images),
|
"images_count": len(images),
|
||||||
"videos_count": len(videos),
|
"videos_count": len(videos),
|
||||||
"audios_count": len(audios),
|
"audios_count": len(audios),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(messages) > 0 {
|
if len(messages) > 0 {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package backend
|
package backend
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand/v2"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -86,7 +86,7 @@ func getSeed(c config.ModelConfig) int32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
if seed == config.RAND_SEED {
|
if seed == config.RAND_SEED {
|
||||||
seed = rand.Int31()
|
seed = rand.Int32()
|
||||||
}
|
}
|
||||||
|
|
||||||
return seed
|
return seed
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/trace"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/core/trace"
|
||||||
"github.com/mudler/LocalAI/pkg/grpc"
|
"github.com/mudler/LocalAI/pkg/grpc"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,11 +8,11 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
|
||||||
"github.com/mudler/LocalAI/core/services"
|
|
||||||
"github.com/mudler/LocalAGI/core/state"
|
"github.com/mudler/LocalAGI/core/state"
|
||||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||||
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -59,7 +59,7 @@ func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
|
||||||
|
|
||||||
appConfig := r.buildAppConfig()
|
appConfig := r.buildAppConfig()
|
||||||
|
|
||||||
poolService, err := services.NewAgentPoolService(appConfig)
|
poolService, err := agentpool.NewAgentPoolService(appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create agent pool service: %w", err)
|
return fmt.Errorf("failed to create agent pool service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
463
core/cli/agent_worker.go
Normal file
463
core/cli/agent_worker.go
Normal file
|
|
@ -0,0 +1,463 @@
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
|
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||||
|
"github.com/mudler/LocalAI/core/services/agents"
|
||||||
|
"github.com/mudler/LocalAI/core/services/jobs"
|
||||||
|
mcpRemote "github.com/mudler/LocalAI/core/services/mcp"
|
||||||
|
"github.com/mudler/LocalAI/core/services/messaging"
|
||||||
|
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||||
|
"github.com/mudler/cogito"
|
||||||
|
"github.com/mudler/cogito/clients"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentWorkerCMD starts a dedicated agent worker process for distributed mode.
|
||||||
|
// It registers with the frontend, subscribes to the NATS agent execution queue,
|
||||||
|
// and executes agent chats using cogito. The worker is a pure executor — it
|
||||||
|
// receives the full agent config and skills in the NATS job payload, so it
|
||||||
|
// does not need direct database access.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// localai agent-worker --nats-url nats://... --register-to http://localai:8080
|
||||||
|
type AgentWorkerCMD struct {
|
||||||
|
// NATS (required)
|
||||||
|
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||||
|
|
||||||
|
// Registration (required)
|
||||||
|
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
|
||||||
|
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
|
||||||
|
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
|
||||||
|
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
|
||||||
|
|
||||||
|
// API access
|
||||||
|
APIURL string `env:"LOCALAI_API_URL" help:"LocalAI API URL for inference (auto-derived from RegisterTo if not set)" group:"api"`
|
||||||
|
APIToken string `env:"LOCALAI_API_TOKEN" help:"API token for LocalAI inference (auto-provisioned during registration if not set)" group:"api"`
|
||||||
|
|
||||||
|
// NATS subjects
|
||||||
|
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||||
|
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||||
|
|
||||||
|
// Timeouts
|
||||||
|
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||||
|
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
||||||
|
|
||||||
|
// Resolve API URL
|
||||||
|
apiURL := cmp.Or(cmd.APIURL, strings.TrimRight(cmd.RegisterTo, "/"))
|
||||||
|
|
||||||
|
// Register with frontend
|
||||||
|
regClient := &workerregistry.RegistrationClient{
|
||||||
|
FrontendURL: cmd.RegisterTo,
|
||||||
|
RegistrationToken: cmd.RegistrationToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeName := cmd.NodeName
|
||||||
|
if nodeName == "" {
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
nodeName = "agent-" + hostname
|
||||||
|
}
|
||||||
|
registrationBody := map[string]any{
|
||||||
|
"name": nodeName,
|
||||||
|
"node_type": "agent",
|
||||||
|
}
|
||||||
|
if cmd.RegistrationToken != "" {
|
||||||
|
registrationBody["token"] = cmd.RegistrationToken
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("registration failed: %w", err)
|
||||||
|
}
|
||||||
|
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||||
|
|
||||||
|
// Use provisioned API token if none was set
|
||||||
|
if cmd.APIToken == "" {
|
||||||
|
cmd.APIToken = apiToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start heartbeat
|
||||||
|
heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
|
||||||
|
if err != nil && cmd.HeartbeatInterval != "" {
|
||||||
|
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||||
|
}
|
||||||
|
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||||
|
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||||
|
|
||||||
|
// Connect to NATS
|
||||||
|
natsClient, err := messaging.New(cmd.NatsURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connecting to NATS: %w", err)
|
||||||
|
}
|
||||||
|
defer natsClient.Close()
|
||||||
|
|
||||||
|
// Create event bridge for publishing results back via NATS
|
||||||
|
eventBridge := agents.NewEventBridge(natsClient, nil, "agent-worker-"+nodeID)
|
||||||
|
|
||||||
|
// Start cancel listener
|
||||||
|
cancelSub, err := eventBridge.StartCancelListener()
|
||||||
|
if err != nil {
|
||||||
|
xlog.Warn("Failed to start cancel listener", "error", err)
|
||||||
|
} else {
|
||||||
|
defer cancelSub.Unsubscribe()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start the NATS dispatcher.
|
||||||
|
// No ConfigProvider or SkillStore needed — config and skills arrive in the job payload.
|
||||||
|
dispatcher := agents.NewNATSDispatcher(
|
||||||
|
natsClient,
|
||||||
|
eventBridge,
|
||||||
|
nil, // no ConfigProvider: config comes in the enriched NATS payload
|
||||||
|
apiURL, cmd.APIToken,
|
||||||
|
cmd.Subject, cmd.Queue,
|
||||||
|
0, // no concurrency limit (CLI worker)
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := dispatcher.Start(shutdownCtx); err != nil {
|
||||||
|
return fmt.Errorf("starting dispatcher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to MCP tool execution requests (load-balanced across workers).
|
||||||
|
// The frontend routes model-level MCP tool calls here via NATS request-reply.
|
||||||
|
if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
|
||||||
|
handleMCPToolRequest(data, reply)
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPToolExecute, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to MCP discovery requests (load-balanced across workers).
|
||||||
|
if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
|
||||||
|
handleMCPDiscoveryRequest(data, reply)
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPDiscovery, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to MCP CI job execution (load-balanced across agent workers).
|
||||||
|
// In distributed mode, MCP CI jobs are routed here because the frontend
|
||||||
|
// cannot create MCP sessions (e.g., stdio servers using docker).
|
||||||
|
mcpCIJobTimeout, err := time.ParseDuration(cmd.MCPCIJobTimeout)
|
||||||
|
if err != nil && cmd.MCPCIJobTimeout != "" {
|
||||||
|
xlog.Warn("invalid MCP CI job timeout, using default 10m", "input", cmd.MCPCIJobTimeout, "error", err)
|
||||||
|
}
|
||||||
|
mcpCIJobTimeout = cmp.Or(mcpCIJobTimeout, config.DefaultMCPCIJobTimeout)
|
||||||
|
|
||||||
|
if _, err := natsClient.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func(data []byte) {
|
||||||
|
handleMCPCIJob(shutdownCtx, data, apiURL, cmd.APIToken, natsClient, mcpCIJobTimeout)
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPCIJobsNew, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to backend stop events to clean up cached MCP sessions.
|
||||||
|
// In the main application this is done via ml.OnModelUnload, but the agent
|
||||||
|
// worker has no model loader — we listen for the NATS stop event instead.
|
||||||
|
if _, err := natsClient.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func(data []byte) {
|
||||||
|
var req struct {
|
||||||
|
Backend string `json:"backend"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &req) == nil && req.Backend != "" {
|
||||||
|
mcpTools.CloseMCPSessions(req.Backend)
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectNodeBackendStop(nodeID), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||||
|
|
||||||
|
// Wait for shutdown
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-sigCh
|
||||||
|
|
||||||
|
xlog.Info("Shutting down agent worker")
|
||||||
|
shutdownCancel() // stop heartbeat loop immediately
|
||||||
|
dispatcher.Stop()
|
||||||
|
mcpTools.CloseAllMCPSessions()
|
||||||
|
regClient.GracefulDeregister(nodeID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||||
|
// The worker creates/caches MCP sessions from the serialized config and executes the tool.
|
||||||
|
func handleMCPToolRequest(data []byte, reply func([]byte)) {
|
||||||
|
var req mcpRemote.MCPToolRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
sendMCPToolReply(reply, "", fmt.Sprintf("unmarshal error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPToolTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create/cache named MCP sessions from the provided config
|
||||||
|
namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil)
|
||||||
|
if err != nil {
|
||||||
|
sendMCPToolReply(reply, "", fmt.Sprintf("session error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discover tools to find the right session
|
||||||
|
tools, err := mcpTools.DiscoverMCPTools(ctx, namedSessions)
|
||||||
|
if err != nil {
|
||||||
|
sendMCPToolReply(reply, "", fmt.Sprintf("discovery error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
argsJSON, _ := json.Marshal(req.Arguments)
|
||||||
|
result, err := mcpTools.ExecuteMCPToolCall(ctx, tools, req.ToolName, string(argsJSON))
|
||||||
|
if err != nil {
|
||||||
|
sendMCPToolReply(reply, "", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMCPToolReply(reply, result, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendMCPToolReply(reply func([]byte), result, errMsg string) {
|
||||||
|
resp := mcpRemote.MCPToolResponse{Result: result, Error: errMsg}
|
||||||
|
data, _ := json.Marshal(resp)
|
||||||
|
reply(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMCPDiscoveryRequest handles a NATS request-reply for MCP tool/prompt/resource discovery.
|
||||||
|
func handleMCPDiscoveryRequest(data []byte, reply func([]byte)) {
|
||||||
|
var req mcpRemote.MCPDiscoveryRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("unmarshal error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPDiscoveryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create/cache named MCP sessions
|
||||||
|
namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil)
|
||||||
|
if err != nil {
|
||||||
|
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("session error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// List servers with their tools/prompts/resources
|
||||||
|
serverInfos, err := mcpTools.ListMCPServers(ctx, namedSessions)
|
||||||
|
if err != nil {
|
||||||
|
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("list error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also get tool function schemas for the frontend
|
||||||
|
tools, _ := mcpTools.DiscoverMCPTools(ctx, namedSessions)
|
||||||
|
var toolDefs []mcpRemote.MCPToolDef
|
||||||
|
for _, t := range tools {
|
||||||
|
toolDefs = append(toolDefs, mcpRemote.MCPToolDef{
|
||||||
|
ServerName: t.ServerName,
|
||||||
|
ToolName: t.ToolName,
|
||||||
|
Function: t.Function,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert server infos
|
||||||
|
var servers []mcpRemote.MCPServerInfo
|
||||||
|
for _, s := range serverInfos {
|
||||||
|
servers = append(servers, mcpRemote.MCPServerInfo{
|
||||||
|
Name: s.Name,
|
||||||
|
Type: s.Type,
|
||||||
|
Tools: s.Tools,
|
||||||
|
Prompts: s.Prompts,
|
||||||
|
Resources: s.Resources,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMCPDiscoveryReply(reply, servers, toolDefs, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendMCPDiscoveryReply(reply func([]byte), servers []mcpRemote.MCPServerInfo, tools []mcpRemote.MCPToolDef, errMsg string) {
|
||||||
|
resp := mcpRemote.MCPDiscoveryResponse{Servers: servers, Tools: tools, Error: errMsg}
|
||||||
|
data, _ := json.Marshal(resp)
|
||||||
|
reply(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMCPCIJob processes an MCP CI job on the agent worker.
|
||||||
|
// The agent worker can create MCP sessions (has docker) and call the LocalAI API for inference.
|
||||||
|
func handleMCPCIJob(shutdownCtx context.Context, data []byte, apiURL, apiToken string, natsClient messaging.MessagingClient, jobTimeout time.Duration) {
|
||||||
|
var evt jobs.JobEvent
|
||||||
|
if err := json.Unmarshal(data, &evt); err != nil {
|
||||||
|
xlog.Error("Failed to unmarshal job event", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
job := evt.Job
|
||||||
|
task := evt.Task
|
||||||
|
if job == nil || task == nil {
|
||||||
|
xlog.Error("MCP CI job missing enriched data", "jobID", evt.JobID)
|
||||||
|
publishJobResult(natsClient, evt.JobID, "failed", "", "job or task data missing from NATS event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelCfg := evt.ModelConfig
|
||||||
|
if modelCfg == nil {
|
||||||
|
publishJobResult(natsClient, evt.JobID, "failed", "", "model config missing from job event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Info("Processing MCP CI job", "jobID", evt.JobID, "taskID", evt.TaskID, "model", task.Model)
|
||||||
|
|
||||||
|
// Publish running status
|
||||||
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||||
|
JobID: evt.JobID, Status: "running", Message: "Job started on agent worker",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Parse MCP config
|
||||||
|
if modelCfg.MCP.Servers == "" && modelCfg.MCP.Stdio == "" {
|
||||||
|
publishJobResult(natsClient, evt.JobID, "failed", "", "no MCP servers configured for model")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
remote, stdio, err := modelCfg.MCP.MCPConfigFromYAML()
|
||||||
|
if err != nil {
|
||||||
|
publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("failed to parse MCP config: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create MCP sessions locally (agent worker has docker)
|
||||||
|
sessions, err := mcpTools.SessionsFromMCPConfig(modelCfg.Name, remote, stdio)
|
||||||
|
if err != nil || len(sessions) == 0 {
|
||||||
|
errMsg := "no working MCP servers found"
|
||||||
|
if err != nil {
|
||||||
|
errMsg = fmt.Sprintf("failed to create MCP sessions: %v", err)
|
||||||
|
}
|
||||||
|
publishJobResult(natsClient, evt.JobID, "failed", "", errMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build prompt from template
|
||||||
|
prompt := task.Prompt
|
||||||
|
if task.CronParametersJSON != "" {
|
||||||
|
var params map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(task.CronParametersJSON), ¶ms); err != nil {
|
||||||
|
xlog.Warn("Failed to unmarshal parameters", "error", err)
|
||||||
|
}
|
||||||
|
for k, v := range params {
|
||||||
|
prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if job.ParametersJSON != "" {
|
||||||
|
var params map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(job.ParametersJSON), ¶ms); err != nil {
|
||||||
|
xlog.Warn("Failed to unmarshal parameters", "error", err)
|
||||||
|
}
|
||||||
|
for k, v := range params {
|
||||||
|
prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create LLM client pointing back to the frontend API
|
||||||
|
llm := clients.NewLocalAILLM(task.Model, apiToken, apiURL)
|
||||||
|
|
||||||
|
// Build cogito options
|
||||||
|
ctx, cancel := context.WithTimeout(shutdownCtx, jobTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Update job status to running in DB
|
||||||
|
publishJobStatus(natsClient, evt.JobID, "running", "")
|
||||||
|
|
||||||
|
// Buffer stream tokens and flush as complete blocks
|
||||||
|
var reasoningBuf, contentBuf strings.Builder
|
||||||
|
var lastStreamType cogito.StreamEventType
|
||||||
|
|
||||||
|
flushStreamBuf := func() {
|
||||||
|
if reasoningBuf.Len() > 0 {
|
||||||
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||||
|
JobID: evt.JobID, TraceType: "reasoning", TraceContent: reasoningBuf.String(),
|
||||||
|
})
|
||||||
|
reasoningBuf.Reset()
|
||||||
|
}
|
||||||
|
if contentBuf.Len() > 0 {
|
||||||
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||||
|
JobID: evt.JobID, TraceType: "content", TraceContent: contentBuf.String(),
|
||||||
|
})
|
||||||
|
contentBuf.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cogitoOpts := modelCfg.BuildCogitoOptions()
|
||||||
|
cogitoOpts = append(cogitoOpts,
|
||||||
|
cogito.WithContext(ctx),
|
||||||
|
cogito.WithMCPs(sessions...),
|
||||||
|
cogito.WithStatusCallback(func(status string) {
|
||||||
|
flushStreamBuf()
|
||||||
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||||
|
JobID: evt.JobID, TraceType: "status", TraceContent: status,
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
|
||||||
|
flushStreamBuf()
|
||||||
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||||
|
JobID: evt.JobID, TraceType: "tool_result", TraceContent: fmt.Sprintf("%s: %s", t.Name, t.Result),
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
cogito.WithStreamCallback(func(ev cogito.StreamEvent) {
|
||||||
|
// Flush if stream type changed (e.g., reasoning → content)
|
||||||
|
if ev.Type != lastStreamType {
|
||||||
|
flushStreamBuf()
|
||||||
|
lastStreamType = ev.Type
|
||||||
|
}
|
||||||
|
switch ev.Type {
|
||||||
|
case cogito.StreamEventReasoning:
|
||||||
|
reasoningBuf.WriteString(ev.Content)
|
||||||
|
case cogito.StreamEventContent:
|
||||||
|
contentBuf.WriteString(ev.Content)
|
||||||
|
case cogito.StreamEventToolCall:
|
||||||
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
||||||
|
JobID: evt.JobID, TraceType: "tool_call", TraceContent: fmt.Sprintf("%s(%s)", ev.ToolName, ev.ToolArgs),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Execute via cogito
|
||||||
|
fragment := cogito.NewEmptyFragment()
|
||||||
|
fragment = fragment.AddMessage("user", prompt)
|
||||||
|
|
||||||
|
f, err := cogito.ExecuteTools(llm, fragment, cogitoOpts...)
|
||||||
|
flushStreamBuf() // flush any remaining buffered tokens
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("cogito execution failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ""
|
||||||
|
if msg := f.LastMessage(); msg != nil {
|
||||||
|
result = msg.Content
|
||||||
|
}
|
||||||
|
publishJobResult(natsClient, evt.JobID, "completed", result, "")
|
||||||
|
xlog.Info("MCP CI job completed", "jobID", evt.JobID, "resultLen", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// publishJobStatus publishes a job status transition (e.g. "running") with an
// optional human-readable message on the job's NATS progress subject. It is a
// thin wrapper over jobs.PublishJobProgress kept for call-site readability.
func publishJobStatus(nc messaging.MessagingClient, jobID, status, message string) {
	jobs.PublishJobProgress(nc, jobID, status, message)
}
|
||||||
|
|
||||||
|
// publishJobResult publishes a job's terminal outcome — status ("completed"
// or "failed"), the final result payload, and an error message when the job
// failed — via jobs.PublishJobResult. Thin wrapper kept for call-site
// readability.
func publishJobResult(nc messaging.MessagingClient, jobID, status, result, errMsg string) {
	jobs.PublishJobResult(nc, jobID, status, result, errMsg)
}
|
||||||
|
|
@ -8,7 +8,7 @@ import (
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
|
||||||
|
|
@ -103,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
modelLoader := model.NewModelLoader(systemState)
|
modelLoader := model.NewModelLoader(systemState)
|
||||||
err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ var CLI struct {
|
||||||
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
||||||
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
|
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
|
||||||
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
||||||
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
P2PWorker worker.Worker `cmd:"" name:"p2p-worker" help:"Run workers to distribute workload via p2p (llama.cpp-only)"`
|
||||||
|
Worker WorkerCMD `cmd:"" help:"Start a worker for distributed mode (generic, backend-agnostic)"`
|
||||||
|
AgentWorker AgentWorkerCMD `cmd:"" name:"agent-worker" help:"Start an agent worker for distributed mode (executes agent chats via NATS)"`
|
||||||
Util UtilCMD `cmd:"" help:"Utility commands"`
|
Util UtilCMD `cmd:"" help:"Utility commands"`
|
||||||
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
|
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
|
||||||
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
||||||
|
|
|
||||||
|
|
@ -186,9 +186,9 @@ _local_ai_completions()
|
||||||
}
|
}
|
||||||
subcmds := []string{}
|
subcmds := []string{}
|
||||||
for _, sub := range cmds {
|
for _, sub := range cmds {
|
||||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
parent, child, found := strings.Cut(sub.fullName, " ")
|
||||||
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
|
if found && parent == cmd.name && !strings.Contains(child, " ") {
|
||||||
subcmds = append(subcmds, parts[1])
|
subcmds = append(subcmds, child)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(subcmds) > 0 {
|
if len(subcmds) > 0 {
|
||||||
|
|
@ -279,8 +279,8 @@ _local_ai() {
|
||||||
// Check for subcommands
|
// Check for subcommands
|
||||||
subcmds := []commandInfo{}
|
subcmds := []commandInfo{}
|
||||||
for _, sub := range cmds {
|
for _, sub := range cmds {
|
||||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
parent, child, found := strings.Cut(sub.fullName, " ")
|
||||||
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
|
if found && parent == cmd.name && !strings.Contains(child, " ") {
|
||||||
subcmds = append(subcmds, sub)
|
subcmds = append(subcmds, sub)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -289,11 +289,11 @@ _local_ai() {
|
||||||
sb.WriteString(" local -a subcmds\n")
|
sb.WriteString(" local -a subcmds\n")
|
||||||
sb.WriteString(" subcmds=(\n")
|
sb.WriteString(" subcmds=(\n")
|
||||||
for _, sub := range subcmds {
|
for _, sub := range subcmds {
|
||||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
_, child, _ := strings.Cut(sub.fullName, " ")
|
||||||
help := strings.ReplaceAll(sub.help, "'", "'\\''")
|
help := strings.ReplaceAll(sub.help, "'", "'\\''")
|
||||||
help = strings.ReplaceAll(help, "[", "\\[")
|
help = strings.ReplaceAll(help, "[", "\\[")
|
||||||
help = strings.ReplaceAll(help, "]", "\\]")
|
help = strings.ReplaceAll(help, "]", "\\]")
|
||||||
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", parts[1], help))
|
sb.WriteString(fmt.Sprintf(" '%s:%s'\n", child, help))
|
||||||
}
|
}
|
||||||
sb.WriteString(" )\n")
|
sb.WriteString(" )\n")
|
||||||
sb.WriteString(" _describe -t commands 'subcommands' subcmds\n")
|
sb.WriteString(" _describe -t commands 'subcommands' subcmds\n")
|
||||||
|
|
@ -372,10 +372,10 @@ func generateFishCompletion(app *kong.Application) string {
|
||||||
|
|
||||||
// Subcommands
|
// Subcommands
|
||||||
for _, sub := range cmds {
|
for _, sub := range cmds {
|
||||||
parts := strings.SplitN(sub.fullName, " ", 2)
|
parent, child, found := strings.Cut(sub.fullName, " ")
|
||||||
if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") {
|
if found && parent == cmd.name && !strings.Contains(child, " ") {
|
||||||
help := strings.ReplaceAll(sub.help, "'", "\\'")
|
help := strings.ReplaceAll(sub.help, "'", "\\'")
|
||||||
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, parts[1], help))
|
sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, child, help))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ import (
|
||||||
|
|
||||||
func getTestApp() *kong.Application {
|
func getTestApp() *kong.Application {
|
||||||
var testCLI struct {
|
var testCLI struct {
|
||||||
Run struct{} `cmd:"" help:"Run the server"`
|
Run struct{} `cmd:"" help:"Run the server"`
|
||||||
Models struct {
|
Models struct {
|
||||||
List struct{} `cmd:"" help:"List models"`
|
List struct{} `cmd:"" help:"List models"`
|
||||||
Install struct{} `cmd:"" help:"Install a model"`
|
Install struct{} `cmd:"" help:"Install a model"`
|
||||||
} `cmd:"" help:"Manage models"`
|
} `cmd:"" help:"Manage models"`
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import (
|
||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/startup"
|
"github.com/mudler/LocalAI/core/startup"
|
||||||
|
|
@ -80,7 +80,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
galleryService := services.NewGalleryService(&config.ApplicationConfig{
|
galleryService := galleryop.NewGalleryService(&config.ApplicationConfig{
|
||||||
SystemState: systemState,
|
SystemState: systemState,
|
||||||
}, model.NewModelLoader(systemState))
|
}, model.NewModelLoader(systemState))
|
||||||
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
|
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,9 @@ type RunCMD struct {
|
||||||
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
|
||||||
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
|
||||||
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
|
||||||
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"`
|
||||||
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"`
|
||||||
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"`
|
||||||
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
|
||||||
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
|
||||||
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
|
||||||
|
|
@ -100,7 +100,7 @@ type RunCMD struct {
|
||||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||||
|
|
||||||
// Agent Pool (LocalAGI)
|
// Agent Pool (LocalAGI)
|
||||||
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
|
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"`
|
||||||
AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"`
|
AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"`
|
||||||
AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"`
|
AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"`
|
||||||
AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"`
|
AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"`
|
||||||
|
|
@ -109,17 +109,17 @@ type RunCMD struct {
|
||||||
AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"`
|
AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"`
|
||||||
AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"`
|
AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"`
|
||||||
AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"`
|
AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"`
|
||||||
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
|
AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"`
|
||||||
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
|
AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"`
|
||||||
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
|
AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"`
|
||||||
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
|
AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"`
|
||||||
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
|
AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"`
|
||||||
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
|
AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL for agent collections" group:"agents"`
|
||||||
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
|
AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"`
|
||||||
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
|
AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"`
|
||||||
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
|
AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
|
||||||
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
||||||
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
||||||
|
|
||||||
// Authentication
|
// Authentication
|
||||||
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
||||||
|
|
@ -136,6 +136,18 @@ type RunCMD struct {
|
||||||
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
|
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
|
||||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||||
|
|
||||||
|
// Distributed / Horizontal Scaling
|
||||||
|
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||||
|
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||||
|
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||||
|
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||||
|
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||||
|
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||||
|
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||||
|
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||||
|
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||||
|
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||||
|
|
||||||
Version bool
|
Version bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -210,6 +222,38 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Distributed mode
|
||||||
|
if r.Distributed {
|
||||||
|
opts = append(opts, config.EnableDistributed)
|
||||||
|
}
|
||||||
|
if r.InstanceID != "" {
|
||||||
|
opts = append(opts, config.WithDistributedInstanceID(r.InstanceID))
|
||||||
|
}
|
||||||
|
if r.NatsURL != "" {
|
||||||
|
opts = append(opts, config.WithNatsURL(r.NatsURL))
|
||||||
|
}
|
||||||
|
if r.StorageURL != "" {
|
||||||
|
opts = append(opts, config.WithStorageURL(r.StorageURL))
|
||||||
|
}
|
||||||
|
if r.StorageBucket != "" {
|
||||||
|
opts = append(opts, config.WithStorageBucket(r.StorageBucket))
|
||||||
|
}
|
||||||
|
if r.StorageRegion != "" {
|
||||||
|
opts = append(opts, config.WithStorageRegion(r.StorageRegion))
|
||||||
|
}
|
||||||
|
if r.StorageAccessKey != "" {
|
||||||
|
opts = append(opts, config.WithStorageAccessKey(r.StorageAccessKey))
|
||||||
|
}
|
||||||
|
if r.StorageSecretKey != "" {
|
||||||
|
opts = append(opts, config.WithStorageSecretKey(r.StorageSecretKey))
|
||||||
|
}
|
||||||
|
if r.RegistrationToken != "" {
|
||||||
|
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||||
|
}
|
||||||
|
if r.AutoApproveNodes {
|
||||||
|
opts = append(opts, config.EnableAutoApproveNodes)
|
||||||
|
}
|
||||||
|
|
||||||
if r.DisableMetricsEndpoint {
|
if r.DisableMetricsEndpoint {
|
||||||
opts = append(opts, config.DisableMetricsEndpoint)
|
opts = append(opts, config.DisableMetricsEndpoint)
|
||||||
}
|
}
|
||||||
|
|
@ -218,10 +262,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||||
opts = append(opts, config.DisableRuntimeSettings)
|
opts = append(opts, config.DisableRuntimeSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.EnableTracing {
|
|
||||||
opts = append(opts, config.EnableTracing)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.EnableTracing {
|
if r.EnableTracing {
|
||||||
opts = append(opts, config.EnableTracing)
|
opts = append(opts, config.EnableTracing)
|
||||||
}
|
}
|
||||||
|
|
@ -479,6 +519,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||||
if err := app.ModelLoader().StopAllGRPC(); err != nil {
|
if err := app.ModelLoader().StopAllGRPC(); err != nil {
|
||||||
xlog.Error("error while stopping all grpc backends", "error", err)
|
xlog.Error("error while stopping all grpc backends", "error", err)
|
||||||
}
|
}
|
||||||
|
// Clean up distributed services (idempotent — safe if already called)
|
||||||
|
if d := app.Distributed(); d != nil {
|
||||||
|
d.Shutdown()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start the agent pool after the HTTP server is listening, because
|
// Start the agent pool after the HTTP server is listening, because
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/format"
|
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
|
|
@ -80,7 +79,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||||
|
|
||||||
switch t.ResponseFormat {
|
switch t.ResponseFormat {
|
||||||
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
|
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
|
||||||
fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat))
|
fmt.Println(schema.TranscriptionResponse(tr, t.ResponseFormat))
|
||||||
case schema.TranscriptionResponseFormatJson:
|
case schema.TranscriptionResponseFormatJson:
|
||||||
tr.Segments = nil
|
tr.Segments = nil
|
||||||
fallthrough
|
fallthrough
|
||||||
|
|
|
||||||
897
core/cli/worker.go
Normal file
897
core/cli/worker.go
Normal file
|
|
@ -0,0 +1,897 @@
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
|
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/core/services/messaging"
|
||||||
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
|
"github.com/mudler/LocalAI/core/services/storage"
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||||
|
process "github.com/mudler/go-processmanager"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// isPathAllowed reports whether path resolves to a location inside one of
// the allowed directories. It is used to reject directory-traversal
// attempts in NATS file-staging requests.
//
// Both the candidate path and each allowed directory are symlink-resolved
// before comparison; resolving only one side (as before) meant a
// symlinked allowed dir (e.g. /tmp -> /private/tmp on macOS) could never
// prefix-match a fully resolved candidate path. Paths that do not exist
// yet fall back to their absolute (unresolved) form.
func isPathAllowed(path string, allowedDirs []string) bool {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false
	}
	resolved, err := filepath.EvalSymlinks(absPath)
	if err != nil {
		// Path may not exist yet; use the absolute path
		resolved = absPath
	}
	for _, dir := range allowedDirs {
		absDir, err := filepath.Abs(dir)
		if err != nil {
			continue
		}
		// Resolve symlinks on the allowed dir too so both sides of the
		// prefix comparison are in canonical form.
		if resolvedDir, err := filepath.EvalSymlinks(absDir); err == nil {
			absDir = resolvedDir
		}
		if strings.HasPrefix(resolved, absDir+string(filepath.Separator)) || resolved == absDir {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// WorkerCMD starts a generic worker process for distributed mode.
// Workers are backend-agnostic — they wait for backend.install NATS events
// from the SmartRouter to install and start the required backend.
//
// NATS is required. The worker acts as a process supervisor:
// - Receives backend.install → installs backend from gallery, starts gRPC process, replies success
// - Receives backend.stop → stops the gRPC process
// - Receives stop → full shutdown (deregister + exit)
//
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
//
// Field tags are CLI/env bindings (kong-style); defaults such as
// ${basepath} and ${backends} are expanded by the CLI layer, not here.
type WorkerCMD struct {
	// gRPC bind address and backend/model search paths.
	Addr               string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
	BackendsPath       string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
	BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
	BackendGalleries   string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
	ModelsPath         string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`

	// HTTP file transfer
	HTTPAddr          string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
	AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`

	// Registration (required)
	AdvertiseAddr     string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`

	// NATS (required)
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`

	// S3 storage for distributed file transfer (optional — staging
	// subjects are only subscribed when StorageURL is set).
	StorageURL       string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
	StorageBucket    string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"`
	StorageRegion    string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"`
	StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"`
	StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"`
}
|
||||||
|
|
||||||
|
// Run is the worker entrypoint. It wires up, in order: system state and
// model loader, self-registration with the frontend (with retry), the
// HTTP file-transfer server, the NATS connection, the heartbeat
// goroutine, and the backend process supervisor — then blocks on sigCh
// until SIGINT/SIGTERM (or a NATS "stop" event forwarded through the
// supervisor's sigCh) before tearing everything down in reverse order.
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
	xlog.Info("Starting worker", "addr", cmd.Addr)

	systemState, err := system.GetSystemState(
		system.WithModelPath(cmd.ModelsPath),
		system.WithBackendPath(cmd.BackendsPath),
		system.WithBackendSystemPath(cmd.BackendsSystemPath),
	)
	if err != nil {
		return fmt.Errorf("getting system state: %w", err)
	}

	ml := model.NewModelLoader(systemState)
	ml.SetBackendLoggingEnabled(true)

	// Register already-installed backends
	gallery.RegisterBackends(systemState, ml)

	// Parse galleries config. A parse failure is non-fatal: the worker can
	// still serve backends that are already installed locally.
	var galleries []config.Gallery
	if err := json.Unmarshal([]byte(cmd.BackendGalleries), &galleries); err != nil {
		xlog.Warn("Failed to parse backend galleries", "error", err)
	}

	// Self-registration with frontend (with retry)
	regClient := &workerregistry.RegistrationClient{
		FrontendURL:       cmd.RegisterTo,
		RegistrationToken: cmd.RegistrationToken,
	}

	registrationBody := cmd.registrationBody()
	nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
	if err != nil {
		return fmt.Errorf("failed to register with frontend: %w", err)
	}

	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
	// On a parse error heartbeatInterval is zero; cmp.Or below substitutes
	// the 10s default in that case (and for an empty input).
	heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
	if err != nil && cmd.HeartbeatInterval != "" {
		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
	}
	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
	// Context cancelled on shutdown — used by heartbeat and other background goroutines
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	defer shutdownCancel()

	// Start HTTP file transfer server
	httpAddr := cmd.resolveHTTPAddr()
	stagingDir := filepath.Join(cmd.ModelsPath, "..", "staging")
	dataDir := filepath.Join(cmd.ModelsPath, "..", "data")
	httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cmd.ModelsPath, dataDir, cmd.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs())
	if err != nil {
		return fmt.Errorf("starting HTTP file transfer server: %w", err)
	}

	// Connect to NATS. On failure the already-started HTTP server is torn
	// down before returning so no listener is leaked.
	xlog.Info("Connecting to NATS", "url", sanitize.URL(cmd.NatsURL))
	natsClient, err := messaging.New(cmd.NatsURL)
	if err != nil {
		nodes.ShutdownFileTransferServer(httpServer)
		return fmt.Errorf("connecting to NATS: %w", err)
	}
	defer natsClient.Close()

	// Start heartbeat goroutine (after NATS is connected so IsConnected check works).
	// Stops when shutdownCtx is cancelled.
	go func() {
		ticker := time.NewTicker(heartbeatInterval)
		defer ticker.Stop()
		for {
			select {
			case <-shutdownCtx.Done():
				return
			case <-ticker.C:
				if !natsClient.IsConnected() {
					xlog.Warn("Skipping heartbeat: NATS disconnected")
					continue
				}
				body := cmd.heartbeatBody()
				if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil {
					xlog.Warn("Heartbeat failed", "error", err)
				}
			}
		}
	}()

	// Process supervisor — manages multiple backend gRPC processes on different ports
	basePort := 50051
	if cmd.Addr != "" {
		// Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
		if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
			if p, err := strconv.Atoi(portStr); err == nil {
				basePort = p
			}
		}
	}
	// Buffered so NATS stop handler can send without blocking
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	// Set the registration token once before any backends are started
	if cmd.RegistrationToken != "" {
		os.Setenv(grpc.AuthTokenEnvVar, cmd.RegistrationToken)
	}

	supervisor := &backendSupervisor{
		cmd:         cmd,
		ml:          ml,
		systemState: systemState,
		galleries:   galleries,
		nodeID:      nodeID,
		nats:        natsClient,
		sigCh:       sigCh,
		processes:   make(map[string]*backendProcess),
		nextPort:    basePort,
	}
	supervisor.subscribeLifecycleEvents()

	// Subscribe to file staging NATS subjects if S3 is configured.
	// Failure here is non-fatal: the worker still serves backend lifecycle
	// events, just without file staging.
	if cmd.StorageURL != "" {
		if err := cmd.subscribeFileStaging(natsClient, nodeID); err != nil {
			xlog.Error("Failed to subscribe to file staging subjects", "error", err)
		}
	}

	xlog.Info("Worker ready, waiting for backend.install events")
	<-sigCh

	// Teardown order: stop heartbeat, deregister from the frontend,
	// stop all backend processes, then the HTTP server.
	xlog.Info("Shutting down worker")
	shutdownCancel() // stop heartbeat loop immediately
	regClient.GracefulDeregister(nodeID)
	supervisor.stopAllBackends()
	nodes.ShutdownFileTransferServer(httpServer)
	return nil
}
|
||||||
|
|
||||||
|
// subscribeFileStaging subscribes to NATS file staging subjects for this node.
// It creates an S3-backed FileManager rooted at a cache dir next to
// ModelsPath, then installs four request/reply handlers keyed by nodeID:
// files.ensure (S3 → local), files.stage (local → S3), files.temp
// (allocate a scratch file), and files.listdir (enumerate a local dir).
// All handlers answer via replyJSON so the requester always gets a reply.
func (cmd *WorkerCMD) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error {
	// Create FileManager with same S3 config as the frontend
	// TODO: propagate a caller-provided context once WorkerCMD carries one
	s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
		Endpoint:        cmd.StorageURL,
		Region:          cmd.StorageRegion,
		Bucket:          cmd.StorageBucket,
		AccessKeyID:     cmd.StorageAccessKey,
		SecretAccessKey: cmd.StorageSecretKey,
		ForcePathStyle:  true,
	})
	if err != nil {
		return fmt.Errorf("initializing S3 store: %w", err)
	}

	// Cache dir sits beside the models dir; captured by every handler below.
	cacheDir := filepath.Join(cmd.ModelsPath, "..", "cache")
	fm, err := storage.NewFileManager(s3Store, cacheDir)
	if err != nil {
		return fmt.Errorf("initializing file manager: %w", err)
	}

	// Subscribe: files.ensure — download S3 key to local, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			Key string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}

		localPath, err := fm.Download(context.Background(), req.Key)
		if err != nil {
			xlog.Error("File ensure failed", "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}

		xlog.Debug("File ensured locally", "key", req.Key, "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})

	// Subscribe: files.stage — upload local path to S3, reply with key
	natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			LocalPath string `json:"local_path"`
			Key       string `json:"key"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]string{"error": "invalid request"})
			return
		}

		// Only files under the cache dir (or models dir) may be uploaded —
		// the requested local path comes from the network.
		allowedDirs := []string{cacheDir}
		if cmd.ModelsPath != "" {
			allowedDirs = append(allowedDirs, cmd.ModelsPath)
		}
		if !isPathAllowed(req.LocalPath, allowedDirs) {
			replyJSON(reply, map[string]string{"error": "path outside allowed directories"})
			return
		}

		if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil {
			xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err)
			replyJSON(reply, map[string]string{"error": err.Error()})
			return
		}

		xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key)
		replyJSON(reply, map[string]string{"key": req.Key})
	})

	// Subscribe: files.temp — allocate temp file, reply with local path
	natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) {
		tmpDir := filepath.Join(cacheDir, "staging-tmp")
		if err := os.MkdirAll(tmpDir, 0750); err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)})
			return
		}

		f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp")
		if err != nil {
			replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)})
			return
		}
		// Only the name is needed; the caller writes to it later.
		localPath := f.Name()
		f.Close()

		xlog.Debug("Allocated temp file", "path", localPath)
		replyJSON(reply, map[string]string{"local_path": localPath})
	})

	// Subscribe: files.listdir — list files in a local directory, reply with relative paths
	natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) {
		var req struct {
			KeyPrefix string `json:"key_prefix"`
		}
		if err := json.Unmarshal(data, &req); err != nil {
			replyJSON(reply, map[string]any{"error": "invalid request"})
			return
		}

		// Resolve key prefix to local directory: model-prefixed keys map to
		// the models dir, data-prefixed keys to the sibling data dir,
		// everything else stays under the cache dir.
		dirPath := filepath.Join(cacheDir, req.KeyPrefix)
		if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cmd.ModelsPath != "" {
			dirPath = filepath.Join(cmd.ModelsPath, rel)
		} else if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.DataKeyPrefix); ok {
			dirPath = filepath.Join(cacheDir, "..", "data", rel)
		}

		// Sanitize to prevent directory traversal via crafted key_prefix:
		// the cleaned path must land inside the cache, models, or data dir.
		// cleanModels is "." when ModelsPath is empty, hence the extra guard.
		dirPath = filepath.Clean(dirPath)
		cleanCache := filepath.Clean(cacheDir)
		cleanModels := filepath.Clean(cmd.ModelsPath)
		cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data"))
		if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) ||
			dirPath == cleanCache ||
			(cleanModels != "." && strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) ||
			dirPath == cleanModels ||
			strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) ||
			dirPath == cleanData) {
			replyJSON(reply, map[string]any{"error": "invalid key prefix"})
			return
		}

		// Walk errors are deliberately swallowed: a missing or partially
		// readable dir yields a (possibly empty) listing, not a failure.
		var files []string
		filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error {
			if err != nil {
				return nil
			}
			if !d.IsDir() {
				rel, err := filepath.Rel(dirPath, path)
				if err == nil {
					files = append(files, rel)
				}
			}
			return nil
		})

		xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files))
		replyJSON(reply, map[string]any{"files": files})
	})

	xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID)
	return nil
}
|
||||||
|
|
||||||
|
// replyJSON marshals v to JSON and calls the reply function.
|
||||||
|
func replyJSON(reply func([]byte), v any) {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("Failed to marshal NATS reply", "error", err)
|
||||||
|
data = []byte(`{"error":"internal marshal error"}`)
|
||||||
|
}
|
||||||
|
reply(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// backendProcess represents a single gRPC backend process.
type backendProcess struct {
	proc    *process.Process // supervised OS process handle
	backend string           // backend (or per-model) key this process serves
	addr    string           // gRPC address (host:port)
}
|
||||||
|
|
||||||
|
// backendSupervisor manages multiple backend gRPC processes on different ports.
// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port.
// Contains a mutex, so it must only be passed by pointer.
type backendSupervisor struct {
	cmd         *WorkerCMD                // worker configuration (paths, token, advertise addr)
	ml          *model.ModelLoader        // starts backend processes and tracks external backends
	systemState *system.SystemState       // gallery install/delete operations run against this
	galleries   []config.Gallery          // default galleries; a request may override them
	nodeID      string                    // identity used in NATS subject names
	nats        messaging.MessagingClient // lifecycle-event subscriptions
	sigCh       chan<- os.Signal          // send shutdown signal instead of os.Exit

	// mu guards the three fields below.
	mu        sync.Mutex
	processes map[string]*backendProcess // key: backend name
	nextPort  int                        // next available port for new backends
	freePorts []int                      // ports freed by stopBackend, reused before nextPort
}
|
||||||
|
|
||||||
|
// startBackend starts a gRPC backend process on a dynamically allocated port.
|
||||||
|
// Returns the gRPC address.
|
||||||
|
func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
|
||||||
|
// Already running?
|
||||||
|
if bp, ok := s.processes[backend]; ok {
|
||||||
|
if bp.proc != nil && bp.proc.IsAlive() {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return bp.addr, nil
|
||||||
|
}
|
||||||
|
// Process died — clean up and restart
|
||||||
|
xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend)
|
||||||
|
delete(s.processes, backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate port — recycle freed ports first, then grow upward from basePort
|
||||||
|
var port int
|
||||||
|
if len(s.freePorts) > 0 {
|
||||||
|
port = s.freePorts[len(s.freePorts)-1]
|
||||||
|
s.freePorts = s.freePorts[:len(s.freePorts)-1]
|
||||||
|
} else {
|
||||||
|
port = s.nextPort
|
||||||
|
s.nextPort++
|
||||||
|
}
|
||||||
|
bindAddr := fmt.Sprintf("0.0.0.0:%d", port)
|
||||||
|
clientAddr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
|
||||||
|
proc, err := s.ml.StartProcess(backendPath, backend, bindAddr)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", fmt.Errorf("starting backend process: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.processes[backend] = &backendProcess{
|
||||||
|
proc: proc,
|
||||||
|
backend: backend,
|
||||||
|
addr: clientAddr,
|
||||||
|
}
|
||||||
|
xlog.Info("Backend process started", "backend", backend, "addr", clientAddr)
|
||||||
|
|
||||||
|
// Capture reference before unlocking for race-safe health check.
|
||||||
|
// Another goroutine could stopBackend and recycle the port while we poll.
|
||||||
|
bp := s.processes[backend]
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Wait for the gRPC server to be ready
|
||||||
|
client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken)
|
||||||
|
for range 20 {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
if ok, _ := client.HealthCheck(ctx); ok {
|
||||||
|
cancel()
|
||||||
|
// Verify the process wasn't stopped/replaced while health-checking
|
||||||
|
s.mu.Lock()
|
||||||
|
currentBP, exists := s.processes[backend]
|
||||||
|
s.mu.Unlock()
|
||||||
|
if !exists || currentBP != bp {
|
||||||
|
return "", fmt.Errorf("backend %s was stopped during startup", backend)
|
||||||
|
}
|
||||||
|
xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr)
|
||||||
|
return clientAddr, nil
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Warn("Backend gRPC server not ready after waiting, proceeding anyway", "backend", backend, "addr", clientAddr)
|
||||||
|
return clientAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopBackend stops a specific backend's gRPC process.
// The map entry is removed and the port recycled while holding the lock;
// the gRPC Free() call and the process stop happen after unlocking so
// network I/O never blocks other supervisor operations. No-op if the
// backend is not tracked.
func (s *backendSupervisor) stopBackend(backend string) {
	s.mu.Lock()
	bp, ok := s.processes[backend]
	if !ok || bp.proc == nil {
		s.mu.Unlock()
		return
	}
	// Clean up map and recycle port while holding lock
	delete(s.processes, backend)
	if _, portStr, err := net.SplitHostPort(bp.addr); err == nil {
		if p, err := strconv.Atoi(portStr); err == nil {
			s.freePorts = append(s.freePorts, p)
		}
	}
	s.mu.Unlock()

	// Network I/O outside the lock.
	// Best-effort Free() first so the backend can release resources
	// (GPU memory) before the process is killed.
	client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
	if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
		xlog.Debug("Calling Free() before stopping backend", "backend", backend)
		if err := freeFunc.Free(context.Background()); err != nil {
			xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
		}
	}

	xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
	if err := bp.proc.Stop(); err != nil {
		xlog.Error("Error stopping backend process", "backend", backend, "error", err)
	}
}
|
||||||
|
|
||||||
|
// stopAllBackends stops all running backend processes.
|
||||||
|
func (s *backendSupervisor) stopAllBackends() {
|
||||||
|
s.mu.Lock()
|
||||||
|
backends := slices.Collect(maps.Keys(s.processes))
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, b := range backends {
|
||||||
|
s.stopBackend(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRunning returns whether a specific backend process is currently running.
|
||||||
|
func (s *backendSupervisor) isRunning(backend string) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
bp, ok := s.processes[backend]
|
||||||
|
return ok && bp.proc != nil && bp.proc.IsAlive()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAddr returns the gRPC address for a running backend, or empty string.
|
||||||
|
func (s *backendSupervisor) getAddr(backend string) string {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if bp, ok := s.processes[backend]; ok {
|
||||||
|
return bp.addr
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// installBackend handles the backend.install flow:
// 1. If already running for this model, return existing address
// 2. Install backend from gallery (if not already installed)
// 3. Find backend binary
// 4. Start gRPC process on a new port
// Returns the gRPC address of the backend process.
func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) {
	// Process key: use ModelID if provided (per-model process), else backend name
	processKey := req.ModelID
	if processKey == "" {
		processKey = req.Backend
	}

	// If already running for this model, return its address
	if addr := s.getAddr(processKey); addr != "" {
		xlog.Info("Backend already running for model", "backend", req.Backend, "model", req.ModelID, "addr", addr)
		return addr, nil
	}

	// Parse galleries from request (override local config if provided).
	// An unparseable override silently falls back to the local galleries.
	galleries := s.galleries
	if req.BackendGalleries != "" {
		var reqGalleries []config.Gallery
		if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil {
			galleries = reqGalleries
		}
	}

	// Try to find the backend binary
	backendPath := s.findBackend(req.Backend)
	if backendPath == "" {
		// Backend not found locally — try auto-installing from gallery
		xlog.Info("Backend not found locally, attempting gallery install", "backend", req.Backend)
		if err := gallery.InstallBackendFromGallery(
			context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, false,
		); err != nil {
			return "", fmt.Errorf("installing backend from gallery: %w", err)
		}
		// Re-register after install and retry
		gallery.RegisterBackends(s.systemState, s.ml)
		backendPath = s.findBackend(req.Backend)
	}

	if backendPath == "" {
		return "", fmt.Errorf("backend %q not found after install attempt", req.Backend)
	}

	xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey)

	// Start the gRPC process on a new port (keyed by model, not just backend)
	return s.startBackend(processKey, backendPath)
}
|
||||||
|
|
||||||
|
// findBackend looks for the backend binary in the backends path and system path.
|
||||||
|
func (s *backendSupervisor) findBackend(backend string) string {
|
||||||
|
candidates := []string{
|
||||||
|
filepath.Join(s.cmd.BackendsPath, backend),
|
||||||
|
filepath.Join(s.cmd.BackendsPath, backend, backend),
|
||||||
|
filepath.Join(s.cmd.BackendsSystemPath, backend),
|
||||||
|
filepath.Join(s.cmd.BackendsSystemPath, backend, backend),
|
||||||
|
}
|
||||||
|
if uri := s.ml.GetExternalBackend(backend); uri != "" {
|
||||||
|
if fi, err := os.Stat(uri); err == nil && !fi.IsDir() {
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, path := range candidates {
|
||||||
|
fi, err := os.Stat(path)
|
||||||
|
if err == nil && !fi.IsDir() {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// subscribeLifecycleEvents subscribes to NATS backend lifecycle events.
|
||||||
|
func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||||
|
// backend.install — install backend + start gRPC process (request-reply)
|
||||||
|
s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||||
|
xlog.Info("Received NATS backend.install event")
|
||||||
|
var req messaging.BackendInstallRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := s.installBackend(req)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("Failed to install backend via NATS", "error", err)
|
||||||
|
resp := messaging.BackendInstallReply{Success: false, Error: err.Error()}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the gRPC address so the router knows which port to use
|
||||||
|
advertiseAddr := addr
|
||||||
|
if s.cmd.AdvertiseAddr != "" {
|
||||||
|
// Replace 0.0.0.0 with the advertised host but keep the dynamic port
|
||||||
|
_, port, _ := net.SplitHostPort(addr)
|
||||||
|
advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
|
||||||
|
advertiseAddr = net.JoinHostPort(advertiseHost, port)
|
||||||
|
}
|
||||||
|
resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// backend.stop — stop a specific backend process
|
||||||
|
s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), func(data []byte) {
|
||||||
|
// Try to parse backend name from payload; if empty, stop all
|
||||||
|
var req struct {
|
||||||
|
Backend string `json:"backend"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &req) == nil && req.Backend != "" {
|
||||||
|
xlog.Info("Received NATS backend.stop event", "backend", req.Backend)
|
||||||
|
s.stopBackend(req.Backend)
|
||||||
|
} else {
|
||||||
|
xlog.Info("Received NATS backend.stop event (all)")
|
||||||
|
s.stopAllBackends()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// backend.delete — stop backend + delete files (request-reply)
|
||||||
|
s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||||
|
xlog.Info("Received NATS backend.delete event")
|
||||||
|
var req messaging.BackendDeleteRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop if running this backend
|
||||||
|
if s.isRunning(req.Backend) {
|
||||||
|
s.stopBackend(req.Backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the backend files
|
||||||
|
if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil {
|
||||||
|
xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err)
|
||||||
|
resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-register backends after deletion
|
||||||
|
gallery.RegisterBackends(s.systemState, s.ml)
|
||||||
|
|
||||||
|
resp := messaging.BackendDeleteReply{Success: true}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// backend.list — list installed backends (request-reply)
|
||||||
|
s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||||
|
xlog.Info("Received NATS backend.list event")
|
||||||
|
backends, err := gallery.ListSystemBackends(s.systemState)
|
||||||
|
if err != nil {
|
||||||
|
resp := messaging.BackendListReply{Error: err.Error()}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var infos []messaging.NodeBackendInfo
|
||||||
|
for name, b := range backends {
|
||||||
|
info := messaging.NodeBackendInfo{
|
||||||
|
Name: name,
|
||||||
|
IsSystem: b.IsSystem,
|
||||||
|
IsMeta: b.IsMeta,
|
||||||
|
}
|
||||||
|
if b.Metadata != nil {
|
||||||
|
info.InstalledAt = b.Metadata.InstalledAt
|
||||||
|
info.GalleryURL = b.Metadata.GalleryURL
|
||||||
|
}
|
||||||
|
infos = append(infos, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := messaging.BackendListReply{Backends: infos}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// model.unload — call gRPC Free() to release GPU memory (request-reply)
|
||||||
|
s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||||
|
xlog.Info("Received NATS model.unload event")
|
||||||
|
var req messaging.ModelUnloadRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the backend address for this model's backend type
|
||||||
|
// The request includes an Address field if the router knows which process to target
|
||||||
|
targetAddr := req.Address
|
||||||
|
if targetAddr == "" {
|
||||||
|
// Fallback: try all running backends
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, bp := range s.processes {
|
||||||
|
targetAddr = bp.addr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetAddr != "" {
|
||||||
|
// Best-effort gRPC Free()
|
||||||
|
client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
|
||||||
|
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
|
||||||
|
if err := freeFunc.Free(context.Background()); err != nil {
|
||||||
|
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := messaging.ModelUnloadReply{Success: true}
|
||||||
|
replyJSON(reply, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// model.delete — remove model files from disk (request-reply)
|
||||||
|
s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||||
|
xlog.Info("Received NATS model.delete event")
|
||||||
|
var req messaging.ModelDeleteRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gallery.DeleteStagedModelFiles(s.cmd.ModelsPath, req.ModelName); err != nil {
|
||||||
|
xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err)
|
||||||
|
replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
replyJSON(reply, messaging.ModelDeleteReply{Success: true})
|
||||||
|
})
|
||||||
|
|
||||||
|
// stop — trigger the normal shutdown path via sigCh so deferred cleanup runs
|
||||||
|
s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), func(data []byte) {
|
||||||
|
xlog.Info("Received NATS stop event — signaling shutdown")
|
||||||
|
select {
|
||||||
|
case s.sigCh <- syscall.SIGTERM:
|
||||||
|
default:
|
||||||
|
xlog.Debug("Shutdown already signaled, ignoring duplicate stop")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// advertiseAddr returns the address the frontend should use to reach this node.
|
||||||
|
func (cmd *WorkerCMD) advertiseAddr() string {
|
||||||
|
if cmd.AdvertiseAddr != "" {
|
||||||
|
return cmd.AdvertiseAddr
|
||||||
|
}
|
||||||
|
host, port, ok := strings.Cut(cmd.Addr, ":")
|
||||||
|
if ok && (host == "0.0.0.0" || host == "") {
|
||||||
|
if hostname, err := os.Hostname(); err == nil {
|
||||||
|
return hostname + ":" + port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cmd.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
|
||||||
|
// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports
|
||||||
|
// which grow upward from basePort.
|
||||||
|
func (cmd *WorkerCMD) resolveHTTPAddr() string {
|
||||||
|
if cmd.HTTPAddr != "" {
|
||||||
|
return cmd.HTTPAddr
|
||||||
|
}
|
||||||
|
host, port, ok := strings.Cut(cmd.Addr, ":")
|
||||||
|
if !ok {
|
||||||
|
return "0.0.0.0:50050"
|
||||||
|
}
|
||||||
|
portNum, _ := strconv.Atoi(port)
|
||||||
|
return fmt.Sprintf("%s:%d", host, portNum-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
|
||||||
|
// this node for file transfer.
|
||||||
|
func (cmd *WorkerCMD) advertiseHTTPAddr() string {
|
||||||
|
if cmd.AdvertiseHTTPAddr != "" {
|
||||||
|
return cmd.AdvertiseHTTPAddr
|
||||||
|
}
|
||||||
|
httpAddr := cmd.resolveHTTPAddr()
|
||||||
|
host, port, ok := strings.Cut(httpAddr, ":")
|
||||||
|
if ok && (host == "0.0.0.0" || host == "") {
|
||||||
|
if hostname, err := os.Hostname(); err == nil {
|
||||||
|
return hostname + ":" + port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return httpAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// registrationBody builds the JSON body for node registration.
|
||||||
|
func (cmd *WorkerCMD) registrationBody() map[string]any {
|
||||||
|
nodeName := cmd.NodeName
|
||||||
|
if nodeName == "" {
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
nodeName = fmt.Sprintf("node-%d", os.Getpid())
|
||||||
|
} else {
|
||||||
|
nodeName = hostname
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect GPU info for VRAM-aware scheduling
|
||||||
|
totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
|
||||||
|
gpuVendor, _ := xsysinfo.DetectGPUVendor()
|
||||||
|
|
||||||
|
body := map[string]any{
|
||||||
|
"name": nodeName,
|
||||||
|
"address": cmd.advertiseAddr(),
|
||||||
|
"http_address": cmd.advertiseHTTPAddr(),
|
||||||
|
"total_vram": totalVRAM,
|
||||||
|
"available_vram": totalVRAM, // initially all VRAM is available
|
||||||
|
"gpu_vendor": gpuVendor,
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no GPU detected, report system RAM so the scheduler/UI has capacity info
|
||||||
|
if totalVRAM == 0 {
|
||||||
|
if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
|
||||||
|
body["total_ram"] = ramInfo.Total
|
||||||
|
body["available_ram"] = ramInfo.Available
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cmd.RegistrationToken != "" {
|
||||||
|
body["token"] = cmd.RegistrationToken
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads.
|
||||||
|
func (cmd *WorkerCMD) heartbeatBody() map[string]any {
|
||||||
|
var availVRAM uint64
|
||||||
|
aggregate := xsysinfo.GetGPUAggregateInfo()
|
||||||
|
if aggregate.TotalVRAM > 0 {
|
||||||
|
availVRAM = aggregate.FreeVRAM
|
||||||
|
} else {
|
||||||
|
// Fallback: report total as available (no usage tracking possible)
|
||||||
|
availVRAM, _ = xsysinfo.TotalAvailableVRAM()
|
||||||
|
}
|
||||||
|
|
||||||
|
body := map[string]any{
|
||||||
|
"available_vram": availVRAM,
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no GPU, report system RAM usage instead
|
||||||
|
if aggregate.TotalVRAM == 0 {
|
||||||
|
if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil {
|
||||||
|
body["available_ram"] = ramInfo.Available
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
272
core/cli/workerregistry/client.go
Normal file
272
core/cli/workerregistry/client.go
Normal file
|
|
@ -0,0 +1,272 @@
|
||||||
|
// Package workerregistry provides a shared HTTP client for worker node
|
||||||
|
// registration, heartbeating, draining, and deregistration against a
|
||||||
|
// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent
|
||||||
|
// worker (AgentWorkerCMD) use this instead of duplicating the logic.
|
||||||
|
package workerregistry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegistrationClient talks to the frontend's /api/node/* endpoints.
// The zero value (plus FrontendURL) is usable: the underlying HTTP client is
// created lazily on first use. When RegistrationToken is set it is sent as a
// Bearer token on every request.
type RegistrationClient struct {
	FrontendURL       string        // base URL of the frontend, e.g. "http://frontend:8080"
	RegistrationToken string        // optional bearer token for the node endpoints
	HTTPTimeout       time.Duration // used for registration calls; defaults to 10s
	client            *http.Client  // lazily initialized shared HTTP client
	clientOnce        sync.Once     // guards one-time creation of client
}
|
||||||
|
|
||||||
|
// httpTimeout returns the configured timeout or a sensible default.
|
||||||
|
func (c *RegistrationClient) httpTimeout() time.Duration {
|
||||||
|
if c.HTTPTimeout > 0 {
|
||||||
|
return c.HTTPTimeout
|
||||||
|
}
|
||||||
|
return 10 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpClient returns the shared HTTP client, initializing it on first use.
|
||||||
|
func (c *RegistrationClient) httpClient() *http.Client {
|
||||||
|
c.clientOnce.Do(func() {
|
||||||
|
c.client = &http.Client{Timeout: c.httpTimeout()}
|
||||||
|
})
|
||||||
|
return c.client
|
||||||
|
}
|
||||||
|
|
||||||
|
// baseURL returns FrontendURL with any trailing slash stripped, so path
// segments can be appended with a single "/". TrimRight (rather than
// TrimSuffix) also collapses repeated trailing slashes.
func (c *RegistrationClient) baseURL() string {
	return strings.TrimRight(c.FrontendURL, "/")
}
|
||||||
|
|
||||||
|
// setAuth adds an Authorization header when a token is configured.
|
||||||
|
func (c *RegistrationClient) setAuth(req *http.Request) {
|
||||||
|
if c.RegistrationToken != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.RegistrationToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
	ID       string `json:"id"`                  // node identifier assigned by the frontend
	APIToken string `json:"api_token,omitempty"` // auto-provisioned API token, when issued
}
|
||||||
|
|
||||||
|
// Register sends a single registration request and returns the node ID and
|
||||||
|
// (optionally) an auto-provisioned API token.
|
||||||
|
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
url := c.baseURL() + "/api/node/register"
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.setAuth(req)
|
||||||
|
|
||||||
|
resp, err := c.httpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("posting to %s: %w", url, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result RegisterResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return "", "", fmt.Errorf("decoding response: %w", err)
|
||||||
|
}
|
||||||
|
return result.ID, result.APIToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterWithRetry retries registration with exponential backoff.
|
||||||
|
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
|
||||||
|
backoff := 2 * time.Second
|
||||||
|
maxBackoff := 30 * time.Second
|
||||||
|
|
||||||
|
var nodeID, apiToken string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||||
|
nodeID, apiToken, err = c.Register(ctx, body)
|
||||||
|
if err == nil {
|
||||||
|
return nodeID, apiToken, nil
|
||||||
|
}
|
||||||
|
if attempt == maxRetries {
|
||||||
|
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||||
|
}
|
||||||
|
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", "", ctx.Err()
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
backoff = min(backoff*2, maxBackoff)
|
||||||
|
}
|
||||||
|
return nodeID, apiToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat sends a single heartbeat POST with the given body.
|
||||||
|
func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error {
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat"
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating heartbeat request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.setAuth(req)
|
||||||
|
|
||||||
|
resp, err := c.httpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeartbeatLoop runs heartbeats at the given interval until ctx is cancelled.
|
||||||
|
// bodyFn is called each tick to build the heartbeat payload (e.g. VRAM stats).
|
||||||
|
func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
body := bodyFn()
|
||||||
|
if err := c.Heartbeat(ctx, nodeID, body); err != nil {
|
||||||
|
xlog.Warn("Heartbeat failed", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain sets the node to draining status via POST /api/node/:id/drain.
|
||||||
|
func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error {
|
||||||
|
url := c.baseURL() + "/api/node/" + nodeID + "/drain"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating drain request: %w", err)
|
||||||
|
}
|
||||||
|
c.setAuth(req)
|
||||||
|
|
||||||
|
resp, err := c.httpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("drain failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForDrain polls GET /api/node/:id/models until all models report 0
|
||||||
|
// in-flight requests, or until timeout elapses.
|
||||||
|
func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) {
|
||||||
|
url := c.baseURL() + "/api/node/" + nodeID + "/models"
|
||||||
|
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Warn("Failed to create drain poll request", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setAuth(req)
|
||||||
|
|
||||||
|
resp, err := c.httpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Warn("Drain poll failed, will retry", "error", err)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
xlog.Warn("Drain wait cancelled")
|
||||||
|
return
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var models []struct {
|
||||||
|
InFlight int `json:"in_flight"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&models)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, m := range models {
|
||||||
|
total += m.InFlight
|
||||||
|
}
|
||||||
|
if total == 0 {
|
||||||
|
xlog.Info("All in-flight requests drained")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
xlog.Info("Waiting for in-flight requests", "count", total)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
xlog.Warn("Drain wait cancelled")
|
||||||
|
return
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlog.Warn("Drain timeout reached, proceeding with shutdown")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister marks the node as offline via POST /api/node/:id/deregister.
|
||||||
|
// The node row is preserved in the database so re-registration restores
|
||||||
|
// approval status.
|
||||||
|
func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error {
|
||||||
|
url := c.baseURL() + "/api/node/" + nodeID + "/deregister"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating deregister request: %w", err)
|
||||||
|
}
|
||||||
|
c.setAuth(req)
|
||||||
|
|
||||||
|
resp, err := c.httpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("deregistration failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GracefulDeregister performs drain -> wait -> deregister in sequence.
|
||||||
|
// This is the standard shutdown sequence for backend workers.
|
||||||
|
func (c *RegistrationClient) GracefulDeregister(nodeID string) {
|
||||||
|
if c.FrontendURL == "" || nodeID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := c.Drain(ctx, nodeID); err != nil {
|
||||||
|
xlog.Warn("Failed to set drain status", "error", err)
|
||||||
|
} else {
|
||||||
|
c.WaitForDrain(ctx, nodeID, 30*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Deregister(ctx, nodeID); err != nil {
|
||||||
|
xlog.Error("Failed to deregister", "error", err)
|
||||||
|
} else {
|
||||||
|
xlog.Info("Deregistered from frontend")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -94,7 +94,7 @@ func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to perform a request without expecting a response body
|
// Helper function to perform a request without expecting a response body
|
||||||
func (c *StoreClient) doRequest(path string, data interface{}) error {
|
func (c *StoreClient) doRequest(path string, data any) error {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -120,7 +120,7 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to perform a request and parse the response body
|
// Helper function to perform a request and parse the response body
|
||||||
func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) {
|
func (c *StoreClient) doRequestWithResponse(path string, data any) ([]byte, error) {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -83,8 +83,8 @@ type ApplicationConfig struct {
|
||||||
|
|
||||||
APIAddress string
|
APIAddress string
|
||||||
|
|
||||||
LlamaCPPTunnelCallback func(tunnels []string)
|
LlamaCPPTunnelCallback func(tunnels []string)
|
||||||
MLXTunnelCallback func(tunnels []string)
|
MLXTunnelCallback func(tunnels []string)
|
||||||
|
|
||||||
DisableRuntimeSettings bool
|
DisableRuntimeSettings bool
|
||||||
|
|
||||||
|
|
@ -99,47 +99,50 @@ type ApplicationConfig struct {
|
||||||
|
|
||||||
// Authentication & Authorization
|
// Authentication & Authorization
|
||||||
Auth AuthConfig
|
Auth AuthConfig
|
||||||
|
|
||||||
|
// Distributed / Horizontal Scaling
|
||||||
|
Distributed DistributedConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthConfig holds configuration for user authentication and authorization.
|
// AuthConfig holds configuration for user authentication and authorization.
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
DatabaseURL string // "postgres://..." or file path for SQLite
|
DatabaseURL string // "postgres://..." or file path for SQLite
|
||||||
GitHubClientID string
|
GitHubClientID string
|
||||||
GitHubClientSecret string
|
GitHubClientSecret string
|
||||||
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
|
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
|
||||||
OIDCClientID string
|
OIDCClientID string
|
||||||
OIDCClientSecret string
|
OIDCClientSecret string
|
||||||
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
|
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
|
||||||
AdminEmail string // auto-promote to admin on login
|
AdminEmail string // auto-promote to admin on login
|
||||||
RegistrationMode string // "open", "approval" (default when empty), "invite"
|
RegistrationMode string // "open", "approval" (default when empty), "invite"
|
||||||
DisableLocalAuth bool // disable local email/password registration and login
|
DisableLocalAuth bool // disable local email/password registration and login
|
||||||
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
|
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
|
||||||
DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry
|
DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry
|
||||||
}
|
}
|
||||||
|
|
||||||
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
|
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
|
||||||
type AgentPoolConfig struct {
|
type AgentPoolConfig struct {
|
||||||
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
|
Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true)
|
||||||
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
|
StateDir string // default: DynamicConfigsDir (LocalAI configuration folder)
|
||||||
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
|
APIURL string // default: self-referencing LocalAI (http://127.0.0.1:<port>)
|
||||||
APIKey string // default: first API key from LocalAI config
|
APIKey string // default: first API key from LocalAI config
|
||||||
DefaultModel string
|
DefaultModel string
|
||||||
MultimodalModel string
|
MultimodalModel string
|
||||||
TranscriptionModel string
|
TranscriptionModel string
|
||||||
TranscriptionLanguage string
|
TranscriptionLanguage string
|
||||||
TTSModel string
|
TTSModel string
|
||||||
Timeout string // default: "5m"
|
Timeout string // default: "5m"
|
||||||
EnableSkills bool
|
EnableSkills bool
|
||||||
EnableLogs bool
|
EnableLogs bool
|
||||||
CustomActionsDir string
|
CustomActionsDir string
|
||||||
CollectionDBPath string
|
CollectionDBPath string
|
||||||
VectorEngine string // default: "chromem"
|
VectorEngine string // default: "chromem"
|
||||||
EmbeddingModel string // default: "granite-embedding-107m-multilingual"
|
EmbeddingModel string // default: "granite-embedding-107m-multilingual"
|
||||||
MaxChunkingSize int // default: 400
|
MaxChunkingSize int // default: 400
|
||||||
ChunkOverlap int // default: 0
|
ChunkOverlap int // default: 0
|
||||||
DatabaseURL string
|
DatabaseURL string
|
||||||
AgentHubURL string // default: "https://agenthub.localai.io"
|
AgentHubURL string // default: "https://agenthub.localai.io"
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppOption func(*ApplicationConfig)
|
type AppOption func(*ApplicationConfig)
|
||||||
|
|
@ -155,12 +158,12 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||||
TracingMaxItems: 1024,
|
TracingMaxItems: 1024,
|
||||||
AgentPool: AgentPoolConfig{
|
AgentPool: AgentPoolConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Timeout: "5m",
|
Timeout: "5m",
|
||||||
VectorEngine: "chromem",
|
VectorEngine: "chromem",
|
||||||
EmbeddingModel: "granite-embedding-107m-multilingual",
|
EmbeddingModel: "granite-embedding-107m-multilingual",
|
||||||
MaxChunkingSize: 400,
|
MaxChunkingSize: 400,
|
||||||
AgentHubURL: "https://agenthub.localai.io",
|
AgentHubURL: "https://agenthub.localai.io",
|
||||||
},
|
},
|
||||||
PathWithoutAuth: []string{
|
PathWithoutAuth: []string{
|
||||||
"/static/",
|
"/static/",
|
||||||
|
|
@ -904,40 +907,40 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||||
agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath
|
agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath
|
||||||
|
|
||||||
return RuntimeSettings{
|
return RuntimeSettings{
|
||||||
WatchdogEnabled: &watchdogEnabled,
|
WatchdogEnabled: &watchdogEnabled,
|
||||||
WatchdogIdleEnabled: &watchdogIdle,
|
WatchdogIdleEnabled: &watchdogIdle,
|
||||||
WatchdogBusyEnabled: &watchdogBusy,
|
WatchdogBusyEnabled: &watchdogBusy,
|
||||||
WatchdogIdleTimeout: &idleTimeout,
|
WatchdogIdleTimeout: &idleTimeout,
|
||||||
WatchdogBusyTimeout: &busyTimeout,
|
WatchdogBusyTimeout: &busyTimeout,
|
||||||
WatchdogInterval: &watchdogInterval,
|
WatchdogInterval: &watchdogInterval,
|
||||||
SingleBackend: &singleBackend,
|
SingleBackend: &singleBackend,
|
||||||
MaxActiveBackends: &maxActiveBackends,
|
MaxActiveBackends: &maxActiveBackends,
|
||||||
ParallelBackendRequests: ¶llelBackendRequests,
|
ParallelBackendRequests: ¶llelBackendRequests,
|
||||||
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
|
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
|
||||||
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
|
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
|
||||||
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
|
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
|
||||||
LRUEvictionMaxRetries: &lruEvictionMaxRetries,
|
LRUEvictionMaxRetries: &lruEvictionMaxRetries,
|
||||||
LRUEvictionRetryInterval: &lruEvictionRetryInterval,
|
LRUEvictionRetryInterval: &lruEvictionRetryInterval,
|
||||||
Threads: &threads,
|
Threads: &threads,
|
||||||
ContextSize: &contextSize,
|
ContextSize: &contextSize,
|
||||||
F16: &f16,
|
F16: &f16,
|
||||||
Debug: &debug,
|
Debug: &debug,
|
||||||
TracingMaxItems: &tracingMaxItems,
|
TracingMaxItems: &tracingMaxItems,
|
||||||
EnableTracing: &enableTracing,
|
EnableTracing: &enableTracing,
|
||||||
EnableBackendLogging: &enableBackendLogging,
|
EnableBackendLogging: &enableBackendLogging,
|
||||||
CORS: &cors,
|
CORS: &cors,
|
||||||
CSRF: &csrf,
|
CSRF: &csrf,
|
||||||
CORSAllowOrigins: &corsAllowOrigins,
|
CORSAllowOrigins: &corsAllowOrigins,
|
||||||
P2PToken: &p2pToken,
|
P2PToken: &p2pToken,
|
||||||
P2PNetworkID: &p2pNetworkID,
|
P2PNetworkID: &p2pNetworkID,
|
||||||
Federated: &federated,
|
Federated: &federated,
|
||||||
Galleries: &galleries,
|
Galleries: &galleries,
|
||||||
BackendGalleries: &backendGalleries,
|
BackendGalleries: &backendGalleries,
|
||||||
AutoloadGalleries: &autoloadGalleries,
|
AutoloadGalleries: &autoloadGalleries,
|
||||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||||
ApiKeys: &apiKeys,
|
ApiKeys: &apiKeys,
|
||||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||||
AgentPoolEnabled: &agentPoolEnabled,
|
AgentPoolEnabled: &agentPoolEnabled,
|
||||||
AgentPoolDefaultModel: &agentPoolDefaultModel,
|
AgentPoolDefaultModel: &agentPoolDefaultModel,
|
||||||
AgentPoolEmbeddingModel: &agentPoolEmbeddingModel,
|
AgentPoolEmbeddingModel: &agentPoolEmbeddingModel,
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||||
F16: true,
|
F16: true,
|
||||||
Debug: true,
|
Debug: true,
|
||||||
CORS: true,
|
CORS: true,
|
||||||
DisableCSRF: true,
|
DisableCSRF: true,
|
||||||
CORSAllowOrigins: "https://example.com",
|
CORSAllowOrigins: "https://example.com",
|
||||||
P2PToken: "test-token",
|
P2PToken: "test-token",
|
||||||
P2PNetworkID: "test-network",
|
P2PNetworkID: "test-network",
|
||||||
|
|
@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||||
F16: true,
|
F16: true,
|
||||||
Debug: false,
|
Debug: false,
|
||||||
CORS: true,
|
CORS: true,
|
||||||
DisableCSRF: false,
|
DisableCSRF: false,
|
||||||
CORSAllowOrigins: "https://test.com",
|
CORSAllowOrigins: "https://test.com",
|
||||||
P2PToken: "round-trip-token",
|
P2PToken: "round-trip-token",
|
||||||
P2PNetworkID: "round-trip-network",
|
P2PNetworkID: "round-trip-network",
|
||||||
|
|
|
||||||
188
core/config/distributed_config.go
Normal file
188
core/config/distributed_config.go
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DistributedConfig holds configuration for horizontal scaling mode.
// When Enabled is true, PostgreSQL and NATS are required.
// Zero-valued duration fields mean "use the documented default".
type DistributedConfig struct {
	Enabled           bool   // --distributed / LOCALAI_DISTRIBUTED
	InstanceID        string // --instance-id / LOCALAI_INSTANCE_ID (auto-generated UUID if empty)
	NatsURL           string // --nats-url / LOCALAI_NATS_URL
	StorageURL        string // --storage-url / LOCALAI_STORAGE_URL (S3 endpoint)
	RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
	AutoApproveNodes  bool   // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)

	// S3 configuration (used when StorageURL is set)
	StorageBucket    string // --storage-bucket / LOCALAI_STORAGE_BUCKET
	StorageRegion    string // --storage-region / LOCALAI_STORAGE_REGION
	StorageAccessKey string // --storage-access-key / LOCALAI_STORAGE_ACCESS_KEY
	StorageSecretKey string // --storage-secret-key / LOCALAI_STORAGE_SECRET_KEY

	// Timeout configuration (all have sensible defaults — zero means use default)
	MCPToolTimeout      time.Duration // MCP tool execution timeout (default 360s)
	MCPDiscoveryTimeout time.Duration // MCP discovery timeout (default 60s)
	WorkerWaitTimeout   time.Duration // Max wait for healthy worker at startup (default 5m)
	DrainTimeout        time.Duration // Time to wait for in-flight requests during drain (default 30s)
	HealthCheckInterval time.Duration // Health monitor check interval (default 15s)
	StaleNodeThreshold  time.Duration // Time before a node is considered stale (default 60s)
	PerModelHealthCheck bool          // Enable per-model backend health checking (default false)
	MCPCIJobTimeout     time.Duration // MCP CI job execution timeout (default 10m)

	MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)

	// NOTE(review): these two carry yaml/json/env tags while the other fields
	// do not — presumably bound elsewhere via struct tags; verify against the
	// CLI/env wiring.
	AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
	JobWorkerConcurrency   int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
}
|
||||||
|
|
||||||
|
// Validate checks that the distributed configuration is internally consistent.
|
||||||
|
// It returns nil if distributed mode is disabled.
|
||||||
|
func (c DistributedConfig) Validate() error {
|
||||||
|
if !c.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.NatsURL == "" {
|
||||||
|
return fmt.Errorf("distributed mode requires --nats-url / LOCALAI_NATS_URL")
|
||||||
|
}
|
||||||
|
// S3 credentials must be paired
|
||||||
|
if (c.StorageAccessKey != "" && c.StorageSecretKey == "") ||
|
||||||
|
(c.StorageAccessKey == "" && c.StorageSecretKey != "") {
|
||||||
|
return fmt.Errorf("storage-access-key and storage-secret-key must both be set or both empty")
|
||||||
|
}
|
||||||
|
// Warn about missing registration token (not an error)
|
||||||
|
if c.RegistrationToken == "" {
|
||||||
|
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||||
|
}
|
||||||
|
// Check for negative durations
|
||||||
|
for name, d := range map[string]time.Duration{
|
||||||
|
"mcp-tool-timeout": c.MCPToolTimeout,
|
||||||
|
"mcp-discovery-timeout": c.MCPDiscoveryTimeout,
|
||||||
|
"worker-wait-timeout": c.WorkerWaitTimeout,
|
||||||
|
"drain-timeout": c.DrainTimeout,
|
||||||
|
"health-check-interval": c.HealthCheckInterval,
|
||||||
|
"stale-node-threshold": c.StaleNodeThreshold,
|
||||||
|
"mcp-ci-job-timeout": c.MCPCIJobTimeout,
|
||||||
|
} {
|
||||||
|
if d < 0 {
|
||||||
|
return fmt.Errorf("%s must not be negative", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Distributed config options
|
||||||
|
|
||||||
|
var EnableDistributed = func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.Enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDistributedInstanceID(id string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.InstanceID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithNatsURL(url string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.NatsURL = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRegistrationToken(token string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.RegistrationToken = token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStorageURL(url string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.StorageURL = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStorageBucket(bucket string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.StorageBucket = bucket
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStorageRegion(region string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.StorageRegion = region
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStorageAccessKey(key string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.StorageAccessKey = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStorageSecretKey(key string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.StorageSecretKey = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||||
|
o.Distributed.AutoApproveNodes = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Defaults for distributed timeouts.
|
||||||
|
const (
|
||||||
|
DefaultMCPToolTimeout = 360 * time.Second
|
||||||
|
DefaultMCPDiscoveryTimeout = 60 * time.Second
|
||||||
|
DefaultWorkerWaitTimeout = 5 * time.Minute
|
||||||
|
DefaultDrainTimeout = 30 * time.Second
|
||||||
|
DefaultHealthCheckInterval = 15 * time.Second
|
||||||
|
DefaultStaleNodeThreshold = 60 * time.Second
|
||||||
|
DefaultMCPCIJobTimeout = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||||
|
const DefaultMaxUploadSize int64 = 50 << 30
|
||||||
|
|
||||||
|
// MCPToolTimeoutOrDefault returns the configured timeout or the default.
|
||||||
|
func (c DistributedConfig) MCPToolTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.MCPToolTimeout, DefaultMCPToolTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPDiscoveryTimeoutOrDefault returns the configured timeout or the default.
|
||||||
|
func (c DistributedConfig) MCPDiscoveryTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.MCPDiscoveryTimeout, DefaultMCPDiscoveryTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkerWaitTimeoutOrDefault returns the configured timeout or the default.
|
||||||
|
func (c DistributedConfig) WorkerWaitTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.WorkerWaitTimeout, DefaultWorkerWaitTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DrainTimeoutOrDefault returns the configured timeout or the default.
|
||||||
|
func (c DistributedConfig) DrainTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.DrainTimeout, DefaultDrainTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheckIntervalOrDefault returns the configured interval or the default.
|
||||||
|
func (c DistributedConfig) HealthCheckIntervalOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.HealthCheckInterval, DefaultHealthCheckInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaleNodeThresholdOrDefault returns the configured threshold or the default.
|
||||||
|
func (c DistributedConfig) StaleNodeThresholdOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.StaleNodeThreshold, DefaultStaleNodeThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPCIJobTimeoutOrDefault returns the configured MCP CI job timeout or the default.
|
||||||
|
func (c DistributedConfig) MCPCIJobTimeoutOrDefault() time.Duration {
|
||||||
|
return cmp.Or(c.MCPCIJobTimeout, DefaultMCPCIJobTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxUploadSizeOrDefault returns the configured max upload size or the default.
|
||||||
|
func (c DistributedConfig) MaxUploadSizeOrDefault() int64 {
|
||||||
|
return cmp.Or(c.MaxUploadSize, DefaultMaxUploadSize)
|
||||||
|
}
|
||||||
|
|
@ -46,11 +46,11 @@ type ModelConfig struct {
|
||||||
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
|
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
|
||||||
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
|
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
|
||||||
|
|
||||||
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
|
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
|
||||||
InputToken [][]int `yaml:"-" json:"-"`
|
InputToken [][]int `yaml:"-" json:"-"`
|
||||||
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
|
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
|
||||||
ResponseFormat string `yaml:"-" json:"-"`
|
ResponseFormat string `yaml:"-" json:"-"`
|
||||||
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
|
ResponseFormatMap map[string]any `yaml:"-" json:"-"`
|
||||||
|
|
||||||
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
|
||||||
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||||
|
|
@ -105,6 +105,11 @@ type AgentConfig struct {
|
||||||
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
|
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasMCPServers returns true if any MCP servers (remote or stdio) are configured.
|
||||||
|
func (c MCPConfig) HasMCPServers() bool {
|
||||||
|
return c.Servers != "" || c.Stdio != ""
|
||||||
|
}
|
||||||
|
|
||||||
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
|
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
|
||||||
var remote MCPGenericConfig[MCPRemoteServers]
|
var remote MCPGenericConfig[MCPRemoteServers]
|
||||||
var stdio MCPGenericConfig[MCPSTDIOServers]
|
var stdio MCPGenericConfig[MCPSTDIOServers]
|
||||||
|
|
@ -619,15 +624,32 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
|
||||||
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
|
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
|
||||||
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
|
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
|
||||||
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||||
|
// Backends that are clearly not text-generation
|
||||||
|
nonTextGenBackends := []string{
|
||||||
|
"whisper", "piper", "kokoro",
|
||||||
|
"diffusers", "stablediffusion", "stablediffusion-ggml",
|
||||||
|
"rerankers", "silero-vad", "rfdetr",
|
||||||
|
"transformers-musicgen", "ace-step", "acestep-cpp",
|
||||||
|
}
|
||||||
|
|
||||||
if (u & FLAG_CHAT) == FLAG_CHAT {
|
if (u & FLAG_CHAT) == FLAG_CHAT {
|
||||||
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
|
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if slices.Contains(nonTextGenBackends, c.Backend) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.Embeddings != nil && *c.Embeddings {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
|
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
|
||||||
if c.TemplateConfig.Completion == "" {
|
if c.TemplateConfig.Completion == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if slices.Contains(nonTextGenBackends, c.Backend) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (u & FLAG_EDIT) == FLAG_EDIT {
|
if (u & FLAG_EDIT) == FLAG_EDIT {
|
||||||
if c.TemplateConfig.Edit == "" {
|
if c.TemplateConfig.Edit == "" {
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
|
@ -215,8 +216,8 @@ func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
|
||||||
res = append(res, v)
|
res = append(res, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.SliceStable(res, func(i, j int) bool {
|
slices.SortStableFunc(res, func(a, b ModelConfig) int {
|
||||||
return res[i].Name < res[j].Name
|
return cmp.Compare(a.Name, b.Name)
|
||||||
})
|
})
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
|
||||||
|
|
@ -27,15 +27,15 @@ type RuntimeSettings struct {
|
||||||
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
||||||
|
|
||||||
// Eviction settings
|
// Eviction settings
|
||||||
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
|
ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
|
||||||
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
|
LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
|
||||||
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
|
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
|
||||||
|
|
||||||
// Performance settings
|
// Performance settings
|
||||||
Threads *int `json:"threads,omitempty"`
|
Threads *int `json:"threads,omitempty"`
|
||||||
ContextSize *int `json:"context_size,omitempty"`
|
ContextSize *int `json:"context_size,omitempty"`
|
||||||
F16 *bool `json:"f16,omitempty"`
|
F16 *bool `json:"f16,omitempty"`
|
||||||
Debug *bool `json:"debug,omitempty"`
|
Debug *bool `json:"debug,omitempty"`
|
||||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||||
|
|
@ -66,11 +66,11 @@ type RuntimeSettings struct {
|
||||||
OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration)
|
OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration)
|
||||||
|
|
||||||
// Agent Pool settings
|
// Agent Pool settings
|
||||||
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
|
AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"`
|
||||||
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
|
AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"`
|
||||||
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
|
AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"`
|
||||||
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
|
AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"`
|
||||||
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
|
AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"`
|
||||||
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
|
AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"`
|
||||||
AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"`
|
AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,10 @@ package explorer
|
||||||
// A simple JSON database for storing and retrieving p2p network tokens and a name and description.
|
// A simple JSON database for storing and retrieving p2p network tokens and a name and description.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gofrs/flock"
|
"github.com/gofrs/flock"
|
||||||
|
|
@ -89,9 +90,8 @@ func (db *Database) TokenList() []string {
|
||||||
tokens = append(tokens, k)
|
tokens = append(tokens, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(tokens, func(i, j int) bool {
|
slices.SortFunc(tokens, func(a, b string) int {
|
||||||
// sort by token
|
return cmp.Compare(a, b)
|
||||||
return tokens[i] < tokens[j]
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return tokens
|
return tokens
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
// modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config.
|
// modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config.
|
||||||
type modelConfigCacheEntry struct {
|
type modelConfigCacheEntry struct {
|
||||||
configMap map[string]interface{}
|
configMap map[string]any
|
||||||
lastUpdated time.Time
|
lastUpdated time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -57,7 +57,7 @@ func resolveBackend(m *GalleryModel, basePath string) string {
|
||||||
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
|
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
|
||||||
// inside it, and returns the result as a map. Results are cached for 1 hour.
|
// inside it, and returns the result as a map. Results are cached for 1 hour.
|
||||||
// Local file:// URLs skip the cache so edits are picked up immediately.
|
// Local file:// URLs skip the cache so edits are picked up immediately.
|
||||||
func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
|
func fetchModelConfigMap(modelURL, basePath string) map[string]any {
|
||||||
// Check cache (skip for file:// URLs so local edits are picked up immediately)
|
// Check cache (skip for file:// URLs so local edits are picked up immediately)
|
||||||
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
|
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
|
||||||
if !isLocal && modelConfigCache.Exists(modelURL) {
|
if !isLocal && modelConfigCache.Exists(modelURL) {
|
||||||
|
|
@ -75,15 +75,15 @@ func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
|
||||||
// Cache the failure for remote URLs to avoid repeated fetch attempts
|
// Cache the failure for remote URLs to avoid repeated fetch attempts
|
||||||
if !isLocal {
|
if !isLocal {
|
||||||
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
|
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
|
||||||
configMap: map[string]interface{}{},
|
configMap: map[string]any{},
|
||||||
lastUpdated: time.Now(),
|
lastUpdated: time.Now(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return map[string]interface{}{}
|
return map[string]any{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the config_file YAML string into a map
|
// Parse the config_file YAML string into a map
|
||||||
configMap := make(map[string]interface{})
|
configMap := make(map[string]any)
|
||||||
if modelConfig.ConfigFile != "" {
|
if modelConfig.ConfigFile != "" {
|
||||||
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
|
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
|
||||||
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
|
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
|
||||||
|
|
@ -108,13 +108,11 @@ func prefetchModelConfigs(urls []string, basePath string) {
|
||||||
sem := make(chan struct{}, maxConcurrency)
|
sem := make(chan struct{}, maxConcurrency)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, url := range urls {
|
for _, url := range urls {
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func(u string) {
|
|
||||||
defer wg.Done()
|
|
||||||
sem <- struct{}{}
|
sem <- struct{}{}
|
||||||
defer func() { <-sem }()
|
defer func() { <-sem }()
|
||||||
fetchModelConfigMap(u, basePath)
|
fetchModelConfigMap(url, basePath)
|
||||||
}(url)
|
})
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,10 @@ package gallery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -20,6 +20,9 @@ import (
|
||||||
cp "github.com/otiai10/copy"
|
cp "github.com/otiai10/copy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrBackendNotFound is returned when a backend is not found in the system.
|
||||||
|
var ErrBackendNotFound = errors.New("backend not found")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
metadataFile = "metadata.json"
|
metadataFile = "metadata.json"
|
||||||
runFile = "run.sh"
|
runFile = "run.sh"
|
||||||
|
|
@ -198,9 +201,16 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||||
} else {
|
} else {
|
||||||
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
|
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
|
||||||
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
|
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||||
// Don't remove backendPath here — fallback OCI extractions need the directory to exist
|
|
||||||
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
|
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
|
||||||
|
|
||||||
|
// resetBackendPath cleans up partial state from a failed OCI extraction
|
||||||
|
// so the next download attempt starts fresh. The directory is re-created
|
||||||
|
// because OCI image extractors need it to exist for writing files into.
|
||||||
|
resetBackendPath := func() {
|
||||||
|
os.RemoveAll(backendPath)
|
||||||
|
os.MkdirAll(backendPath, 0750)
|
||||||
|
}
|
||||||
|
|
||||||
success := false
|
success := false
|
||||||
// Try to download from mirrors
|
// Try to download from mirrors
|
||||||
for _, mirror := range config.Mirrors {
|
for _, mirror := range config.Mirrors {
|
||||||
|
|
@ -210,6 +220,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
resetBackendPath()
|
||||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||||
success = true
|
success = true
|
||||||
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
|
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
|
||||||
|
|
@ -221,28 +232,22 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||||
// Try fallback: replace latestTag + "-" with masterTag + "-" in the URI
|
// Try fallback: replace latestTag + "-" with masterTag + "-" in the URI
|
||||||
fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1)
|
fallbackURI := strings.Replace(string(config.URI), latestTag+"-", masterTag+"-", 1)
|
||||||
if fallbackURI != string(config.URI) {
|
if fallbackURI != string(config.URI) {
|
||||||
xlog.Debug("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
resetBackendPath()
|
||||||
|
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
||||||
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||||
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
|
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
|
||||||
success = true
|
success = true
|
||||||
} else {
|
} else {
|
||||||
// Try another fallback: add "-" + devSuffix suffix to the backend name
|
xlog.Info("Fallback URI failed", "fallback", fallbackURI, "error", err)
|
||||||
// For example: master-gpu-nvidia-cuda-13-ace-step -> master-gpu-nvidia-cuda-13-ace-step-development
|
|
||||||
if !strings.Contains(fallbackURI, "-"+devSuffix) {
|
if !strings.Contains(fallbackURI, "-"+devSuffix) {
|
||||||
// Extract backend name from URI and add -development
|
resetBackendPath()
|
||||||
parts := strings.Split(fallbackURI, "-")
|
devFallbackURI := fallbackURI + "-" + devSuffix
|
||||||
if len(parts) >= 2 {
|
xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
|
||||||
// Find where the backend name ends (usually the last part before the tag)
|
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||||
// Pattern: quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step
|
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
||||||
lastDash := strings.LastIndex(fallbackURI, "-")
|
success = true
|
||||||
if lastDash > 0 {
|
} else {
|
||||||
devFallbackURI := fallbackURI[:lastDash] + "-" + devSuffix
|
xlog.Info("Development fallback URI failed", "fallback", devFallbackURI, "error", err)
|
||||||
xlog.Debug("Trying development fallback URI", "fallback", devFallbackURI)
|
|
||||||
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
|
||||||
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
|
||||||
success = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -295,7 +300,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||||
|
|
||||||
backend, ok := backends.Get(name)
|
backend, ok := backends.Get(name)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("backend %q not found", name)
|
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
if backend.IsSystem {
|
if backend.IsSystem {
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -106,64 +106,64 @@ func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
|
func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
|
||||||
sort.Slice(gm, func(i, j int) bool {
|
slices.SortFunc(gm, func(a, b T) int {
|
||||||
if sortOrder == "asc" {
|
r := strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
|
||||||
return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
|
if sortOrder == "desc" {
|
||||||
} else {
|
return -r
|
||||||
return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName())
|
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
})
|
})
|
||||||
return gm
|
return gm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] {
|
func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] {
|
||||||
sort.Slice(gm, func(i, j int) bool {
|
slices.SortFunc(gm, func(a, b T) int {
|
||||||
if sortOrder == "asc" {
|
r := strings.Compare(strings.ToLower(a.GetGallery().Name), strings.ToLower(b.GetGallery().Name))
|
||||||
return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name)
|
if sortOrder == "desc" {
|
||||||
} else {
|
return -r
|
||||||
return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name)
|
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
})
|
})
|
||||||
return gm
|
return gm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] {
|
func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] {
|
||||||
sort.Slice(gm, func(i, j int) bool {
|
slices.SortFunc(gm, func(a, b T) int {
|
||||||
licenseI := gm[i].GetLicense()
|
licenseA := a.GetLicense()
|
||||||
licenseJ := gm[j].GetLicense()
|
licenseB := b.GetLicense()
|
||||||
var result bool
|
var r int
|
||||||
if licenseI == "" && licenseJ != "" {
|
if licenseA == "" && licenseB != "" {
|
||||||
return sortOrder == "desc"
|
r = 1
|
||||||
} else if licenseI != "" && licenseJ == "" {
|
} else if licenseA != "" && licenseB == "" {
|
||||||
return sortOrder == "asc"
|
r = -1
|
||||||
} else if licenseI == "" && licenseJ == "" {
|
|
||||||
return false
|
|
||||||
} else {
|
} else {
|
||||||
result = strings.ToLower(licenseI) < strings.ToLower(licenseJ)
|
r = strings.Compare(strings.ToLower(licenseA), strings.ToLower(licenseB))
|
||||||
}
|
}
|
||||||
if sortOrder == "desc" {
|
if sortOrder == "desc" {
|
||||||
return !result
|
return -r
|
||||||
} else {
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
})
|
})
|
||||||
return gm
|
return gm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] {
|
func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] {
|
||||||
sort.Slice(gm, func(i, j int) bool {
|
slices.SortFunc(gm, func(a, b T) int {
|
||||||
var result bool
|
var r int
|
||||||
// Sort by installed status: installed items first (true > false)
|
// Sort by installed status: installed items first (true > false)
|
||||||
if gm[i].GetInstalled() != gm[j].GetInstalled() {
|
if a.GetInstalled() != b.GetInstalled() {
|
||||||
result = gm[i].GetInstalled()
|
if a.GetInstalled() {
|
||||||
|
r = -1
|
||||||
|
} else {
|
||||||
|
r = 1
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
|
r = strings.Compare(strings.ToLower(a.GetName()), strings.ToLower(b.GetName()))
|
||||||
}
|
}
|
||||||
if sortOrder == "desc" {
|
if sortOrder == "desc" {
|
||||||
return !result
|
return -r
|
||||||
} else {
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
})
|
})
|
||||||
return gm
|
return gm
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ var _ = Describe("Gallery", func() {
|
||||||
|
|
||||||
Describe("ReadConfigFile", func() {
|
Describe("ReadConfigFile", func() {
|
||||||
It("should read and unmarshal a valid YAML file", func() {
|
It("should read and unmarshal a valid YAML file", func() {
|
||||||
testConfig := map[string]interface{}{
|
testConfig := map[string]any{
|
||||||
"name": "test-model",
|
"name": "test-model",
|
||||||
"description": "A test model",
|
"description": "A test model",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
|
|
@ -39,8 +39,8 @@ var _ = Describe("Gallery", func() {
|
||||||
err = os.WriteFile(filePath, yamlData, 0644)
|
err = os.WriteFile(filePath, yamlData, 0644)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]any
|
||||||
config, err := ReadConfigFile[map[string]interface{}](filePath)
|
config, err := ReadConfigFile[map[string]any](filePath)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(config).NotTo(BeNil())
|
Expect(config).NotTo(BeNil())
|
||||||
result = *config
|
result = *config
|
||||||
|
|
@ -50,7 +50,7 @@ var _ = Describe("Gallery", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return error when file does not exist", func() {
|
It("should return error when file does not exist", func() {
|
||||||
_, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml")
|
_, err := ReadConfigFile[map[string]any]("nonexistent.yaml")
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -59,7 +59,7 @@ var _ = Describe("Gallery", func() {
|
||||||
err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644)
|
err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
_, err = ReadConfigFile[map[string]interface{}](filePath)
|
_, err = ReadConfigFile[map[string]any](filePath)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
@ -552,23 +552,23 @@ var _ = Describe("Gallery", func() {
|
||||||
// Verify first model
|
// Verify first model
|
||||||
Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8"))
|
Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8"))
|
||||||
Expect(models[0].Overrides).NotTo(BeNil())
|
Expect(models[0].Overrides).NotTo(BeNil())
|
||||||
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
|
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
|
||||||
params := models[0].Overrides["parameters"].(map[string]interface{})
|
params := models[0].Overrides["parameters"].(map[string]any)
|
||||||
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf"))
|
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf"))
|
||||||
|
|
||||||
// Verify second model (merged)
|
// Verify second model (merged)
|
||||||
Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4"))
|
Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4"))
|
||||||
Expect(models[1].Overrides).NotTo(BeNil())
|
Expect(models[1].Overrides).NotTo(BeNil())
|
||||||
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
|
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]any{}))
|
||||||
params = models[1].Overrides["parameters"].(map[string]interface{})
|
params = models[1].Overrides["parameters"].(map[string]any)
|
||||||
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||||
|
|
||||||
// Simulate the mergo.Merge call that was failing in models.go:251
|
// Simulate the mergo.Merge call that was failing in models.go:251
|
||||||
// This should not panic with yaml.v3
|
// This should not panic with yaml.v3
|
||||||
configMap := make(map[string]interface{})
|
configMap := make(map[string]any)
|
||||||
configMap["name"] = "test"
|
configMap["name"] = "test"
|
||||||
configMap["backend"] = "llama-cpp"
|
configMap["backend"] = "llama-cpp"
|
||||||
configMap["parameters"] = map[string]interface{}{
|
configMap["parameters"] = map[string]any{
|
||||||
"model": "original.gguf",
|
"model": "original.gguf",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -577,7 +577,7 @@ var _ = Describe("Gallery", func() {
|
||||||
Expect(configMap["parameters"]).NotTo(BeNil())
|
Expect(configMap["parameters"]).NotTo(BeNil())
|
||||||
|
|
||||||
// Verify the merge worked correctly
|
// Verify the merge worked correctly
|
||||||
mergedParams := configMap["parameters"].(map[string]interface{})
|
mergedParams := configMap["parameters"].(map[string]any)
|
||||||
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ var _ = Describe("ImportLocalPath", func() {
|
||||||
|
|
||||||
adapterConfig := map[string]any{
|
adapterConfig := map[string]any{
|
||||||
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
}
|
}
|
||||||
data, _ := json.Marshal(adapterConfig)
|
data, _ := json.Marshal(adapterConfig)
|
||||||
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||||
|
|
|
||||||
|
|
@ -158,7 +158,7 @@ func InstallModelFromGallery(
|
||||||
return applyModel(model)
|
return applyModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]any, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||||
basePath := systemState.Model.ModelsPath
|
basePath := systemState.Model.ModelsPath
|
||||||
// Create base path if it doesn't exist
|
// Create base path if it doesn't exist
|
||||||
err := os.MkdirAll(basePath, 0750)
|
err := os.MkdirAll(basePath, 0750)
|
||||||
|
|
@ -239,7 +239,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
|
||||||
configFilePath := filepath.Join(basePath, name+".yaml")
|
configFilePath := filepath.Join(basePath, name+".yaml")
|
||||||
|
|
||||||
// Read and update config file as map[string]interface{}
|
// Read and update config file as map[string]interface{}
|
||||||
configMap := make(map[string]interface{})
|
configMap := make(map[string]any)
|
||||||
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
|
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err)
|
return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err)
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ var _ = Describe("Model test", func() {
|
||||||
system.WithModelPath(tempdir),
|
system.WithModelPath(tempdir),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]any{}, func(string, string, string, float64) {}, true)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||||
|
|
@ -43,7 +43,7 @@ var _ = Describe("Model test", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
|
dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
@ -95,7 +95,7 @@ var _ = Describe("Model test", func() {
|
||||||
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
err = yaml.Unmarshal(dat, &content)
|
err = yaml.Unmarshal(dat, &content)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||||
|
|
@ -130,7 +130,7 @@ var _ = Describe("Model test", func() {
|
||||||
system.WithModelPath(tempdir),
|
system.WithModelPath(tempdir),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||||
|
|
@ -150,7 +150,7 @@ var _ = Describe("Model test", func() {
|
||||||
system.WithModelPath(tempdir),
|
system.WithModelPath(tempdir),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]any{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||||
|
|
@ -158,7 +158,7 @@ var _ = Describe("Model test", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
|
dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
@ -180,7 +180,7 @@ var _ = Describe("Model test", func() {
|
||||||
system.WithModelPath(tempdir),
|
system.WithModelPath(tempdir),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]any{}, func(string, string, string, float64) {}, true)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ import (
|
||||||
type GalleryModel struct {
|
type GalleryModel struct {
|
||||||
Metadata `json:",inline" yaml:",inline"`
|
Metadata `json:",inline" yaml:",inline"`
|
||||||
// config_file is read in the situation where URL is blank - and therefore this is a base config.
|
// config_file is read in the situation where URL is blank - and therefore this is a base config.
|
||||||
ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"`
|
ConfigFile map[string]any `json:"config_file,omitempty" yaml:"config_file,omitempty"`
|
||||||
// Overrides are used to override the configuration of the model located at URL
|
// Overrides are used to override the configuration of the model located at URL
|
||||||
Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"`
|
Overrides map[string]any `json:"overrides,omitempty" yaml:"overrides,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GalleryModel) GetInstalled() bool {
|
func (m *GalleryModel) GetInstalled() bool {
|
||||||
|
|
|
||||||
66
core/gallery/worker.go
Normal file
66
core/gallery/worker.go
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
package gallery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeleteStagedModelFiles removes all staged files for a model from a worker's
|
||||||
|
// models directory. Files are expected to be in a subdirectory named after the
|
||||||
|
// model's tracking key (created by stageModelFiles in the router).
|
||||||
|
//
|
||||||
|
// Workers receive model files via S3/HTTP file staging, not gallery install,
|
||||||
|
// so they lack the YAML configs that DeleteModelFromSystem requires.
|
||||||
|
//
|
||||||
|
// Falls back to glob-based cleanup for single-file models or legacy layouts.
|
||||||
|
func DeleteStagedModelFiles(modelsPath, modelName string) error {
|
||||||
|
if modelName == "" {
|
||||||
|
return fmt.Errorf("empty model name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean and validate: resolved path must stay within modelsPath
|
||||||
|
modelPath := filepath.Clean(filepath.Join(modelsPath, modelName))
|
||||||
|
absModels := filepath.Clean(modelsPath)
|
||||||
|
if !strings.HasPrefix(modelPath, absModels+string(filepath.Separator)) {
|
||||||
|
return fmt.Errorf("model name %q escapes models directory", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Primary: remove the model's subdirectory (contains all staged files)
|
||||||
|
if info, err := os.Stat(modelPath); err == nil && info.IsDir() {
|
||||||
|
return os.RemoveAll(modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback for single-file models or legacy layouts:
|
||||||
|
// remove exact file match + glob siblings
|
||||||
|
removed := false
|
||||||
|
if _, err := os.Stat(modelPath); err == nil {
|
||||||
|
if err := os.Remove(modelPath); err != nil {
|
||||||
|
xlog.Warn("Failed to remove model file", "path", modelPath, "error", err)
|
||||||
|
} else {
|
||||||
|
removed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove sibling files (e.g., model.gguf.mmproj alongside model.gguf)
|
||||||
|
matches, _ := filepath.Glob(modelPath + ".*")
|
||||||
|
for _, m := range matches {
|
||||||
|
clean := filepath.Clean(m)
|
||||||
|
if !strings.HasPrefix(clean, absModels+string(filepath.Separator)) {
|
||||||
|
continue // skip any glob result that escapes
|
||||||
|
}
|
||||||
|
if err := os.Remove(clean); err != nil {
|
||||||
|
xlog.Warn("Failed to remove model-related file", "path", clean, "error", err)
|
||||||
|
} else {
|
||||||
|
removed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !removed {
|
||||||
|
xlog.Debug("No files found to delete for model", "model", modelName, "path", modelPath)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
99
core/gallery/worker_test.go
Normal file
99
core/gallery/worker_test.go
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
package gallery_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteStagedModelFiles(t *testing.T) {
|
||||||
|
t.Run("rejects empty model name", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty model name")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects path traversal via ..", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "../../etc/passwd")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for path traversal attempt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects path traversal via ../foo", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "../foo")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for path traversal attempt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("removes model subdirectory with all files", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
modelDir := filepath.Join(dir, "my-model", "sd-cpp", "models")
|
||||||
|
if err := os.MkdirAll(modelDir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Create model files in subdirectory
|
||||||
|
os.WriteFile(filepath.Join(modelDir, "flux.gguf"), []byte("model"), 0o644)
|
||||||
|
os.WriteFile(filepath.Join(modelDir, "flux.gguf.mmproj"), []byte("mmproj"), 0o644)
|
||||||
|
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "my-model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entire my-model directory should be gone
|
||||||
|
if _, err := os.Stat(filepath.Join(dir, "my-model")); !os.IsNotExist(err) {
|
||||||
|
t.Fatal("expected model directory to be removed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("removes single file model", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
modelFile := filepath.Join(dir, "model.gguf")
|
||||||
|
os.WriteFile(modelFile, []byte("model"), 0o644)
|
||||||
|
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "model.gguf")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
|
||||||
|
t.Fatal("expected model file to be removed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("removes sibling files via glob", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
modelFile := filepath.Join(dir, "model.gguf")
|
||||||
|
siblingFile := filepath.Join(dir, "model.gguf.mmproj")
|
||||||
|
os.WriteFile(modelFile, []byte("model"), 0o644)
|
||||||
|
os.WriteFile(siblingFile, []byte("mmproj"), 0o644)
|
||||||
|
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "model.gguf")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(modelFile); !os.IsNotExist(err) {
|
||||||
|
t.Fatal("expected model file to be removed")
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(siblingFile); !os.IsNotExist(err) {
|
||||||
|
t.Fatal("expected sibling file to be removed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no error when model does not exist", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
err := gallery.DeleteStagedModelFiles(dir, "nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -16,12 +16,17 @@ import (
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/http/auth"
|
"github.com/mudler/LocalAI/core/http/auth"
|
||||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
|
||||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||||
"github.com/mudler/LocalAI/core/http/routes"
|
"github.com/mudler/LocalAI/core/http/routes"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/application"
|
"github.com/mudler/LocalAI/core/application"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/finetune"
|
||||||
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||||
|
"github.com/mudler/LocalAI/core/services/nodes"
|
||||||
|
"github.com/mudler/LocalAI/core/services/quantization"
|
||||||
|
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
@ -155,7 +160,7 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||||
|
|
||||||
// Metrics middleware
|
// Metrics middleware
|
||||||
if !application.ApplicationConfig().DisableMetrics {
|
if !application.ApplicationConfig().DisableMetrics {
|
||||||
metricsService, err := services.NewLocalAIMetricsService()
|
metricsService, err := monitoring.NewLocalAIMetricsService()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -295,9 +300,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||||
|
|
||||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||||
var opcache *services.OpCache
|
var opcache *galleryop.OpCache
|
||||||
if !application.ApplicationConfig().DisableWebUI {
|
if !application.ApplicationConfig().DisableWebUI {
|
||||||
opcache = services.NewOpCache(application.GalleryService())
|
opcache = galleryop.NewOpCache(application.GalleryService())
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
||||||
|
|
@ -305,22 +310,51 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||||
// Fine-tuning routes
|
// Fine-tuning routes
|
||||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||||
ftService := services.NewFineTuneService(
|
ftService := finetune.NewFineTuneService(
|
||||||
application.ApplicationConfig(),
|
application.ApplicationConfig(),
|
||||||
application.ModelLoader(),
|
application.ModelLoader(),
|
||||||
application.ModelConfigLoader(),
|
application.ModelConfigLoader(),
|
||||||
)
|
)
|
||||||
|
if d := application.Distributed(); d != nil {
|
||||||
|
ftService.SetNATSClient(d.Nats)
|
||||||
|
if d.DistStores != nil && d.DistStores.FineTune != nil {
|
||||||
|
ftService.SetFineTuneStore(d.DistStores.FineTune)
|
||||||
|
}
|
||||||
|
}
|
||||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||||
|
|
||||||
// Quantization routes
|
// Quantization routes
|
||||||
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
||||||
qService := services.NewQuantizationService(
|
qService := quantization.NewQuantizationService(
|
||||||
application.ApplicationConfig(),
|
application.ApplicationConfig(),
|
||||||
application.ModelLoader(),
|
application.ModelLoader(),
|
||||||
application.ModelConfigLoader(),
|
application.ModelConfigLoader(),
|
||||||
)
|
)
|
||||||
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
||||||
|
|
||||||
|
// Node management routes (distributed mode)
|
||||||
|
distCfg := application.ApplicationConfig().Distributed
|
||||||
|
var registry *nodes.NodeRegistry
|
||||||
|
var remoteUnloader nodes.NodeCommandSender
|
||||||
|
if d := application.Distributed(); d != nil {
|
||||||
|
registry = d.Registry
|
||||||
|
if d.Router != nil {
|
||||||
|
remoteUnloader = d.Router.Unloader()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||||
|
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||||
|
|
||||||
|
// Distributed SSE routes (job progress + agent events via NATS)
|
||||||
|
if d := application.Distributed(); d != nil {
|
||||||
|
if d.Dispatcher != nil {
|
||||||
|
e.GET("/api/agent/jobs/:id/progress", d.Dispatcher.SSEHandler(), mcpJobsMw)
|
||||||
|
}
|
||||||
|
if d.AgentBridge != nil {
|
||||||
|
e.GET("/api/agents/:name/sse/distributed", d.AgentBridge.SSEHandler(), agentsMw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||||
|
|
|
||||||
|
|
@ -44,14 +44,14 @@ Say hello.
|
||||||
### Response:`
|
### Response:`
|
||||||
|
|
||||||
type modelApplyRequest struct {
|
type modelApplyRequest struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
ConfigURL string `json:"config_url"`
|
ConfigURL string `json:"config_url"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Overrides map[string]interface{} `json:"overrides"`
|
Overrides map[string]any `json:"overrides"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getModelStatus(url string) (response map[string]interface{}) {
|
func getModelStatus(url string) (response map[string]any) {
|
||||||
// Create the HTTP request
|
// Create the HTTP request
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
@ -94,7 +94,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
|
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]any) {
|
||||||
|
|
||||||
//url := "http://localhost:AI/models/apply"
|
//url := "http://localhost:AI/models/apply"
|
||||||
|
|
||||||
|
|
@ -336,7 +336,7 @@ var _ = Describe("API test", func() {
|
||||||
Name: "bert",
|
Name: "bert",
|
||||||
URL: bertEmbeddingsURL,
|
URL: bertEmbeddingsURL,
|
||||||
},
|
},
|
||||||
Overrides: map[string]interface{}{"backend": "llama-cpp"},
|
Overrides: map[string]any{"backend": "llama-cpp"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Metadata: gallery.Metadata{
|
Metadata: gallery.Metadata{
|
||||||
|
|
@ -344,7 +344,7 @@ var _ = Describe("API test", func() {
|
||||||
URL: bertEmbeddingsURL,
|
URL: bertEmbeddingsURL,
|
||||||
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}},
|
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}},
|
||||||
},
|
},
|
||||||
Overrides: map[string]interface{}{"foo": "bar"},
|
Overrides: map[string]any{"foo": "bar"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
out, err := yaml.Marshal(g)
|
out, err := yaml.Marshal(g)
|
||||||
|
|
@ -464,7 +464,7 @@ var _ = Describe("API test", func() {
|
||||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||||
|
|
||||||
uuid := response["uuid"].(string)
|
uuid := response["uuid"].(string)
|
||||||
resp := map[string]interface{}{}
|
resp := map[string]any{}
|
||||||
Eventually(func() bool {
|
Eventually(func() bool {
|
||||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||||
fmt.Println(response)
|
fmt.Println(response)
|
||||||
|
|
@ -479,7 +479,7 @@ var _ = Describe("API test", func() {
|
||||||
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
|
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
err = yaml.Unmarshal(dat, &content)
|
err = yaml.Unmarshal(dat, &content)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||||
|
|
@ -503,7 +503,7 @@ var _ = Describe("API test", func() {
|
||||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||||
URL: bertEmbeddingsURL,
|
URL: bertEmbeddingsURL,
|
||||||
Name: "bert",
|
Name: "bert",
|
||||||
Overrides: map[string]interface{}{
|
Overrides: map[string]any{
|
||||||
"backend": "llama",
|
"backend": "llama",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
@ -520,7 +520,7 @@ var _ = Describe("API test", func() {
|
||||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
err = yaml.Unmarshal(dat, &content)
|
err = yaml.Unmarshal(dat, &content)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(content["backend"]).To(Equal("llama"))
|
Expect(content["backend"]).To(Equal("llama"))
|
||||||
|
|
@ -529,7 +529,7 @@ var _ = Describe("API test", func() {
|
||||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||||
URL: bertEmbeddingsURL,
|
URL: bertEmbeddingsURL,
|
||||||
Name: "bert",
|
Name: "bert",
|
||||||
Overrides: map[string]interface{}{},
|
Overrides: map[string]any{},
|
||||||
})
|
})
|
||||||
|
|
||||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||||
|
|
@ -544,7 +544,7 @@ var _ = Describe("API test", func() {
|
||||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
err = yaml.Unmarshal(dat, &content)
|
err = yaml.Unmarshal(dat, &content)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||||
|
|
@ -586,7 +586,7 @@ parameters:
|
||||||
Expect(response.ID).ToNot(BeEmpty())
|
Expect(response.ID).ToNot(BeEmpty())
|
||||||
|
|
||||||
uuid := response.ID
|
uuid := response.ID
|
||||||
resp := map[string]interface{}{}
|
resp := map[string]any{}
|
||||||
Eventually(func() bool {
|
Eventually(func() bool {
|
||||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||||
resp = response
|
resp = response
|
||||||
|
|
@ -601,7 +601,7 @@ parameters:
|
||||||
dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]any{}
|
||||||
err = yaml.Unmarshal(dat, &content)
|
err = yaml.Unmarshal(dat, &content)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(content["name"]).To(Equal("test-import-model"))
|
Expect(content["name"]).To(Equal("test-import-model"))
|
||||||
|
|
@ -657,7 +657,7 @@ parameters:
|
||||||
Expect(response.ID).ToNot(BeEmpty())
|
Expect(response.ID).ToNot(BeEmpty())
|
||||||
|
|
||||||
uuid := response.ID
|
uuid := response.ID
|
||||||
resp := map[string]interface{}{}
|
resp := map[string]any{}
|
||||||
Eventually(func() bool {
|
Eventually(func() bool {
|
||||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||||
resp = response
|
resp = response
|
||||||
|
|
@ -1248,7 +1248,7 @@ parameters:
|
||||||
Context("Agent Jobs", Label("agent-jobs"), func() {
|
Context("Agent Jobs", Label("agent-jobs"), func() {
|
||||||
It("creates and manages tasks", func() {
|
It("creates and manages tasks", func() {
|
||||||
// Create a task
|
// Create a task
|
||||||
taskBody := map[string]interface{}{
|
taskBody := map[string]any{
|
||||||
"name": "Test Task",
|
"name": "Test Task",
|
||||||
"description": "Test Description",
|
"description": "Test Description",
|
||||||
"model": "testmodel.ggml",
|
"model": "testmodel.ggml",
|
||||||
|
|
@ -1256,7 +1256,7 @@ parameters:
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var createResp map[string]interface{}
|
var createResp map[string]any
|
||||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(createResp["id"]).ToNot(BeEmpty())
|
Expect(createResp["id"]).ToNot(BeEmpty())
|
||||||
|
|
@ -1302,20 +1302,20 @@ parameters:
|
||||||
|
|
||||||
It("executes and monitors jobs", func() {
|
It("executes and monitors jobs", func() {
|
||||||
// Create a task first
|
// Create a task first
|
||||||
taskBody := map[string]interface{}{
|
taskBody := map[string]any{
|
||||||
"name": "Job Test Task",
|
"name": "Job Test Task",
|
||||||
"model": "testmodel.ggml",
|
"model": "testmodel.ggml",
|
||||||
"prompt": "Say hello",
|
"prompt": "Say hello",
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var createResp map[string]interface{}
|
var createResp map[string]any
|
||||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
taskID := createResp["id"].(string)
|
taskID := createResp["id"].(string)
|
||||||
|
|
||||||
// Execute a job
|
// Execute a job
|
||||||
jobBody := map[string]interface{}{
|
jobBody := map[string]any{
|
||||||
"task_id": taskID,
|
"task_id": taskID,
|
||||||
"parameters": map[string]string{},
|
"parameters": map[string]string{},
|
||||||
}
|
}
|
||||||
|
|
@ -1357,14 +1357,14 @@ parameters:
|
||||||
|
|
||||||
It("executes task by name", func() {
|
It("executes task by name", func() {
|
||||||
// Create a task with a specific name
|
// Create a task with a specific name
|
||||||
taskBody := map[string]interface{}{
|
taskBody := map[string]any{
|
||||||
"name": "Named Task",
|
"name": "Named Task",
|
||||||
"model": "testmodel.ggml",
|
"model": "testmodel.ggml",
|
||||||
"prompt": "Hello",
|
"prompt": "Hello",
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var createResp map[string]interface{}
|
var createResp map[string]any
|
||||||
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -516,6 +516,17 @@ func isExemptPath(path string, appConfig *config.ApplicationConfig) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Node self-service endpoints — authenticated via registration token, not global auth.
|
||||||
|
// Only exempt the specific known endpoints, not the entire prefix.
|
||||||
|
if strings.HasPrefix(path, "/api/node/") {
|
||||||
|
if path == "/api/node/register" ||
|
||||||
|
strings.HasSuffix(path, "/heartbeat") ||
|
||||||
|
strings.HasSuffix(path, "/drain") ||
|
||||||
|
strings.HasSuffix(path, "/deregister") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check configured exempt paths
|
// Check configured exempt paths
|
||||||
for _, p := range appConfig.PathWithoutAuth {
|
for _, p := range appConfig.PathWithoutAuth {
|
||||||
if strings.HasPrefix(path, p) {
|
if strings.HasPrefix(path, p) {
|
||||||
|
|
@ -540,6 +551,14 @@ func isAPIPath(path string) bool {
|
||||||
strings.HasPrefix(path, "/system") ||
|
strings.HasPrefix(path, "/system") ||
|
||||||
strings.HasPrefix(path, "/ws/") ||
|
strings.HasPrefix(path, "/ws/") ||
|
||||||
strings.HasPrefix(path, "/generated-") ||
|
strings.HasPrefix(path, "/generated-") ||
|
||||||
|
strings.HasPrefix(path, "/chat/") ||
|
||||||
|
strings.HasPrefix(path, "/completions") ||
|
||||||
|
strings.HasPrefix(path, "/edits") ||
|
||||||
|
strings.HasPrefix(path, "/embeddings") ||
|
||||||
|
strings.HasPrefix(path, "/audio/") ||
|
||||||
|
strings.HasPrefix(path, "/images/") ||
|
||||||
|
strings.HasPrefix(path, "/messages") ||
|
||||||
|
strings.HasPrefix(path, "/responses") ||
|
||||||
path == "/metrics"
|
path == "/metrics"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,24 +9,25 @@ import (
|
||||||
|
|
||||||
// Auth provider constants.
|
// Auth provider constants.
|
||||||
const (
|
const (
|
||||||
ProviderLocal = "local"
|
ProviderLocal = "local"
|
||||||
ProviderGitHub = "github"
|
ProviderGitHub = "github"
|
||||||
ProviderOIDC = "oidc"
|
ProviderOIDC = "oidc"
|
||||||
|
ProviderAgentWorker = "agent-worker"
|
||||||
)
|
)
|
||||||
|
|
||||||
// User represents an authenticated user.
|
// User represents an authenticated user.
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string `gorm:"primaryKey;size:36"`
|
ID string `gorm:"primaryKey;size:36"`
|
||||||
Email string `gorm:"size:255;index"`
|
Email string `gorm:"size:255;index"`
|
||||||
Name string `gorm:"size:255"`
|
Name string `gorm:"size:255"`
|
||||||
AvatarURL string `gorm:"size:512"`
|
AvatarURL string `gorm:"size:512"`
|
||||||
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
|
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
|
||||||
Subject string `gorm:"size:255"` // provider-specific user ID
|
Subject string `gorm:"size:255"` // provider-specific user ID
|
||||||
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
|
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
|
||||||
Role string `gorm:"size:20;default:user"`
|
Role string `gorm:"size:20;default:user"`
|
||||||
Status string `gorm:"size:20;default:active"` // "active", "pending"
|
Status string `gorm:"size:20;default:active"` // "active", "pending"
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session represents a user login session.
|
// Session represents a user login session.
|
||||||
|
|
@ -90,16 +91,16 @@ func (p *PermissionMap) Scan(value any) error {
|
||||||
|
|
||||||
// InviteCode represents an admin-generated invitation for user registration.
|
// InviteCode represents an admin-generated invitation for user registration.
|
||||||
type InviteCode struct {
|
type InviteCode struct {
|
||||||
ID string `gorm:"primaryKey;size:36"`
|
ID string `gorm:"primaryKey;size:36"`
|
||||||
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
|
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
|
||||||
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
|
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
|
||||||
CreatedBy string `gorm:"size:36;not null"`
|
CreatedBy string `gorm:"size:36;not null"`
|
||||||
UsedBy *string `gorm:"size:36"`
|
UsedBy *string `gorm:"size:36"`
|
||||||
UsedAt *time.Time
|
UsedAt *time.Time
|
||||||
ExpiresAt time.Time `gorm:"not null;index"`
|
ExpiresAt time.Time `gorm:"not null;index"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
Creator User `gorm:"foreignKey:CreatedBy"`
|
Creator User `gorm:"foreignKey:CreatedBy"`
|
||||||
Consumer *User `gorm:"foreignKey:UsedBy"`
|
Consumer *User `gorm:"foreignKey:UsedBy"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelAllowlist controls which models a user can access.
|
// ModelAllowlist controls which models a user can access.
|
||||||
|
|
|
||||||
|
|
@ -33,24 +33,24 @@ const (
|
||||||
FeatureMCPJobs = "mcp_jobs"
|
FeatureMCPJobs = "mcp_jobs"
|
||||||
|
|
||||||
// General features (default OFF for new users)
|
// General features (default OFF for new users)
|
||||||
FeatureFineTuning = "fine_tuning"
|
FeatureFineTuning = "fine_tuning"
|
||||||
FeatureQuantization = "quantization"
|
FeatureQuantization = "quantization"
|
||||||
|
|
||||||
// API features (default ON for new users)
|
// API features (default ON for new users)
|
||||||
FeatureChat = "chat"
|
FeatureChat = "chat"
|
||||||
FeatureImages = "images"
|
FeatureImages = "images"
|
||||||
FeatureAudioSpeech = "audio_speech"
|
FeatureAudioSpeech = "audio_speech"
|
||||||
FeatureAudioTranscription = "audio_transcription"
|
FeatureAudioTranscription = "audio_transcription"
|
||||||
FeatureVAD = "vad"
|
FeatureVAD = "vad"
|
||||||
FeatureDetection = "detection"
|
FeatureDetection = "detection"
|
||||||
FeatureVideo = "video"
|
FeatureVideo = "video"
|
||||||
FeatureEmbeddings = "embeddings"
|
FeatureEmbeddings = "embeddings"
|
||||||
FeatureSound = "sound"
|
FeatureSound = "sound"
|
||||||
FeatureRealtime = "realtime"
|
FeatureRealtime = "realtime"
|
||||||
FeatureRerank = "rerank"
|
FeatureRerank = "rerank"
|
||||||
FeatureTokenize = "tokenize"
|
FeatureTokenize = "tokenize"
|
||||||
FeatureMCP = "mcp"
|
FeatureMCP = "mcp"
|
||||||
FeatureStores = "stores"
|
FeatureStores = "stores"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AgentFeatures lists agent-related features (default OFF).
|
// AgentFeatures lists agent-related features (default OFF).
|
||||||
|
|
|
||||||
|
|
@ -24,14 +24,14 @@ type QuotaRule struct {
|
||||||
|
|
||||||
// QuotaStatus is returned to clients with current usage included.
|
// QuotaStatus is returned to clients with current usage included.
|
||||||
type QuotaStatus struct {
|
type QuotaStatus struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
MaxRequests *int64 `json:"max_requests"`
|
MaxRequests *int64 `json:"max_requests"`
|
||||||
MaxTotalTokens *int64 `json:"max_total_tokens"`
|
MaxTotalTokens *int64 `json:"max_total_tokens"`
|
||||||
Window string `json:"window"`
|
Window string `json:"window"`
|
||||||
CurrentRequests int64 `json:"current_requests"`
|
CurrentRequests int64 `json:"current_requests"`
|
||||||
CurrentTokens int64 `json:"current_total_tokens"`
|
CurrentTokens int64 `json:"current_total_tokens"`
|
||||||
ResetsAt string `json:"resets_at,omitempty"`
|
ResetsAt string `json:"resets_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── CRUD ──
|
// ── CRUD ──
|
||||||
|
|
@ -209,9 +209,9 @@ func QuotaExceeded(db *gorm.DB, userID, model string) (bool, int64, string) {
|
||||||
var quotaCache = newQuotaCacheStore()
|
var quotaCache = newQuotaCacheStore()
|
||||||
|
|
||||||
type quotaCacheStore struct {
|
type quotaCacheStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
rules map[string]cachedRules // userID -> rules
|
rules map[string]cachedRules // userID -> rules
|
||||||
usage map[string]cachedUsage // "userID|model|windowStart" -> counts
|
usage map[string]cachedUsage // "userID|model|windowStart" -> counts
|
||||||
}
|
}
|
||||||
|
|
||||||
type cachedRules struct {
|
type cachedRules struct {
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sessionDuration = 30 * 24 * time.Hour // 30 days
|
sessionDuration = 30 * 24 * time.Hour // 30 days
|
||||||
sessionIDBytes = 32 // 32 bytes = 64 hex chars
|
sessionIDBytes = 32 // 32 bytes = 64 hex chars
|
||||||
sessionCookie = "session"
|
sessionCookie = "session"
|
||||||
sessionRotationInterval = 1 * time.Hour
|
sessionRotationInterval = 1 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,15 +10,15 @@ import (
|
||||||
|
|
||||||
// UsageRecord represents a single API request's token usage.
|
// UsageRecord represents a single API request's token usage.
|
||||||
type UsageRecord struct {
|
type UsageRecord struct {
|
||||||
ID uint `gorm:"primaryKey;autoIncrement"`
|
ID uint `gorm:"primaryKey;autoIncrement"`
|
||||||
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
UserID string `gorm:"size:36;index:idx_usage_user_time"`
|
||||||
UserName string `gorm:"size:255"`
|
UserName string `gorm:"size:255"`
|
||||||
Model string `gorm:"size:255;index"`
|
Model string `gorm:"size:255;index"`
|
||||||
Endpoint string `gorm:"size:255"`
|
Endpoint string `gorm:"size:255"`
|
||||||
PromptTokens int64
|
PromptTokens int64
|
||||||
CompletionTokens int64
|
CompletionTokens int64
|
||||||
TotalTokens int64
|
TotalTokens int64
|
||||||
Duration int64 // milliseconds
|
Duration int64 // milliseconds
|
||||||
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
|
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,10 +127,10 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||||
|
|
||||||
query := db.Model(&UsageRecord{}).
|
query := db.Model(&UsageRecord{}).
|
||||||
Select(bucketExpr+", model, user_id, user_name, "+
|
Select(bucketExpr + ", model, user_id, user_name, " +
|
||||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
"SUM(prompt_tokens) as prompt_tokens, " +
|
||||||
"SUM(completion_tokens) as completion_tokens, "+
|
"SUM(completion_tokens) as completion_tokens, " +
|
||||||
"SUM(total_tokens) as total_tokens, "+
|
"SUM(total_tokens) as total_tokens, " +
|
||||||
"COUNT(*) as request_count").
|
"COUNT(*) as request_count").
|
||||||
Group("bucket, model, user_id, user_name").
|
Group("bucket, model, user_id, user_name").
|
||||||
Order("bucket ASC")
|
Order("bucket ASC")
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ var _ = Describe("Usage", func() {
|
||||||
db := testDB()
|
db := testDB()
|
||||||
|
|
||||||
// Insert records for two users
|
// Insert records for two users
|
||||||
for i := 0; i < 3; i++ {
|
for range 3 {
|
||||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||||
UserID: "user-a",
|
UserID: "user-a",
|
||||||
UserName: "Alice",
|
UserName: "Alice",
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ package anthropic
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
@ -25,7 +24,7 @@ import (
|
||||||
// @Param request body schema.AnthropicRequest true "query params"
|
// @Param request body schema.AnthropicRequest true "query params"
|
||||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||||
// @Router /v1/messages [post]
|
// @Router /v1/messages [post]
|
||||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
|
|
||||||
|
|
@ -52,7 +51,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||||
funcs, shouldUseFn := convertAnthropicTools(input, cfg)
|
funcs, shouldUseFn := convertAnthropicTools(input, cfg)
|
||||||
|
|
||||||
// MCP injection: prompts, resources, and tools
|
// MCP injection: prompts, resources, and tools
|
||||||
var mcpToolInfos []mcpTools.MCPToolInfo
|
var mcpExecutor mcpTools.ToolExecutor
|
||||||
mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata)
|
mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata)
|
||||||
mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata)
|
mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata)
|
||||||
mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata)
|
mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata)
|
||||||
|
|
@ -60,76 +59,29 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||||
if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") {
|
if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") {
|
||||||
remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
|
remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
|
||||||
if mcpErr == nil {
|
if mcpErr == nil {
|
||||||
|
mcpExecutor = mcpTools.NewToolExecutor(c.Request().Context(), natsClient, cfg.Name, remote, stdio, mcpServers)
|
||||||
|
|
||||||
|
// Prompt and resource injection (pre-processing step — resolves locally regardless of distributed mode)
|
||||||
namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers)
|
namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers)
|
||||||
if sessErr == nil && len(namedSessions) > 0 {
|
if sessErr == nil && len(namedSessions) > 0 {
|
||||||
// Prompt injection
|
mcpCtx, _ := mcpTools.InjectMCPContext(c.Request().Context(), namedSessions, mcpPromptName, mcpPromptArgs, mcpResourceURIs)
|
||||||
if mcpPromptName != "" {
|
if mcpCtx != nil {
|
||||||
prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions)
|
openAIMessages = append(mcpCtx.PromptMessages, openAIMessages...)
|
||||||
if discErr == nil {
|
mcpTools.AppendResourceSuffix(openAIMessages, mcpCtx.ResourceSuffix)
|
||||||
promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs)
|
|
||||||
if getErr == nil {
|
|
||||||
var injected []schema.Message
|
|
||||||
for _, pm := range promptMsgs {
|
|
||||||
injected = append(injected, schema.Message{
|
|
||||||
Role: string(pm.Role),
|
|
||||||
Content: mcpTools.PromptMessageToText(pm),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
openAIMessages = append(injected, openAIMessages...)
|
|
||||||
xlog.Debug("Anthropic MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected))
|
|
||||||
} else {
|
|
||||||
xlog.Error("Failed to get MCP prompt", "error", getErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Resource injection
|
// Tool injection via executor
|
||||||
if len(mcpResourceURIs) > 0 {
|
if mcpExecutor.HasTools() {
|
||||||
resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions)
|
mcpFuncs, discErr := mcpExecutor.DiscoverTools(c.Request().Context())
|
||||||
if discErr == nil {
|
if discErr == nil {
|
||||||
var resourceTexts []string
|
for _, fn := range mcpFuncs {
|
||||||
for _, uri := range mcpResourceURIs {
|
funcs = append(funcs, fn)
|
||||||
content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri)
|
|
||||||
if readErr != nil {
|
|
||||||
xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := uri
|
|
||||||
for _, r := range resources {
|
|
||||||
if r.URI == uri {
|
|
||||||
name = r.Name
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content))
|
|
||||||
}
|
|
||||||
if len(resourceTexts) > 0 && len(openAIMessages) > 0 {
|
|
||||||
lastIdx := len(openAIMessages) - 1
|
|
||||||
suffix := "\n\n" + strings.Join(resourceTexts, "\n\n")
|
|
||||||
switch ct := openAIMessages[lastIdx].Content.(type) {
|
|
||||||
case string:
|
|
||||||
openAIMessages[lastIdx].Content = ct + suffix
|
|
||||||
default:
|
|
||||||
openAIMessages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix)
|
|
||||||
}
|
|
||||||
xlog.Debug("Anthropic MCP resources injected", "count", len(resourceTexts))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tool injection
|
|
||||||
if len(mcpServers) > 0 {
|
|
||||||
discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions)
|
|
||||||
if discErr == nil {
|
|
||||||
mcpToolInfos = discovered
|
|
||||||
for _, ti := range mcpToolInfos {
|
|
||||||
funcs = append(funcs, ti.Function)
|
|
||||||
}
|
|
||||||
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
|
|
||||||
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs))
|
|
||||||
} else {
|
|
||||||
xlog.Error("Failed to discover MCP tools", "error", discErr)
|
|
||||||
}
|
}
|
||||||
|
shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
|
||||||
|
xlog.Debug("Anthropic MCP tools injected", "count", len(mcpFuncs), "total_funcs", len(funcs))
|
||||||
|
} else {
|
||||||
|
xlog.Error("Failed to discover MCP tools", "error", discErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -177,19 +129,19 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||||
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
|
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
|
||||||
|
|
||||||
if input.Stream {
|
if input.Stream {
|
||||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
|
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||||
}
|
}
|
||||||
|
|
||||||
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
|
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error {
|
func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||||
mcpMaxIterations := 10
|
mcpMaxIterations := 10
|
||||||
if cfg.Agent.MaxIterations > 0 {
|
if cfg.Agent.MaxIterations > 0 {
|
||||||
mcpMaxIterations = cfg.Agent.MaxIterations
|
mcpMaxIterations = cfg.Agent.MaxIterations
|
||||||
}
|
}
|
||||||
hasMCPTools := len(mcpToolInfos) > 0
|
hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
|
||||||
|
|
||||||
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
|
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
|
||||||
// Re-template on each MCP iteration since messages may have changed
|
// Re-template on each MCP iteration since messages may have changed
|
||||||
|
|
@ -227,7 +179,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||||
if hasMCPTools && shouldUseFn && len(toolCalls) > 0 {
|
if hasMCPTools && shouldUseFn && len(toolCalls) > 0 {
|
||||||
var hasMCPCalls bool
|
var hasMCPCalls bool
|
||||||
for _, tc := range toolCalls {
|
for _, tc := range toolCalls {
|
||||||
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) {
|
if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
|
||||||
hasMCPCalls = true
|
hasMCPCalls = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
@ -257,13 +209,12 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||||
|
|
||||||
// Execute each MCP tool call and append results
|
// Execute each MCP tool call and append results
|
||||||
for _, tc := range assistantMsg.ToolCalls {
|
for _, tc := range assistantMsg.ToolCalls {
|
||||||
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
|
if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
|
xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
|
||||||
toolResult, toolErr := mcpTools.ExecuteMCPToolCall(
|
toolResult, toolErr := mcpExecutor.ExecuteTool(
|
||||||
c.Request().Context(), mcpToolInfos,
|
c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
||||||
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
|
||||||
)
|
)
|
||||||
if toolErr != nil {
|
if toolErr != nil {
|
||||||
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
|
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
|
||||||
|
|
@ -290,10 +241,10 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||||
if shouldUseFn && len(toolCalls) > 0 {
|
if shouldUseFn && len(toolCalls) > 0 {
|
||||||
stopReason = "tool_use"
|
stopReason = "tool_use"
|
||||||
for _, tc := range toolCalls {
|
for _, tc := range toolCalls {
|
||||||
var inputArgs map[string]interface{}
|
var inputArgs map[string]any
|
||||||
if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
|
if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
|
||||||
xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
|
xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
|
||||||
inputArgs = map[string]interface{}{"raw": tc.Arguments}
|
inputArgs = map[string]any{"raw": tc.Arguments}
|
||||||
}
|
}
|
||||||
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
|
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
|
||||||
Type: "tool_use",
|
Type: "tool_use",
|
||||||
|
|
@ -316,9 +267,9 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||||
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped})
|
contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{Type: "text", Text: stripped})
|
||||||
}
|
}
|
||||||
for i, fc := range parsed {
|
for i, fc := range parsed {
|
||||||
var inputArgs map[string]interface{}
|
var inputArgs map[string]any
|
||||||
if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil {
|
if err := json.Unmarshal([]byte(fc.Arguments), &inputArgs); err != nil {
|
||||||
inputArgs = map[string]interface{}{"raw": fc.Arguments}
|
inputArgs = map[string]any{"raw": fc.Arguments}
|
||||||
}
|
}
|
||||||
toolCallID := fc.ID
|
toolCallID := fc.ID
|
||||||
if toolCallID == "" {
|
if toolCallID == "" {
|
||||||
|
|
@ -365,7 +316,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||||
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
|
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error {
|
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||||
c.Response().Header().Set("Connection", "keep-alive")
|
c.Response().Header().Set("Connection", "keep-alive")
|
||||||
|
|
@ -388,7 +339,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||||
if cfg.Agent.MaxIterations > 0 {
|
if cfg.Agent.MaxIterations > 0 {
|
||||||
mcpMaxIterations = cfg.Agent.MaxIterations
|
mcpMaxIterations = cfg.Agent.MaxIterations
|
||||||
}
|
}
|
||||||
hasMCPTools := len(mcpToolInfos) > 0
|
hasMCPTools := mcpExecutor != nil && mcpExecutor.HasTools()
|
||||||
|
|
||||||
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
|
for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
|
||||||
// Re-template on MCP iterations
|
// Re-template on MCP iterations
|
||||||
|
|
@ -483,7 +434,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Anthropic stream model inference failed", "error", err)
|
xlog.Error("Anthropic stream model inference failed", "error", err)
|
||||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
|
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||||
|
Type: "error",
|
||||||
|
Error: &schema.AnthropicError{
|
||||||
|
Type: "api_error",
|
||||||
|
Message: fmt.Sprintf("model inference failed: %v", err),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also check chat deltas for tool calls
|
// Also check chat deltas for tool calls
|
||||||
|
|
@ -495,7 +453,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||||
if hasMCPTools && len(collectedToolCalls) > 0 {
|
if hasMCPTools && len(collectedToolCalls) > 0 {
|
||||||
var hasMCPCalls bool
|
var hasMCPCalls bool
|
||||||
for _, tc := range collectedToolCalls {
|
for _, tc := range collectedToolCalls {
|
||||||
if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) {
|
if mcpExecutor != nil && mcpExecutor.IsTool(tc.Name) {
|
||||||
hasMCPCalls = true
|
hasMCPCalls = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
@ -525,13 +483,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||||
|
|
||||||
// Execute MCP tool calls
|
// Execute MCP tool calls
|
||||||
for _, tc := range assistantMsg.ToolCalls {
|
for _, tc := range assistantMsg.ToolCalls {
|
||||||
if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
|
if mcpExecutor == nil || !mcpExecutor.IsTool(tc.FunctionCall.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
|
xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration)
|
||||||
toolResult, toolErr := mcpTools.ExecuteMCPToolCall(
|
toolResult, toolErr := mcpExecutor.ExecuteTool(
|
||||||
c.Request().Context(), mcpToolInfos,
|
c.Request().Context(), tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
||||||
tc.FunctionCall.Name, tc.FunctionCall.Arguments,
|
|
||||||
)
|
)
|
||||||
if toolErr != nil {
|
if toolErr != nil {
|
||||||
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
|
xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr)
|
||||||
|
|
@ -686,7 +643,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||||
case string:
|
case string:
|
||||||
openAIMsg.StringContent = content
|
openAIMsg.StringContent = content
|
||||||
openAIMsg.Content = content
|
openAIMsg.Content = content
|
||||||
case []interface{}:
|
case []any:
|
||||||
// Handle array of content blocks
|
// Handle array of content blocks
|
||||||
var textContent string
|
var textContent string
|
||||||
var stringImages []string
|
var stringImages []string
|
||||||
|
|
@ -694,7 +651,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||||
toolCallIndex := 0
|
toolCallIndex := 0
|
||||||
|
|
||||||
for _, block := range content {
|
for _, block := range content {
|
||||||
if blockMap, ok := block.(map[string]interface{}); ok {
|
if blockMap, ok := block.(map[string]any); ok {
|
||||||
blockType, _ := blockMap["type"].(string)
|
blockType, _ := blockMap["type"].(string)
|
||||||
switch blockType {
|
switch blockType {
|
||||||
case "text":
|
case "text":
|
||||||
|
|
@ -703,7 +660,7 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||||
}
|
}
|
||||||
case "image":
|
case "image":
|
||||||
// Handle image content
|
// Handle image content
|
||||||
if source, ok := blockMap["source"].(map[string]interface{}); ok {
|
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||||
if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
|
if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
|
||||||
if data, ok := source["data"].(string); ok {
|
if data, ok := source["data"].(string); ok {
|
||||||
mediaType, _ := source["media_type"].(string)
|
mediaType, _ := source["media_type"].(string)
|
||||||
|
|
@ -751,10 +708,10 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M
|
||||||
switch rc := resultContent.(type) {
|
switch rc := resultContent.(type) {
|
||||||
case string:
|
case string:
|
||||||
resultText = rc
|
resultText = rc
|
||||||
case []interface{}:
|
case []any:
|
||||||
// Array of content blocks
|
// Array of content blocks
|
||||||
for _, cb := range rc {
|
for _, cb := range rc {
|
||||||
if cbMap, ok := cb.(map[string]interface{}); ok {
|
if cbMap, ok := cb.(map[string]any); ok {
|
||||||
if cbMap["type"] == "text" {
|
if cbMap["type"] == "text" {
|
||||||
if text, ok := cbMap["text"].(string); ok {
|
if text, ok := cbMap["text"].(string); ok {
|
||||||
resultText += text
|
resultText += text
|
||||||
|
|
@ -823,7 +780,7 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
// "auto" is the default - let model decide
|
// "auto" is the default - let model decide
|
||||||
case map[string]interface{}:
|
case map[string]any:
|
||||||
// Specific tool selection: {"type": "tool", "name": "tool_name"}
|
// Specific tool selection: {"type": "tool", "name": "tool_name"}
|
||||||
if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
|
if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
|
||||||
if name, ok := tc["name"].(string); ok {
|
if name, ok := tc["name"].(string); ok {
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
package explorer
|
package explorer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
@ -14,7 +15,7 @@ import (
|
||||||
|
|
||||||
func Dashboard() echo.HandlerFunc {
|
func Dashboard() echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
summary := map[string]interface{}{
|
summary := map[string]any{
|
||||||
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
||||||
"Version": internal.PrintableVersion(),
|
"Version": internal.PrintableVersion(),
|
||||||
"BaseURL": middleware.BaseURL(c),
|
"BaseURL": middleware.BaseURL(c),
|
||||||
|
|
@ -61,8 +62,8 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// order by number of clusters
|
// order by number of clusters
|
||||||
sort.Slice(results, func(i, j int) bool {
|
slices.SortFunc(results, func(a, b Network) int {
|
||||||
return len(results[i].Clusters) > len(results[j].Clusters)
|
return cmp.Compare(len(b.Clusters), len(a.Clusters))
|
||||||
})
|
})
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, results)
|
return c.JSON(http.StatusOK, results)
|
||||||
|
|
@ -73,36 +74,36 @@ func AddNetwork(db *explorer.Database) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
request := new(AddNetworkRequest)
|
request := new(AddNetworkRequest)
|
||||||
if err := c.Bind(request); err != nil {
|
if err := c.Bind(request); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Cannot parse JSON"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Token == "" {
|
if request.Token == "" {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token is required"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Name == "" {
|
if request.Name == "" {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Name is required"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Description == "" {
|
if request.Description == "" {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Description is required"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check if token is valid, otherwise reject
|
// TODO: check if token is valid, otherwise reject
|
||||||
// try to decode the token from base64
|
// try to decode the token from base64
|
||||||
_, err := base64.StdEncoding.DecodeString(request.Token)
|
_, err := base64.StdEncoding.DecodeString(request.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Invalid token"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, exists := db.Get(request.Token); exists {
|
if _, exists := db.Get(request.Token); exists {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
|
return c.JSON(http.StatusBadRequest, map[string]any{"error": "Token already exists"})
|
||||||
}
|
}
|
||||||
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
|
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
|
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "Cannot add token"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
|
return c.JSON(http.StatusOK, map[string]any{"message": "Token added"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package localai
|
package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
@ -8,12 +9,12 @@ import (
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/application"
|
"github.com/mudler/LocalAI/core/application"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// getJobService returns the job service for the current user.
|
// getJobService returns the job service for the current user.
|
||||||
// Falls back to the global service when no user is authenticated.
|
// Falls back to the global service when no user is authenticated.
|
||||||
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService {
|
func getJobService(app *application.Application, c echo.Context) *agentpool.AgentJobService {
|
||||||
userID := getUserID(c)
|
userID := getUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return app.AgentJobService()
|
return app.AgentJobService()
|
||||||
|
|
@ -54,7 +55,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := getJobService(app, c).UpdateTask(id, task); err != nil {
|
if err := getJobService(app, c).UpdateTask(id, task); err != nil {
|
||||||
if err.Error() == "task not found: "+id {
|
if errors.Is(err, agentpool.ErrTaskNotFound) {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -68,7 +69,7 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
if err := getJobService(app, c).DeleteTask(id); err != nil {
|
if err := getJobService(app, c).DeleteTask(id); err != nil {
|
||||||
if err.Error() == "task not found: "+id {
|
if errors.Is(err, agentpool.ErrTaskNotFound) {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -244,7 +245,7 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
if err := getJobService(app, c).CancelJob(id); err != nil {
|
if err := getJobService(app, c).CancelJob(id); err != nil {
|
||||||
if err.Error() == "job not found: "+id {
|
if errors.Is(err, agentpool.ErrJobNotFound) {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -258,7 +259,7 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
if err := getJobService(app, c).DeleteJob(id); err != nil {
|
if err := getJobService(app, c).DeleteJob(id); err != nil {
|
||||||
if err.Error() == "job not found: "+id {
|
if errors.Is(err, agentpool.ErrJobNotFound) {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -275,7 +276,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
if c.Request().ContentLength > 0 {
|
if c.Request().ContentLength > 0 {
|
||||||
if err := c.Bind(¶ms); err != nil {
|
if err := c.Bind(¶ms); err != nil {
|
||||||
body := make(map[string]interface{})
|
body := make(map[string]any)
|
||||||
if err := c.Bind(&body); err == nil {
|
if err := c.Bind(&body); err == nil {
|
||||||
params = make(map[string]string)
|
params = make(map[string]string)
|
||||||
for k, v := range body {
|
for k, v := range body {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
@ -10,8 +11,9 @@ import (
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/application"
|
|
||||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||||
|
"github.com/mudler/LocalAI/core/application"
|
||||||
|
"github.com/mudler/LocalAI/core/services/agents"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
@ -50,55 +52,105 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this model name is an agent
|
// Check if this model name is an agent — try in-process agent first,
|
||||||
ag := svc.GetAgent(req.Model)
|
// fall back to config lookup (covers distributed mode where agents
|
||||||
if ag == nil {
|
// don't run in-process).
|
||||||
return next(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is an agent — handle the request directly
|
|
||||||
messages := parseInputToMessages(req.Input)
|
messages := parseInputToMessages(req.Input)
|
||||||
if len(messages) == 0 {
|
userID := effectiveUserID(c)
|
||||||
return c.JSON(http.StatusBadRequest, map[string]any{
|
ag := svc.GetAgent(req.Model)
|
||||||
"error": map[string]string{
|
if ag == nil && svc.GetAgentConfigForUser(userID, req.Model) == nil {
|
||||||
"type": "invalid_request_error",
|
return next(c) // not an agent
|
||||||
"message": "no input messages provided",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jobOptions := []coreTypes.JobOption{
|
// Extract the last user message for the executor
|
||||||
coreTypes.WithConversationHistory(messages),
|
var userMessage string
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if messages[i].Role == "user" {
|
||||||
|
userMessage = messages[i].Content
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res := ag.Ask(jobOptions...)
|
var responseText string
|
||||||
|
|
||||||
if res == nil {
|
if ag != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
// Local mode: use LocalAGI agent directly
|
||||||
"error": map[string]string{
|
jobOptions := []coreTypes.JobOption{
|
||||||
"type": "server_error",
|
coreTypes.WithConversationHistory(messages),
|
||||||
"message": "agent request failed or was cancelled",
|
}
|
||||||
},
|
|
||||||
})
|
res := ag.Ask(jobOptions...)
|
||||||
}
|
if res == nil {
|
||||||
if res.Error != nil {
|
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||||
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
|
"error": map[string]string{
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]any{
|
"type": "server_error",
|
||||||
"error": map[string]string{
|
"message": "agent request failed or was cancelled",
|
||||||
"type": "server_error",
|
},
|
||||||
"message": res.Error.Error(),
|
})
|
||||||
},
|
}
|
||||||
|
if res.Error != nil {
|
||||||
|
xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error)
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": "server_error",
|
||||||
|
"message": res.Error.Error(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
responseText = res.Response
|
||||||
|
} else {
|
||||||
|
// Distributed mode: dispatch via NATS + wait for response synchronously
|
||||||
|
var bridge *agents.EventBridge
|
||||||
|
if d := app.Distributed(); d != nil {
|
||||||
|
bridge = d.AgentBridge
|
||||||
|
}
|
||||||
|
if bridge == nil {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe BEFORE dispatching so we never miss a fast response
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
responseCh := make(chan string, 1)
|
||||||
|
sub, err := bridge.SubscribeEvents(req.Model, userID, func(evt agents.AgentEvent) {
|
||||||
|
if evt.EventType == "json_message" && evt.Sender == "agent" {
|
||||||
|
responseCh <- evt.Content
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||||
|
"error": map[string]string{"type": "server_error", "message": "failed to subscribe to agent events"},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
defer sub.Unsubscribe()
|
||||||
|
|
||||||
|
// Now dispatch via ChatForUser (publishes to NATS)
|
||||||
|
_, err = svc.ChatForUser(userID, req.Model, userMessage)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]any{
|
||||||
|
"error": map[string]string{"type": "server_error", "message": err.Error()},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case responseText = <-responseCh:
|
||||||
|
// Got the response
|
||||||
|
case <-ctx.Done():
|
||||||
|
return c.JSON(http.StatusGatewayTimeout, map[string]any{
|
||||||
|
"error": map[string]string{"type": "server_error", "message": "agent response timeout"},
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
id := fmt.Sprintf("resp_%s", uuid.New().String())
|
id := fmt.Sprintf("resp_%s", uuid.New().String())
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]any{
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
"id": id,
|
"id": id,
|
||||||
"object": "response",
|
"object": "response",
|
||||||
"created_at": time.Now().Unix(),
|
"created_at": time.Now().Unix(),
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"model": req.Model,
|
"model": req.Model,
|
||||||
"previous_response_id": nil,
|
"previous_response_id": nil,
|
||||||
"output": []any{
|
"output": []any{
|
||||||
map[string]any{
|
map[string]any{
|
||||||
|
|
@ -109,7 +161,7 @@ func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc
|
||||||
"content": []map[string]any{
|
"content": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "output_text",
|
"type": "output_text",
|
||||||
"text": res.Response,
|
"text": responseText,
|
||||||
"annotations": []any{},
|
"annotations": []any{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/application"
|
"github.com/mudler/LocalAI/core/application"
|
||||||
|
skillsManager "github.com/mudler/LocalAI/core/services/skills"
|
||||||
skilldomain "github.com/mudler/skillserver/pkg/domain"
|
skilldomain "github.com/mudler/skillserver/pkg/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -41,27 +42,48 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getSkillManager returns a SkillManager for the request's user.
|
||||||
|
func getSkillManager(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
|
||||||
|
svc := app.AgentPoolService()
|
||||||
|
userID := getUserID(c)
|
||||||
|
return svc.SkillManagerForUser(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSkillManagerEffective(c echo.Context, app *application.Application) (skillsManager.Manager, error) {
|
||||||
|
svc := app.AgentPoolService()
|
||||||
|
userID := effectiveUserID(c)
|
||||||
|
return svc.SkillManagerForUser(userID)
|
||||||
|
}
|
||||||
|
|
||||||
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
skills, err := svc.ListSkillsForUser(userID)
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
skills, err := mgr.List()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Admin cross-user aggregation
|
// Admin cross-user aggregation
|
||||||
if wantsAllUsers(c) {
|
if wantsAllUsers(c) {
|
||||||
|
svc := app.AgentPoolService()
|
||||||
usm := svc.UserServicesManager()
|
usm := svc.UserServicesManager()
|
||||||
if usm != nil {
|
if usm != nil {
|
||||||
userIDs, _ := usm.ListAllUserIDs()
|
userIDs, _ := usm.ListAllUserIDs()
|
||||||
userGroups := map[string]any{}
|
userGroups := map[string]any{}
|
||||||
|
userID := getUserID(c)
|
||||||
for _, uid := range userIDs {
|
for _, uid := range userIDs {
|
||||||
if uid == userID {
|
if uid == userID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
userSkills, err := svc.ListSkillsForUser(uid)
|
uidMgr, mgrErr := svc.SkillManagerForUser(uid)
|
||||||
if err != nil || len(userSkills) == 0 {
|
if mgrErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userSkills, listErr := uidMgr.List()
|
||||||
|
if listErr != nil || len(userSkills) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
|
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
|
||||||
|
|
@ -76,25 +98,28 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, skillsToResponses(skills))
|
return c.JSON(http.StatusOK, map[string]any{"skills": skillsToResponses(skills)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
cfg := svc.GetSkillsConfigForUser(userID)
|
return c.JSON(http.StatusOK, map[string]string{})
|
||||||
return c.JSON(http.StatusOK, cfg)
|
}
|
||||||
|
return c.JSON(http.StatusOK, mgr.GetConfig())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
query := c.QueryParam("q")
|
query := c.QueryParam("q")
|
||||||
skills, err := svc.SearchSkillsForUser(userID, query)
|
skills, err := mgr.Search(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -104,8 +129,10 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
var payload struct {
|
var payload struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
|
|
@ -118,7 +145,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
if err := c.Bind(&payload); err != nil {
|
if err := c.Bind(&payload); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
skill, err := mgr.Create(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "already exists") {
|
if strings.Contains(err.Error(), "already exists") {
|
||||||
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -131,9 +158,11 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManagerEffective(c, app)
|
||||||
userID := effectiveUserID(c)
|
if err != nil {
|
||||||
skill, err := svc.GetSkillForUser(userID, c.Param("name"))
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
skill, err := mgr.Get(c.Param("name"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -143,8 +172,10 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManagerEffective(c, app)
|
||||||
userID := effectiveUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
var payload struct {
|
var payload struct {
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
|
@ -156,7 +187,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
if err := c.Bind(&payload); err != nil {
|
if err := c.Bind(&payload); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
skill, err := mgr.Update(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -169,9 +200,11 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManagerEffective(c, app)
|
||||||
userID := effectiveUserID(c)
|
if err != nil {
|
||||||
if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil {
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
if err := mgr.Delete(c.Param("name")); err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
|
@ -180,10 +213,12 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManagerEffective(c, app)
|
||||||
userID := effectiveUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
name := c.Param("*")
|
name := c.Param("*")
|
||||||
data, err := svc.ExportSkillForUser(userID, name)
|
data, err := mgr.Export(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -195,8 +230,10 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
file, err := c.FormFile("file")
|
file, err := c.FormFile("file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
|
||||||
|
|
@ -210,7 +247,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
skill, err := svc.ImportSkillForUser(userID, data)
|
skill, err := mgr.Import(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -222,9 +259,11 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManagerEffective(c, app)
|
||||||
userID := effectiveUserID(c)
|
if err != nil {
|
||||||
resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name"))
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
resources, skill, err := mgr.ListResources(c.Param("name"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -260,9 +299,11 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManagerEffective(c, app)
|
||||||
userID := effectiveUserID(c)
|
if err != nil {
|
||||||
content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*"))
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
content, info, err := mgr.GetResource(c.Param("name"), c.Param("*"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -281,10 +322,12 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
file, fileErr := c.FormFile("file")
|
||||||
|
if fileErr != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
|
||||||
}
|
}
|
||||||
path := c.FormValue("path")
|
path := c.FormValue("path")
|
||||||
|
|
@ -300,7 +343,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil {
|
if err := mgr.CreateResource(c.Param("name"), path, data); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusCreated, map[string]string{"path": path})
|
return c.JSON(http.StatusCreated, map[string]string{"path": path})
|
||||||
|
|
@ -309,15 +352,17 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||||
|
|
||||||
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
var payload struct {
|
var payload struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&payload); err != nil {
|
if err := c.Bind(&payload); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil {
|
if err := mgr.UpdateResource(c.Param("name"), c.Param("*"), payload.Content); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
|
@ -326,9 +371,11 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||||
|
|
||||||
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil {
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
if err := mgr.DeleteResource(c.Param("name"), c.Param("*")); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
|
@ -339,9 +386,11 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||||
|
|
||||||
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
repos, err := svc.ListGitReposForUser(userID)
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
repos, err := mgr.ListGitRepos()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -351,15 +400,17 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
var payload struct {
|
var payload struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&payload); err != nil {
|
if err := c.Bind(&payload); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
repo, err := svc.AddGitRepoForUser(userID, payload.URL)
|
repo, err := mgr.AddGitRepo(payload.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -369,8 +420,10 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
var payload struct {
|
var payload struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
Enabled *bool `json:"enabled"`
|
Enabled *bool `json:"enabled"`
|
||||||
|
|
@ -378,7 +431,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
if err := c.Bind(&payload); err != nil {
|
if err := c.Bind(&payload); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled)
|
repo, err := mgr.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
|
|
@ -391,9 +444,11 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil {
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
if err := mgr.DeleteGitRepo(c.Param("id")); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
@ -405,9 +460,11 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil {
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
if err := mgr.SyncGitRepo(c.Param("id")); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
|
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
|
||||||
|
|
@ -416,9 +473,11 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
|
||||||
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
mgr, err := getSkillManager(c, app)
|
||||||
userID := getUserID(c)
|
if err != nil {
|
||||||
repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id"))
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
repo, err := mgr.ToggleGitRepo(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,20 +4,23 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/application"
|
|
||||||
"github.com/mudler/LocalAI/core/http/auth"
|
|
||||||
"github.com/mudler/LocalAI/core/services"
|
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
|
||||||
"github.com/mudler/LocalAGI/core/state"
|
"github.com/mudler/LocalAGI/core/state"
|
||||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||||
agiServices "github.com/mudler/LocalAGI/services"
|
agiServices "github.com/mudler/LocalAGI/services"
|
||||||
|
"github.com/mudler/LocalAI/core/application"
|
||||||
|
"github.com/mudler/LocalAI/core/http/auth"
|
||||||
|
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||||
|
"github.com/mudler/LocalAI/core/services/agents"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// getUserID extracts the scoped user ID from the request context.
|
// getUserID extracts the scoped user ID from the request context.
|
||||||
|
|
@ -42,25 +45,39 @@ func wantsAllUsers(c echo.Context) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// effectiveUserID returns the user ID to scope operations to.
|
// effectiveUserID returns the user ID to scope operations to.
|
||||||
// SECURITY: Only admins may supply ?user_id=<id> to operate on another user's
|
// SECURITY: Only admins and agent-worker service accounts may supply
|
||||||
// resources. Non-admin callers always get their own ID regardless of query params.
|
// ?user_id=<id> to operate on another user's resources. Agent-worker users are
|
||||||
|
// created exclusively server-side during node registration and need to access
|
||||||
|
// collections on behalf of the user whose agent they are executing.
|
||||||
|
// Regular callers always get their own ID regardless of query params.
|
||||||
func effectiveUserID(c echo.Context) string {
|
func effectiveUserID(c echo.Context) string {
|
||||||
if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) {
|
if targetUID := c.QueryParam("user_id"); targetUID != "" && canImpersonateUser(c) {
|
||||||
|
if callerID := getUserID(c); callerID != targetUID {
|
||||||
|
xlog.Info("User impersonation", "caller", callerID, "target", targetUID, "path", c.Path())
|
||||||
|
}
|
||||||
return targetUID
|
return targetUID
|
||||||
}
|
}
|
||||||
return getUserID(c)
|
return getUserID(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// canImpersonateUser returns true if the caller is allowed to use ?user_id= to
|
||||||
|
// scope operations to another user. Allowed for admins and agent-worker service
|
||||||
|
// accounts (ProviderAgentWorker is set server-side during node registration and
|
||||||
|
// cannot be self-assigned).
|
||||||
|
func canImpersonateUser(c echo.Context) bool {
|
||||||
|
user := auth.GetUser(c)
|
||||||
|
if user == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return user.Role == auth.RoleAdmin || user.Provider == auth.ProviderAgentWorker
|
||||||
|
}
|
||||||
|
|
||||||
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
|
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
svc := app.AgentPoolService()
|
||||||
userID := getUserID(c)
|
userID := getUserID(c)
|
||||||
statuses := svc.ListAgentsForUser(userID)
|
statuses := svc.ListAgentsForUser(userID)
|
||||||
agents := make([]string, 0, len(statuses))
|
agents := slices.Sorted(maps.Keys(statuses))
|
||||||
for name := range statuses {
|
|
||||||
agents = append(agents, name)
|
|
||||||
}
|
|
||||||
sort.Strings(agents)
|
|
||||||
resp := map[string]any{
|
resp := map[string]any{
|
||||||
"agents": agents,
|
"agents": agents,
|
||||||
"agentCount": len(agents),
|
"agentCount": len(agents),
|
||||||
|
|
@ -111,13 +128,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
svc := app.AgentPoolService()
|
svc := app.AgentPoolService()
|
||||||
userID := effectiveUserID(c)
|
userID := effectiveUserID(c)
|
||||||
name := c.Param("name")
|
name := c.Param("name")
|
||||||
ag := svc.GetAgentForUser(userID, name)
|
|
||||||
if ag == nil {
|
statuses := svc.ListAgentsForUser(userID)
|
||||||
|
active, exists := statuses[name]
|
||||||
|
if !exists {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]any{
|
return c.JSON(http.StatusOK, map[string]any{"active": active})
|
||||||
"active": !ag.Paused(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -192,9 +209,13 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
svc := app.AgentPoolService()
|
svc := app.AgentPoolService()
|
||||||
userID := effectiveUserID(c)
|
userID := effectiveUserID(c)
|
||||||
name := c.Param("name")
|
name := c.Param("name")
|
||||||
|
|
||||||
history := svc.GetAgentStatusForUser(userID, name)
|
history := svc.GetAgentStatusForUser(userID, name)
|
||||||
if history == nil {
|
if history == nil {
|
||||||
history = &state.Status{ActionResults: []coreTypes.ActionState{}}
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
|
"Name": name,
|
||||||
|
"History": []string{},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
entries := []string{}
|
entries := []string{}
|
||||||
for i := len(history.Results()) - 1; i >= 0; i-- {
|
for i := len(history.Results()) - 1; i >= 0; i-- {
|
||||||
|
|
@ -221,10 +242,14 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
|
||||||
svc := app.AgentPoolService()
|
svc := app.AgentPoolService()
|
||||||
userID := effectiveUserID(c)
|
userID := effectiveUserID(c)
|
||||||
name := c.Param("name")
|
name := c.Param("name")
|
||||||
|
|
||||||
history, err := svc.GetAgentObservablesForUser(userID, name)
|
history, err := svc.GetAgentObservablesForUser(userID, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
if history == nil {
|
||||||
|
history = []json.RawMessage{}
|
||||||
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]any{
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
"Name": name,
|
"Name": name,
|
||||||
"History": history,
|
"History": history,
|
||||||
|
|
@ -278,26 +303,30 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
svc := app.AgentPoolService()
|
svc := app.AgentPoolService()
|
||||||
userID := effectiveUserID(c)
|
userID := effectiveUserID(c)
|
||||||
name := c.Param("name")
|
name := c.Param("name")
|
||||||
manager := svc.GetSSEManagerForUser(userID, name)
|
|
||||||
if manager == nil {
|
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
|
||||||
}
|
|
||||||
return services.HandleSSE(c, manager)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentConfigMetaResponse struct {
|
// Try local SSE manager first
|
||||||
state.AgentConfigMeta
|
manager := svc.GetSSEManagerForUser(userID, name)
|
||||||
OutputsDir string `json:"OutputsDir"`
|
if manager != nil {
|
||||||
|
return agentpool.HandleSSE(c, manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to distributed EventBridge SSE
|
||||||
|
var bridge *agents.EventBridge
|
||||||
|
if d := app.Distributed(); d != nil {
|
||||||
|
bridge = d.AgentBridge
|
||||||
|
}
|
||||||
|
if bridge != nil {
|
||||||
|
return bridge.HandleSSE(c, name, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
|
func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
svc := app.AgentPoolService()
|
svc := app.AgentPoolService()
|
||||||
return c.JSON(http.StatusOK, agentConfigMetaResponse{
|
return c.JSON(http.StatusOK, svc.GetConfigMetaResult())
|
||||||
AgentConfigMeta: svc.GetConfigMeta(),
|
|
||||||
OutputsDir: svc.OutputsDir(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/http/middleware"
|
"github.com/mudler/LocalAI/core/http/middleware"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
@ -19,14 +19,14 @@ type BackendEndpointService struct {
|
||||||
galleries []config.Gallery
|
galleries []config.Gallery
|
||||||
backendPath string
|
backendPath string
|
||||||
backendSystemPath string
|
backendSystemPath string
|
||||||
backendApplier *services.GalleryService
|
backendApplier *galleryop.GalleryService
|
||||||
}
|
}
|
||||||
|
|
||||||
type GalleryBackend struct {
|
type GalleryBackend struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService {
|
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService) BackendEndpointService {
|
||||||
return BackendEndpointService{
|
return BackendEndpointService{
|
||||||
galleries: galleries,
|
galleries: galleries,
|
||||||
backendPath: systemState.Backend.BackendsPath,
|
backendPath: systemState.Backend.BackendsPath,
|
||||||
|
|
@ -37,7 +37,7 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
|
||||||
|
|
||||||
// GetOpStatusEndpoint returns the job status
|
// GetOpStatusEndpoint returns the job status
|
||||||
// @Summary Returns the job status
|
// @Summary Returns the job status
|
||||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
// @Success 200 {object} galleryop.OpStatus "Response"
|
||||||
// @Router /backends/jobs/{uuid} [get]
|
// @Router /backends/jobs/{uuid} [get]
|
||||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
|
|
@ -51,7 +51,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||||
|
|
||||||
// GetAllStatusEndpoint returns all the jobs status progress
|
// GetAllStatusEndpoint returns all the jobs status progress
|
||||||
// @Summary Returns all the jobs status progress
|
// @Summary Returns all the jobs status progress
|
||||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
||||||
// @Router /backends/jobs [get]
|
// @Router /backends/jobs [get]
|
||||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
|
|
@ -76,7 +76,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
|
mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||||
ID: uuid.String(),
|
ID: uuid.String(),
|
||||||
GalleryElementName: input.ID,
|
GalleryElementName: input.ID,
|
||||||
Galleries: mgs.galleries,
|
Galleries: mgs.galleries,
|
||||||
|
|
@ -95,7 +95,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
backendName := c.Param("name")
|
backendName := c.Param("name")
|
||||||
|
|
||||||
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
|
mgs.backendApplier.BackendGalleryChannel <- galleryop.ManagementOp[gallery.GalleryBackend, any]{
|
||||||
Delete: true,
|
Delete: true,
|
||||||
GalleryElementName: backendName,
|
GalleryElementName: backendName,
|
||||||
Galleries: mgs.galleries,
|
Galleries: mgs.galleries,
|
||||||
|
|
@ -114,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
||||||
// @Summary List all Backends
|
// @Summary List all Backends
|
||||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||||
// @Router /backends [get]
|
// @Router /backends [get]
|
||||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
backends, err := gallery.ListSystemBackends(systemState)
|
backends, err := mgs.backendApplier.ListBackends()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package localai
|
||||||
import (
|
import (
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BackendMonitorEndpoint returns the status of the specified backend
|
// BackendMonitorEndpoint returns the status of the specified backend
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||||
// @Success 200 {object} proto.StatusResponse "Response"
|
// @Success 200 {object} proto.StatusResponse "Response"
|
||||||
// @Router /backend/monitor [get]
|
// @Router /backend/monitor [get]
|
||||||
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
|
func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
|
|
||||||
input := new(schema.BackendMonitorRequest)
|
input := new(schema.BackendMonitorRequest)
|
||||||
|
|
@ -32,7 +32,7 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc
|
||||||
// @Summary Backend monitor endpoint
|
// @Summary Backend monitor endpoint
|
||||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||||
// @Router /backend/shutdown [post]
|
// @Router /backend/shutdown [post]
|
||||||
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
|
func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
input := new(schema.BackendMonitorRequest)
|
input := new(schema.BackendMonitorRequest)
|
||||||
// Get input data from the request body
|
// Get input data from the request body
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue