From 6704a81255ab637a0e9c9faebb0bdefec913362f Mon Sep 17 00:00:00 2001 From: Brian Murray Date: Fri, 14 Apr 2023 09:52:10 -0600 Subject: [PATCH] go : exposed various parts to the Go Interface (#697) --- bindings/go/params.go | 4 ++++ bindings/go/pkg/whisper/context.go | 15 ++++++++++++--- bindings/go/pkg/whisper/interface.go | 8 +++++--- bindings/go/whisper.go | 10 +++++++++- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/bindings/go/params.go b/bindings/go/params.go index c413895..1ddcbea 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -105,6 +105,10 @@ func (p *Params) SetMaxSegmentLength(n int) { p.max_len = C.int(n) } +func (p *Params) SetTokenTimestamps(b bool) { + p.token_timestamps = toBool(b) +} + // Set max tokens per segment (0 = no limit) func (p *Params) SetMaxTokensPerSegment(n int) { p.max_tokens = C.int(n) diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 0a6e9cb..593b32b 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -111,6 +111,11 @@ func (context *context) SetMaxSegmentLength(n uint) { context.params.SetMaxSegmentLength(int(n)) } +// Set token timestamps flag +func (context *context) SetTokenTimestamps(b bool) { + context.params.SetTokenTimestamps(b) +} + // Set max tokens per segment (0 = no limit) func (context *context) SetMaxTokensPerSegment(n uint) { context.params.SetMaxTokensPerSegment(int(n)) @@ -280,10 +285,14 @@ func toSegment(ctx *whisper.Context, n int) Segment { func toTokens(ctx *whisper.Context, n int) []Token { result := make([]Token, ctx.Whisper_full_n_tokens(n)) for i := 0; i < len(result); i++ { + data := ctx.Whisper_full_get_token_data(n, i) + result[i] = Token{ - Id: int(ctx.Whisper_full_get_token_id(n, i)), - Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)), - P: ctx.Whisper_full_get_token_p(n, i), + Id: int(ctx.Whisper_full_get_token_id(n, i)), + Text: ctx.Whisper_full_get_token_text(n, i), + P: ctx.Whisper_full_get_token_p(n, i), + Start: time.Duration(data.T0()) * time.Millisecond * 10, + End: time.Duration(data.T1()) * time.Millisecond * 10, } } return result diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index a1d3f68..e65fed1 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -41,6 +41,7 @@ type Context interface { SetTokenThreshold(float32) // Set timestamp token probability threshold SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold SetMaxSegmentLength(uint) // Set max segment length in characters + SetTokenTimestamps(bool) // Set token timestamps flag SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) // Process mono audio data and return any errors. @@ -85,7 +86,8 @@ type Segment struct { // Token is a text or special token type Token struct { - Id int - Text string - P float32 + Id int + Text string + P float32 + Start, End time.Duration } diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index d47f7f7..babadf0 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -356,7 +356,7 @@ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token { // Get token data for the specified token in the specified segment. // This contains probabilities, timestamps, etc. -func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData { +func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData { return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) } @@ -407,3 +407,11 @@ func callEncoderBegin(user_data unsafe.Pointer) C.bool { } return true } + +func (t TokenData) T0() int64 { + return int64(t.t0) +} + +func (t TokenData) T1() int64 { + return int64(t.t1) +}