1#!/usr/bin/env python3
2#
3# Unit tests for the notification control
4# Copyright (C) Stefan Metzmacher 2016
5#
6# This program is free software; you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation; either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19from __future__ import print_function
20import optparse
21import sys
22import os
23
24sys.path.insert(0, "bin/python")
25import samba
26
27from samba.tests.subunitrun import SubunitOptions, TestProgram
28
29import samba.getopt as options
30
31from samba.auth import system_session
32from samba import ldb
33from samba.samdb import SamDB
34from samba.ndr import ndr_unpack
35from samba import gensec
36from samba.credentials import Credentials
37import samba.tests
38
39from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES
40
41from ldb import SCOPE_SUBTREE, SCOPE_ONELEVEL, SCOPE_BASE, LdbError
42from ldb import ERR_TIME_LIMIT_EXCEEDED, ERR_ADMIN_LIMIT_EXCEEDED, ERR_UNWILLING_TO_PERFORM
43from ldb import Message
44
# Command line handling: the standard Samba option groups plus one
# mandatory positional argument naming the host (or database) under test.
parser = optparse.OptionParser("notification.py [options] <host>")

# Build the option groups first (in the same order as they are registered)...
sambaopts = options.SambaOptions(parser)
versionopts = options.VersionOptions(parser)
# use command line creds if available
credopts = options.CredentialsOptions(parser)
subunitopts = SubunitOptions(parser)

# ...then attach them to the parser.
for option_group in (sambaopts, versionopts, credopts, subunitopts):
    parser.add_option_group(option_group)

opts, args = parser.parse_args()

# Without a target there is nothing to test against.
if not args:
    parser.print_usage()
    sys.exit(1)

url = args[0]

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
64
65
class LDAPNotificationTest(samba.tests.TestCase):
    """Exercise the "notification:1" search control via SamDB.search_iterator().

    Every test fails immediately unless the module-level ``url`` starts with
    "ldap".
    """

    def setUp(self):
        """Connect to ``url`` and remember the base DN and a <SID=...> DN.

        The <SID=...> DN is built from the first ``tokenGroups`` value
        returned by a base search on the empty DN.
        """
        super(LDAPNotificationTest, self).setUp()
        # Import explicitly instead of relying on some other samba module
        # having loaded samba.dcerpc.security as a side effect.
        from samba.dcerpc import security

        self.ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(security.dom_sid, res[0]["tokenGroups"][0]))

    def _require_ldap(self):
        """Fail the current test unless ``url`` uses an ldap scheme."""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

    def _notification_search(self, expression, attrs, timeout):
        """Start a subtree search below base_dn with the notification control.

        Returns the iterator handle produced by ``search_iterator``.
        """
        return self.ldb.search_iterator(base=self.base_dn,
                                        expression=expression,
                                        scope=ldb.SCOPE_SUBTREE,
                                        attrs=attrs,
                                        controls=["notification:1"],
                                        timeout=timeout)

    def _assert_notification_error(self, expression, timeout, expected_error):
        """Check that a notification search yields nothing and then fails.

        The search for ``expression`` must not return any reply and must end
        with an LdbError whose code is ``expected_error``. An LdbError raised
        while starting or iterating the search is checked the same way.
        """
        try:
            hnd = self._notification_search(expression, ["name"], timeout)
            for _ in hnd:
                self.fail("unexpected reply for notification filter %s" % expression)
            hnd.result()
        except LdbError as e:
            (num, _) = e.args
            self.assertEqual(num, expected_error,
                             "notification filter %s: expected ldb error %s, got %s"
                             % (expression, expected_error, num))
        else:
            self.fail("notification filter %s completed without an error" % expression)

    def test_simple_search(self):
        """Testing a notification with a modify and a timeout"""
        self._require_ldap()

        attrs = ["name", "objectGUID", "displayName"]

        # Reference copy of the object, looked up directly by its SID.
        msg1 = None
        search1 = self.ldb.search_iterator(base=self.user_sid_dn,
                                           expression="(objectClass=*)",
                                           scope=ldb.SCOPE_SUBTREE,
                                           attrs=attrs)
        for reply in search1:
            self.assertIsInstance(reply, ldb.Message)
            self.assertIsNone(msg1)
            msg1 = reply
        search1.result()
        # Everything below compares against msg1, so fail clearly if the
        # lookup returned nothing instead of raising a TypeError later.
        self.assertIsNotNone(msg1, "no entry found for %s" % self.user_sid_dn)

        # The same object must show up exactly once in a plain subtree search.
        search2 = self.ldb.search_iterator(base=self.base_dn,
                                           expression="(objectClass=*)",
                                           scope=ldb.SCOPE_SUBTREE,
                                           attrs=attrs)
        refs2 = 0
        msg2 = None
        for reply in search2:
            if isinstance(reply, str):
                # String replies are counted separately and otherwise ignored.
                refs2 += 1
                continue
            self.assertIsInstance(reply, ldb.Message)
            if reply["objectGUID"][0] == msg1["objectGUID"][0]:
                self.assertIsNone(msg2)
                msg2 = reply
                self.assertEqual(msg1.dn, msg2.dn)
                self.assertEqual(len(msg1), len(msg2))
                self.assertEqual(msg1["name"], msg2["name"])
        search2.result()

        # NOTE(review): the trailing double quote after BEFORE/AFTER is part
        # of the value sent to the server; kept verbatim - confirm whether it
        # is intentional.
        self.ldb.modify_ldif("""
dn: """ + self.user_sid_dn + """
changetype: modify
replace: otherLoginWorkstations
otherLoginWorkstations: BEFORE"
""")
        notify1 = self._notification_search("(objectClass=*)", attrs, 1)

        # This change is made after the notification search has been started.
        self.ldb.modify_ldif("""
dn: """ + self.user_sid_dn + """
changetype: modify
replace: otherLoginWorkstations
otherLoginWorkstations: AFTER"
""")

        msg3 = None
        for reply in notify1:
            self.assertIsInstance(reply, ldb.Message)
            if reply["objectGUID"][0] == msg1["objectGUID"][0]:
                self.assertIsNone(msg3)
                msg3 = reply
                self.assertEqual(msg1.dn, msg3.dn)
                self.assertEqual(len(msg1), len(msg3))
                self.assertEqual(msg1["name"], msg3["name"])
        # The notification search is expected to end with a time limit error.
        try:
            notify1.result()
        except LdbError as e:
            (num, _) = e.args
            self.assertEqual(num, ERR_TIME_LIMIT_EXCEEDED)
        else:
            self.fail("notification search completed without an error")
        self.assertIsNotNone(msg3)

        # NOTE(review): this LDIF combines "changetype: delete" with a
        # "delete:" line; kept verbatim - confirm whether "changetype: modify"
        # was intended.
        self.ldb.modify_ldif("""
dn: """ + self.user_sid_dn + """
changetype: delete
delete: otherLoginWorkstations
""")

    def test_max_search(self):
        """Testing the max allowed notifications"""
        self._require_ldap()

        max_notifications = 5

        # Start one more notification search than the limit under test.
        notifies = [self._notification_search("(objectClass=*)", ["name"], 1)
                    for _ in range(max_notifications + 1)]

        num_admin_limit = 0
        num_time_limit = 0
        for hnd in notifies:
            try:
                for _ in hnd:
                    pass
                hnd.result()
                self.fail("notification search completed without an error")
            except LdbError as e:
                (num, _) = e.args
                if num == ERR_ADMIN_LIMIT_EXCEEDED:
                    num_admin_limit += 1
                    continue
                if num == ERR_TIME_LIMIT_EXCEEDED:
                    num_time_limit += 1
                    continue
                raise
        # Exactly one search must be rejected; the others run into the timeout.
        self.assertEqual(num_admin_limit, 1)
        self.assertEqual(num_time_limit, max_notifications)

    def test_invalid_filter(self):
        """Testing invalid filters for notifications"""
        self._require_ldap()

        valid_attrs = ["objectClass", "objectGUID", "distinguishedName", "name"]

        for va in valid_attrs:
            # These filters are expected to be accepted: the search then
            # ends with its one second time limit.
            for expression in ("(%s=*)" % va,
                               "(|(%s=*)(%s=value))" % (va, va)):
                self._assert_notification_error(expression, 1,
                                                ERR_TIME_LIMIT_EXCEEDED)

            # Every other filter shape must be refused.
            for expression in ("(&(%s=*)(%s=value))" % (va, va),
                               "(%s=value)" % va,
                               "(%s>=value)" % va,
                               "(%s<=value)" % va,
                               "(%s=*value*)" % va,
                               "(!(%s=*))" % va):
                self._assert_notification_error(expression, 0,
                                                ERR_UNWILLING_TO_PERFORM)

        # A presence filter on any other schema attribute must be refused.
        res = self.ldb.search(base=self.ldb.get_schema_basedn(),
                              expression="(objectClass=attributeSchema)",
                              scope=ldb.SCOPE_ONELEVEL,
                              attrs=["lDAPDisplayName"],
                              controls=["paged_results:1:2500"])
        for msg in res:
            va = str(msg["lDAPDisplayName"][0])
            if va in valid_attrs:
                continue
            self._assert_notification_error("(%s=*)" % va, 0,
                                            ERR_UNWILLING_TO_PERFORM)

        # The same holds for an attribute name outside the valid list.
        self._assert_notification_error("(%s=*)" % "noneAttributeName", 0,
                                        ERR_UNWILLING_TO_PERFORM)
365
366
# A target given without a scheme is opened as a tdb file when it names an
# existing file, and is otherwise treated as an LDAP host.
if "://" not in url:
    url_template = "tdb://%s" if os.path.isfile(url) else "ldap://%s"
    url = url_template % url

# Hand this module's tests to the subunit-aware test runner.
TestProgram(module=__name__, opts=subunitopts)
374