1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/common/conflicts/module_watcher_win.h"
6 
7 #include <memory>
8 
9 #include "base/bind.h"
10 #include "base/test/task_environment.h"
11 #include "testing/gtest/include/gtest/gtest.h"
12 
13 #include <windows.h>
14 
15 class ModuleWatcherTest : public testing::Test {
16  protected:
ModuleWatcherTest()17   ModuleWatcherTest()
18       : module_(nullptr),
19         module_event_count_(0),
20         module_already_loaded_event_count_(0),
21         module_loaded_event_count_(0) {}
22 
OnModuleEvent(const ModuleWatcher::ModuleEvent & event)23   void OnModuleEvent(const ModuleWatcher::ModuleEvent& event) {
24     ++module_event_count_;
25     switch (event.event_type) {
26       case ModuleWatcher::ModuleEventType::kModuleAlreadyLoaded:
27         ++module_already_loaded_event_count_;
28         break;
29       case ModuleWatcher::ModuleEventType::kModuleLoaded:
30         ++module_loaded_event_count_;
31         break;
32     }
33   }
34 
TearDown()35   void TearDown() override { UnloadModule(); }
36 
LoadModule()37   void LoadModule() {
38     if (module_)
39       return;
40     // This module should not be a static dependency of the unit-test
41     // executable, but should be a build-system dependency or a module that is
42     // present on any Windows machine.
43     static constexpr wchar_t kModuleName[] = L"conflicts_dll.dll";
44     // The module should not already be loaded.
45     ASSERT_FALSE(::GetModuleHandle(kModuleName));
46     // It should load successfully.
47     module_ = ::LoadLibrary(kModuleName);
48     ASSERT_TRUE(module_);
49   }
50 
UnloadModule()51   void UnloadModule() {
52     if (!module_)
53       return;
54     ::FreeLibrary(module_);
55     module_ = nullptr;
56   }
57 
RunUntilIdle()58   void RunUntilIdle() { task_environment_.RunUntilIdle(); }
59 
Create()60   std::unique_ptr<ModuleWatcher> Create() {
61     return ModuleWatcher::Create(base::BindRepeating(
62         &ModuleWatcherTest::OnModuleEvent, base::Unretained(this)));
63   }
64 
65   base::test::TaskEnvironment task_environment_;
66 
67   // Holds a handle to a loaded module.
68   HMODULE module_;
69   // Total number of module events seen.
70   int module_event_count_;
71   // Total number of MODULE_ALREADY_LOADED events seen.
72   int module_already_loaded_event_count_;
73   // Total number of MODULE_LOADED events seen.
74   int module_loaded_event_count_;
75 
76  private:
77   DISALLOW_COPY_AND_ASSIGN(ModuleWatcherTest);
78 };
79 
// While one ModuleWatcher is alive, a second Create() call must be refused
// and return null.
TEST_F(ModuleWatcherTest, SingleModuleWatcherOnly) {
  auto first_watcher = Create();
  EXPECT_NE(nullptr, first_watcher.get());

  // |first_watcher| is still alive here, so this request should be denied.
  auto second_watcher = Create();
  EXPECT_EQ(nullptr, second_watcher.get());
}
87 
// Checks three behaviors of a live ModuleWatcher:
// - It reports the already-loaded modules once its startup enumeration runs.
// - It reports a load event for each dynamic load of the test module.
// - Once the watcher is destroyed, it reports nothing further.
TEST_F(ModuleWatcherTest, ModuleEvents) {
  // Create the module watcher. This should immediately enumerate all already
  // loaded modules on a background task.
  std::unique_ptr<ModuleWatcher> mw(Create());
  // Without a watcher every count below would fail confusingly, so stop here.
  ASSERT_TRUE(mw);
  RunUntilIdle();

  EXPECT_LT(0, module_event_count_);
  EXPECT_LT(0, module_already_loaded_event_count_);
  EXPECT_EQ(0, module_loaded_event_count_);

  // Dynamically load a module and ensure a notification is received for it.
  // LoadModule() asserts internally; propagate its failures to this test.
  int previous_module_loaded_event_count = module_loaded_event_count_;
  ASSERT_NO_FATAL_FAILURE(LoadModule());
  EXPECT_LT(previous_module_loaded_event_count, module_loaded_event_count_);

  UnloadModule();

  // Reload the same module and ensure a second notification is received.
  previous_module_loaded_event_count = module_loaded_event_count_;
  ASSERT_NO_FATAL_FAILURE(LoadModule());
  EXPECT_LT(previous_module_loaded_event_count, module_loaded_event_count_);

  UnloadModule();

  // Destroy the module watcher.
  mw.reset();

  // Load the module and ensure no notification is received this time. The
  // fixture's TearDown() unloads it afterwards.
  previous_module_loaded_event_count = module_loaded_event_count_;
  ASSERT_NO_FATAL_FAILURE(LoadModule());
  EXPECT_EQ(previous_module_loaded_event_count, module_loaded_event_count_);
}
120