1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "build/build_config.h"
6 
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <memory>
10 
11 #include "base/message_loop/message_pump_type.h"
12 #include "base/pickle.h"
13 #include "base/run_loop.h"
14 #include "base/threading/thread.h"
15 #include "ipc/ipc_message.h"
16 #include "ipc/ipc_test_base.h"
17 #include "ipc/message_filter.h"
18 
19 // Get basic type definitions.
20 #define IPC_MESSAGE_IMPL
21 #include "ipc/ipc_channel_proxy_unittest_messages.h"
22 
23 // Generate constructors.
24 #include "ipc/struct_constructor_macros.h"
25 #include "ipc/ipc_channel_proxy_unittest_messages.h"
26 
27 // Generate param traits write methods.
28 #include "ipc/param_traits_write_macros.h"
29 namespace IPC {
30 #include "ipc/ipc_channel_proxy_unittest_messages.h"
31 }  // namespace IPC
32 
33 // Generate param traits read methods.
34 #include "ipc/param_traits_read_macros.h"
35 namespace IPC {
36 #include "ipc/ipc_channel_proxy_unittest_messages.h"
37 }  // namespace IPC
38 
39 // Generate param traits log methods.
40 #include "ipc/param_traits_log_macros.h"
41 namespace IPC {
42 #include "ipc/ipc_channel_proxy_unittest_messages.h"
43 }  // namespace IPC
44 
45 
46 namespace {
47 
CreateRunLoopAndRun(base::RunLoop ** run_loop_ptr)48 void CreateRunLoopAndRun(base::RunLoop** run_loop_ptr) {
49   base::RunLoop run_loop;
50   *run_loop_ptr = &run_loop;
51   run_loop.Run();
52   *run_loop_ptr = nullptr;
53 }
54 
55 class QuitListener : public IPC::Listener {
56  public:
57   QuitListener() = default;
58 
OnMessageReceived(const IPC::Message & message)59   bool OnMessageReceived(const IPC::Message& message) override {
60     IPC_BEGIN_MESSAGE_MAP(QuitListener, message)
61       IPC_MESSAGE_HANDLER(WorkerMsg_Quit, OnQuit)
62       IPC_MESSAGE_HANDLER(TestMsg_BadMessage, OnBadMessage)
63     IPC_END_MESSAGE_MAP()
64     return true;
65   }
66 
OnBadMessageReceived(const IPC::Message & message)67   void OnBadMessageReceived(const IPC::Message& message) override {
68     bad_message_received_ = true;
69   }
70 
OnChannelError()71   void OnChannelError() override { CHECK(quit_message_received_); }
72 
OnQuit()73   void OnQuit() {
74     quit_message_received_ = true;
75     run_loop_->QuitWhenIdle();
76   }
77 
OnBadMessage(const BadType & bad_type)78   void OnBadMessage(const BadType& bad_type) {
79     // Should never be called since IPC wouldn't be deserialized correctly.
80     CHECK(false);
81   }
82 
83   bool bad_message_received_ = false;
84   bool quit_message_received_ = false;
85   base::RunLoop* run_loop_ = nullptr;
86 };
87 
88 class ChannelReflectorListener : public IPC::Listener {
89  public:
90   ChannelReflectorListener() = default;
91 
Init(IPC::Channel * channel)92   void Init(IPC::Channel* channel) {
93     DCHECK(!channel_);
94     channel_ = channel;
95   }
96 
OnMessageReceived(const IPC::Message & message)97   bool OnMessageReceived(const IPC::Message& message) override {
98     IPC_BEGIN_MESSAGE_MAP(ChannelReflectorListener, message)
99       IPC_MESSAGE_HANDLER(TestMsg_Bounce, OnTestBounce)
100       IPC_MESSAGE_HANDLER(TestMsg_SendBadMessage, OnSendBadMessage)
101       IPC_MESSAGE_HANDLER(AutomationMsg_Bounce, OnAutomationBounce)
102       IPC_MESSAGE_HANDLER(WorkerMsg_Bounce, OnBounce)
103       IPC_MESSAGE_HANDLER(WorkerMsg_Quit, OnQuit)
104     IPC_END_MESSAGE_MAP()
105     return true;
106   }
107 
OnTestBounce()108   void OnTestBounce() {
109     channel_->Send(new TestMsg_Bounce());
110   }
111 
OnSendBadMessage()112   void OnSendBadMessage() {
113     channel_->Send(new TestMsg_BadMessage(BadType()));
114   }
115 
OnAutomationBounce()116   void OnAutomationBounce() { channel_->Send(new AutomationMsg_Bounce()); }
117 
OnBounce()118   void OnBounce() {
119     channel_->Send(new WorkerMsg_Bounce());
120   }
121 
OnQuit()122   void OnQuit() {
123     channel_->Send(new WorkerMsg_Quit());
124     run_loop_->QuitWhenIdle();
125   }
126 
127   base::RunLoop* run_loop_ = nullptr;
128 
129  private:
130   IPC::Channel* channel_ = nullptr;
131 };
132 
133 class MessageCountFilter : public IPC::MessageFilter {
134  public:
135   enum FilterEvent {
136     NONE,
137     FILTER_ADDED,
138     CHANNEL_CONNECTED,
139     CHANNEL_ERROR,
140     CHANNEL_CLOSING,
141     FILTER_REMOVED
142   };
143 
144   MessageCountFilter() = default;
MessageCountFilter(uint32_t supported_message_class)145   MessageCountFilter(uint32_t supported_message_class)
146       : supported_message_class_(supported_message_class),
147         is_global_filter_(false) {}
148 
OnFilterAdded(IPC::Channel * channel)149   void OnFilterAdded(IPC::Channel* channel) override {
150     EXPECT_TRUE(channel);
151     EXPECT_EQ(NONE, last_filter_event_);
152     last_filter_event_ = FILTER_ADDED;
153   }
154 
OnChannelConnected(int32_t peer_pid)155   void OnChannelConnected(int32_t peer_pid) override {
156     EXPECT_EQ(FILTER_ADDED, last_filter_event_);
157     EXPECT_NE(static_cast<int32_t>(base::kNullProcessId), peer_pid);
158     last_filter_event_ = CHANNEL_CONNECTED;
159   }
160 
OnChannelError()161   void OnChannelError() override {
162     EXPECT_EQ(CHANNEL_CONNECTED, last_filter_event_);
163     last_filter_event_ = CHANNEL_ERROR;
164   }
165 
OnChannelClosing()166   void OnChannelClosing() override {
167     // We may or may not have gotten OnChannelError; if not, the last event has
168     // to be OnChannelConnected.
169     EXPECT_NE(FILTER_REMOVED, last_filter_event_);
170     if (last_filter_event_ != CHANNEL_ERROR)
171       EXPECT_EQ(CHANNEL_CONNECTED, last_filter_event_);
172     last_filter_event_ = CHANNEL_CLOSING;
173   }
174 
OnFilterRemoved()175   void OnFilterRemoved() override {
176     // A filter may be removed at any time, even before the channel is connected
177     // (and thus before OnFilterAdded is ever able to dispatch.) The only time
178     // we won't see OnFilterRemoved is immediately after OnFilterAdded, because
179     // OnChannelConnected is always the next event to fire after that.
180     EXPECT_NE(FILTER_ADDED, last_filter_event_);
181     last_filter_event_ = FILTER_REMOVED;
182   }
183 
OnMessageReceived(const IPC::Message & message)184   bool OnMessageReceived(const IPC::Message& message) override {
185     // We should always get the OnFilterAdded and OnChannelConnected events
186     // prior to any messages.
187     EXPECT_EQ(CHANNEL_CONNECTED, last_filter_event_);
188 
189     if (!is_global_filter_) {
190       EXPECT_EQ(supported_message_class_, IPC_MESSAGE_CLASS(message));
191     }
192     ++messages_received_;
193 
194     if (!message_filtering_enabled_)
195       return false;
196 
197     bool handled = true;
198     IPC_BEGIN_MESSAGE_MAP(MessageCountFilter, message)
199       IPC_MESSAGE_HANDLER(TestMsg_BadMessage, OnBadMessage)
200       IPC_MESSAGE_UNHANDLED(handled = false)
201     IPC_END_MESSAGE_MAP()
202     return handled;
203   }
204 
OnBadMessage(const BadType & bad_type)205   void OnBadMessage(const BadType& bad_type) {
206     // Should never be called since IPC wouldn't be deserialized correctly.
207     CHECK(false);
208   }
209 
GetSupportedMessageClasses(std::vector<uint32_t> * supported_message_classes) const210   bool GetSupportedMessageClasses(
211       std::vector<uint32_t>* supported_message_classes) const override {
212     if (is_global_filter_)
213       return false;
214     supported_message_classes->push_back(supported_message_class_);
215     return true;
216   }
217 
set_message_filtering_enabled(bool enabled)218   void set_message_filtering_enabled(bool enabled) {
219     message_filtering_enabled_ = enabled;
220   }
221 
messages_received() const222   size_t messages_received() const { return messages_received_; }
last_filter_event() const223   FilterEvent last_filter_event() const { return last_filter_event_; }
224 
225  private:
226   ~MessageCountFilter() override = default;
227 
228   size_t messages_received_ = 0;
229   uint32_t supported_message_class_ = 0;
230   bool is_global_filter_ = true;
231 
232   FilterEvent last_filter_event_ = NONE;
233   bool message_filtering_enabled_ = false;
234 };
235 
236 class IPCChannelProxyTest : public IPCChannelMojoTestBase {
237  public:
238   IPCChannelProxyTest() = default;
239   ~IPCChannelProxyTest() override = default;
240 
SetUp()241   void SetUp() override {
242     IPCChannelMojoTestBase::SetUp();
243 
244     Init("ChannelProxyClient");
245 
246     thread_.reset(new base::Thread("ChannelProxyTestServerThread"));
247     base::Thread::Options options;
248     options.message_pump_type = base::MessagePumpType::IO;
249     thread_->StartWithOptions(options);
250 
251     listener_.reset(new QuitListener());
252     channel_proxy_ = IPC::ChannelProxy::Create(
253         TakeHandle().release(), IPC::Channel::MODE_SERVER, listener_.get(),
254         thread_->task_runner(), base::ThreadTaskRunnerHandle::Get());
255   }
256 
TearDown()257   void TearDown() override {
258     channel_proxy_.reset();
259     thread_.reset();
260     listener_.reset();
261     IPCChannelMojoTestBase::TearDown();
262   }
263 
SendQuitMessageAndWaitForIdle()264   void SendQuitMessageAndWaitForIdle() {
265     sender()->Send(new WorkerMsg_Quit);
266     CreateRunLoopAndRun(&listener_->run_loop_);
267     EXPECT_TRUE(WaitForClientShutdown());
268   }
269 
DidListenerGetBadMessage()270   bool DidListenerGetBadMessage() {
271     return listener_->bad_message_received_;
272   }
273 
channel_proxy()274   IPC::ChannelProxy* channel_proxy() { return channel_proxy_.get(); }
sender()275   IPC::Sender* sender() { return channel_proxy_.get(); }
276 
277  private:
278   std::unique_ptr<base::Thread> thread_;
279   std::unique_ptr<QuitListener> listener_;
280   std::unique_ptr<IPC::ChannelProxy> channel_proxy_;
281 };
282 
// Verifies that class-specific filters only see messages of their own class.
TEST_F(IPCChannelProxyTest, MessageClassFilters) {
  // Construct a filter per message class.
  std::vector<scoped_refptr<MessageCountFilter>> class_filters;
  class_filters.push_back(
      base::MakeRefCounted<MessageCountFilter>(TestMsgStart));
  class_filters.push_back(
      base::MakeRefCounted<MessageCountFilter>(AutomationMsgStart));
  for (const auto& filter : class_filters)
    channel_proxy()->AddFilter(filter.get());

  // Send one message of each filtered class. The client bounces each one back,
  // so each filter should receive exactly one message.
  sender()->Send(new TestMsg_Bounce);
  sender()->Send(new AutomationMsg_Bounce);

  // Send a message of a class that no filter is registered for. Neither class
  // filter should see its echo.
  sender()->Send(new WorkerMsg_Bounce);

  // Each filter should have received just the one sent message of the
  // corresponding class.
  SendQuitMessageAndWaitForIdle();
  for (const auto& filter : class_filters)
    EXPECT_EQ(1U, filter->messages_received());
}
306 
// Verifies that a global filter sees every message while a class filter sees
// only messages of its class.
TEST_F(IPCChannelProxyTest, GlobalAndMessageClassFilters) {
  // Add a class and global filter.
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  class_filter->set_message_filtering_enabled(false);
  channel_proxy()->AddFilter(class_filter.get());

  auto global_filter = base::MakeRefCounted<MessageCountFilter>();
  global_filter->set_message_filtering_enabled(false);
  channel_proxy()->AddFilter(global_filter.get());

  // A message of class Test should be seen by both the global filter and the
  // Test-specific filter.
  sender()->Send(new TestMsg_Bounce);

  // A message of a different class should be seen only by the global filter.
  sender()->Send(new AutomationMsg_Bounce);

  // Flush all messages.
  SendQuitMessageAndWaitForIdle();

  // The class filter should have received only the class-specific message.
  EXPECT_EQ(1U, class_filter->messages_received());

  // The global filter should have received both messages, as well as the final
  // QUIT message.
  EXPECT_EQ(3U, global_filter->messages_received());
}
335 
// Verifies that filters removed before any traffic never see messages and
// end in the FILTER_REMOVED state.
TEST_F(IPCChannelProxyTest, FilterRemoval) {
  // Add a class and global filter.
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  auto global_filter = base::MakeRefCounted<MessageCountFilter>();

  // Add and remove both types of filters.
  channel_proxy()->AddFilter(class_filter.get());
  channel_proxy()->AddFilter(global_filter.get());
  channel_proxy()->RemoveFilter(global_filter.get());
  channel_proxy()->RemoveFilter(class_filter.get());

  // Send some messages; they should not be seen by either filter.
  sender()->Send(new TestMsg_Bounce);
  sender()->Send(new AutomationMsg_Bounce);

  // Ensure that the filters were removed and did not receive any messages.
  SendQuitMessageAndWaitForIdle();
  EXPECT_EQ(MessageCountFilter::FILTER_REMOVED,
            global_filter->last_filter_event());
  EXPECT_EQ(MessageCountFilter::FILTER_REMOVED,
            class_filter->last_filter_event());
  EXPECT_EQ(0U, class_filter->messages_received());
  EXPECT_EQ(0U, global_filter->messages_received());
}
361 
// The filter passes the bad message through without handling it, so the
// deserialization failure happens on the listener side. The listener must be
// told about it via OnBadMessageReceived().
TEST_F(IPCChannelProxyTest, BadMessageOnListenerThread) {
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  class_filter->set_message_filtering_enabled(false);
  channel_proxy()->AddFilter(class_filter.get());

  // The client answers this with an undeserializable TestMsg_BadMessage.
  sender()->Send(new TestMsg_SendBadMessage());

  SendQuitMessageAndWaitForIdle();
  EXPECT_TRUE(DidListenerGetBadMessage());
}
373 
// The filter tries to handle the bad message itself, so the deserialization
// failure happens inside the filter. The listener must still be told about it
// via OnBadMessageReceived().
TEST_F(IPCChannelProxyTest, BadMessageOnIPCThread) {
  auto class_filter = base::MakeRefCounted<MessageCountFilter>(TestMsgStart);
  class_filter->set_message_filtering_enabled(true);
  channel_proxy()->AddFilter(class_filter.get());

  // The client answers this with an undeserializable TestMsg_BadMessage.
  sender()->Send(new TestMsg_SendBadMessage());

  SendQuitMessageAndWaitForIdle();
  EXPECT_TRUE(DidListenerGetBadMessage());
}
385 
386 class IPCChannelBadMessageTest : public IPCChannelMojoTestBase {
387  public:
SetUp()388   void SetUp() override {
389     IPCChannelMojoTestBase::SetUp();
390 
391     Init("ChannelProxyClient");
392 
393     listener_.reset(new QuitListener());
394     CreateChannel(listener_.get());
395     ASSERT_TRUE(ConnectChannel());
396   }
397 
TearDown()398   void TearDown() override {
399     IPCChannelMojoTestBase::TearDown();
400     listener_.reset();
401   }
402 
SendQuitMessageAndWaitForIdle()403   void SendQuitMessageAndWaitForIdle() {
404     sender()->Send(new WorkerMsg_Quit);
405     CreateRunLoopAndRun(&listener_->run_loop_);
406     EXPECT_TRUE(WaitForClientShutdown());
407   }
408 
DidListenerGetBadMessage()409   bool DidListenerGetBadMessage() {
410     return listener_->bad_message_received_;
411   }
412 
413  private:
414   std::unique_ptr<QuitListener> listener_;
415 };
416 
// Without any proxy or filter, an undeserializable reply must still be
// reported to the listener through OnBadMessageReceived().
TEST_F(IPCChannelBadMessageTest, BadMessage) {
  auto* const to_client = sender();
  // The client answers this with a TestMsg_BadMessage carrying a BadType.
  to_client->Send(new TestMsg_SendBadMessage());

  SendQuitMessageAndWaitForIdle();
  EXPECT_TRUE(DidListenerGetBadMessage());
}
422 
// Test-client entry point shared by both fixtures. It reflects the server's
// messages back until WorkerMsg_Quit arrives, then closes the channel.
DEFINE_IPC_CHANNEL_MOJO_TEST_CLIENT(ChannelProxyClient) {
  ChannelReflectorListener reflector;
  Connect(&reflector);
  // Give the reflector the channel it should reply on.
  reflector.Init(channel());

  // Returns once ChannelReflectorListener::OnQuit() quits the loop.
  CreateRunLoopAndRun(&reflector.run_loop_);

  Close();
}
432 
433 }  // namespace
434