mirror of
https://github.com/mudler/LocalAI
synced 2026-04-21 13:27:21 +00:00
feat: wire transcription for llama.cpp, add streaming support (#9353)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
b361d2ddd6
commit
87e6de1989
22 changed files with 722 additions and 71 deletions
17
.github/workflows/test-extra.yml
vendored
17
.github/workflows/test-extra.yml
vendored
|
|
@ -485,6 +485,23 @@ jobs:
|
||||||
- name: Build llama-cpp backend image and run gRPC e2e tests
|
- name: Build llama-cpp backend image and run gRPC e2e tests
|
||||||
run: |
|
run: |
|
||||||
make test-extra-backend-llama-cpp
|
make test-extra-backend-llama-cpp
|
||||||
|
tests-llama-cpp-grpc-transcription:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 90
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.4'
|
||||||
|
- name: Build llama-cpp backend image and run audio transcription gRPC e2e tests
|
||||||
|
run: |
|
||||||
|
make test-extra-backend-llama-cpp-transcription
|
||||||
tests-ik-llama-cpp-grpc:
|
tests-ik-llama-cpp-grpc:
|
||||||
needs: detect-changes
|
needs: detect-changes
|
||||||
if: needs.detect-changes.outputs.ik-llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
if: needs.detect-changes.outputs.ik-llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
|
|
|
||||||
17
Makefile
17
Makefile
|
|
@ -493,6 +493,10 @@ test-extra-backend: protogen-go
|
||||||
BACKEND_TEST_MODEL_URL="$${BACKEND_TEST_MODEL_URL:-$(BACKEND_TEST_MODEL_URL)}" \
|
BACKEND_TEST_MODEL_URL="$${BACKEND_TEST_MODEL_URL:-$(BACKEND_TEST_MODEL_URL)}" \
|
||||||
BACKEND_TEST_MODEL_FILE="$$BACKEND_TEST_MODEL_FILE" \
|
BACKEND_TEST_MODEL_FILE="$$BACKEND_TEST_MODEL_FILE" \
|
||||||
BACKEND_TEST_MODEL_NAME="$$BACKEND_TEST_MODEL_NAME" \
|
BACKEND_TEST_MODEL_NAME="$$BACKEND_TEST_MODEL_NAME" \
|
||||||
|
BACKEND_TEST_MMPROJ_URL="$$BACKEND_TEST_MMPROJ_URL" \
|
||||||
|
BACKEND_TEST_MMPROJ_FILE="$$BACKEND_TEST_MMPROJ_FILE" \
|
||||||
|
BACKEND_TEST_AUDIO_URL="$$BACKEND_TEST_AUDIO_URL" \
|
||||||
|
BACKEND_TEST_AUDIO_FILE="$$BACKEND_TEST_AUDIO_FILE" \
|
||||||
BACKEND_TEST_CAPS="$$BACKEND_TEST_CAPS" \
|
BACKEND_TEST_CAPS="$$BACKEND_TEST_CAPS" \
|
||||||
BACKEND_TEST_PROMPT="$$BACKEND_TEST_PROMPT" \
|
BACKEND_TEST_PROMPT="$$BACKEND_TEST_PROMPT" \
|
||||||
BACKEND_TEST_OPTIONS="$$BACKEND_TEST_OPTIONS" \
|
BACKEND_TEST_OPTIONS="$$BACKEND_TEST_OPTIONS" \
|
||||||
|
|
@ -507,6 +511,19 @@ test-extra-backend-llama-cpp: docker-build-llama-cpp
|
||||||
test-extra-backend-ik-llama-cpp: docker-build-ik-llama-cpp
|
test-extra-backend-ik-llama-cpp: docker-build-ik-llama-cpp
|
||||||
BACKEND_IMAGE=local-ai-backend:ik-llama-cpp $(MAKE) test-extra-backend
|
BACKEND_IMAGE=local-ai-backend:ik-llama-cpp $(MAKE) test-extra-backend
|
||||||
|
|
||||||
|
## Audio transcription wrapper for the llama-cpp backend.
|
||||||
|
## Drives the new AudioTranscription / AudioTranscriptionStream RPCs against
|
||||||
|
## ggml-org/Qwen3-ASR-0.6B-GGUF (a small ASR model that requires its mmproj
|
||||||
|
## audio encoder companion). The audio fixture is a short public-domain
|
||||||
|
## "jfk.wav" clip ggml-org bundles with whisper.cpp's CI assets.
|
||||||
|
test-extra-backend-llama-cpp-transcription: docker-build-llama-cpp
|
||||||
|
BACKEND_IMAGE=local-ai-backend:llama-cpp \
|
||||||
|
BACKEND_TEST_MODEL_URL=https://huggingface.co/ggml-org/Qwen3-ASR-0.6B-GGUF/resolve/main/Qwen3-ASR-0.6B-Q8_0.gguf \
|
||||||
|
BACKEND_TEST_MMPROJ_URL=https://huggingface.co/ggml-org/Qwen3-ASR-0.6B-GGUF/resolve/main/mmproj-Qwen3-ASR-0.6B-Q8_0.gguf \
|
||||||
|
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||||
|
BACKEND_TEST_CAPS=health,load,transcription \
|
||||||
|
$(MAKE) test-extra-backend
|
||||||
|
|
||||||
## vllm is resolved from a HuggingFace model id (no file download) and
|
## vllm is resolved from a HuggingFace model id (no file download) and
|
||||||
## exercises Predict + streaming + tool-call extraction via the hermes parser.
|
## exercises Predict + streaming + tool-call extraction via the hermes parser.
|
||||||
## Requires a host CPU with the SIMD instructions the prebuilt vllm CPU
|
## Requires a host CPU with the SIMD instructions the prebuilt vllm CPU
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ service Backend {
|
||||||
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
||||||
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
|
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
|
||||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||||
|
rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {}
|
||||||
rpc TTS(TTSRequest) returns (Result) {}
|
rpc TTS(TTSRequest) returns (Result) {}
|
||||||
rpc TTSStream(TTSRequest) returns (stream Reply) {}
|
rpc TTSStream(TTSRequest) returns (stream Reply) {}
|
||||||
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||||
|
|
@ -322,11 +323,21 @@ message TranscriptRequest {
|
||||||
bool translate = 5;
|
bool translate = 5;
|
||||||
bool diarize = 6;
|
bool diarize = 6;
|
||||||
string prompt = 7;
|
string prompt = 7;
|
||||||
|
float temperature = 8;
|
||||||
|
repeated string timestamp_granularities = 9;
|
||||||
|
bool stream = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TranscriptResult is the complete outcome of one transcription request.
message TranscriptResult {
  // Per-segment breakdown of the transcription; may be empty for backends
  // that only produce plain text.
  repeated TranscriptSegment segments = 1;
  // Full transcribed text.
  string text = 2;
  // Language associated with the transcription. The backends visible in this
  // change echo back the language supplied in the request.
  string language = 3;
  // Length of the audio. The Whisper backend fills this with
  // sample count / sample rate; other backends may leave it at 0.
  float duration = 4;
}
|
||||||
|
|
||||||
|
// TranscriptStreamResponse is one event of the AudioTranscriptionStream RPC.
// An event carries an incremental text fragment (delta), the completed
// transcription (final_result), or both.
message TranscriptStreamResponse {
  // Incremental text fragment; empty when the event only carries the final
  // result.
  string delta = 1;
  // Completed transcription, sent as the last event of the stream.
  TranscriptResult final_result = 2;
}
|
||||||
|
|
||||||
message TranscriptSegment {
|
message TranscriptSegment {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
|
|
||||||
LLAMA_VERSION?=e97492369888f5311e4d1f3beb325a36bbed70e9
|
LLAMA_VERSION?=6a6780a232b73fe44799b0c0d5f01c61612f1b79
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iterator>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
@ -76,6 +78,27 @@ static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Standard (RFC 4648, padded) base64 encoder for raw byte buffers.
//
// The backend already gets a decoder from llama.cpp's server-common.cpp but
// no encoder, and the PredictOptions.audios path expects base64 strings, so
// audio bytes read from disk have to be encoded here.
//
// data: pointer to the first input byte (not dereferenced when len == 0).
// len:  number of input bytes.
// Returns the encoded text; its length is always a multiple of 4.
static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
    static const char alphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    std::string encoded;
    encoded.reserve(((len + 2) / 3) * 4);

    // Whole 3-byte groups map to exactly four output characters.
    size_t pos = 0;
    while (len - pos >= 3) {
        const uint32_t group = (uint32_t(data[pos]) << 16) |
                               (uint32_t(data[pos + 1]) << 8) |
                               uint32_t(data[pos + 2]);
        encoded += alphabet[(group >> 18) & 0x3F];
        encoded += alphabet[(group >> 12) & 0x3F];
        encoded += alphabet[(group >> 6) & 0x3F];
        encoded += alphabet[group & 0x3F];
        pos += 3;
    }

    // A trailing group of one or two bytes is zero-extended and padded with '='.
    const size_t rest = len - pos;
    if (rest == 1) {
        const uint32_t group = uint32_t(data[pos]) << 16;
        encoded += alphabet[(group >> 18) & 0x3F];
        encoded += alphabet[(group >> 12) & 0x3F];
        encoded += '=';
        encoded += '=';
    } else if (rest == 2) {
        const uint32_t group = (uint32_t(data[pos]) << 16) |
                               (uint32_t(data[pos + 1]) << 8);
        encoded += alphabet[(group >> 18) & 0x3F];
        encoded += alphabet[(group >> 12) & 0x3F];
        encoded += alphabet[(group >> 6) & 0x3F];
        encoded += '=';
    }
    return encoded;
}
|
||||||
|
|
||||||
// END LocalAI
|
// END LocalAI
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2931,6 +2954,119 @@ public:
|
||||||
|
|
||||||
return grpc::Status::OK;
|
return grpc::Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runTranscriptionAsCompletion implements OAI /v1/audio/transcriptions on
|
||||||
|
// top of the existing chat-completion + multimodal-audio pipeline, exactly
|
||||||
|
// the way upstream llama.cpp's server does it (see
|
||||||
|
// tools/server/server-context.cpp post_transcriptions_oai → forwards into
|
||||||
|
// handle_completions_impl with a single user message attaching the audio
|
||||||
|
// file via the mtmd marker).
|
||||||
|
//
|
||||||
|
// We synthesize a backend::PredictOptions with one user message
|
||||||
|
// ("Transcribe audio to text" + optional language hint) and the audio
|
||||||
|
// bytes attached via the existing PredictOptions.audios field, then
|
||||||
|
// delegate to our own Predict() handler. This keeps every multimodal
|
||||||
|
// codepath identical to the chat path and avoids duplicating ~700 lines
|
||||||
|
// of task-construction logic.
|
||||||
|
grpc::Status runTranscriptionAsCompletion(grpc::ServerContext* context,
|
||||||
|
const backend::TranscriptRequest* request,
|
||||||
|
backend::Reply* out_reply) {
|
||||||
|
if (params_base.model.path.empty()) {
|
||||||
|
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||||
|
}
|
||||||
|
if (request->dst().empty()) {
|
||||||
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "dst (audio file path) is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read audio bytes from the path LocalAI's HTTP layer wrote.
|
||||||
|
std::ifstream f(request->dst(), std::ios::binary);
|
||||||
|
if (!f.is_open()) {
|
||||||
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "failed to open audio file: " + request->dst());
|
||||||
|
}
|
||||||
|
std::vector<unsigned char> bytes((std::istreambuf_iterator<char>(f)),
|
||||||
|
std::istreambuf_iterator<char>());
|
||||||
|
f.close();
|
||||||
|
if (bytes.empty()) {
|
||||||
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "audio file is empty: " + request->dst());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string b64 = base64_encode_bytes(bytes.data(), bytes.size());
|
||||||
|
|
||||||
|
// Build the same prompt upstream uses in convert_transcriptions_to_chatcmpl.
|
||||||
|
std::string user_prompt = "Transcribe audio to text";
|
||||||
|
if (!request->language().empty()) {
|
||||||
|
user_prompt += " (language: " + request->language() + ")";
|
||||||
|
}
|
||||||
|
if (!request->prompt().empty()) {
|
||||||
|
// Optional context hint from the caller.
|
||||||
|
user_prompt += "\n" + request->prompt();
|
||||||
|
}
|
||||||
|
|
||||||
|
backend::PredictOptions synthetic;
|
||||||
|
synthetic.set_usetokenizertemplate(true);
|
||||||
|
synthetic.set_temperature(request->temperature());
|
||||||
|
// Generation length: leave at 0 so parse_options uses -1 (model default).
|
||||||
|
// The model's stop tokens / EOS handle termination naturally for ASR.
|
||||||
|
backend::Message* msg = synthetic.add_messages();
|
||||||
|
msg->set_role("user");
|
||||||
|
msg->set_content(user_prompt);
|
||||||
|
synthetic.add_audios(b64);
|
||||||
|
|
||||||
|
return Predict(context, &synthetic, out_reply);
|
||||||
|
}
|
||||||
|
|
||||||
|
// AudioTranscription is the unary transcription RPC. After the auth check it
// runs the request through runTranscriptionAsCompletion and copies the
// completion text into the response. The response language is the language
// the caller asked for (if any); segments and duration are not populated here.
//
// Returns the auth or transcription error status on failure, OK otherwise.
grpc::Status AudioTranscription(ServerContext* context,
                                const backend::TranscriptRequest* request,
                                backend::TranscriptResult* response) override {
    const grpc::Status auth_status = checkAuth(context);
    if (!auth_status.ok()) {
        return auth_status;
    }

    backend::Reply completion;
    const grpc::Status run_status = runTranscriptionAsCompletion(context, request, &completion);
    if (!run_status.ok()) {
        return run_status;
    }

    response->set_text(completion.message());

    // Echo the requested language back; it is not detected from the audio.
    const std::string& language = request->language();
    if (!language.empty()) {
        response->set_language(language);
    }
    return grpc::Status::OK;
}
|
||||||
|
|
||||||
|
// AudioTranscriptionStream is the server-streaming transcription RPC.
//
// Streaming is buffered: the transcription runs to completion as a normal
// chat completion (runTranscriptionAsCompletion), then at most one delta
// event with the whole text and exactly one final_result event are written.
// True token-by-token streaming would require refactoring PredictStream's
// writer-coupled body; clients that only consume the assembled text see the
// same event sequence either way.
//
// Returns the auth or transcription error status on failure, CANCELLED when
// a write to the stream fails (the client is gone), OK otherwise.
grpc::Status AudioTranscriptionStream(ServerContext* context,
                                      const backend::TranscriptRequest* request,
                                      grpc::ServerWriter<backend::TranscriptStreamResponse>* writer) override {
    auto auth = checkAuth(context);
    if (!auth.ok()) return auth;

    backend::Reply reply;
    grpc::Status st = runTranscriptionAsCompletion(context, request, &reply);
    if (!st.ok()) {
        return st;
    }

    const std::string& text = reply.message();

    // Skip the delta when there is nothing to say; the final event still goes
    // out so the caller always receives a terminal message.
    if (!text.empty()) {
        backend::TranscriptStreamResponse delta_chunk;
        delta_chunk.set_delta(text);
        // Write() returns false once the stream is broken; stop instead of
        // pretending the events were delivered.
        if (!writer->Write(delta_chunk)) {
            return grpc::Status(grpc::StatusCode::CANCELLED, "transcription stream closed by client");
        }
    }

    backend::TranscriptStreamResponse final_chunk;
    backend::TranscriptResult* final_result = final_chunk.mutable_final_result();
    final_result->set_text(text);
    // Echo the requested language back; it is not detected from the audio.
    if (!request->language().empty()) {
        final_result->set_language(request->language());
    }
    if (!writer->Write(final_chunk)) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "transcription stream closed by client");
    }
    return grpc::Status::OK;
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,5 +56,6 @@ func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||||
return pb.TranscriptResult{
|
return pb.TranscriptResult{
|
||||||
Segments: segments,
|
Segments: segments,
|
||||||
Text: text,
|
Text: text,
|
||||||
|
Language: opts.Language,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,12 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||||
}
|
}
|
||||||
|
|
||||||
data := buf.AsFloat32Buffer().Data
|
data := buf.AsFloat32Buffer().Data
|
||||||
|
// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
|
||||||
|
// for the converted file produced by AudioToWav above.
|
||||||
|
var duration float32
|
||||||
|
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||||
|
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||||
|
}
|
||||||
segsLen := uintptr(0xdeadbeef)
|
segsLen := uintptr(0xdeadbeef)
|
||||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||||
|
|
||||||
|
|
@ -158,5 +164,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||||
return pb.TranscriptResult{
|
return pb.TranscriptResult{
|
||||||
Segments: segments,
|
Segments: segments,
|
||||||
Text: strings.TrimSpace(text),
|
Text: strings.TrimSpace(text),
|
||||||
|
Language: opts.Language,
|
||||||
|
Duration: duration,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,26 +10,68 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/trace"
|
"github.com/mudler/LocalAI/core/trace"
|
||||||
|
|
||||||
|
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
// TranscriptionRequest bundles the inputs of a transcription call so callers
// can set only the fields they care about instead of threading a long
// positional argument list through ModelTranscription.
type TranscriptionRequest struct {
	// Audio is the path of the audio file to transcribe; it is sent to the
	// backend as the request's Dst field.
	Audio string
	// Language is an optional language hint forwarded to the backend.
	Language string
	// Translate is forwarded to the backend unchanged.
	Translate bool
	// Diarize is forwarded to the backend unchanged.
	Diarize bool
	// Prompt is optional context text forwarded to the backend.
	Prompt string
	// Temperature is the sampling temperature; the zero value is sent as 0.
	Temperature float32
	// TimestampGranularities lists the requested timestamp granularities
	// (the HTTP endpoint documents "word" and "segment").
	TimestampGranularities []string
}
|
||||||
|
|
||||||
|
func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest {
|
||||||
|
return &proto.TranscriptRequest{
|
||||||
|
Dst: r.Audio,
|
||||||
|
Language: r.Language,
|
||||||
|
Translate: r.Translate,
|
||||||
|
Diarize: r.Diarize,
|
||||||
|
Threads: threads,
|
||||||
|
Prompt: r.Prompt,
|
||||||
|
Temperature: r.Temperature,
|
||||||
|
TimestampGranularities: r.TimestampGranularities,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||||
if modelConfig.Backend == "" {
|
if modelConfig.Backend == "" {
|
||||||
modelConfig.Backend = model.WhisperBackend
|
modelConfig.Backend = model.WhisperBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := ModelOptions(modelConfig, appConfig)
|
opts := ModelOptions(modelConfig, appConfig)
|
||||||
|
|
||||||
transcriptionModel, err := ml.Load(opts...)
|
transcriptionModel, err := ml.Load(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if transcriptionModel == nil {
|
if transcriptionModel == nil {
|
||||||
return nil, fmt.Errorf("could not load transcription model")
|
return nil, fmt.Errorf("could not load transcription model")
|
||||||
}
|
}
|
||||||
|
return transcriptionModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
|
return ModelTranscriptionWithOptions(TranscriptionRequest{
|
||||||
|
Audio: audio,
|
||||||
|
Language: language,
|
||||||
|
Translate: translate,
|
||||||
|
Diarize: diarize,
|
||||||
|
Prompt: prompt,
|
||||||
|
}, ml, modelConfig, appConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
|
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var startTime time.Time
|
var startTime time.Time
|
||||||
var audioSnippet map[string]any
|
var audioSnippet map[string]any
|
||||||
|
|
@ -37,25 +79,18 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
// Capture audio before the backend call — the backend may delete the file.
|
// Capture audio before the backend call — the backend may delete the file.
|
||||||
audioSnippet = trace.AudioSnippet(audio)
|
audioSnippet = trace.AudioSnippet(req.Audio)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads)))
|
||||||
Dst: audio,
|
|
||||||
Language: language,
|
|
||||||
Translate: translate,
|
|
||||||
Diarize: diarize,
|
|
||||||
Threads: uint32(*modelConfig.Threads),
|
|
||||||
Prompt: prompt,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if appConfig.EnableTracing {
|
if appConfig.EnableTracing {
|
||||||
errData := map[string]any{
|
errData := map[string]any{
|
||||||
"audio_file": audio,
|
"audio_file": req.Audio,
|
||||||
"language": language,
|
"language": req.Language,
|
||||||
"translate": translate,
|
"translate": req.Translate,
|
||||||
"diarize": diarize,
|
"diarize": req.Diarize,
|
||||||
"prompt": prompt,
|
"prompt": req.Prompt,
|
||||||
}
|
}
|
||||||
if audioSnippet != nil {
|
if audioSnippet != nil {
|
||||||
maps.Copy(errData, audioSnippet)
|
maps.Copy(errData, audioSnippet)
|
||||||
|
|
@ -66,15 +101,83 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||||
Type: trace.BackendTraceTranscription,
|
Type: trace.BackendTraceTranscription,
|
||||||
ModelName: modelConfig.Name,
|
ModelName: modelConfig.Name,
|
||||||
Backend: modelConfig.Backend,
|
Backend: modelConfig.Backend,
|
||||||
Summary: trace.TruncateString(audio, 200),
|
Summary: trace.TruncateString(req.Audio, 200),
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
Data: errData,
|
Data: errData,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
tr := transcriptResultFromProto(r)
|
||||||
|
|
||||||
|
if appConfig.EnableTracing {
|
||||||
|
data := map[string]any{
|
||||||
|
"audio_file": req.Audio,
|
||||||
|
"language": req.Language,
|
||||||
|
"translate": req.Translate,
|
||||||
|
"diarize": req.Diarize,
|
||||||
|
"prompt": req.Prompt,
|
||||||
|
"result_text": tr.Text,
|
||||||
|
"segments_count": len(tr.Segments),
|
||||||
|
}
|
||||||
|
if audioSnippet != nil {
|
||||||
|
maps.Copy(data, audioSnippet)
|
||||||
|
}
|
||||||
|
trace.RecordBackendTrace(trace.BackendTrace{
|
||||||
|
Timestamp: startTime,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
Type: trace.BackendTraceTranscription,
|
||||||
|
ModelName: modelConfig.Name,
|
||||||
|
Backend: modelConfig.Backend,
|
||||||
|
Summary: trace.TruncateString(req.Audio+" -> "+tr.Text, 200),
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscriptionStreamChunk is a streaming event emitted by
|
||||||
|
// ModelTranscriptionStream. Either Delta carries an incremental text fragment,
|
||||||
|
// or Final carries the completed transcription as the very last event.
|
||||||
|
type TranscriptionStreamChunk struct {
|
||||||
|
Delta string
|
||||||
|
Final *schema.TranscriptionResult
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelTranscriptionStream runs the gRPC streaming transcription RPC and
|
||||||
|
// invokes onChunk for each event the backend produces. Backends that don't
|
||||||
|
// support real streaming should still emit one terminal event with Final set,
|
||||||
|
// which the HTTP layer turns into a single delta + done SSE pair.
|
||||||
|
func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||||
|
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pbReq := req.toProto(uint32(*modelConfig.Threads))
|
||||||
|
pbReq.Stream = true
|
||||||
|
|
||||||
|
return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
||||||
|
if chunk == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := TranscriptionStreamChunk{Delta: chunk.Delta}
|
||||||
|
if chunk.FinalResult != nil {
|
||||||
|
out.Final = transcriptResultFromProto(chunk.FinalResult)
|
||||||
|
}
|
||||||
|
onChunk(out)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionResult {
|
||||||
|
if r == nil {
|
||||||
|
return &schema.TranscriptionResult{}
|
||||||
|
}
|
||||||
tr := &schema.TranscriptionResult{
|
tr := &schema.TranscriptionResult{
|
||||||
Text: r.Text,
|
Text: r.Text,
|
||||||
|
Language: r.Language,
|
||||||
|
Duration: float64(r.Duration),
|
||||||
}
|
}
|
||||||
for _, s := range r.Segments {
|
for _, s := range r.Segments {
|
||||||
var tks []int
|
var tks []int
|
||||||
|
|
@ -91,30 +194,5 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||||
Speaker: s.Speaker,
|
Speaker: s.Speaker,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
return tr
|
||||||
if appConfig.EnableTracing {
|
|
||||||
data := map[string]any{
|
|
||||||
"audio_file": audio,
|
|
||||||
"language": language,
|
|
||||||
"translate": translate,
|
|
||||||
"diarize": diarize,
|
|
||||||
"prompt": prompt,
|
|
||||||
"result_text": tr.Text,
|
|
||||||
"segments_count": len(tr.Segments),
|
|
||||||
}
|
|
||||||
if audioSnippet != nil {
|
|
||||||
maps.Copy(data, audioSnippet)
|
|
||||||
}
|
|
||||||
trace.RecordBackendTrace(trace.BackendTrace{
|
|
||||||
Timestamp: startTime,
|
|
||||||
Duration: time.Since(startTime),
|
|
||||||
Type: trace.BackendTraceTranscription,
|
|
||||||
ModelName: modelConfig.Name,
|
|
||||||
Backend: modelConfig.Backend,
|
|
||||||
Summary: trace.TruncateString(audio+" -> "+tr.Text, 200),
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return tr, err
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/mudler/LocalAI/core/backend"
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
|
@ -24,6 +28,9 @@ import (
|
||||||
// @accept multipart/form-data
|
// @accept multipart/form-data
|
||||||
// @Param model formData string true "model"
|
// @Param model formData string true "model"
|
||||||
// @Param file formData file true "file"
|
// @Param file formData file true "file"
|
||||||
|
// @Param temperature formData number false "sampling temperature"
|
||||||
|
// @Param timestamp_granularities formData []string false "timestamp granularities (word, segment)"
|
||||||
|
// @Param stream formData boolean false "stream partial results as SSE"
|
||||||
// @Success 200 {object} map[string]string "Response"
|
// @Success 200 {object} map[string]string "Response"
|
||||||
// @Router /v1/audio/transcriptions [post]
|
// @Router /v1/audio/transcriptions [post]
|
||||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
|
@ -42,6 +49,38 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||||
prompt := c.FormValue("prompt")
|
prompt := c.FormValue("prompt")
|
||||||
responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))
|
responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))
|
||||||
|
|
||||||
|
// OpenAI accepts `temperature` as a string in multipart form. Tolerate
|
||||||
|
// missing/invalid values rather than failing the whole request.
|
||||||
|
var temperature float32
|
||||||
|
if v := c.FormValue("temperature"); v != "" {
|
||||||
|
if t, err := strconv.ParseFloat(v, 32); err == nil {
|
||||||
|
temperature = float32(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// timestamp_granularities[] is a multi-value form field per the OpenAI spec.
|
||||||
|
// Echo exposes all values for a key via FormParams.
|
||||||
|
var timestampGranularities []string
|
||||||
|
if form, err := c.FormParams(); err == nil {
|
||||||
|
for _, key := range []string{"timestamp_granularities[]", "timestamp_granularities"} {
|
||||||
|
if vals, ok := form[key]; ok {
|
||||||
|
for _, v := range vals {
|
||||||
|
v = strings.TrimSpace(v)
|
||||||
|
if v != "" {
|
||||||
|
timestampGranularities = append(timestampGranularities, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := false
|
||||||
|
if v := c.FormValue("stream"); v != "" {
|
||||||
|
if b, err := strconv.ParseBool(v); err == nil {
|
||||||
|
stream = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// retrieve the file data from the request
|
// retrieve the file data from the request
|
||||||
file, err := c.FormFile("file")
|
file, err := c.FormFile("file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -73,7 +112,21 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||||
|
|
||||||
xlog.Debug("Audio file copied", "dst", dst)
|
xlog.Debug("Audio file copied", "dst", dst)
|
||||||
|
|
||||||
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig)
|
req := backend.TranscriptionRequest{
|
||||||
|
Audio: dst,
|
||||||
|
Language: input.Language,
|
||||||
|
Translate: input.Translate,
|
||||||
|
Diarize: diarize,
|
||||||
|
Prompt: prompt,
|
||||||
|
Temperature: temperature,
|
||||||
|
TimestampGranularities: timestampGranularities,
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream {
|
||||||
|
return streamTranscription(c, req, ml, *config, appConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -93,3 +146,79 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// streamTranscription emits OpenAI-format SSE events for a transcription
|
||||||
|
// request: one `transcript.text.delta` per backend chunk, a final
|
||||||
|
// `transcript.text.done` with the assembled text, and `[DONE]`. Backends that
|
||||||
|
// can't truly stream still produce a single Final event, which we surface as
|
||||||
|
// one delta + done.
|
||||||
|
func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *model.ModelLoader, config config.ModelConfig, appConfig *config.ApplicationConfig) error {
|
||||||
|
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Response().Header().Set("Connection", "keep-alive")
|
||||||
|
c.Response().WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
writeEvent := func(payload any) error {
|
||||||
|
data, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(c.Response().Writer, "data: %s\n\n", data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.Response().Flush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var assembled strings.Builder
|
||||||
|
var finalResult *schema.TranscriptionResult
|
||||||
|
|
||||||
|
err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||||
|
if chunk.Delta != "" {
|
||||||
|
assembled.WriteString(chunk.Delta)
|
||||||
|
_ = writeEvent(map[string]any{
|
||||||
|
"type": "transcript.text.delta",
|
||||||
|
"delta": chunk.Delta,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if chunk.Final != nil {
|
||||||
|
finalResult = chunk.Final
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
errPayload := map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]any{
|
||||||
|
"message": err.Error(),
|
||||||
|
"type": "server_error",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = writeEvent(errPayload)
|
||||||
|
_, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||||
|
c.Response().Flush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the final event. Prefer the backend-provided final result; if the
|
||||||
|
// backend only emitted deltas, synthesize the result from what we collected.
|
||||||
|
if finalResult == nil {
|
||||||
|
finalResult = &schema.TranscriptionResult{Text: assembled.String()}
|
||||||
|
} else if finalResult.Text == "" && assembled.Len() > 0 {
|
||||||
|
finalResult.Text = assembled.String()
|
||||||
|
}
|
||||||
|
// If the backend never produced a delta but did return a final text, emit
|
||||||
|
// it as a single delta so clients always see at least one delta event.
|
||||||
|
if assembled.Len() == 0 && finalResult.Text != "" {
|
||||||
|
_ = writeEvent(map[string]any{
|
||||||
|
"type": "transcript.text.delta",
|
||||||
|
"delta": finalResult.Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ = writeEvent(map[string]any{
|
||||||
|
"type": "transcript.text.done",
|
||||||
|
"text": finalResult.Text,
|
||||||
|
})
|
||||||
|
_, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||||
|
c.Response().Flush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,4 +14,6 @@ type TranscriptionSegment struct {
|
||||||
type TranscriptionResult struct {
|
type TranscriptionResult struct {
|
||||||
Segments []TranscriptionSegment `json:"segments,omitempty"`
|
Segments []TranscriptionSegment `json:"segments,omitempty"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
|
Language string `json:"language,omitempty"`
|
||||||
|
Duration float64 `json:"duration,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,21 @@ func (f *FileStagingClient) AudioTranscription(ctx context.Context, in *pb.Trans
|
||||||
return f.Backend.AudioTranscription(ctx, in, opts...)
|
return f.Backend.AudioTranscription(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *FileStagingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, fn func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
|
||||||
|
reqID := requestID()
|
||||||
|
|
||||||
|
// Stage input audio file
|
||||||
|
if in.Dst != "" && isFilePath(in.Dst) {
|
||||||
|
backendPath, _, err := f.stageInputFile(ctx, reqID, in.Dst, "inputs")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("staging audio for transcription stream: %w", err)
|
||||||
|
}
|
||||||
|
in.Dst = backendPath
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Backend.AudioTranscriptionStream(ctx, in, fn, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *FileStagingClient) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
func (f *FileStagingClient) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||||
frontendOutputPath := in.OutputPath
|
frontendOutputPath := in.OutputPath
|
||||||
if frontendOutputPath != "" {
|
if frontendOutputPath != "" {
|
||||||
|
|
|
||||||
|
|
@ -171,6 +171,9 @@ func (c *fakeBackendClient) Detect(_ context.Context, _ *pb.DetectOptions, _ ...
|
||||||
func (c *fakeBackendClient) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
func (c *fakeBackendClient) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (c *fakeBackendClient) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (c *fakeBackendClient) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
|
func (c *fakeBackendClient) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -105,6 +105,11 @@ func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.
|
||||||
return c.Backend.AudioTranscription(ctx, in, opts...)
|
return c.Backend.AudioTranscription(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
|
||||||
|
defer c.track(ctx)()
|
||||||
|
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
||||||
defer c.track(ctx)()
|
defer c.track(ctx)()
|
||||||
return c.Backend.Detect(ctx, in, opts...)
|
return c.Backend.Detect(ctx, in, opts...)
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,10 @@ func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.Transcript
|
||||||
return &pb.TranscriptResult{}, nil
|
return &pb.TranscriptResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fakeGRPCBackend) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
|
func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
|
||||||
return &pb.TokenizationResponse{}, nil
|
return &pb.TokenizationResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,7 @@ type Backend interface {
|
||||||
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
|
Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
|
||||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
|
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
|
||||||
|
AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error
|
||||||
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
||||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,10 @@ func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult,
|
||||||
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
|
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
||||||
|
return fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (llm *Base) TTS(*pb.TTSRequest) error {
|
func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||||
return fmt.Errorf("unimplemented")
|
return fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -352,6 +352,50 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||||
return client.AudioTranscription(ctx, in, opts...)
|
return client.AudioTranscription(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := c.dial()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
|
||||||
|
stream, err := client.AudioTranscriptionStream(ctx, in, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk, err := stream.Recv()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
|
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
|
||||||
if !c.parallel {
|
if !c.parallel {
|
||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,14 @@ func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.Transcript
|
||||||
return e.s.AudioTranscription(ctx, in)
|
return e.s.AudioTranscription(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error {
|
||||||
|
bs := &embedBackendAudioTranscriptionStream{
|
||||||
|
ctx: ctx,
|
||||||
|
fn: f,
|
||||||
|
}
|
||||||
|
return e.s.AudioTranscriptionStream(in, bs)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
|
func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
|
||||||
return e.s.TokenizeString(ctx, in)
|
return e.s.TokenizeString(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
@ -168,6 +176,44 @@ func (e *embedBackend) Free(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ pb.Backend_AudioTranscriptionStreamServer = new(embedBackendAudioTranscriptionStream)
|
||||||
|
|
||||||
|
type embedBackendAudioTranscriptionStream struct {
|
||||||
|
ctx context.Context
|
||||||
|
fn func(chunk *pb.TranscriptStreamResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) Send(chunk *pb.TranscriptStreamResponse) error {
|
||||||
|
e.fn(chunk)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) SetHeader(md metadata.MD) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) SendHeader(md metadata.MD) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) SetTrailer(md metadata.MD) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) Context() context.Context {
|
||||||
|
return e.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) SendMsg(m any) error {
|
||||||
|
if x, ok := m.(*pb.TranscriptStreamResponse); ok {
|
||||||
|
return e.Send(x)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendAudioTranscriptionStream) RecvMsg(m any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)
|
var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)
|
||||||
|
|
||||||
type embedBackendFineTuneProgressStream struct {
|
type embedBackendFineTuneProgressStream struct {
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ type AIModel interface {
|
||||||
GenerateVideo(*pb.GenerateVideoRequest) error
|
GenerateVideo(*pb.GenerateVideoRequest) error
|
||||||
Detect(*pb.DetectOptions) (pb.DetectResponse, error)
|
Detect(*pb.DetectOptions) (pb.DetectResponse, error)
|
||||||
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||||
|
AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
||||||
TTS(*pb.TTSRequest) error
|
TTS(*pb.TTSRequest) error
|
||||||
TTSStream(*pb.TTSRequest, chan []byte) error
|
TTSStream(*pb.TTSRequest, chan []byte) error
|
||||||
SoundGeneration(*pb.SoundGenerationRequest) error
|
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||||
|
|
|
||||||
|
|
@ -168,18 +168,42 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||||
}
|
}
|
||||||
tresult.Segments = append(tresult.Segments,
|
tresult.Segments = append(tresult.Segments,
|
||||||
&pb.TranscriptSegment{
|
&pb.TranscriptSegment{
|
||||||
Text: s.Text,
|
Text: s.Text,
|
||||||
Id: int32(s.Id),
|
Id: int32(s.Id),
|
||||||
Start: int64(s.Start),
|
Start: int64(s.Start),
|
||||||
End: int64(s.End),
|
End: int64(s.End),
|
||||||
Tokens: tks,
|
Tokens: tks,
|
||||||
|
Speaker: s.Speaker,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
tresult.Text = result.Text
|
tresult.Text = result.Text
|
||||||
|
tresult.Language = result.Language
|
||||||
|
tresult.Duration = result.Duration
|
||||||
return tresult, nil
|
return tresult, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Backend_AudioTranscriptionStreamServer) error {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
resultChan := make(chan *pb.TranscriptStreamResponse)
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for chunk := range resultChan {
|
||||||
|
stream.Send(chunk)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := s.llm.AudioTranscriptionStream(in, resultChan)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
|
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
|
||||||
if s.llm.Locking() {
|
if s.llm.Locking() {
|
||||||
s.llm.Lock()
|
s.llm.Lock()
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,12 @@ func (c *ConnectionEvictingClient) AudioTranscription(ctx context.Context, in *p
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConnectionEvictingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
|
||||||
|
err := c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
|
||||||
|
c.checkErr(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ConnectionEvictingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
func (c *ConnectionEvictingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
||||||
result, err := c.Backend.Detect(ctx, in, opts...)
|
result, err := c.Backend.Detect(ctx, in, opts...)
|
||||||
c.checkErr(err)
|
c.checkErr(err)
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,16 @@ import (
|
||||||
//
|
//
|
||||||
// Optional:
|
// Optional:
|
||||||
//
|
//
|
||||||
|
// BACKEND_TEST_MMPROJ_URL HTTP(S) URL of an mmproj file (audio/vision encoder)
|
||||||
|
// to download alongside the main model — required for
|
||||||
|
// multimodal models like Qwen3-ASR-0.6B-GGUF.
|
||||||
|
// BACKEND_TEST_MMPROJ_FILE Path to an already-available mmproj file.
|
||||||
|
// BACKEND_TEST_AUDIO_URL HTTP(S) URL of a sample audio file used by the
|
||||||
|
// transcription specs.
|
||||||
|
// BACKEND_TEST_AUDIO_FILE Path to an already-available sample audio file.
|
||||||
// BACKEND_TEST_CAPS Comma-separated list of capabilities to exercise.
|
// BACKEND_TEST_CAPS Comma-separated list of capabilities to exercise.
|
||||||
// Supported values: health, load, predict, stream,
|
// Supported values: health, load, predict, stream,
|
||||||
// embeddings, tools.
|
// embeddings, tools, transcription.
|
||||||
// Defaults to "health,load,predict,stream".
|
// Defaults to "health,load,predict,stream".
|
||||||
// A backend that only does embeddings would set this to
|
// A backend that only does embeddings would set this to
|
||||||
// "health,load,embeddings"; an image/TTS backend that cannot
|
// "health,load,embeddings"; an image/TTS backend that cannot
|
||||||
|
|
@ -58,12 +65,13 @@ import (
|
||||||
// file path to LoadModel, so GGUF, ONNX, safetensors, .bin etc. all work so
|
// file path to LoadModel, so GGUF, ONNX, safetensors, .bin etc. all work so
|
||||||
// long as the backend under test accepts that format.
|
// long as the backend under test accepts that format.
|
||||||
const (
|
const (
|
||||||
capHealth = "health"
|
capHealth = "health"
|
||||||
capLoad = "load"
|
capLoad = "load"
|
||||||
capPredict = "predict"
|
capPredict = "predict"
|
||||||
capStream = "stream"
|
capStream = "stream"
|
||||||
capEmbeddings = "embeddings"
|
capEmbeddings = "embeddings"
|
||||||
capTools = "tools"
|
capTools = "tools"
|
||||||
|
capTranscription = "transcription"
|
||||||
|
|
||||||
defaultPrompt = "The capital of France is"
|
defaultPrompt = "The capital of France is"
|
||||||
streamPrompt = "Once upon a time"
|
streamPrompt = "Once upon a time"
|
||||||
|
|
@ -99,17 +107,19 @@ func parseCaps() map[string]bool {
|
||||||
|
|
||||||
var _ = Describe("Backend container", Ordered, func() {
|
var _ = Describe("Backend container", Ordered, func() {
|
||||||
var (
|
var (
|
||||||
caps map[string]bool
|
caps map[string]bool
|
||||||
workDir string
|
workDir string
|
||||||
binaryDir string
|
binaryDir string
|
||||||
modelFile string // set when a local file is used
|
modelFile string // set when a local file is used
|
||||||
modelName string // set when a HuggingFace model id is used
|
modelName string // set when a HuggingFace model id is used
|
||||||
addr string
|
mmprojFile string // optional multimodal projector
|
||||||
serverCmd *exec.Cmd
|
audioFile string // optional audio fixture for transcription specs
|
||||||
conn *grpc.ClientConn
|
addr string
|
||||||
client pb.BackendClient
|
serverCmd *exec.Cmd
|
||||||
prompt string
|
conn *grpc.ClientConn
|
||||||
options []string
|
client pb.BackendClient
|
||||||
|
prompt string
|
||||||
|
options []string
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeAll(func() {
|
BeforeAll(func() {
|
||||||
|
|
@ -155,6 +165,25 @@ var _ = Describe("Backend container", Ordered, func() {
|
||||||
downloadFile(modelURL, modelFile)
|
downloadFile(modelURL, modelFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Multimodal projector (mmproj): required by audio/vision-capable
|
||||||
|
// llama.cpp models like Qwen3-ASR-0.6B-GGUF. Either file or URL.
|
||||||
|
mmprojFile = os.Getenv("BACKEND_TEST_MMPROJ_FILE")
|
||||||
|
if mmprojFile == "" {
|
||||||
|
if url := os.Getenv("BACKEND_TEST_MMPROJ_URL"); url != "" {
|
||||||
|
mmprojFile = filepath.Join(workDir, "mmproj.bin")
|
||||||
|
downloadFile(url, mmprojFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio fixture for the transcription specs.
|
||||||
|
audioFile = os.Getenv("BACKEND_TEST_AUDIO_FILE")
|
||||||
|
if audioFile == "" {
|
||||||
|
if url := os.Getenv("BACKEND_TEST_AUDIO_URL"); url != "" {
|
||||||
|
audioFile = filepath.Join(workDir, "sample.wav")
|
||||||
|
downloadFile(url, audioFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Pick a free port and launch the backend.
|
// Pick a free port and launch the backend.
|
||||||
port, err := freeport.GetFreePort()
|
port, err := freeport.GetFreePort()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
@ -244,6 +273,7 @@ var _ = Describe("Backend container", Ordered, func() {
|
||||||
MMap: true,
|
MMap: true,
|
||||||
NBatch: 128,
|
NBatch: 128,
|
||||||
Options: options,
|
Options: options,
|
||||||
|
MMProj: mmprojFile,
|
||||||
})
|
})
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(res.GetSuccess()).To(BeTrue(), "LoadModel failed: %s", res.GetMessage())
|
Expect(res.GetSuccess()).To(BeTrue(), "LoadModel failed: %s", res.GetMessage())
|
||||||
|
|
@ -385,6 +415,75 @@ var _ = Describe("Backend container", Ordered, func() {
|
||||||
Expect(matched).To(BeTrue(),
|
Expect(matched).To(BeTrue(),
|
||||||
"Expected a tool call named %q in ChatDelta.tool_calls", toolName)
|
"Expected a tool call named %q in ChatDelta.tool_calls", toolName)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("transcribes audio via AudioTranscription", func() {
|
||||||
|
if !caps[capTranscription] {
|
||||||
|
Skip("transcription capability not enabled")
|
||||||
|
}
|
||||||
|
Expect(audioFile).NotTo(BeEmpty(),
|
||||||
|
"BACKEND_TEST_AUDIO_FILE or BACKEND_TEST_AUDIO_URL must be set when transcription cap is enabled")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
res, err := client.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||||
|
Dst: audioFile,
|
||||||
|
Threads: uint32(envInt32("BACKEND_TEST_THREADS", 4)),
|
||||||
|
Temperature: 0.0,
|
||||||
|
})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(strings.TrimSpace(res.GetText())).NotTo(BeEmpty(),
|
||||||
|
"AudioTranscription returned empty text")
|
||||||
|
GinkgoWriter.Printf("AudioTranscription: text=%q language=%q duration=%v\n",
|
||||||
|
res.GetText(), res.GetLanguage(), res.GetDuration())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("streams audio transcription via AudioTranscriptionStream", func() {
|
||||||
|
if !caps[capTranscription] {
|
||||||
|
Skip("transcription capability not enabled")
|
||||||
|
}
|
||||||
|
Expect(audioFile).NotTo(BeEmpty(),
|
||||||
|
"BACKEND_TEST_AUDIO_FILE or BACKEND_TEST_AUDIO_URL must be set when transcription cap is enabled")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
stream, err := client.AudioTranscriptionStream(ctx, &pb.TranscriptRequest{
|
||||||
|
Dst: audioFile,
|
||||||
|
Threads: uint32(envInt32("BACKEND_TEST_THREADS", 4)),
|
||||||
|
Temperature: 0.0,
|
||||||
|
Stream: true,
|
||||||
|
})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
var deltas []string
|
||||||
|
var assembled strings.Builder
|
||||||
|
var finalText string
|
||||||
|
for {
|
||||||
|
chunk, err := stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
if d := chunk.GetDelta(); d != "" {
|
||||||
|
deltas = append(deltas, d)
|
||||||
|
assembled.WriteString(d)
|
||||||
|
}
|
||||||
|
if final := chunk.GetFinalResult(); final != nil && final.GetText() != "" {
|
||||||
|
finalText = final.GetText()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// The stream must yield at least one delta; final text alone is not enough.
|
||||||
|
Expect(deltas).NotTo(BeEmpty(),
|
||||||
|
"AudioTranscriptionStream did not emit any deltas (assembled=%q final=%q)",
|
||||||
|
assembled.String(), finalText)
|
||||||
|
|
||||||
|
// If both arrived, the final event should match the assembled deltas.
|
||||||
|
if finalText != "" && assembled.Len() > 0 {
|
||||||
|
Expect(finalText).To(Equal(assembled.String()),
|
||||||
|
"final transcript should match concatenated deltas")
|
||||||
|
}
|
||||||
|
GinkgoWriter.Printf("AudioTranscriptionStream: deltas=%d assembled=%q final=%q\n",
|
||||||
|
len(deltas), assembled.String(), finalText)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// extractImage runs `docker create` + `docker export` to materialise the image
|
// extractImage runs `docker create` + `docker export` to materialise the image
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue