diff --git a/.gitignore b/.gitignore index 4effe81..29cf029 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,5 @@ models/* third_party/whisper.cpp .env dist + +/.vscode diff --git a/cmd/crtowebsocket/README.md b/cmd/crtowebsocket/README.md new file mode 100644 index 0000000..b00ac90 --- /dev/null +++ b/cmd/crtowebsocket/README.md @@ -0,0 +1,36 @@ +# Chromium to WebSocket + +This is a simple HTTP server that listens for WebSpeech connections from a +Chromium browser, converts audio data to WAV format, streams it to a WebSocket +client, receives text data from the WebSocket client, and sends it back to the +Chromium browser. + +## Usage + +```bash +go run . +``` + +## Building + +```bash +go build -o crtowebsocket . +``` + +## Running + +```bash +./crtowebocket +``` + +## Building for release + +```bash +go build -o crtowebsocket . +``` + +## Running for release + +```bash +./crtowebsocket +``` diff --git a/cmd/crtowebsocket/main.go b/cmd/crtowebsocket/main.go new file mode 100644 index 0000000..acbf1c6 --- /dev/null +++ b/cmd/crtowebsocket/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "net/http" + "os" + "strconv" + "time" + + "github.com/brave-experiments/go-stt/cr_api_websocket_proxy" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/urfave/cli/v2" +) + +// Configuration for the remote WebSocket STT service +const ( + version = "1" + defaultListenAddress = "127.0.0.1:8090" + defaultWebsocketURL = "ws://127.0.0.1:8080/api-speech-wss/" +) + +func main() { + zerolog.SetGlobalLevel(zerolog.InfoLevel) + log.Logger = log.Output( + zerolog.ConsoleWriter{ + Out: os.Stderr, + NoColor: true, + }, + ) + zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string { + short := file + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + file = short + return file + ":" + strconv.Itoa(line) + } + + zerolog.SetGlobalLevel(zerolog.DebugLevel) + log.Logger = log.With().Caller().Logger() + + app := cli.NewApp() + app.Name = "Chromium WebSpeech API Endpoint to WebSocket proxy" + app.Version = version + app.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "listen-address", + Value: defaultListenAddress, + }, + &cli.StringFlag{ + Name: "websocket-url", + Value: defaultWebsocketURL, + }, + &cli.DurationFlag{ + Name: "timeout", + Value: 60 * time.Second, + }, + &cli.BoolFlag{ + Name: "try-to-finalize-text", + Value: false, + }, + } + app.Action = run + + if err := app.Run(os.Args); err != nil { + log.Fatal().Err(err) + } +} + +func run(c *cli.Context) error { + // Create a configuration struct + config := &cr_api_websocket_proxy.HandlerConfig{ + WebsocketURL: c.String("websocket-url"), + Timeout: c.Duration("timeout"), + TryToFinalizeText: c.Bool("try-to-finalize-text"), + } + + // Create a handler instance with the config + handler := cr_api_websocket_proxy.NewHandler(config) + + // Register handlers that have access to the config + http.HandleFunc("/up", handler.HandleUpstreamRequest) + http.HandleFunc("/down", handler.HandleDownstreamRequest) + + http.ListenAndServe(c.String("listen-address"), nil) + + return nil +} diff --git a/cmd/crtowebsocket/pprof.go b/cmd/crtowebsocket/pprof.go new file mode 100644 index 0000000..a822e57 --- /dev/null +++ b/cmd/crtowebsocket/pprof.go @@ -0,0 +1,7 @@ +//go:build pprof + +package main + +import ( + _ "net/http/pprof" +) diff --git a/cr_api_websocket_proxy/audio.go b/cr_api_websocket_proxy/audio.go new file mode 100644 index 0000000..959511e --- /dev/null +++ b/cr_api_websocket_proxy/audio.go @@ -0,0 +1,51 @@ +package cr_api_websocket_proxy + +import ( + "fmt" + "net/http" + + "azul3d.org/engine/audio" + "github.com/colega/zeropool" + + // Add flac decoder for decoding incoming audio + _ "azul3d.org/engine/audio/flac" +) + +const expectedSampleRate = 16000 + +const ( + samplesPerChunk = expectedSampleRate / 1000 * 20 // 20ms + bytesPerChunk = samplesPerChunk * 2 +) + +var audioSamplesBufferPool = zeropool.New( + func() audio.Int16 { + return make( + audio.Int16, + samplesPerChunk, + ) + }, +) + +var audioBytesBufferPool = zeropool.New( + func() []byte { + return make( + []byte, + bytesPerChunk, + ) + }, +) + +func NewAudioDecoder(req *http.Request) (audio.Decoder, error) { + dec, _, err := audio.NewDecoder(req.Body) + if err != nil { + return nil, err + } + + // Ensure we're working with the correct sample rate + if dec.Config().SampleRate != expectedSampleRate { + return nil, fmt.Errorf("unexpected sample rate: %d", dec.Config().SampleRate) + } + + return dec, nil +} diff --git a/cr_api_websocket_proxy/audio_test.go b/cr_api_websocket_proxy/audio_test.go new file mode 100644 index 0000000..955b912 --- /dev/null +++ b/cr_api_websocket_proxy/audio_test.go @@ -0,0 +1,112 @@ +package cr_api_websocket_proxy + +import ( + "bytes" + "io" + "net/http" + "os" + "testing" + + "azul3d.org/engine/audio" +) + +func TestAudioBufferPools(t *testing.T) { + t.Run("samples buffer pool", func(t *testing.T) { + samples := audioSamplesBufferPool.Get() + if samples == nil { + t.Error("expected non-nil samples buffer") + } + if got := len(samples); got != samplesPerChunk { + t.Errorf("samples buffer length = %v, want %v", got, samplesPerChunk) + } + + audioSamplesBufferPool.Put(samples) + }) + + t.Run("bytes buffer pool", func(t *testing.T) { + buffer := audioBytesBufferPool.Get() + if buffer == nil { + t.Error("expected non-nil bytes buffer") + } + if got := len(buffer); got != bytesPerChunk { + t.Errorf("bytes buffer length = %v, want %v", got, bytesPerChunk) + } + + audioBytesBufferPool.Put(buffer) + }) +} + +type mockBody struct { + *bytes.Buffer +} + +func (m mockBody) Close() error { + return nil +} + +func TestFlacDecoder_InvalidData(t *testing.T) { + t.Run("invalid audio data", func(t *testing.T) { + req := &http.Request{ + Body: mockBody{bytes.NewBuffer([]byte("invalid audio data"))}, + } + + decoder, err := NewAudioDecoder(req) + if err == nil { + t.Error("expected error for invalid audio data") + } + if decoder != nil { + t.Error("expected nil decoder for invalid audio data") + } + }) +} + +func TestFlacDecoder_ValidData(t *testing.T) { + req := &http.Request{ + Body: mockBody{bytes.NewBuffer(readTestFile(t, "testdata/16khz.flac"))}, + } + + decoder, err := NewAudioDecoder(req) + if err != nil { + t.Fatalf("failed to create decoder: %v", err) + } + + // Verify decoder config + config := decoder.Config() + if config.SampleRate != expectedSampleRate { + t.Errorf("sample rate = %v, want %v", config.SampleRate, expectedSampleRate) + } + + // Try reading some samples + samples := make(audio.Int16, 1024) + n, err := decoder.Read(samples) + if err != nil && err != io.EOF { + t.Errorf("failed to read samples: %v", err) + } + if n == 0 { + t.Error("expected to read some samples") + } +} + +func TestFlacDecoder_InvalidSampleRate(t *testing.T) { + req := &http.Request{ + Body: mockBody{bytes.NewBuffer(readTestFile(t, "testdata/8khz.flac"))}, + } + + decoder, err := NewAudioDecoder(req) + if decoder != nil { + t.Error("expected nil decoder for invalid sample rate") + } + if err == nil { + t.Error("expected error for invalid sample rate") + } +} + +// Helper to read test file contents +func readTestFile(t *testing.T, path string) []byte { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read test file %s: %v", path, err) + } + return data +} diff --git a/cr_api_websocket_proxy/chromium_handlers.go b/cr_api_websocket_proxy/chromium_handlers.go new file mode 100644 index 0000000..7762cc8 --- /dev/null +++ b/cr_api_websocket_proxy/chromium_handlers.go @@ -0,0 +1,191 @@ +package cr_api_websocket_proxy + +import ( + "context" + "encoding/binary" + "net/http" + "time" + + "github.com/brave-experiments/go-stt/google_streaming_api" + + "github.com/rs/zerolog/log" +) + +// HandleUpstreamRequest handles the upstream request from Chromium. +func (h *Handler) HandleUpstreamRequest( + w http.ResponseWriter, + req *http.Request, +) { + pairContext := pairContexts.getOrCreatePairContextForRequest(h, req) + if pairContext == nil { + return + } + + defer pairContext.Close() + defer close(pairContext.audioChan) + + log.Debug().Msgf("[%s] [UPSTREAM] Start", pairContext.pair) + defer log.Debug().Msgf("[%s] [UPSTREAM] Done", pairContext.pair) + + dec, err := NewAudioDecoder(req) + if err != nil { + log.Warn().Msgf("[%s] [UPSTREAM] Failed to create audio decoder: %v", pairContext.pair, err) + return + } + + samples := audioSamplesBufferPool.Get() + defer audioSamplesBufferPool.Put(samples) + + for { + select { + case <-pairContext.ctx.Done(): + return + case <-req.Context().Done(): + log.Debug().Msgf("[%s] [UPSTREAM] Request context done: %v", pairContext.pair, req.Context().Err()) + return + default: + n, err := dec.Read(samples) + if n > 0 { + bytes := audioBytesBufferPool.Get() + // Convert int16 samples to bytes + for i, sample := range samples[:n] { + binary.LittleEndian.PutUint16(bytes[i*2:], uint16(sample)) + } + pairContext.audioChan <- bytes[:n*2] + } + if err != nil { + log.Warn().Msgf("[%s] [UPSTREAM] Failed to read audio: %v", pairContext.pair, err) + return + } + } + } +} + +// HandleDownstreamRequest handles the downstream request to Chromium. +func (h *Handler) HandleDownstreamRequest( + w http.ResponseWriter, + req *http.Request, +) { + pairContext := pairContexts.getOrCreatePairContextForRequest(h, req) + if pairContext == nil { + return + } + + defer pairContext.Close() + + log.Debug().Msgf("[%s] [DOWNSTREAM] Start", pairContext.pair) + defer log.Debug().Msgf("[%s] [DOWNSTREAM] Done", pairContext.pair) + + sentFinalIndices := make( + map[int]bool, + ) + for { + select { + case <-pairContext.ctx.Done(): + return + case <-req.Context().Done(): + log.Debug().Msgf("[%s] [DOWNSTREAM] Request context done: %v", pairContext.pair, req.Context().Err()) + return + case segments, ok := <-pairContext.results: + if !ok { + return + } + hasMessages := false + message := google_streaming_api.NewRecognitionMessage() + for i, segment := range segments { + // Skip if segment is final and we've already sent it + if h.config.TryToFinalizeText && segment.Final && sentFinalIndices[i] { + continue + } + + message.Add(segment.Text, h.config.TryToFinalizeText && segment.Final) + hasMessages = true + // Track final segments we've sent + if h.config.TryToFinalizeText && segment.Final { + sentFinalIndices[i] = true + } + } + + // Only send if we have messages to send + if hasMessages { + bytes, err := message.Serialize() + if err != nil { + log.Warn().Msgf("[%s] [DOWNSTREAM] Failed to serialize message: %v", pairContext.pair, err) + return + } + binary.Write(w, binary.BigEndian, uint32(len(bytes))) + w.Write(bytes) + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } + } + } + } +} + +func (hs *PairContexts) getOrCreatePairContextForRequest( + h *Handler, + req *http.Request, +) *PairContext { + isUpstreamRequest := req.URL.Path == "/up" + isDownstreamRequest := req.URL.Path == "/down" + + if !isUpstreamRequest && !isDownstreamRequest { + return nil + } + + err := req.ParseForm() + if err != nil { + log.Debug().Msgf("[UPSTREAM] Failed to parse form: %v", err) + return nil + } + + pair := req.FormValue("pair") + if pair == "" { + log.Debug().Msgf("[UPSTREAM] Pair is empty") + return nil + } + + pairContext := hs.GetOrCreate( + pair, + func(pair string) *PairContext { + ctx, cancel := context.WithDeadline( + context.Background(), + time.Now().Add(h.config.Timeout), + ) + pairContext := &PairContext{ + pair: pair, + audioChan: make(chan []byte, bytesPerChunk), + results: make(chan []TextSegment, 10), + ctx: ctx, + cancel: cancel, + } + return pairContext + }, + ) + + if isUpstreamRequest { + lang := req.FormValue("lang") + if len(lang) > 2 { + lang = lang[:2] + } else { + log.Warn().Msgf("[UPSTREAM] Language is empty, using default: %s", lang) + lang = "en" + } + pairContext.lang = lang + go processAudioOverWebSocket(h, pairContext) + + pairContext.SetUpstreamConnected() + } + + if isDownstreamRequest { + pairContext.SetDownstreamConnected() + } + + if pairContext.IsPaired() { + log.Info().Msgf("[%s] PairContext is paired", pair) + hs.Remove(pair) + } + + return pairContext +} diff --git a/cr_api_websocket_proxy/chromium_handlers_test.go b/cr_api_websocket_proxy/chromium_handlers_test.go new file mode 100644 index 0000000..731b5e7 --- /dev/null +++ b/cr_api_websocket_proxy/chromium_handlers_test.go @@ -0,0 +1,392 @@ +package cr_api_websocket_proxy + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/brave-experiments/go-stt/google_streaming_api" + "github.com/gorilla/websocket" + "google.golang.org/protobuf/proto" +) + +// Add this type at the top level +type infiniteReader struct { + audio []byte + done chan struct{} + position int +} + +func newInfiniteReader(audio []byte) *infiniteReader { + return &infiniteReader{ + audio: audio, + done: make(chan struct{}), + position: 0, + } +} + +func (r *infiniteReader) Read(p []byte) (n int, err error) { + select { + case <-r.done: + return 0, io.EOF + default: + if r.position >= len(r.audio) { + return 0, nil + } + + // Copy remaining audio data + n = copy(p, r.audio[r.position:]) + r.position += n + // Add small delay between iterations + time.Sleep(10 * time.Millisecond) + return n, nil + } +} + +func (r *infiniteReader) Close() { + close(r.done) +} + +// Add at the top level +type responseRecorder struct { + io.Writer + done chan struct{} + http.Flusher +} + +func (r *responseRecorder) Header() http.Header { + return make(http.Header) +} + +func (r *responseRecorder) Write(b []byte) (int, error) { + return r.Writer.Write(b) +} + +func (r *responseRecorder) WriteHeader(statusCode int) {} + +func (r *responseRecorder) Flush() {} + +func (r *responseRecorder) CloseNotify() <-chan bool { + notify := make(chan bool, 1) + go func() { + <-r.done + notify <- true + }() + return notify +} + +// Test Helpers and Mocks +// ---------------------------------------- + +// Test WebSocket Server Setup +// ---------------------------------------- + +func setupTestWebSocketServer(t *testing.T) *httptest.Server { + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("Failed to upgrade connection: %v", err) + return + } + defer conn.Close() + + for { + messageType, _, err := conn.ReadMessage() + if err != nil { + return + } + if messageType == websocket.BinaryMessage { + response := "test transcription" + err = conn.WriteMessage(websocket.TextMessage, []byte(response)) + if err != nil { + t.Errorf("Failed to write message: %v", err) + return + } + } + } + })) + return wsServer +} + +// Test Setup Helpers +// ---------------------------------------- + +func setupTestHandler(wsURL string) *Handler { + return NewHandler(&HandlerConfig{ + Timeout: 5 * time.Second, + WebsocketURL: wsURL, + TryToFinalizeText: true, + }) +} + +func createTestRequest(method, path, pair, lang string, body io.Reader) *http.Request { + form := url.Values{} + form.Add("pair", pair) + if lang != "" { + form.Add("lang", lang) + } + return httptest.NewRequest(method, path+"?"+form.Encode(), body) +} + +// Tests +// ---------------------------------------- + +func TestHandleUpstreamRequest(t *testing.T) { + tests := []struct { + name string + path string + pairParam string + langParam string + wantStatus int + }{ + { + name: "valid upstream request", + path: "/up", + pairParam: "test-pair-1", + langParam: "en", + wantStatus: http.StatusOK, + }, + { + name: "missing pair parameter", + path: "/up", + pairParam: "", + langParam: "en", + wantStatus: http.StatusOK, // The handler doesn't set status codes explicitly + }, + { + name: "invalid path", + path: "/invalid", + pairParam: "test-pair-1", + langParam: "en", + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := NewHandler(&HandlerConfig{ + Timeout: 5 * time.Second, + WebsocketURL: "ws://localhost:8080", + }) + + form := url.Values{} + form.Add("pair", tt.pairParam) + form.Add("lang", tt.langParam) + + req := httptest.NewRequest(http.MethodPost, tt.path+"?"+form.Encode(), strings.NewReader("")) + w := httptest.NewRecorder() + + handler.HandleUpstreamRequest(w, req) + + if got := w.Code; got != tt.wantStatus { + t.Errorf("HandleUpstreamRequest() status = %v, want %v", got, tt.wantStatus) + } + }) + } +} + +func TestPairContextCreation(t *testing.T) { + handler := NewHandler(&HandlerConfig{ + Timeout: 5 * time.Second, + WebsocketURL: "ws://localhost:8080", + }) + + t.Run("creates new pair context", func(t *testing.T) { + form := url.Values{} + form.Add("pair", "test-pair-2") + form.Add("lang", "en") + + req := httptest.NewRequest(http.MethodPost, "/up?"+form.Encode(), strings.NewReader("")) + + pairContext := pairContexts.getOrCreatePairContextForRequest(handler, req) + if pairContext == nil { + t.Fatal("expected non-nil pair context") + } + if got := pairContext.pair; got != "test-pair-2" { + t.Errorf("pair = %v, want %v", got, "test-pair-2") + } + if got := pairContext.lang; got != "en" { + t.Errorf("lang = %v, want %v", got, "en") + } + }) + + t.Run("reuses existing pair context", func(t *testing.T) { + pair := "test-pair-3" + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Create initial pair context + firstContext := &PairContext{ + pair: pair, + audioChan: make(chan []byte, bytesPerChunk), + results: make(chan []TextSegment, 10), + ctx: ctx, + cancel: cancel, + lang: "en", + } + pairContexts.pairContexts[pair] = firstContext + + // Try to get the same pair context + form := url.Values{} + form.Add("pair", pair) + form.Add("lang", "es") // Different language shouldn't create new context + + req := httptest.NewRequest(http.MethodPost, "/up?"+form.Encode(), strings.NewReader("")) + + secondContext := pairContexts.getOrCreatePairContextForRequest(handler, req) + if secondContext == nil { + t.Fatal("expected non-nil pair context") + } + if secondContext != firstContext { + t.Error("expected same context to be reused") + } + }) +} + +func TestPairContextLifecycle(t *testing.T) { + handler := NewHandler(&HandlerConfig{ + Timeout: 2 * time.Second, + WebsocketURL: "ws://localhost:8080", + }) + + t.Run("context times out", func(t *testing.T) { + form := url.Values{} + form.Add("pair", "test-pair-4") + form.Add("lang", "en") + + req := httptest.NewRequest(http.MethodPost, "/up?"+form.Encode(), strings.NewReader("")) + + pairContext := pairContexts.getOrCreatePairContextForRequest(handler, req) + if pairContext == nil { + t.Fatal("expected non-nil pair context") + } + + // Wait for timeout + time.Sleep(3 * time.Second) + + // Context should be done + select { + case <-pairContext.ctx.Done(): + // Success + default: + t.Error("Context should have timed out") + } + }) +} + +func TestUpDownPairWithWebSocket(t *testing.T) { + wsServer := setupTestWebSocketServer(t) + defer wsServer.Close() + + wsURL := "ws" + strings.TrimPrefix(wsServer.URL, "http") + "/" + handler := setupTestHandler(wsURL) + pairID := "test-pair-websocket" + + // Start upstream request + go handleUpstreamRequest(t, handler, pairID) + + // Start downstream request and wait for results + if err := handleDownstreamRequestAndVerify(t, handler, pairID); err != nil { + t.Error(err) + } +} + +// Helper functions for TestUpDownPairWithWebSocket +func handleUpstreamRequest(t *testing.T, handler *Handler, pairID string) { + reader := newInfiniteReader(readTestFile(t, "testdata/16khz.flac")) + defer reader.Close() + + req := createTestRequest(http.MethodPost, "/up", pairID, "en", reader) + req.Header.Set("Content-Type", "audio/x-flac; rate=16000") + req.Header.Set("Transfer-Encoding", "chunked") + + w := httptest.NewRecorder() + handler.HandleUpstreamRequest(w, req) +} + +func handleDownstreamRequestAndVerify(t *testing.T, handler *Handler, pairID string) error { + receivedText := make(chan string, 1) + + req := createTestRequest(http.MethodGet, "/down", pairID, "", nil) + req.Header.Set("Connection", "keep-alive") + + pr, pw := io.Pipe() + downW := &responseRecorder{ + Writer: pw, + done: make(chan struct{}), + } + + go processDownstreamResponse(t, pr, receivedText) + go func() { + defer pw.Close() + handler.HandleDownstreamRequest(downW, req) + close(downW.done) + }() + + select { + case text := <-receivedText: + t.Logf("Test passed with transcription: %s", text) + return nil + case <-time.After(2 * time.Second): + return fmt.Errorf("timeout waiting for transcription response") + } +} + +func processDownstreamResponse(t *testing.T, pr *io.PipeReader, receivedText chan<- string) { + defer pr.Close() + + for { + message, err := readProtobufMessage(pr) + if err != nil { + if err != io.EOF { + t.Errorf("Failed to read message: %v", err) + } + return + } + + fullText := extractFullText(message) + if strings.Contains(fullText, "test transcription") { + receivedText <- fullText + return + } + } +} + +func readProtobufMessage(r io.Reader) (*google_streaming_api.SpeechRecognitionEvent, error) { + var length uint32 + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + return nil, err + } + + data := make([]byte, length) + if _, err := io.ReadAtLeast(r, data, int(length)); err != nil { + return nil, err + } + + message := &google_streaming_api.SpeechRecognitionEvent{} + if err := proto.Unmarshal(data, message); err != nil { + return nil, err + } + + return message, nil +} + +func extractFullText(message *google_streaming_api.SpeechRecognitionEvent) string { + var fullText string + for _, result := range message.Result { + if len(result.Alternative) > 0 { + fullText += *result.Alternative[0].Transcript + } + } + return fullText +} diff --git a/cr_api_websocket_proxy/handler.go b/cr_api_websocket_proxy/handler.go new file mode 100644 index 0000000..49979de --- /dev/null +++ b/cr_api_websocket_proxy/handler.go @@ -0,0 +1,22 @@ +package cr_api_websocket_proxy + +import "time" + +// HandlerConfig holds configuration for the WebSocket handlers +type HandlerConfig struct { + WebsocketURL string + Timeout time.Duration + TryToFinalizeText bool +} + +// Handler holds the configuration and provides methods that return http.HandlerFunc +type Handler struct { + config *HandlerConfig +} + +// NewHandler creates a new Handler with the given configuration +func NewHandler(config *HandlerConfig) *Handler { + return &Handler{ + config: config, + } +} diff --git a/cr_api_websocket_proxy/pair_context.go b/cr_api_websocket_proxy/pair_context.go new file mode 100644 index 0000000..9505b51 --- /dev/null +++ b/cr_api_websocket_proxy/pair_context.go @@ -0,0 +1,102 @@ +package cr_api_websocket_proxy + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +const textPartFinalizeDelay = 5 * time.Second // Time to wait before marking text as final + +type TextSegment struct { + Text string + Timestamp time.Time + Final bool +} + +type PairContext struct { + sync.Mutex + pair string + lang string + audioChan chan []byte + results chan []TextSegment + ctx context.Context + cancel context.CancelFunc + upstreamConnected bool + downstreamConnected bool + textCache []TextSegment +} + +func (h *PairContext) SetUpstreamConnected() { + h.Lock() + h.upstreamConnected = true + h.Unlock() +} + +func (h *PairContext) SetDownstreamConnected() { + h.Lock() + h.downstreamConnected = true + h.Unlock() +} + +func (h *PairContext) IsPaired() bool { + h.Lock() + defer h.Unlock() + return h.upstreamConnected && h.downstreamConnected +} + +func (h *PairContext) Close() { + log.Debug().Msgf("[%s] Closing pairContext", h.pair) + h.cancel() // Signal all goroutines to stop +} + +func (h *PairContext) updateTextCache(newText string) { + h.Lock() + defer h.Unlock() + + // Split text into parts + parts := strings.Split(newText, " ") + + currentTime := time.Now() + newCache := make([]TextSegment, len(parts)) + + // Process each part + for i, part := range parts { + if i < len(parts)-1 { + part += " " + } + + // Check if part exists in cache + found := false + for _, cached := range h.textCache { + if cached.Text == part { + // Keep existing cache entry + newCache[i] = cached + found = true + break + } + } + + if !found { + // Add new part + newCache[i] = TextSegment{ + Text: part, + Timestamp: currentTime, + Final: false, + } + } + } + + // Mark parts as final if they haven't changed for textPartFinalizeDelay + for i := range newCache { + if !newCache[i].Final && currentTime.Sub(newCache[i].Timestamp) > textPartFinalizeDelay { + newCache[i].Final = true + } + } + + h.textCache = newCache + h.results <- newCache +} diff --git a/cr_api_websocket_proxy/pair_contexts.go b/cr_api_websocket_proxy/pair_contexts.go new file mode 100644 index 0000000..4ab4334 --- /dev/null +++ b/cr_api_websocket_proxy/pair_contexts.go @@ -0,0 +1,38 @@ +package cr_api_websocket_proxy + +import ( + "sync" +) + +type PairContexts struct { + sync.Mutex + pairContexts map[string]*PairContext +} + +func (h *PairContexts) GetOrCreate(key string, createFn func(key string) *PairContext) *PairContext { + h.Lock() + defer h.Unlock() + + pairContext, exists := h.pairContexts[key] + if !exists { + pairContext = createFn(key) + h.pairContexts[key] = pairContext + + go func() { + <-pairContext.ctx.Done() + h.Remove(key) + }() + } + return pairContext +} + +func (h *PairContexts) Remove(key string) { + h.Lock() + defer h.Unlock() + + delete(h.pairContexts, key) +} + +var pairContexts = &PairContexts{ + pairContexts: make(map[string]*PairContext), +} diff --git a/cr_api_websocket_proxy/testdata/16khz.flac b/cr_api_websocket_proxy/testdata/16khz.flac new file mode 100644 index 0000000..5de06eb Binary files /dev/null and b/cr_api_websocket_proxy/testdata/16khz.flac differ diff --git a/cr_api_websocket_proxy/testdata/8khz.flac b/cr_api_websocket_proxy/testdata/8khz.flac new file mode 100644 index 0000000..a633c41 Binary files /dev/null and b/cr_api_websocket_proxy/testdata/8khz.flac differ diff --git a/cr_api_websocket_proxy/wav_writer_debug.go b/cr_api_websocket_proxy/wav_writer_debug.go new file mode 100644 index 0000000..d3deaa4 --- /dev/null +++ b/cr_api_websocket_proxy/wav_writer_debug.go @@ -0,0 +1,85 @@ +//go:build wav_write + +package cr_api_websocket_proxy + +import ( + "encoding/binary" + "os" + "time" + + "github.com/rs/zerolog/log" +) + +type wavFileWriter struct { + file *os.File +} + +func initWavWriter(writer **wavFileWriter, pairID string) { + // Create WAV file with timestamp name + timestamp := time.Now().Format("20060102-150405") + filename := timestamp + ".wav" + f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Error().Err(err).Msgf("[%s] Failed to create WAV file", pairID) + return + } + + // Write WAV header if file is new (size 0) + info, _ := f.Stat() + if info.Size() == 0 { + header := []byte{ + 'R', 'I', 'F', 'F', // ChunkID + 0, 0, 0, 0, // ChunkSize (to be updated on close) + 'W', 'A', 'V', 'E', // Format + 'f', 'm', 't', ' ', // Subchunk1ID + 16, 0, 0, 0, // Subchunk1Size + 1, 0, // AudioFormat (PCM) + 1, 0, // NumChannels (Mono) + 0x80, 0x3E, 0, 0, // SampleRate (16000) + 0x00, 0x7D, 0, 0, // ByteRate + 2, 0, // BlockAlign + 16, 0, // BitsPerSample + 'd', 'a', 't', 'a', // Subchunk2ID + 0, 0, 0, 0, // Subchunk2Size (to be updated on close) + } + f.Write(header) + } + + *writer = &wavFileWriter{file: f} + log.Info().Msgf("[%s] WAV recording started: %s", pairID, filename) +} + +func writeToWavFile(writer *wavFileWriter, audioBytes []byte) { + if writer != nil && writer.file != nil { + writer.file.Write(audioBytes) + } +} + +func closeWavWriter(writer *wavFileWriter) { + if writer == nil || writer.file == nil { + return + } + + log.Info().Msg("Closing WAV file") + f := writer.file + + // Update WAV header with final sizes before closing + f.Sync() + f.Seek(0, 0) + info, _ := f.Stat() + fileSize := info.Size() + log.Info().Msgf("File size: %d", fileSize) + + // Update ChunkSize + chunkSize := uint32(fileSize - 8) + f.Seek(4, 0) + binary.Write(f, binary.LittleEndian, chunkSize) + + // Update Subchunk2Size + dataSize := uint32(fileSize - 44) + f.Seek(40, 0) + binary.Write(f, binary.LittleEndian, dataSize) + + f.Close() + writer.file = nil +} diff --git a/cr_api_websocket_proxy/wav_writer_release.go b/cr_api_websocket_proxy/wav_writer_release.go new file mode 100644 index 0000000..3854e7c --- /dev/null +++ b/cr_api_websocket_proxy/wav_writer_release.go @@ -0,0 +1,19 @@ +//go:build !wav_write + +package cr_api_websocket_proxy + +// Empty implementations for release builds + +type wavFileWriter struct{} + +func initWavWriter(writer **wavFileWriter, pairID string) { + // Do nothing in release builds +} + +func writeToWavFile(writer *wavFileWriter, audioBytes []byte) { + // Do nothing in release builds +} + +func closeWavWriter(writer *wavFileWriter) { + // Do nothing in release builds +} diff --git a/cr_api_websocket_proxy/websocket_processor.go b/cr_api_websocket_proxy/websocket_processor.go new file mode 100644 index 0000000..4e7301b --- /dev/null +++ b/cr_api_websocket_proxy/websocket_processor.go @@ -0,0 +1,164 @@ +package cr_api_websocket_proxy + +import ( + "bytes" + "net/http" + "net/url" + "time" + "unicode/utf8" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +func processAudioOverWebSocket(h *Handler, pairContext *PairContext) { + defer pairContext.Close() + defer close(pairContext.results) + + // parse wsURL + u, err := url.Parse(h.config.WebsocketURL) + if err != nil { + log.Error(). + Err(err). + Msg("Failed to parse WebSocket URL") + return + } + + // set parameters + q := u.Query() + q.Set("output_native", "True") + q.Set("not_use_prompt", "False") + q.Set("denoise", "False") + q.Set("lang", pairContext.lang) + u.RawQuery = q.Encode() + wsURL := u.String() + + wsDialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second} + headers := http.Header{} + + log.Debug().Msgf("[%s] Dialing %s", pairContext.pair, wsURL) + wsConn, _, err := wsDialer.Dial(wsURL, headers) + if wsConn == nil || err != nil { + log.Error().Err(err).Msg("Failed to connect to WebSocket") + return + } + log.Debug().Msgf("[%s] Connected to %s", pairContext.pair, wsURL) + + defer func() { + if wsConn == nil { + return + } + log.Debug().Msgf("[%s] Closing WebSocket connection", pairContext.pair) + wsConn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + time.Now().Add(time.Second), + ) + for { + if _, _, err := wsConn.NextReader(); err != nil { + break + } + } + }() + + // Configure WebSocket connection + wsConn.SetReadLimit(32768) // 32KB max message size + wsConn.SetWriteDeadline(time.Now().Add(60 * time.Second)) + wsConn.SetReadDeadline(time.Now().Add(60 * time.Second)) + wsConn.SetPongHandler( + func(string) error { + if wsConn != nil { + wsConn.SetReadDeadline(time.Now().Add(60 * time.Second)) + } + return nil + }, + ) + + // Handle the WebSocket disconnect + wsConn.SetCloseHandler( + func(code int, text string) error { + log.Debug().Msgf("[%s] WebSocket disconnected with code %d: %s", pairContext.pair, code, text) + pairContext.Close() + return nil + }, + ) + + // Handle incoming text messages + go func() { + for { + select { + case <-pairContext.ctx.Done(): + return + default: + messageType, message, err := wsConn.ReadMessage() + if err != nil { + if websocket.IsCloseError( + err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + ) { + log.Debug().Msg("WebSocket closed normally") + } else { + log.Error().Err(err).Msg("WebSocket read error") + } + return + } + + log.Debug().Msgf("Read message type: %d", messageType) + + var text string + switch messageType { + case websocket.TextMessage: + text = string(message) + case websocket.BinaryMessage: + text = string(bytes.TrimRight(message, "\x00")) // Try to convert binary to UTF-8 string + if !utf8.ValidString(text) { + log.Warn().Msg("Received invalid UTF-8 in binary message") + continue + } + default: + log.Error().Msgf("Received unsupported message type: %d", messageType) + continue + } + + log.Debug().Msgf("Received text: %s", text) + pairContext.updateTextCache(text) + } + } + }() + + defer pairContext.Close() + + // Initialize WAV file writer for in enabled builds + var wavWriter *wavFileWriter + initWavWriter(&wavWriter, pairContext.pair) + defer closeWavWriter(wavWriter) + + for { + select { + case <-pairContext.ctx.Done(): + log.Debug().Msgf("[%s] PairContext cancelled", pairContext.pair) + return + case audioBytes, ok := <-pairContext.audioChan: + if !ok { + log.Debug().Msgf("[%s] Audio channel closed", pairContext.pair) + return + } + + // Write audio data to WAV file in debug builds + writeToWavFile(wavWriter, audioBytes) + + if wsConn == nil { + log.Error().Msg("WebSocket connection is nil") + return + } + + err := wsConn.WriteMessage(websocket.BinaryMessage, audioBytes) + audioBytesBufferPool.Put(audioBytes) + if err != nil { + log.Error().Err(err).Msg("Failed to send audio") + return + } + } + } +} diff --git a/go.mod b/go.mod index 27a8f39..aa14360 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.20 require ( azul3d.org/engine v0.0.0-20211024043305-793ea6c2839d github.com/brave-experiments/whisper.cpp/bindings/go v0.0.0-20231102103318-0dad03e80436 + github.com/colega/zeropool v0.0.0-20230505084239-6fb4a4f75381 github.com/golang/protobuf v1.5.3 + github.com/gorilla/websocket v1.5.3 github.com/rs/zerolog v1.30.0 github.com/urfave/cli/v2 v2.25.7 google.golang.org/protobuf v1.31.0 diff --git a/go.sum b/go.sum index 0d788be..46288d6 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ azul3d.org/engine v0.0.0-20211024043305-793ea6c2839d h1:oLr8Nu7iEKm8OKfHNznBnBj8 azul3d.org/engine v0.0.0-20211024043305-793ea6c2839d/go.mod h1:3y1cwzJTKvXXop+EAg+AUVfNm3bfHf3djeX+l1UBuUE= github.com/brave-experiments/whisper.cpp/bindings/go v0.0.0-20231102103318-0dad03e80436 h1:8oivPf0lUKpMbN6V7QE6N4htLXMkAJI+aolXtzKL4eI= github.com/brave-experiments/whisper.cpp/bindings/go v0.0.0-20231102103318-0dad03e80436/go.mod h1:KKRZumBM2HhAnyL4papFJ01NHAaKDWluPID9/5JF4uA= +github.com/colega/zeropool v0.0.0-20230505084239-6fb4a4f75381 h1:d5EKgQfRQvO97jnISfR89AiCCCJMwMFoSxUiU0OGCRU= +github.com/colega/zeropool v0.0.0-20230505084239-6fb4a4f75381/go.mod h1:OU76gHeRo8xrzGJU3F3I1CqX1ekM8dfJw0+wPeMwnp0= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -15,6 +17,8 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0= github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A= github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=