1aed50d7eSPierre Schweitzer #include "shellext.h"
2aed50d7eSPierre Schweitzer #ifndef __REACTOS__
3aed50d7eSPierre Schweitzer #include "mountmgr.h"
4aed50d7eSPierre Schweitzer #else
5aed50d7eSPierre Schweitzer #include "mountmgr_local.h"
6aed50d7eSPierre Schweitzer #endif
7aed50d7eSPierre Schweitzer #include <mountmgr.h>
8aed50d7eSPierre Schweitzer 
9aed50d7eSPierre Schweitzer using namespace std;
10aed50d7eSPierre Schweitzer 
mountmgr()11aed50d7eSPierre Schweitzer mountmgr::mountmgr() {
12aed50d7eSPierre Schweitzer     UNICODE_STRING us;
13aed50d7eSPierre Schweitzer     OBJECT_ATTRIBUTES attr;
14aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
15aed50d7eSPierre Schweitzer     NTSTATUS Status;
16aed50d7eSPierre Schweitzer 
17aed50d7eSPierre Schweitzer     RtlInitUnicodeString(&us, MOUNTMGR_DEVICE_NAME);
18aed50d7eSPierre Schweitzer     InitializeObjectAttributes(&attr, &us, 0, nullptr, nullptr);
19aed50d7eSPierre Schweitzer 
20aed50d7eSPierre Schweitzer     Status = NtOpenFile(&h, FILE_GENERIC_READ | FILE_GENERIC_WRITE, &attr, &iosb,
21aed50d7eSPierre Schweitzer                         FILE_SHARE_READ, FILE_SYNCHRONOUS_IO_ALERT);
22aed50d7eSPierre Schweitzer 
23aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
24aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
25aed50d7eSPierre Schweitzer }
26aed50d7eSPierre Schweitzer 
~mountmgr()27aed50d7eSPierre Schweitzer mountmgr::~mountmgr() {
28aed50d7eSPierre Schweitzer     NtClose(h);
29aed50d7eSPierre Schweitzer }
30aed50d7eSPierre Schweitzer 
create_point(const wstring_view & symlink,const wstring_view & device) const31aed50d7eSPierre Schweitzer void mountmgr::create_point(const wstring_view& symlink, const wstring_view& device) const {
32aed50d7eSPierre Schweitzer     NTSTATUS Status;
33aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
34aed50d7eSPierre Schweitzer 
35aed50d7eSPierre Schweitzer     vector<uint8_t> buf(sizeof(MOUNTMGR_CREATE_POINT_INPUT) + ((symlink.length() + device.length()) * sizeof(WCHAR)));
36aed50d7eSPierre Schweitzer     auto mcpi = reinterpret_cast<MOUNTMGR_CREATE_POINT_INPUT*>(buf.data());
37aed50d7eSPierre Schweitzer 
38aed50d7eSPierre Schweitzer     mcpi->SymbolicLinkNameOffset = sizeof(MOUNTMGR_CREATE_POINT_INPUT);
39aed50d7eSPierre Schweitzer     mcpi->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
40aed50d7eSPierre Schweitzer     mcpi->DeviceNameOffset = (USHORT)(mcpi->SymbolicLinkNameOffset + mcpi->SymbolicLinkNameLength);
41aed50d7eSPierre Schweitzer     mcpi->DeviceNameLength = (USHORT)(device.length() * sizeof(WCHAR));
42aed50d7eSPierre Schweitzer 
43aed50d7eSPierre Schweitzer     memcpy((uint8_t*)mcpi + mcpi->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
44aed50d7eSPierre Schweitzer     memcpy((uint8_t*)mcpi + mcpi->DeviceNameOffset, device.data(), device.length() * sizeof(WCHAR));
45aed50d7eSPierre Schweitzer 
46aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_CREATE_POINT,
47aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), nullptr, 0);
48aed50d7eSPierre Schweitzer 
49aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
50aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
51aed50d7eSPierre Schweitzer }
52aed50d7eSPierre Schweitzer 
delete_points(const wstring_view & symlink,const wstring_view & unique_id,const wstring_view & device_name) const53aed50d7eSPierre Schweitzer void mountmgr::delete_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
54aed50d7eSPierre Schweitzer     NTSTATUS Status;
55aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
56aed50d7eSPierre Schweitzer 
57aed50d7eSPierre Schweitzer     vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
58aed50d7eSPierre Schweitzer     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
59aed50d7eSPierre Schweitzer 
60aed50d7eSPierre Schweitzer     memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
61aed50d7eSPierre Schweitzer 
62aed50d7eSPierre Schweitzer     if (symlink.length() > 0) {
63aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
64aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
65aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
66aed50d7eSPierre Schweitzer     }
67aed50d7eSPierre Schweitzer 
68aed50d7eSPierre Schweitzer     if (unique_id.length() > 0) {
69aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0)
70aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
71aed50d7eSPierre Schweitzer         else
72aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
73aed50d7eSPierre Schweitzer 
74aed50d7eSPierre Schweitzer         mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
75aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
76aed50d7eSPierre Schweitzer     }
77aed50d7eSPierre Schweitzer 
78aed50d7eSPierre Schweitzer     if (device_name.length() > 0) {
79aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
80aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
81aed50d7eSPierre Schweitzer         else if (mmp->SymbolicLinkNameLength != 0)
82aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
83aed50d7eSPierre Schweitzer         else
84aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
85aed50d7eSPierre Schweitzer 
86aed50d7eSPierre Schweitzer         mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
87aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
88aed50d7eSPierre Schweitzer     }
89aed50d7eSPierre Schweitzer 
90aed50d7eSPierre Schweitzer     vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
91aed50d7eSPierre Schweitzer     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
92aed50d7eSPierre Schweitzer 
93aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
94aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
95aed50d7eSPierre Schweitzer 
96aed50d7eSPierre Schweitzer     if (Status == STATUS_BUFFER_OVERFLOW) {
97aed50d7eSPierre Schweitzer         buf2.resize(mmps->Size);
98aed50d7eSPierre Schweitzer 
99aed50d7eSPierre Schweitzer         Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_DELETE_POINTS,
100aed50d7eSPierre Schweitzer                                        buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
101aed50d7eSPierre Schweitzer     }
102aed50d7eSPierre Schweitzer 
103aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
104aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
105aed50d7eSPierre Schweitzer }
106aed50d7eSPierre Schweitzer 
query_points(const wstring_view & symlink,const wstring_view & unique_id,const wstring_view & device_name) const107aed50d7eSPierre Schweitzer vector<mountmgr_point> mountmgr::query_points(const wstring_view& symlink, const wstring_view& unique_id, const wstring_view& device_name) const {
108aed50d7eSPierre Schweitzer     NTSTATUS Status;
109aed50d7eSPierre Schweitzer     IO_STATUS_BLOCK iosb;
110aed50d7eSPierre Schweitzer     vector<mountmgr_point> v;
111aed50d7eSPierre Schweitzer 
112aed50d7eSPierre Schweitzer     vector<uint8_t> buf(sizeof(MOUNTMGR_MOUNT_POINT) + ((symlink.length() + unique_id.length() + device_name.length()) * sizeof(WCHAR)));
113aed50d7eSPierre Schweitzer     auto mmp = reinterpret_cast<MOUNTMGR_MOUNT_POINT*>(buf.data());
114aed50d7eSPierre Schweitzer 
115aed50d7eSPierre Schweitzer     memset(mmp, 0, sizeof(MOUNTMGR_MOUNT_POINT));
116aed50d7eSPierre Schweitzer 
117aed50d7eSPierre Schweitzer     if (symlink.length() > 0) {
118aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
119aed50d7eSPierre Schweitzer         mmp->SymbolicLinkNameLength = (USHORT)(symlink.length() * sizeof(WCHAR));
120aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->SymbolicLinkNameOffset, symlink.data(), symlink.length() * sizeof(WCHAR));
121aed50d7eSPierre Schweitzer     }
122aed50d7eSPierre Schweitzer 
123aed50d7eSPierre Schweitzer     if (unique_id.length() > 0) {
124aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0)
125aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINT);
126aed50d7eSPierre Schweitzer         else
127aed50d7eSPierre Schweitzer             mmp->UniqueIdOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
128aed50d7eSPierre Schweitzer 
129aed50d7eSPierre Schweitzer         mmp->UniqueIdLength = (USHORT)(unique_id.length() * sizeof(WCHAR));
130aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->UniqueIdOffset, unique_id.data(), unique_id.length() * sizeof(WCHAR));
131aed50d7eSPierre Schweitzer     }
132aed50d7eSPierre Schweitzer 
133aed50d7eSPierre Schweitzer     if (device_name.length() > 0) {
134aed50d7eSPierre Schweitzer         if (mmp->SymbolicLinkNameLength == 0 && mmp->UniqueIdOffset == 0)
135aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
136aed50d7eSPierre Schweitzer         else if (mmp->SymbolicLinkNameLength != 0)
137aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->SymbolicLinkNameOffset + mmp->SymbolicLinkNameLength;
138aed50d7eSPierre Schweitzer         else
139aed50d7eSPierre Schweitzer             mmp->DeviceNameOffset = mmp->UniqueIdOffset + mmp->UniqueIdLength;
140aed50d7eSPierre Schweitzer 
141aed50d7eSPierre Schweitzer         mmp->DeviceNameLength = (USHORT)(device_name.length() * sizeof(WCHAR));
142aed50d7eSPierre Schweitzer         memcpy((uint8_t*)mmp + mmp->DeviceNameOffset, device_name.data(), device_name.length() * sizeof(WCHAR));
143aed50d7eSPierre Schweitzer     }
144aed50d7eSPierre Schweitzer 
145aed50d7eSPierre Schweitzer     vector<uint8_t> buf2(sizeof(MOUNTMGR_MOUNT_POINTS));
146aed50d7eSPierre Schweitzer     auto mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
147aed50d7eSPierre Schweitzer 
148aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
149aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
150aed50d7eSPierre Schweitzer 
151aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status) && Status != STATUS_BUFFER_OVERFLOW)
152aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
153aed50d7eSPierre Schweitzer 
154aed50d7eSPierre Schweitzer     buf2.resize(mmps->Size);
155aed50d7eSPierre Schweitzer     mmps = reinterpret_cast<MOUNTMGR_MOUNT_POINTS*>(buf2.data());
156aed50d7eSPierre Schweitzer 
157aed50d7eSPierre Schweitzer     Status = NtDeviceIoControlFile(h, nullptr, nullptr, nullptr, &iosb, IOCTL_MOUNTMGR_QUERY_POINTS,
158aed50d7eSPierre Schweitzer                                    buf.data(), (ULONG)buf.size(), buf2.data(), (ULONG)buf2.size());
159aed50d7eSPierre Schweitzer 
160aed50d7eSPierre Schweitzer     if (!NT_SUCCESS(Status))
161aed50d7eSPierre Schweitzer         throw ntstatus_error(Status);
162aed50d7eSPierre Schweitzer 
163aed50d7eSPierre Schweitzer     for (ULONG i = 0; i < mmps->NumberOfMountPoints; i++) {
164aed50d7eSPierre Schweitzer         wstring_view mpsl, mpdn;
165aed50d7eSPierre Schweitzer         string_view mpuid;
166aed50d7eSPierre Schweitzer 
167aed50d7eSPierre Schweitzer         if (mmps->MountPoints[i].SymbolicLinkNameLength)
168aed50d7eSPierre Schweitzer             mpsl = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].SymbolicLinkNameOffset), mmps->MountPoints[i].SymbolicLinkNameLength / sizeof(WCHAR));
169aed50d7eSPierre Schweitzer 
170aed50d7eSPierre Schweitzer         if (mmps->MountPoints[i].UniqueIdLength)
171aed50d7eSPierre Schweitzer             mpuid = string_view((char*)((uint8_t*)mmps + mmps->MountPoints[i].UniqueIdOffset), mmps->MountPoints[i].UniqueIdLength);
172aed50d7eSPierre Schweitzer 
173aed50d7eSPierre Schweitzer         if (mmps->MountPoints[i].DeviceNameLength)
174aed50d7eSPierre Schweitzer             mpdn = wstring_view((WCHAR*)((uint8_t*)mmps + mmps->MountPoints[i].DeviceNameOffset), mmps->MountPoints[i].DeviceNameLength / sizeof(WCHAR));
175aed50d7eSPierre Schweitzer 
176*5f779048SPierre Schweitzer #ifndef __REACTOS__
177aed50d7eSPierre Schweitzer         v.emplace_back(mpsl, mpuid, mpdn);
178*5f779048SPierre Schweitzer #else
179*5f779048SPierre Schweitzer         v.push_back(mountmgr_point(mpsl, mpuid, mpdn));
180*5f779048SPierre Schweitzer #endif
181aed50d7eSPierre Schweitzer     }
182aed50d7eSPierre Schweitzer 
183aed50d7eSPierre Schweitzer     return v;
184aed50d7eSPierre Schweitzer }
185