1//go:build integration
2// +build integration
3
4package transcribestreamingservice
5
6import (
7	"bytes"
8	"context"
9	"encoding/base64"
10	"flag"
11	"io"
12	"os"
13	"strings"
14	"sync"
15	"testing"
16
17	"github.com/aws/aws-sdk-go/aws"
18	"github.com/aws/aws-sdk-go/awstesting/integration"
19)
20
// Audio settings for the integration tests. Each one is bound to a
// command-line flag in init.
var (
	// audioFilename is the path of an audio file to stream. When empty, the
	// tests use a small base64-embedded WAV clip instead.
	audioFilename string
	// audioFormat is sent as the request's MediaEncoding.
	audioFormat string
	// audioLang is sent as the request's LanguageCode.
	audioLang string
	// audioSampleRate is sent as the request's MediaSampleRateHertz.
	audioSampleRate int
	// audioFrameSize is the frame size handed to StreamAudioFromReader
	// (presumably bytes per uploaded chunk — confirm against that helper).
	audioFrameSize int
)

// withDebug enables SDK debug logging, including event stream bodies,
// written through the test's logger.
var withDebug bool
29
30func init() {
31	flag.BoolVar(&withDebug, "debug", false, "Include debug logging with test.")
32	flag.StringVar(&audioFilename, "audio-file", "", "Audio file filename to perform test with.")
33	flag.StringVar(&audioLang, "audio-lang", LanguageCodeEnUs, "Language of audio speech.")
34	flag.StringVar(&audioFormat, "audio-format", MediaEncodingPcm, "Format of audio.")
35	flag.IntVar(&audioSampleRate, "audio-sample", 16000, "Sample rate of the audio.")
36	flag.IntVar(&audioFrameSize, "audio-frame", 15*1024, "Size of frames of audio uploaded.")
37}
38
39func TestInteg_StartStreamTranscription(t *testing.T) {
40	var audio io.Reader
41	if len(audioFilename) != 0 {
42		audioFile, err := os.Open(audioFilename)
43		if err != nil {
44			t.Fatalf("expect to open file, %v", err)
45		}
46		defer audioFile.Close()
47		audio = audioFile
48	} else {
49		b, err := base64.StdEncoding.DecodeString(
50			`UklGRjzxPQBXQVZFZm10IBAAAAABAAEAgD4AAAB9AAACABAAZGF0YVTwPQAAAAAAAAAAAAAAAAD//wIA/f8EAA==`,
51		)
52		if err != nil {
53			t.Fatalf("expect decode audio bytes, %v", err)
54		}
55		audio = bytes.NewReader(b)
56	}
57
58	sess := integration.SessionWithDefaultRegion("us-west-2")
59	var cfgs []*aws.Config
60	if withDebug {
61		cfgs = append(cfgs, &aws.Config{
62			Logger:   t,
63			LogLevel: aws.LogLevel(aws.LogDebugWithEventStreamBody),
64		})
65	}
66
67	client := New(sess, cfgs...)
68	resp, err := client.StartStreamTranscription(&StartStreamTranscriptionInput{
69		LanguageCode:         aws.String(audioLang),
70		MediaEncoding:        aws.String(audioFormat),
71		MediaSampleRateHertz: aws.Int64(int64(audioSampleRate)),
72	})
73	if err != nil {
74		t.Fatalf("failed to start streaming, %v", err)
75	}
76	stream := resp.GetStream()
77	defer stream.Close()
78
79	go StreamAudioFromReader(context.Background(), stream.Writer, audioFrameSize, audio)
80
81	for event := range stream.Events() {
82		switch e := event.(type) {
83		case *TranscriptEvent:
84			t.Logf("got event, %v results", len(e.Transcript.Results))
85			for _, res := range e.Transcript.Results {
86				for _, alt := range res.Alternatives {
87					t.Logf("* %s", aws.StringValue(alt.Transcript))
88				}
89			}
90		default:
91			t.Fatalf("unexpected event, %T", event)
92		}
93	}
94
95	if err := stream.Err(); err != nil {
96		t.Fatalf("expect no error from stream, got %v", err)
97	}
98}
99
100func TestInteg_StartStreamTranscription_contextClose(t *testing.T) {
101	b, err := base64.StdEncoding.DecodeString(
102		`UklGRjzxPQBXQVZFZm10IBAAAAABAAEAgD4AAAB9AAACABAAZGF0YVTwPQAAAAAAAAAAAAAAAAD//wIA/f8EAA==`,
103	)
104	if err != nil {
105		t.Fatalf("expect decode audio bytes, %v", err)
106	}
107	audio := bytes.NewReader(b)
108
109	sess := integration.SessionWithDefaultRegion("us-west-2")
110	var cfgs []*aws.Config
111
112	client := New(sess, cfgs...)
113	resp, err := client.StartStreamTranscription(&StartStreamTranscriptionInput{
114		LanguageCode:         aws.String(LanguageCodeEnUs),
115		MediaEncoding:        aws.String(MediaEncodingPcm),
116		MediaSampleRateHertz: aws.Int64(16000),
117	})
118	if err != nil {
119		t.Fatalf("failed to start streaming, %v", err)
120	}
121	stream := resp.GetStream()
122	defer stream.Close()
123
124	ctx, cancelFn := context.WithCancel(context.Background())
125	var wg sync.WaitGroup
126	wg.Add(1)
127	go func() {
128		err := StreamAudioFromReader(ctx, stream.Writer, audioFrameSize, audio)
129		if err == nil {
130			t.Errorf("expect error")
131		}
132		if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
133			t.Errorf("expect %q error in %q", e, a)
134		}
135		wg.Done()
136	}()
137
138	cancelFn()
139
140Loop:
141	for {
142		select {
143		case <-ctx.Done():
144			break Loop
145		case event, ok := <-stream.Events():
146			if !ok {
147				break Loop
148			}
149			switch e := event.(type) {
150			case *TranscriptEvent:
151				t.Logf("got event, %v results", len(e.Transcript.Results))
152				for _, res := range e.Transcript.Results {
153					for _, alt := range res.Alternatives {
154						t.Logf("* %s", aws.StringValue(alt.Transcript))
155					}
156				}
157			default:
158				t.Fatalf("unexpected event, %T", event)
159			}
160		}
161	}
162
163	wg.Wait()
164
165	if err := stream.Err(); err != nil {
166		t.Fatalf("expect no error from stream, got %v", err)
167	}
168}
169