1 #include "shellext.h"
2 #ifndef __REACTOS__
3 #include "mountmgr.h"
4 #else
5 #include "mountmgr_local.h"
6 #endif
7 #include <mountmgr.h>
8 
9 using namespace std;
10 
mountmgr()11 mountmgr::mountmgr() {
12     UNICODE_STRING us;
13     OBJECT_ATTRIBUTES attr;
14     IO_STATUS_BLOCK iosb;
15     NTSTATUS Status;
16 
17     RtlInitUnicodeString(&us, MOUNTMGR_DEVICE_NAME);
18     InitializeObjectAttributes(&attr, &us, 0, nullptr, nullptr);
19 
20     Status = NtOpenFile(&h, FILE_GENERIC_READ | FILE_GENERIC_WRITE, &attr, &iosb,
21                         FILE_SHARE_READ, FILE_SYNCHRONOUS_IO_ALERT);
22 
23     if (!NT_SUCCESS(Status))
24         throw ntstatus_error(Status);
25 }
26 
~mountmgr()27 mountmgr::~mountmgr() {
28     NtClose(h);
29 }
30 
create_point(const wstring_view & symlink,const wstring_view & device) const31 void mountmgr::create_point(const wstring_view& symlink, const wstring_view& device) const {
32     NTSTATUS Status;
33     IO_STATUS_BLOCK iosb;
34 
35     vector<uint8_t> buf(sizeof(MOUNTMGR_CREATE_POINT_INPUT) + ((symlink.length() + device.length()) * sizeof(WCHAR)));
36     auto mcpi = reinterpret_cast<MOUNTMGR_CREATE_POINT_INPUT*>(buf.data());
37 
38     mcpi->SymbolicLinkNameOffset = sizeof(MOUNTMGR_CREATE_POINT_INPUT);
39     mcpi->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
40     mcpi->DeviceNameOffset = (USHORT)(mcpi->SymbolicLinkNameOffset + mcpi->SymbolicLinkNameLength);
41     mcpi->DeviceNameLength = (USHORT)(device.length() * sizeof(WCHAR));
42 
43     memcpy((uint8_t*)mcpi + mcpi->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
44     memcpy((uint8_t*)mcpi + mcpi->DeviceNameOffset, device.data(), device.length() * sizeof(WCHAR));
45 
46     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_CREATE_POINT,
47                                    buf.data(), (ULONG)buf.size(), nullptr, 0);
48 
49     if (!NT_SUCCESS(Status))
50         throw ntstatus_error(Status);
51 }
52 
delete_points(const wstring_view & symlink,const wstring_view & unique_id,const wstring_view & device_name) const53 void mountmgr::delete_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
54     NTSTATUS Status;
55     IO_STATUS_BLOCK iosb;
56 
57     vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
58     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
59 
60     memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
61 
62     if (symlink.length() > 0) {
63         mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
64         mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
65         memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
66     }
67 
68     if (unique_id.length() > 0) {
69         if (mmp->SymbolicLinkNameLength == 0)
70             mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
71         else
72             mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
73 
74         mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
75         memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
76     }
77 
78     if (device_name.length() > 0) {
79         if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
80             mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
81         else if (mmp->SymbolicLinkNameLength != 0)
82             mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
83         else
84             mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
85 
86         mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
87         memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
88     }
89 
90     vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
91     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
92 
93     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
94                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
95 
96     if (Status == STATUS_BUFFER_OVERFLOW) {
97         buf2.resize(mmps->Size);
98 
99         Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
100                                        buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
101     }
102 
103     if (!NT_SUCCESS(Status))
104         throw ntstatus_error(Status);
105 }
106 
query_points(const wstring_view & symlink,const wstring_view & unique_id,const wstring_view & device_name) const107 vector<mountmgr_point> mountmgr::query_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
108     NTSTATUS Status;
109     IO_STATUS_BLOCK iosb;
110     vector<mountmgr_point> v;
111 
112     vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
113     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
114 
115     memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
116 
117     if (symlink.length() > 0) {
118         mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
119         mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
120         memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
121     }
122 
123     if (unique_id.length() > 0) {
124         if (mmp->SymbolicLinkNameLength == 0)
125             mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
126         else
127             mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
128 
129         mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
130         memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
131     }
132 
133     if (device_name.length() > 0) {
134         if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
135             mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
136         else if (mmp->SymbolicLinkNameLength != 0)
137             mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
138         else
139             mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
140 
141         mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
142         memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
143     }
144 
145     vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
146     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
147 
148     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
149                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
150 
151     if (!NT_SUCCESS(Status) && Status != STATUS_BUFFER_OVERFLOW)
152         throw ntstatus_error(Status);
153 
154     buf2.resize(mmps->Size);
155     mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
156 
157     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
158                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
159 
160     if (!NT_SUCCESS(Status))
161         throw ntstatus_error(Status);
162 
163     for (ULONG i = 0; i < mmps->NumberOfMountPoints; i++) {
164         wstring_view mpsl, mpdn;
165         string_view mpuid;
166 
167         if (mmps->MountPoints[i].SymbolicLinkNameLength)
168             mpsl = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].SymbolicLinkNameOffset), mmps->MountPoints[i].SymbolicLinkNameLength / sizeof(WCHAR));
169 
170         if (mmps->MountPoints[i].UniqueIdLength)
171             mpuid = string_view((char*)((uint8_t*)mmps + mmps->MountPoints[i].UniqueIdOffset), mmps->MountPoints[i].UniqueIdLength);
172 
173         if (mmps->MountPoints[i].DeviceNameLength)
174             mpdn = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].DeviceNameOffset), mmps->MountPoints[i].DeviceNameLength / sizeof(WCHAR));
175 
176 #ifndef __REACTOS__
177         v.emplace_back(mpsl, mpuid, mpdn);
178 #else
179         v.push_back(mountmgr_point(mpsl, mpuid, mpdn));
180 #endif
181     }
182 
183     return v;
184 }
185