1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/common/conflicts/module_watcher_win.h"
6
7 #include <memory>
8
9 #include "base/bind.h"
10 #include "base/test/task_environment.h"
11 #include "testing/gtest/include/gtest/gtest.h"
12
13 #include <windows.h>
14
15 class ModuleWatcherTest : public testing::Test {
16 protected:
ModuleWatcherTest()17 ModuleWatcherTest()
18 : module_(nullptr),
19 module_event_count_(0),
20 module_already_loaded_event_count_(0),
21 module_loaded_event_count_(0) {}
22
OnModuleEvent(const ModuleWatcher::ModuleEvent & event)23 void OnModuleEvent(const ModuleWatcher::ModuleEvent& event) {
24 ++module_event_count_;
25 switch (event.event_type) {
26 case ModuleWatcher::ModuleEventType::kModuleAlreadyLoaded:
27 ++module_already_loaded_event_count_;
28 break;
29 case ModuleWatcher::ModuleEventType::kModuleLoaded:
30 ++module_loaded_event_count_;
31 break;
32 }
33 }
34
TearDown()35 void TearDown() override { UnloadModule(); }
36
LoadModule()37 void LoadModule() {
38 if (module_)
39 return;
40 // This module should not be a static dependency of the unit-test
41 // executable, but should be a build-system dependency or a module that is
42 // present on any Windows machine.
43 static constexpr wchar_t kModuleName[] = L"conflicts_dll.dll";
44 // The module should not already be loaded.
45 ASSERT_FALSE(::GetModuleHandle(kModuleName));
46 // It should load successfully.
47 module_ = ::LoadLibrary(kModuleName);
48 ASSERT_TRUE(module_);
49 }
50
UnloadModule()51 void UnloadModule() {
52 if (!module_)
53 return;
54 ::FreeLibrary(module_);
55 module_ = nullptr;
56 }
57
RunUntilIdle()58 void RunUntilIdle() { task_environment_.RunUntilIdle(); }
59
Create()60 std::unique_ptr<ModuleWatcher> Create() {
61 return ModuleWatcher::Create(base::BindRepeating(
62 &ModuleWatcherTest::OnModuleEvent, base::Unretained(this)));
63 }
64
65 base::test::TaskEnvironment task_environment_;
66
67 // Holds a handle to a loaded module.
68 HMODULE module_;
69 // Total number of module events seen.
70 int module_event_count_;
71 // Total number of MODULE_ALREADY_LOADED events seen.
72 int module_already_loaded_event_count_;
73 // Total number of MODULE_LOADED events seen.
74 int module_loaded_event_count_;
75
76 private:
77 DISALLOW_COPY_AND_ASSIGN(ModuleWatcherTest);
78 };
79
// While one ModuleWatcher is alive, a second call to Create() must not
// produce another watcher; it is expected to return null.
TEST_F(ModuleWatcherTest, SingleModuleWatcherOnly) {
  std::unique_ptr<ModuleWatcher> first_watcher = Create();
  EXPECT_TRUE(first_watcher.get());

  // |first_watcher| is still in scope here.
  std::unique_ptr<ModuleWatcher> second_watcher = Create();
  EXPECT_FALSE(second_watcher.get());
}
87
// Exercises the notification lifecycle of a ModuleWatcher:
//   1. Creating the watcher reports already-loaded modules. This is done on a
//      background task, hence RunUntilIdle().
//   2. Each dynamic load of a module while the watcher is alive produces a
//      kModuleLoaded event. This is checked over two load/unload cycles.
//   3. Once the watcher is destroyed, loading a module produces no event.
TEST_F(ModuleWatcherTest, ModuleEvents) {
  // Create the module watcher. This should immediately enumerate all already
  // loaded modules on a background task.
  std::unique_ptr<ModuleWatcher> mw(Create());
  // None of the expectations below are meaningful without a watcher.
  ASSERT_TRUE(mw.get());
  RunUntilIdle();

  EXPECT_LT(0, module_event_count_);
  EXPECT_LT(0, module_already_loaded_event_count_);
  EXPECT_EQ(0, module_loaded_event_count_);

  // Dynamically load (then unload) a module twice. Ensure a notification is
  // received for each load.
  for (int iteration = 0; iteration < 2; ++iteration) {
    SCOPED_TRACE(iteration);
    const int previous_module_loaded_event_count = module_loaded_event_count_;
    // LoadModule() uses fatal assertions that only exit the helper. Propagate
    // them so that a failed load aborts the test, instead of producing
    // misleading follow-on failures.
    ASSERT_NO_FATAL_FAILURE(LoadModule());
    EXPECT_LT(previous_module_loaded_event_count, module_loaded_event_count_);

    UnloadModule();
  }

  // Destroy the module watcher.
  mw.reset();

  // Load the module and ensure no notification is received this time. The
  // fixture's TearDown() unloads it.
  const int previous_module_loaded_event_count = module_loaded_event_count_;
  ASSERT_NO_FATAL_FAILURE(LoadModule());
  EXPECT_EQ(previous_module_loaded_event_count, module_loaded_event_count_);
}
120