Skip to content

Commit

Permalink
feat(whisper): add translate option (#2649)
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Jun 24, 2024
1 parent 9e6dec0 commit 03b1cf5
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 8 deletions.
1 change: 1 addition & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ message TranscriptRequest {
string dst = 2;
string language = 3;
uint32 threads = 4;
bool translate = 5;
}

message TranscriptResult {
Expand Down
6 changes: 5 additions & 1 deletion backend/go/transcribe/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func audioToWav(src, dst string) error {
return nil
}

func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) {
func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) {
res := schema.TranscriptionResult{}

dir, err := os.MkdirTemp("", "whisper")
Expand Down Expand Up @@ -75,6 +75,10 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
context.SetLanguage("auto")
}

if translate {
context.SetTranslate(true)
}

if err := context.Process(data, nil, nil); err != nil {
return res, err
}
Expand Down
2 changes: 1 addition & 1 deletion backend/go/transcribe/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
}

func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads))
}
9 changes: 5 additions & 4 deletions core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)

func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {

opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
Expand All @@ -31,8 +31,9 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo
}

return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(*backendConfig.Threads),
Dst: audio,
Language: language,
Translate: translate,
Threads: uint32(*backendConfig.Threads),
})
}
3 changes: 2 additions & 1 deletion core/cli/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type TranscriptCMD struct {
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
Model string `short:"m" required:"" help:"Model name to run the TTS"`
Language string `short:"l" help:"Language of the audio file"`
Translate bool `short:"t" help:"Translate the transcription to english"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
Expand Down Expand Up @@ -50,7 +51,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
}
}()

tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts)
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion core/http/endpoints/openai/transcription.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a

log.Debug().Msgf("Audio file copied to: %+v", dst)

tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, ml, *config, appConfig)
if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions core/schema/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ type PredictionOptions struct {
// Also part of the OpenAI official spec
Language string `json:"language"`

// Only for audio transcription
Translate bool `json:"translate"`

// Also part of the OpenAI official spec. use it for returning multiple results
N int `json:"n"`

Expand Down

0 comments on commit 03b1cf5

Please sign in to comment.