1aed50d7eSPierre Schweitzer #include "shellext.h"
2aed50d7eSPierre Schweitzer #ifndef __REACTOS__
3aed50d7eSPierre Schweitzer #include "mountmgr.h"
4aed50d7eSPierre Schweitzer #else
5aed50d7eSPierre Schweitzer #include "mountmgr_local.h"
6aed50d7eSPierre Schweitzer #endif
7aed50d7eSPierre Schweitzer #include <mountmgr.h>
8aed50d7eSPierre Schweitzer 
9aed50d7eSPierre Schweitzer using namespace std;
10aed50d7eSPierre Schweitzer 
11aed50d7eSPierre Schweitzer mountmgr::mountmgr() {
12aed50d7eSPierre Schweitzer     UNICODE_STRING us;
13aed50d7eSPierre Schweitzer     OBJECT_ATTRIBUTES attr;
14aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
15aed50d7eSPierre Schweitzer     NTSTATUS Status;
16aed50d7eSPierre Schweitzer 
17aed50d7eSPierre Schweitzer     RtlInitUnicodeString(&us, MOUNTMGR_DEVICE_NAME);
18aed50d7eSPierre Schweitzer     InitializeObjectAttributes(&attr, &us, 0, nullptr, nullptr);
19aed50d7eSPierre Schweitzer 
20aed50d7eSPierre Schweitzer     Status = NtOpenFile(&h, FILE_GENERIC_READ | FILE_GENERIC_WRITE, &attr, &iosb,
21aed50d7eSPierre Schweitzer                         FILE_SHARE_READ, FILE_SYNCHRONOUS_IO_ALERT);
22aed50d7eSPierre Schweitzer 
23aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
24aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
25aed50d7eSPierre Schweitzer }
26aed50d7eSPierre Schweitzer 
27aed50d7eSPierre Schweitzer mountmgr::~mountmgr() {
28aed50d7eSPierre Schweitzer     NtClose(h);
29aed50d7eSPierre Schweitzer }
30aed50d7eSPierre Schweitzer 
31aed50d7eSPierre Schweitzer void mountmgr::create_point(const wstring_view& symlink, const wstring_view& device) const {
32aed50d7eSPierre Schweitzer     NTSTATUS Status;
33aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
34aed50d7eSPierre Schweitzer 
35aed50d7eSPierre Schweitzer     vector<uint8_t> buf(sizeof(MOUNTMGR_CREATE_POINT_INPUT) + ((symlink.length() + device.length()) * sizeof(WCHAR)));
361725ddfdSPierre Schweitzer #ifndef __REACTOS__
37aed50d7eSPierre Schweitzer     auto mcpi = reinterpret_cast<MOUNTMGR_CREATE_POINT_INPUT*>(buf.data());
381725ddfdSPierre Schweitzer #else
391725ddfdSPierre Schweitzer     auto mcpi = reinterpret_cast<MOUNTMGR_CREATE_POINT_INPUT*>(&buf[0]);
401725ddfdSPierre Schweitzer #endif
41aed50d7eSPierre Schweitzer 
42aed50d7eSPierre Schweitzer     mcpi->SymbolicLinkNameOffset = sizeof(MOUNTMGR_CREATE_POINT_INPUT);
43aed50d7eSPierre Schweitzer     mcpi->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
44aed50d7eSPierre Schweitzer     mcpi->DeviceNameOffset = (USHORT)(mcpi->SymbolicLinkNameOffset + mcpi->SymbolicLinkNameLength);
45aed50d7eSPierre Schweitzer     mcpi->DeviceNameLength = (USHORT)(device.length() * sizeof(WCHAR));
46aed50d7eSPierre Schweitzer 
47aed50d7eSPierre Schweitzer     memcpy((uint8_t*)mcpi + mcpi->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
48aed50d7eSPierre Schweitzer     memcpy((uint8_t*)mcpi + mcpi->DeviceNameOffset, device.data(), device.length() * sizeof(WCHAR));
49aed50d7eSPierre Schweitzer 
50aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_CREATE_POINT,
51a3c13c62SPierre Schweitzer #ifndef __REACTOS__
52aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), nullptr, 0);
53a3c13c62SPierre Schweitzer #else
54a3c13c62SPierre Schweitzer                                    &buf[0], (ULONG)buf.size(), nullptr, 0);
55a3c13c62SPierre Schweitzer #endif
56aed50d7eSPierre Schweitzer 
57aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
58aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
59aed50d7eSPierre Schweitzer }
60aed50d7eSPierre Schweitzer 
61aed50d7eSPierre Schweitzer void mountmgr::delete_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
62aed50d7eSPierre Schweitzer     NTSTATUS Status;
63aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
64aed50d7eSPierre Schweitzer 
65aed50d7eSPierre Schweitzer     vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
661725ddfdSPierre Schweitzer #ifndef __REACTOS__
67aed50d7eSPierre Schweitzer     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
681725ddfdSPierre Schweitzer #else
691725ddfdSPierre Schweitzer     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(&buf[0]);
701725ddfdSPierre Schweitzer #endif
71aed50d7eSPierre Schweitzer 
72aed50d7eSPierre Schweitzer     memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
73aed50d7eSPierre Schweitzer 
74aed50d7eSPierre Schweitzer     if (symlink.length() > 0) {
75aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
76aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
77aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
78aed50d7eSPierre Schweitzer     }
79aed50d7eSPierre Schweitzer 
80aed50d7eSPierre Schweitzer     if (unique_id.length() > 0) {
81aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0)
82aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
83aed50d7eSPierre Schweitzer         else
84aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
85aed50d7eSPierre Schweitzer 
86aed50d7eSPierre Schweitzer         mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
87aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
88aed50d7eSPierre Schweitzer     }
89aed50d7eSPierre Schweitzer 
90aed50d7eSPierre Schweitzer     if (device_name.length() > 0) {
91aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
92aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
93aed50d7eSPierre Schweitzer         else if (mmp->SymbolicLinkNameLength != 0)
94aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
95aed50d7eSPierre Schweitzer         else
96aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
97aed50d7eSPierre Schweitzer 
98aed50d7eSPierre Schweitzer         mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
99aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
100aed50d7eSPierre Schweitzer     }
101aed50d7eSPierre Schweitzer 
102aed50d7eSPierre Schweitzer     vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
1031725ddfdSPierre Schweitzer #ifndef __REACTOS__
104aed50d7eSPierre Schweitzer     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
1051725ddfdSPierre Schweitzer #else
1061725ddfdSPierre Schweitzer     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(&buf2[0]);
1071725ddfdSPierre Schweitzer #endif
108aed50d7eSPierre Schweitzer 
109aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
110a3c13c62SPierre Schweitzer #ifndef __REACTOS__
111aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
112a3c13c62SPierre Schweitzer #else
113a3c13c62SPierre Schweitzer                                    &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
114a3c13c62SPierre Schweitzer #endif
115aed50d7eSPierre Schweitzer 
116aed50d7eSPierre Schweitzer     if (Status == STATUS_BUFFER_OVERFLOW) {
117aed50d7eSPierre Schweitzer         buf2.resize(mmps->Size);
118aed50d7eSPierre Schweitzer 
119aed50d7eSPierre Schweitzer         Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
120a3c13c62SPierre Schweitzer #ifndef __REACTOS__
121aed50d7eSPierre Schweitzer                                        buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
122a3c13c62SPierre Schweitzer #else
123a3c13c62SPierre Schweitzer                                        &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
124a3c13c62SPierre Schweitzer #endif
125aed50d7eSPierre Schweitzer     }
126aed50d7eSPierre Schweitzer 
127aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
128aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
129aed50d7eSPierre Schweitzer }
130aed50d7eSPierre Schweitzer 
131aed50d7eSPierre Schweitzer vector<mountmgr_point> mountmgr::query_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
132aed50d7eSPierre Schweitzer     NTSTATUS Status;
133aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
134aed50d7eSPierre Schweitzer     vector<mountmgr_point> v;
135aed50d7eSPierre Schweitzer 
136aed50d7eSPierre Schweitzer     vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
1371725ddfdSPierre Schweitzer #ifndef __REACTOS__
138aed50d7eSPierre Schweitzer     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
1391725ddfdSPierre Schweitzer #else
1401725ddfdSPierre Schweitzer     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(&buf[0]);
1411725ddfdSPierre Schweitzer #endif
142aed50d7eSPierre Schweitzer 
143aed50d7eSPierre Schweitzer     memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
144aed50d7eSPierre Schweitzer 
145aed50d7eSPierre Schweitzer     if (symlink.length() > 0) {
146aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
147aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
148aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
149aed50d7eSPierre Schweitzer     }
150aed50d7eSPierre Schweitzer 
151aed50d7eSPierre Schweitzer     if (unique_id.length() > 0) {
152aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0)
153aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
154aed50d7eSPierre Schweitzer         else
155aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
156aed50d7eSPierre Schweitzer 
157aed50d7eSPierre Schweitzer         mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
158aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
159aed50d7eSPierre Schweitzer     }
160aed50d7eSPierre Schweitzer 
161aed50d7eSPierre Schweitzer     if (device_name.length() > 0) {
162aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
163aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
164aed50d7eSPierre Schweitzer         else if (mmp->SymbolicLinkNameLength != 0)
165aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
166aed50d7eSPierre Schweitzer         else
167aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
168aed50d7eSPierre Schweitzer 
169aed50d7eSPierre Schweitzer         mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
170aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
171aed50d7eSPierre Schweitzer     }
172aed50d7eSPierre Schweitzer 
173aed50d7eSPierre Schweitzer     vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
1741725ddfdSPierre Schweitzer #ifndef __REACTOS__
175aed50d7eSPierre Schweitzer     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
1761725ddfdSPierre Schweitzer #else
1771725ddfdSPierre Schweitzer     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(&buf2[0]);
1781725ddfdSPierre Schweitzer #endif
179aed50d7eSPierre Schweitzer 
180aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
181a3c13c62SPierre Schweitzer #ifndef __REACTOS__
182aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
183a3c13c62SPierre Schweitzer #else
184a3c13c62SPierre Schweitzer                                    &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
185a3c13c62SPierre Schweitzer #endif
186aed50d7eSPierre Schweitzer 
187aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status) && Status != STATUS_BUFFER_OVERFLOW)
188aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
189aed50d7eSPierre Schweitzer 
190aed50d7eSPierre Schweitzer     buf2.resize(mmps->Size);
1911725ddfdSPierre Schweitzer #ifndef __REACTOS__
192aed50d7eSPierre Schweitzer     mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
1931725ddfdSPierre Schweitzer #else
1941725ddfdSPierre Schweitzer     mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(&buf2[0]);
1951725ddfdSPierre Schweitzer #endif
196aed50d7eSPierre Schweitzer 
197aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
198a3c13c62SPierre Schweitzer #ifndef __REACTOS__
199aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
200a3c13c62SPierre Schweitzer #else
201a3c13c62SPierre Schweitzer                                    &buf[0], (ULONG)buf.size(), &buf2[0], (ULONG)buf2.size());
202a3c13c62SPierre Schweitzer #endif
203aed50d7eSPierre Schweitzer 
204aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
205aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
206aed50d7eSPierre Schweitzer 
207aed50d7eSPierre Schweitzer     for (ULONG i = 0; i < mmps->NumberOfMountPoints; i++) {
208aed50d7eSPierre Schweitzer         wstring_view mpsl, mpdn;
209aed50d7eSPierre Schweitzer         string_view mpuid;
210aed50d7eSPierre Schweitzer 
211aed50d7eSPierre Schweitzer         if (mmps->MountPoints[i].SymbolicLinkNameLength)
212aed50d7eSPierre Schweitzer             mpsl = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].SymbolicLinkNameOffset), mmps->MountPoints[i].SymbolicLinkNameLength / sizeof(WCHAR));
213aed50d7eSPierre Schweitzer 
214aed50d7eSPierre Schweitzer         if (mmps->MountPoints[i].UniqueIdLength)
215aed50d7eSPierre Schweitzer             mpuid = string_view((char*)((uint8_t*)mmps + mmps->MountPoints[i].UniqueIdOffset), mmps->MountPoints[i].UniqueIdLength);
216aed50d7eSPierre Schweitzer 
217aed50d7eSPierre Schweitzer         if (mmps->MountPoints[i].DeviceNameLength)
218aed50d7eSPierre Schweitzer             mpdn = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].DeviceNameOffset), mmps->MountPoints[i].DeviceNameLength / sizeof(WCHAR));
219aed50d7eSPierre Schweitzer 
220*5f779048SPierre Schweitzer #ifndef __REACTOS__
221aed50d7eSPierre Schweitzer         v.emplace_back(mpsl, mpuid, mpdn);
222*5f779048SPierre Schweitzer #else
223*5f779048SPierre Schweitzer         v.push_back(mountmgr_point(mpsl, mpuid, mpdn));
224*5f779048SPierre Schweitzer #endif
225aed50d7eSPierre Schweitzer     }
226aed50d7eSPierre Schweitzer 
227aed50d7eSPierre Schweitzer     return v;
228aed50d7eSPierre Schweitzer }
229