fix(nodes): better detection of whether a node goes down or a model is not available (#9274)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2026-04-08 12:11:02 +02:00 committed by GitHub
parent 154fa000d3
commit 510d6759fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 216 additions and 4 deletions

View file

@ -0,0 +1,35 @@
package model
import (
"errors"
"strings"
"syscall"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// isConnectionError returns true if the error indicates the remote endpoint is
// unreachable or the connection is dead (connection refused/reset/aborted,
// broken pipe, gRPC Unavailable). It deliberately returns false for timeouts
// and deadline exceeded — those may indicate a busy server, not a dead one.
func isConnectionError(err error) bool {
	if err == nil {
		return false
	}
	// gRPC Unavailable = server not reachable (covers connection refused, DNS, TLS errors)
	if s, ok := status.FromError(err); ok && s.Code() == codes.Unavailable {
		return true
	}
	// Syscall-level errors that mean the peer is gone or was never reachable.
	// ECONNABORTED and EPIPE are included: a broken pipe / aborted connection
	// means the remote side is no longer serving this connection.
	switch {
	case errors.Is(err, syscall.ECONNREFUSED),
		errors.Is(err, syscall.ECONNRESET),
		errors.Is(err, syscall.ECONNABORTED),
		errors.Is(err, syscall.EPIPE):
		return true
	}
	// Fallback string matching for wrapped errors that lose the typed error.
	// Lower-cased so the match is robust to capitalization in wrapper messages.
	msg := strings.ToLower(err.Error())
	return strings.Contains(msg, "connection refused") ||
		strings.Contains(msg, "connection reset") ||
		strings.Contains(msg, "broken pipe") ||
		strings.Contains(msg, "no such host")
}

View file

@ -0,0 +1,109 @@
package model
import (
"context"
"sync"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
ggrpc "google.golang.org/grpc"
)
// ConnectionEvictingClient wraps a grpc.Backend. When any inference method
// fails with a connection error (server unreachable), it calls the evict
// callback to remove the model from the ModelLoader's cache. The error is
// still returned to the caller — the NEXT request will trigger rescheduling
// via SmartRouter.
type ConnectionEvictingClient struct {
	grpc.Backend           // embedded: any method not intercepted below passes straight through
	modelID string         // identifier of the wrapped model, used in the eviction log line
	evict   func()         // callback that removes the model from the loader's cache
	once    sync.Once      // guarantees evict runs at most once per wrapper instance
}
// newConnectionEvictingClient returns inner wrapped so that connection
// errors observed during inference invoke evict (at most once).
func newConnectionEvictingClient(inner grpc.Backend, modelID string, evict func()) grpc.Backend {
	wrapper := &ConnectionEvictingClient{
		Backend: inner,
		modelID: modelID,
		evict:   evict,
	}
	return wrapper
}
// checkErr inspects an inference error; on a definitive connection error it
// logs a warning and fires the evict callback exactly once for this wrapper.
func (c *ConnectionEvictingClient) checkErr(err error) {
	if err == nil || !isConnectionError(err) {
		return
	}
	c.once.Do(func() {
		xlog.Warn("Connection error during inference, evicting model from cache",
			"model", c.modelID, "error", err)
		c.evict()
	})
}
// --- Intercepted inference methods ---

// Predict delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
	out, callErr := c.Backend.Predict(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// PredictStream delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
	streamErr := c.Backend.PredictStream(ctx, in, f, opts...)
	c.checkErr(streamErr)
	return streamErr
}
// Embeddings delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
	out, callErr := c.Backend.Embeddings(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// GenerateImage delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
	out, callErr := c.Backend.GenerateImage(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// GenerateVideo delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
	out, callErr := c.Backend.GenerateVideo(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// TTS delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
	out, callErr := c.Backend.TTS(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// TTSStream delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
	streamErr := c.Backend.TTSStream(ctx, in, f, opts...)
	c.checkErr(streamErr)
	return streamErr
}
// SoundGeneration delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
	out, callErr := c.Backend.SoundGeneration(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// AudioTranscription delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
	out, callErr := c.Backend.AudioTranscription(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// Detect delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
	out, callErr := c.Backend.Detect(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}
// Rerank delegates to the wrapped backend and flags connection failures.
func (c *ConnectionEvictingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
	out, callErr := c.Backend.Rerank(ctx, in, opts...)
	c.checkErr(callErr)
	return out, callErr
}

View file

@ -253,7 +253,14 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
xlog.Debug("Model already loaded", "model", o.modelID)
// Update last used time for LRU tracking
ml.updateModelLastUsed(m)
return m.GRPC(o.parallelRequests, ml.wd), nil
client := m.GRPC(o.parallelRequests, ml.wd)
// Wrap remote models so connection errors during inference trigger eviction
if m.Process() == nil {
client = newConnectionEvictingClient(client, o.modelID, func() {
ml.ShutdownModel(o.modelID)
})
}
return client, nil
}
// Enforce LRU limit before loading a new model
@ -265,6 +272,12 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
if err != nil {
return nil, err
}
// Wrap remote models so connection errors during inference trigger eviction
if m := ml.CheckIsLoaded(o.modelID); m != nil && m.Process() == nil {
client = newConnectionEvictingClient(client, o.modelID, func() {
ml.ShutdownModel(o.modelID)
})
}
return client, nil
}
@ -297,6 +310,12 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
model, modelerr := ml.backendLoader(options...)
if modelerr == nil && model != nil {
xlog.Info("Loads OK", "backend", key)
// Wrap remote models so connection errors during inference trigger eviction
if m := ml.CheckIsLoaded(o.modelID); m != nil && m.Process() == nil {
model = newConnectionEvictingClient(model, o.modelID, func() {
ml.ShutdownModel(o.modelID)
})
}
return model, nil
} else if modelerr != nil {
err = errors.Join(err, fmt.Errorf("[%s]: %w", key, modelerr))

View file

@ -347,7 +347,17 @@ func (ml *ModelLoader) checkIsLoaded(s string) *Model {
xlog.Warn("Deleting the process in order to recreate it")
process := m.Process()
if process == nil {
xlog.Error("Process not found and the model is not responding anymore", "model", s)
// Remote/distributed model — no local process to check.
// Only evict on definitive connection errors (node is down).
// Timeouts may mean the node is busy, so keep the model cached.
if isConnectionError(err) {
xlog.Warn("Remote model unreachable (connection error), removing from cache", "model", s, "error", err)
if delErr := ml.deleteProcess(s); delErr != nil {
xlog.Error("error cleaning up remote model", "error", delErr, "model", s)
}
return nil
}
xlog.Warn("Remote model health check failed (possible timeout), keeping cached", "model", s, "error", err)
return m
}
if !process.IsAlive() {

View file

@ -75,6 +75,7 @@ var _ = Describe("ModelLoader", func() {
Context("LoadModel", func() {
It("should load a model and keep it in memory", func() {
mockModel = model.NewModel("foo", "test.model", nil)
mockModel.MarkHealthy() // skip gRPC health check (no real server)
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
return mockModel, nil
@ -97,6 +98,40 @@ var _ = Describe("ModelLoader", func() {
})
})
// Exercises the eviction path for remote models (process == nil): an
// unreachable endpoint must be dropped from the cache, while a model with a
// recent successful health check must be served from cache without re-probing.
Context("Remote model eviction", func() {
It("should evict unreachable remote models from cache on health check", func() {
// Create a remote model (process=nil) with an unreachable address
// (port 1 on loopback — nothing listens there, so dialing fails fast)
remoteModel := model.NewModel("remote-test", "127.0.0.1:1", nil)
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
return remoteModel, nil
}
_, err := modelLoader.LoadModel("remote-test", "test.model", mockLoader)
Expect(err).To(BeNil())
// CheckIsLoaded should detect the connection error and evict
result := modelLoader.CheckIsLoaded("remote-test")
Expect(result).To(BeNil(), "unreachable remote model should be evicted from cache")
})
It("should keep recently-healthy remote models in cache", func() {
remoteModel := model.NewModel("healthy-remote", "127.0.0.1:1", nil)
remoteModel.MarkHealthy() // simulate a recent successful health check
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
return remoteModel, nil
}
loaded, err := modelLoader.LoadModel("healthy-remote", "test.model", mockLoader)
Expect(err).To(BeNil())
// Within TTL, should return the model without health check
result := modelLoader.CheckIsLoaded("healthy-remote")
Expect(result).To(Equal(loaded), "recently-healthy model should be returned from cache")
})
})
Context("ShutdownModel", func() {
It("should shutdown a loaded model", func() {
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
@ -118,7 +153,9 @@ var _ = Describe("ModelLoader", func() {
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
atomic.AddInt32(&loadCount, 1)
time.Sleep(100 * time.Millisecond) // Simulate loading time
return model.NewModel(modelID, modelName, nil), nil
m := model.NewModel(modelID, modelName, nil)
m.MarkHealthy() // skip gRPC health check (no real server)
return m, nil
}
var wg sync.WaitGroup
@ -154,7 +191,9 @@ var _ = Describe("ModelLoader", func() {
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
atomic.AddInt32(&loadCount, 1)
time.Sleep(50 * time.Millisecond) // Simulate loading time
return model.NewModel(modelID, modelName, nil), nil
m := model.NewModel(modelID, modelName, nil)
m.MarkHealthy() // skip gRPC health check (no real server)
return m, nil
}
var wg sync.WaitGroup