mirror of
https://github.com/mudler/LocalAI
synced 2026-04-21 13:27:21 +00:00
fix(nodes): better detection of whether a node goes down or a model is not available (#9274)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
154fa000d3
commit
510d6759fe
5 changed files with 216 additions and 4 deletions
35
pkg/model/connection_errors.go
Normal file
35
pkg/model/connection_errors.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// isConnectionError returns true if the error indicates the remote endpoint is
|
||||
// unreachable (connection refused, reset, gRPC Unavailable). Returns false for
|
||||
// timeouts and deadline exceeded — those may indicate a busy server, not a dead one.
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// gRPC Unavailable = server not reachable (covers connection refused, DNS, TLS errors)
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Unavailable {
|
||||
return true
|
||||
}
|
||||
|
||||
// Syscall-level connection errors
|
||||
if errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ECONNRESET) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback string matching for wrapped errors that lose the typed error
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "connection refused") ||
|
||||
strings.Contains(msg, "connection reset") ||
|
||||
strings.Contains(msg, "no such host")
|
||||
}
|
||||
109
pkg/model/connection_evicting_client.go
Normal file
109
pkg/model/connection_evicting_client.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
ggrpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// ConnectionEvictingClient wraps a grpc.Backend. When any inference method
// fails with a connection error (server unreachable), it calls the evict
// callback to remove the model from the ModelLoader's cache. The error is
// still returned to the caller — the NEXT request will trigger rescheduling
// via SmartRouter.
type ConnectionEvictingClient struct {
	grpc.Backend // wrapped backend; calls not intercepted below pass straight through

	// modelID identifies the wrapped model in eviction log messages.
	modelID string
	// evict removes the model from the loader's cache; invoked at most once.
	evict func()
	// once guarantees the eviction callback fires a single time per wrapper,
	// even if several in-flight requests fail concurrently.
	once sync.Once
}
|
||||
|
||||
func newConnectionEvictingClient(inner grpc.Backend, modelID string, evict func()) grpc.Backend {
|
||||
return &ConnectionEvictingClient{
|
||||
Backend: inner,
|
||||
modelID: modelID,
|
||||
evict: evict,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) checkErr(err error) {
|
||||
if err != nil && isConnectionError(err) {
|
||||
c.once.Do(func() {
|
||||
xlog.Warn("Connection error during inference, evicting model from cache",
|
||||
"model", c.modelID, "error", err)
|
||||
c.evict()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Intercepted inference methods ---
|
||||
|
||||
func (c *ConnectionEvictingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
|
||||
reply, err := c.Backend.Predict(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
err := c.Backend.PredictStream(ctx, in, f, opts...)
|
||||
c.checkErr(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
|
||||
result, err := c.Backend.Embeddings(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
result, err := c.Backend.GenerateImage(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
result, err := c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
result, err := c.Backend.TTS(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
err := c.Backend.TTSStream(ctx, in, f, opts...)
|
||||
c.checkErr(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
result, err := c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
||||
result, err := c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
||||
result, err := c.Backend.Detect(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *ConnectionEvictingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
|
||||
result, err := c.Backend.Rerank(ctx, in, opts...)
|
||||
c.checkErr(err)
|
||||
return result, err
|
||||
}
|
||||
|
|
@ -253,7 +253,14 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
|
|||
xlog.Debug("Model already loaded", "model", o.modelID)
|
||||
// Update last used time for LRU tracking
|
||||
ml.updateModelLastUsed(m)
|
||||
return m.GRPC(o.parallelRequests, ml.wd), nil
|
||||
client := m.GRPC(o.parallelRequests, ml.wd)
|
||||
// Wrap remote models so connection errors during inference trigger eviction
|
||||
if m.Process() == nil {
|
||||
client = newConnectionEvictingClient(client, o.modelID, func() {
|
||||
ml.ShutdownModel(o.modelID)
|
||||
})
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Enforce LRU limit before loading a new model
|
||||
|
|
@ -265,6 +272,12 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Wrap remote models so connection errors during inference trigger eviction
|
||||
if m := ml.CheckIsLoaded(o.modelID); m != nil && m.Process() == nil {
|
||||
client = newConnectionEvictingClient(client, o.modelID, func() {
|
||||
ml.ShutdownModel(o.modelID)
|
||||
})
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
|
|
@ -297,6 +310,12 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
|
|||
model, modelerr := ml.backendLoader(options...)
|
||||
if modelerr == nil && model != nil {
|
||||
xlog.Info("Loads OK", "backend", key)
|
||||
// Wrap remote models so connection errors during inference trigger eviction
|
||||
if m := ml.CheckIsLoaded(o.modelID); m != nil && m.Process() == nil {
|
||||
model = newConnectionEvictingClient(model, o.modelID, func() {
|
||||
ml.ShutdownModel(o.modelID)
|
||||
})
|
||||
}
|
||||
return model, nil
|
||||
} else if modelerr != nil {
|
||||
err = errors.Join(err, fmt.Errorf("[%s]: %w", key, modelerr))
|
||||
|
|
|
|||
|
|
@ -347,7 +347,17 @@ func (ml *ModelLoader) checkIsLoaded(s string) *Model {
|
|||
xlog.Warn("Deleting the process in order to recreate it")
|
||||
process := m.Process()
|
||||
if process == nil {
|
||||
xlog.Error("Process not found and the model is not responding anymore", "model", s)
|
||||
// Remote/distributed model — no local process to check.
|
||||
// Only evict on definitive connection errors (node is down).
|
||||
// Timeouts may mean the node is busy, so keep the model cached.
|
||||
if isConnectionError(err) {
|
||||
xlog.Warn("Remote model unreachable (connection error), removing from cache", "model", s, "error", err)
|
||||
if delErr := ml.deleteProcess(s); delErr != nil {
|
||||
xlog.Error("error cleaning up remote model", "error", delErr, "model", s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
xlog.Warn("Remote model health check failed (possible timeout), keeping cached", "model", s, "error", err)
|
||||
return m
|
||||
}
|
||||
if !process.IsAlive() {
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ var _ = Describe("ModelLoader", func() {
|
|||
Context("LoadModel", func() {
|
||||
It("should load a model and keep it in memory", func() {
|
||||
mockModel = model.NewModel("foo", "test.model", nil)
|
||||
mockModel.MarkHealthy() // skip gRPC health check (no real server)
|
||||
|
||||
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
|
||||
return mockModel, nil
|
||||
|
|
@ -97,6 +98,40 @@ var _ = Describe("ModelLoader", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("Remote model eviction", func() {
|
||||
It("should evict unreachable remote models from cache on health check", func() {
|
||||
// Create a remote model (process=nil) with an unreachable address
|
||||
remoteModel := model.NewModel("remote-test", "127.0.0.1:1", nil)
|
||||
|
||||
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
|
||||
return remoteModel, nil
|
||||
}
|
||||
|
||||
_, err := modelLoader.LoadModel("remote-test", "test.model", mockLoader)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
// CheckIsLoaded should detect the connection error and evict
|
||||
result := modelLoader.CheckIsLoaded("remote-test")
|
||||
Expect(result).To(BeNil(), "unreachable remote model should be evicted from cache")
|
||||
})
|
||||
|
||||
It("should keep recently-healthy remote models in cache", func() {
|
||||
remoteModel := model.NewModel("healthy-remote", "127.0.0.1:1", nil)
|
||||
remoteModel.MarkHealthy() // simulate a recent successful health check
|
||||
|
||||
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
|
||||
return remoteModel, nil
|
||||
}
|
||||
|
||||
loaded, err := modelLoader.LoadModel("healthy-remote", "test.model", mockLoader)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
// Within TTL, should return the model without health check
|
||||
result := modelLoader.CheckIsLoaded("healthy-remote")
|
||||
Expect(result).To(Equal(loaded), "recently-healthy model should be returned from cache")
|
||||
})
|
||||
})
|
||||
|
||||
Context("ShutdownModel", func() {
|
||||
It("should shutdown a loaded model", func() {
|
||||
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
|
||||
|
|
@ -118,7 +153,9 @@ var _ = Describe("ModelLoader", func() {
|
|||
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
|
||||
atomic.AddInt32(&loadCount, 1)
|
||||
time.Sleep(100 * time.Millisecond) // Simulate loading time
|
||||
return model.NewModel(modelID, modelName, nil), nil
|
||||
m := model.NewModel(modelID, modelName, nil)
|
||||
m.MarkHealthy() // skip gRPC health check (no real server)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
@ -154,7 +191,9 @@ var _ = Describe("ModelLoader", func() {
|
|||
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
|
||||
atomic.AddInt32(&loadCount, 1)
|
||||
time.Sleep(50 * time.Millisecond) // Simulate loading time
|
||||
return model.NewModel(modelID, modelName, nil), nil
|
||||
m := model.NewModel(modelID, modelName, nil)
|
||||
m.MarkHealthy() // skip gRPC health check (no real server)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
|
|||
Loading…
Reference in a new issue