1 /*
2  *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "video/video_stream_decoder_impl.h"
12 
13 #include <vector>
14 
15 #include "api/video/i420_buffer.h"
16 #include "test/gmock.h"
17 #include "test/gtest.h"
18 #include "test/time_controller/simulated_time_controller.h"
19 
20 namespace webrtc {
21 namespace {
22 using ::testing::_;
23 using ::testing::ByMove;
24 using ::testing::NiceMock;
25 using ::testing::Return;
26 
27 class MockVideoStreamDecoderCallbacks
28     : public VideoStreamDecoderInterface::Callbacks {
29  public:
30   MOCK_METHOD(void, OnNonDecodableState, (), (override));
31   MOCK_METHOD(void,
32               OnContinuousUntil,
33               (const video_coding::VideoLayerFrameId& key),
34               (override));
35   MOCK_METHOD(
36       void,
37       OnDecodedFrame,
38       (VideoFrame frame,
39        const VideoStreamDecoderInterface::Callbacks::FrameInfo& frame_info),
40       (override));
41 };
42 
43 class StubVideoDecoder : public VideoDecoder {
44  public:
45   MOCK_METHOD(int32_t,
46               InitDecode,
47               (const VideoCodec*, int32_t number_of_cores),
48               (override));
49 
Decode(const EncodedImage & input_image,bool missing_frames,int64_t render_time_ms)50   int32_t Decode(const EncodedImage& input_image,
51                  bool missing_frames,
52                  int64_t render_time_ms) override {
53     int32_t ret_code = DecodeCall(input_image, missing_frames, render_time_ms);
54     if (ret_code == WEBRTC_VIDEO_CODEC_OK ||
55         ret_code == WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME) {
56       VideoFrame frame = VideoFrame::Builder()
57                              .set_video_frame_buffer(I420Buffer::Create(1, 1))
58                              .build();
59       callback_->Decoded(frame);
60     }
61     return ret_code;
62   }
63 
64   MOCK_METHOD(int32_t,
65               DecodeCall,
66               (const EncodedImage& input_image,
67                bool missing_frames,
68                int64_t render_time_ms),
69               ());
70 
Release()71   int32_t Release() override { return 0; }
72 
RegisterDecodeCompleteCallback(DecodedImageCallback * callback)73   int32_t RegisterDecodeCompleteCallback(
74       DecodedImageCallback* callback) override {
75     callback_ = callback;
76     return 0;
77   }
78 
79  private:
80   DecodedImageCallback* callback_;
81 };
82 
83 class WrappedVideoDecoder : public VideoDecoder {
84  public:
WrappedVideoDecoder(StubVideoDecoder * decoder)85   explicit WrappedVideoDecoder(StubVideoDecoder* decoder) : decoder_(decoder) {}
86 
InitDecode(const VideoCodec * codec_settings,int32_t number_of_cores)87   int32_t InitDecode(const VideoCodec* codec_settings,
88                      int32_t number_of_cores) override {
89     return decoder_->InitDecode(codec_settings, number_of_cores);
90   }
Decode(const EncodedImage & input_image,bool missing_frames,int64_t render_time_ms)91   int32_t Decode(const EncodedImage& input_image,
92                  bool missing_frames,
93                  int64_t render_time_ms) override {
94     return decoder_->Decode(input_image, missing_frames, render_time_ms);
95   }
Release()96   int32_t Release() override { return decoder_->Release(); }
97 
RegisterDecodeCompleteCallback(DecodedImageCallback * callback)98   int32_t RegisterDecodeCompleteCallback(
99       DecodedImageCallback* callback) override {
100     return decoder_->RegisterDecodeCompleteCallback(callback);
101   }
102 
103  private:
104   StubVideoDecoder* decoder_;
105 };
106 
107 class FakeVideoDecoderFactory : public VideoDecoderFactory {
108  public:
GetSupportedFormats() const109   std::vector<SdpVideoFormat> GetSupportedFormats() const override {
110     return {};
111   }
CreateVideoDecoder(const SdpVideoFormat & format)112   std::unique_ptr<VideoDecoder> CreateVideoDecoder(
113       const SdpVideoFormat& format) override {
114     if (format.name == "VP8") {
115       return std::make_unique<WrappedVideoDecoder>(&vp8_decoder_);
116     }
117 
118     if (format.name == "AV1") {
119       return std::make_unique<WrappedVideoDecoder>(&av1_decoder_);
120     }
121 
122     return {};
123   }
124 
Vp8Decoder()125   StubVideoDecoder& Vp8Decoder() { return vp8_decoder_; }
Av1Decoder()126   StubVideoDecoder& Av1Decoder() { return av1_decoder_; }
127 
128  private:
129   NiceMock<StubVideoDecoder> vp8_decoder_;
130   NiceMock<StubVideoDecoder> av1_decoder_;
131 };
132 
133 class FakeEncodedFrame : public video_coding::EncodedFrame {
134  public:
ReceivedTime() const135   int64_t ReceivedTime() const override { return 0; }
RenderTime() const136   int64_t RenderTime() const override { return 0; }
137 
138   // Setters for protected variables.
SetPayloadType(int payload_type)139   void SetPayloadType(int payload_type) { _payloadType = payload_type; }
140 };
141 
142 class FrameBuilder {
143  public:
FrameBuilder()144   FrameBuilder() : frame_(std::make_unique<FakeEncodedFrame>()) {}
145 
WithPayloadType(int payload_type)146   FrameBuilder& WithPayloadType(int payload_type) {
147     frame_->SetPayloadType(payload_type);
148     return *this;
149   }
150 
WithPictureId(int picture_id)151   FrameBuilder& WithPictureId(int picture_id) {
152     frame_->id.picture_id = picture_id;
153     return *this;
154   }
155 
Build()156   std::unique_ptr<FakeEncodedFrame> Build() { return std::move(frame_); }
157 
158  private:
159   std::unique_ptr<FakeEncodedFrame> frame_;
160 };
161 
162 class VideoStreamDecoderImplTest : public ::testing::Test {
163  public:
VideoStreamDecoderImplTest()164   VideoStreamDecoderImplTest()
165       : time_controller_(Timestamp::Seconds(0)),
166         video_stream_decoder_(&callbacks_,
167                               &decoder_factory_,
168                               time_controller_.GetTaskQueueFactory(),
169                               {{1, std::make_pair(SdpVideoFormat("VP8"), 1)},
170                                {2, std::make_pair(SdpVideoFormat("AV1"), 1)}}) {
171   }
172 
173   NiceMock<MockVideoStreamDecoderCallbacks> callbacks_;
174   FakeVideoDecoderFactory decoder_factory_;
175   GlobalSimulatedTimeController time_controller_;
176   VideoStreamDecoderImpl video_stream_decoder_;
177 };
178 
// A VP8 frame (payload type 1) is decoded and reported via OnDecodedFrame once
// simulated time advances by 1 ms.
TEST_F(VideoStreamDecoderImplTest, InsertAndDecodeFrame) {
  EXPECT_CALL(callbacks_, OnDecodedFrame).Times(1);

  video_stream_decoder_.OnFrame(FrameBuilder().WithPayloadType(1).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));
}
184 
// With no frame ever inserted, OnNonDecodableState is reported within 200 ms
// of simulated time.
TEST_F(VideoStreamDecoderImplTest, NonDecodableStateWaitingForKeyframe) {
  const TimeDelta wait_for_keyframe = TimeDelta::Millis(200);
  EXPECT_CALL(callbacks_, OnNonDecodableState).Times(1);
  time_controller_.AdvanceTime(wait_for_keyframe);
}
189 
// After a first frame has been decoded, OnNonDecodableState is reported when
// no further frame arrives during the next 3 s of simulated time.
TEST_F(VideoStreamDecoderImplTest, NonDecodableStateWaitingForDeltaFrame) {
  EXPECT_CALL(callbacks_, OnDecodedFrame).Times(1);
  video_stream_decoder_.OnFrame(FrameBuilder().WithPayloadType(1).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));

  // This expectation is added only after the first frame was handled, so only
  // notifications during the following 3 s count toward it.
  EXPECT_CALL(callbacks_, OnNonDecodableState).Times(1);
  time_controller_.AdvanceTime(TimeDelta::Seconds(3));
}
197 
// If the VP8 decoder returns WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME, the
// decoded frame is still delivered and OnNonDecodableState is also reported.
TEST_F(VideoStreamDecoderImplTest, InsertAndDecodeFrameWithKeyframeRequest) {
  EXPECT_CALL(decoder_factory_.Vp8Decoder(), DecodeCall)
      .WillOnce(Return(WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME));
  EXPECT_CALL(callbacks_, OnDecodedFrame).Times(1);
  EXPECT_CALL(callbacks_, OnNonDecodableState).Times(1);

  video_stream_decoder_.OnFrame(FrameBuilder().WithPayloadType(1).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));
}
206 
// When the VP8 decoder's InitDecode() fails, OnNonDecodableState is reported.
TEST_F(VideoStreamDecoderImplTest, FailToInitDecoder) {
  ON_CALL(decoder_factory_.Vp8Decoder(), InitDecode)
      .WillByDefault(Return(WEBRTC_VIDEO_CODEC_ERROR));
  EXPECT_CALL(callbacks_, OnNonDecodableState).Times(1);

  video_stream_decoder_.OnFrame(FrameBuilder().WithPayloadType(1).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));
}
214 
// When decoding returns WEBRTC_VIDEO_CODEC_ERROR, OnNonDecodableState is
// reported.
TEST_F(VideoStreamDecoderImplTest, FailToDecodeFrame) {
  ON_CALL(decoder_factory_.Vp8Decoder(), DecodeCall)
      .WillByDefault(Return(WEBRTC_VIDEO_CODEC_ERROR));
  EXPECT_CALL(callbacks_, OnNonDecodableState).Times(1);

  video_stream_decoder_.OnFrame(FrameBuilder().WithPayloadType(1).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));
}
222 
// Switching payload type between consecutive frames routes decoding from the
// VP8 stub to the AV1 stub, and both frames are delivered.
TEST_F(VideoStreamDecoderImplTest, ChangeFramePayloadType) {
  // Phase 1: payload type 1 (VP8), picture id 0.
  EXPECT_CALL(decoder_factory_.Vp8Decoder(), DecodeCall).Times(1);
  EXPECT_CALL(callbacks_, OnDecodedFrame).Times(1);
  video_stream_decoder_.OnFrame(
      FrameBuilder().WithPayloadType(1).WithPictureId(0).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));

  // Phase 2: payload type 2 (AV1), picture id 1.
  EXPECT_CALL(decoder_factory_.Av1Decoder(), DecodeCall).Times(1);
  EXPECT_CALL(callbacks_, OnDecodedFrame).Times(1);
  video_stream_decoder_.OnFrame(
      FrameBuilder().WithPayloadType(2).WithPictureId(1).Build());
  time_controller_.AdvanceTime(TimeDelta::Millis(1));
}
236 
237 }  // namespace
238 }  // namespace webrtc
239