1# Unix SMB/CIFS implementation.
2# Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3# Copyright (C) Matthias Dieter Wallnoefer 2009
4#
5# Based on the original in EJS:
6# Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
7# Copyright (C) Giampaolo Lauria <lauria2@yahoo.com> 2011
8#
9# This program is free software; you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation; either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program.  If not, see <http://www.gnu.org/licenses/>.
21#
22
23"""Convenience functions for using the SAM."""
24
25import samba
26import ldb
27import time
28import base64
29import os
30import re
31from samba import dsdb, dsdb_dns
32from samba.ndr import ndr_unpack, ndr_pack
33from samba.dcerpc import drsblobs, misc
34from samba.common import normalise_int32
35from samba.compat import text_type
36from samba.compat import binary_type
37from samba.compat import get_bytes
38from samba.dcerpc import security
39
40__docformat__ = "restructuredText"
41
42
def get_default_backend_store():
    """Return the name of the default backend store (``"tdb"``)."""
    default_store = "tdb"
    return default_store
45
46
47class SamDB(samba.Ldb):
48    """The SAM database."""
49
50    hash_oid_name = {}
51    hash_well_known = {}
52
53    def __init__(self, url=None, lp=None, modules_dir=None, session_info=None,
54                 credentials=None, flags=ldb.FLG_DONT_CREATE_DB,
55                 options=None, global_schema=True,
56                 auto_connect=True, am_rodc=None):
57        self.lp = lp
58        if not auto_connect:
59            url = None
60        elif url is None and lp is not None:
61            url = lp.samdb_url()
62
63        self.url = url
64
65        super(SamDB, self).__init__(url=url, lp=lp, modules_dir=modules_dir,
66                                    session_info=session_info, credentials=credentials, flags=flags,
67                                    options=options)
68
69        if global_schema:
70            dsdb._dsdb_set_global_schema(self)
71
72        if am_rodc is not None:
73            dsdb._dsdb_set_am_rodc(self, am_rodc)
74
75    def connect(self, url=None, flags=0, options=None):
76        '''connect to the database'''
77        if self.lp is not None and not os.path.exists(url):
78            url = self.lp.private_path(url)
79        self.url = url
80
81        super(SamDB, self).connect(url=url, flags=flags,
82                                   options=options)
83
84    def am_rodc(self):
85        '''return True if we are an RODC'''
86        return dsdb._am_rodc(self)
87
88    def am_pdc(self):
89        '''return True if we are an PDC emulator'''
90        return dsdb._am_pdc(self)
91
92    def domain_dn(self):
93        '''return the domain DN'''
94        return str(self.get_default_basedn())
95
96    def schema_dn(self):
97        '''return the schema partition dn'''
98        return str(self.get_schema_basedn())
99
100    def disable_account(self, search_filter):
101        """Disables an account
102
103        :param search_filter: LDAP filter to find the user (eg
104            samccountname=name)
105        """
106
107        flags = samba.dsdb.UF_ACCOUNTDISABLE
108        self.toggle_userAccountFlags(search_filter, flags, on=True)
109
110    def enable_account(self, search_filter):
111        """Enables an account
112
113        :param search_filter: LDAP filter to find the user (eg
114            samccountname=name)
115        """
116
117        flags = samba.dsdb.UF_ACCOUNTDISABLE | samba.dsdb.UF_PASSWD_NOTREQD
118        self.toggle_userAccountFlags(search_filter, flags, on=False)
119
120    def toggle_userAccountFlags(self, search_filter, flags, flags_str=None,
121                                on=True, strict=False):
122        """Toggle_userAccountFlags
123
124        :param search_filter: LDAP filter to find the user (eg
125            samccountname=name)
126        :param flags: samba.dsdb.UF_* flags
127        :param on: on=True (default) => set, on=False => unset
128        :param strict: strict=False (default) ignore if no action is needed
129                 strict=True raises an Exception if...
130        """
131        res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
132                          expression=search_filter, attrs=["userAccountControl"])
133        if len(res) == 0:
134                raise Exception("Unable to find account where '%s'" % search_filter)
135        assert(len(res) == 1)
136        account_dn = res[0].dn
137
138        old_uac = int(res[0]["userAccountControl"][0])
139        if on:
140            if strict and (old_uac & flags):
141                error = "Account flag(s) '%s' already set" % flags_str
142                raise Exception(error)
143
144            new_uac = old_uac | flags
145        else:
146            if strict and not (old_uac & flags):
147                error = "Account flag(s) '%s' already unset" % flags_str
148                raise Exception(error)
149
150            new_uac = old_uac & ~flags
151
152        if old_uac == new_uac:
153            return
154
155        mod = """
156dn: %s
157changetype: modify
158delete: userAccountControl
159userAccountControl: %u
160add: userAccountControl
161userAccountControl: %u
162""" % (account_dn, old_uac, new_uac)
163        self.modify_ldif(mod)
164
165    def force_password_change_at_next_login(self, search_filter):
166        """Forces a password change at next login
167
168        :param search_filter: LDAP filter to find the user (eg
169            samccountname=name)
170        """
171        res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
172                          expression=search_filter, attrs=[])
173        if len(res) == 0:
174                raise Exception('Unable to find user "%s"' % search_filter)
175        assert(len(res) == 1)
176        user_dn = res[0].dn
177
178        mod = """
179dn: %s
180changetype: modify
181replace: pwdLastSet
182pwdLastSet: 0
183""" % (user_dn)
184        self.modify_ldif(mod)
185
186    def newgroup(self, groupname, groupou=None, grouptype=None,
187                 description=None, mailaddress=None, notes=None, sd=None,
188                 gidnumber=None, nisdomain=None):
189        """Adds a new group with additional parameters
190
191        :param groupname: Name of the new group
192        :param grouptype: Type of the new group
193        :param description: Description of the new group
194        :param mailaddress: Email address of the new group
195        :param notes: Notes of the new group
196        :param gidnumber: GID Number of the new group
197        :param nisdomain: NIS Domain Name of the new group
198        :param sd: security descriptor of the object
199        """
200
201        group_dn = "CN=%s,%s,%s" % (groupname, (groupou or "CN=Users"), self.domain_dn())
202
203        # The new user record. Note the reliance on the SAMLDB module which
204        # fills in the default information
205        ldbmessage = {"dn": group_dn,
206                      "sAMAccountName": groupname,
207                      "objectClass": "group"}
208
209        if grouptype is not None:
210            ldbmessage["groupType"] = normalise_int32(grouptype)
211
212        if description is not None:
213            ldbmessage["description"] = description
214
215        if mailaddress is not None:
216            ldbmessage["mail"] = mailaddress
217
218        if notes is not None:
219            ldbmessage["info"] = notes
220
221        if gidnumber is not None:
222            ldbmessage["gidNumber"] = normalise_int32(gidnumber)
223
224        if nisdomain is not None:
225            ldbmessage["msSFU30Name"] = groupname
226            ldbmessage["msSFU30NisDomain"] = nisdomain
227
228        if sd is not None:
229            ldbmessage["nTSecurityDescriptor"] = ndr_pack(sd)
230
231        self.add(ldbmessage)
232
233    def deletegroup(self, groupname):
234        """Deletes a group
235
236        :param groupname: Name of the target group
237        """
238
239        groupfilter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (ldb.binary_encode(groupname), "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())
240        self.transaction_start()
241        try:
242            targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
243                                      expression=groupfilter, attrs=[])
244            if len(targetgroup) == 0:
245                raise Exception('Unable to find group "%s"' % groupname)
246            assert(len(targetgroup) == 1)
247            self.delete(targetgroup[0].dn)
248        except:
249            self.transaction_cancel()
250            raise
251        else:
252            self.transaction_commit()
253
254    def group_member_filter(self, member, member_types):
255        filter = ""
256
257        all_member_types = [ 'user',
258                             'group',
259                             'computer',
260                             'serviceaccount',
261                             'contact',
262                           ]
263
264        if 'all' in member_types:
265            member_types = all_member_types
266
267        for member_type in member_types:
268            if member_type not in all_member_types:
269                raise Exception('Invalid group member type "%s". '
270                                'Valid types are %s and all.' %
271                                (member_type, ", ".join(all_member_types)))
272
273        if 'user' in member_types:
274            filter += ('(&(sAMAccountName=%s)(samAccountType=%d))' %
275                       (ldb.binary_encode(member), dsdb.ATYPE_NORMAL_ACCOUNT))
276        if 'group' in member_types:
277            filter += ('(&(sAMAccountName=%s)'
278                       '(objectClass=group)'
279                       '(!(groupType:1.2.840.113556.1.4.803:=1)))' %
280                       ldb.binary_encode(member))
281        if 'computer' in member_types:
282            samaccountname = member
283            if member[-1] != '$':
284                samaccountname = "%s$" % member
285            filter += ('(&(samAccountType=%d)'
286                       '(!(objectCategory=msDS-ManagedServiceAccount))'
287                       '(sAMAccountName=%s))' %
288                       (dsdb.ATYPE_WORKSTATION_TRUST,
289                        ldb.binary_encode(samaccountname)))
290        if 'serviceaccount' in member_types:
291            samaccountname = member
292            if member[-1] != '$':
293                samaccountname = "%s$" % member
294            filter += ('(&(samAccountType=%d)'
295                       '(objectCategory=msDS-ManagedServiceAccount)'
296                       '(sAMAccountName=%s))' %
297                       (dsdb.ATYPE_WORKSTATION_TRUST,
298                        ldb.binary_encode(samaccountname)))
299        if 'contact' in member_types:
300            filter += ('(&(objectCategory=Person)(!(objectSid=*))(name=%s))' %
301                       ldb.binary_encode(member))
302
303        filter = "(|%s)" % filter
304
305        return filter
306
307    def add_remove_group_members(self, groupname, members,
308                                 add_members_operation=True,
309                                 member_types=[ 'user', 'group', 'computer' ],
310                                 member_base_dn=None):
311        """Adds or removes group members
312
313        :param groupname: Name of the target group
314        :param members: list of group members
315        :param add_members_operation: Defines if its an add or remove
316            operation
317        """
318
319        groupfilter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (
320            ldb.binary_encode(groupname), "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())
321
322        self.transaction_start()
323        try:
324            targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
325                                      expression=groupfilter, attrs=['member'])
326            if len(targetgroup) == 0:
327                raise Exception('Unable to find group "%s"' % groupname)
328            assert(len(targetgroup) == 1)
329
330            modified = False
331
332            addtargettogroup = """
333dn: %s
334changetype: modify
335""" % (str(targetgroup[0].dn))
336
337            for member in members:
338                targetmember_dn = None
339                if member_base_dn is None:
340                    member_base_dn = self.domain_dn()
341
342                try:
343                    membersid = security.dom_sid(member)
344                    targetmember_dn = "<SID=%s>" % str(membersid)
345                except TypeError as e:
346                    pass
347
348                if targetmember_dn is None:
349                    try:
350                        member_dn = ldb.Dn(self, member)
351                        if member_dn.get_linearized() == member_dn.extended_str(1):
352                            full_member_dn = self.normalize_dn_in_domain(member_dn)
353                        else:
354                            full_member_dn = member_dn
355                        targetmember_dn = full_member_dn.extended_str(1)
356                    except ValueError as e:
357                        pass
358
359                if targetmember_dn is None:
360                    filter = self.group_member_filter(member, member_types)
361                    targetmember = self.search(base=member_base_dn,
362                                               scope=ldb.SCOPE_SUBTREE,
363                                               expression=filter,
364                                               attrs=[])
365
366                    if len(targetmember) > 1:
367                        targetmemberlist_str = ""
368                        for msg in targetmember:
369                            targetmemberlist_str += "%s\n" % msg.get("dn")
370                        raise Exception('Found multiple results for "%s":\n%s' %
371                                        (member, targetmemberlist_str))
372                    if len(targetmember) != 1:
373                        raise Exception('Unable to find "%s". Operation cancelled.' % member)
374                    targetmember_dn = targetmember[0].dn.extended_str(1)
375
376                if add_members_operation is True and (targetgroup[0].get('member') is None or get_bytes(targetmember_dn) not in [str(x) for x in targetgroup[0]['member']]):
377                    modified = True
378                    addtargettogroup += """add: member
379member: %s
380""" % (str(targetmember_dn))
381
382                elif add_members_operation is False and (targetgroup[0].get('member') is not None and get_bytes(targetmember_dn) in targetgroup[0]['member']):
383                    modified = True
384                    addtargettogroup += """delete: member
385member: %s
386""" % (str(targetmember_dn))
387
388            if modified is True:
389                self.modify_ldif(addtargettogroup)
390
391        except:
392            self.transaction_cancel()
393            raise
394        else:
395            self.transaction_commit()
396
397    def newuser(self, username, password,
398                force_password_change_at_next_login_req=False,
399                useusernameascn=False, userou=None, surname=None, givenname=None,
400                initials=None, profilepath=None, scriptpath=None, homedrive=None,
401                homedirectory=None, jobtitle=None, department=None, company=None,
402                description=None, mailaddress=None, internetaddress=None,
403                telephonenumber=None, physicaldeliveryoffice=None, sd=None,
404                setpassword=True, uidnumber=None, gidnumber=None, gecos=None,
405                loginshell=None, uid=None, nisdomain=None, unixhome=None,
406                smartcard_required=False):
407        """Adds a new user with additional parameters
408
409        :param username: Name of the new user
410        :param password: Password for the new user
411        :param force_password_change_at_next_login_req: Force password change
412        :param useusernameascn: Use username as cn rather that firstname +
413            initials + lastname
414        :param userou: Object container (without domainDN postfix) for new user
415        :param surname: Surname of the new user
416        :param givenname: First name of the new user
417        :param initials: Initials of the new user
418        :param profilepath: Profile path of the new user
419        :param scriptpath: Logon script path of the new user
420        :param homedrive: Home drive of the new user
421        :param homedirectory: Home directory of the new user
422        :param jobtitle: Job title of the new user
423        :param department: Department of the new user
424        :param company: Company of the new user
425        :param description: of the new user
426        :param mailaddress: Email address of the new user
427        :param internetaddress: Home page of the new user
428        :param telephonenumber: Phone number of the new user
429        :param physicaldeliveryoffice: Office location of the new user
430        :param sd: security descriptor of the object
431        :param setpassword: optionally disable password reset
432        :param uidnumber: RFC2307 Unix numeric UID of the new user
433        :param gidnumber: RFC2307 Unix primary GID of the new user
434        :param gecos: RFC2307 Unix GECOS field of the new user
435        :param loginshell: RFC2307 Unix login shell of the new user
436        :param uid: RFC2307 Unix username of the new user
437        :param nisdomain: RFC2307 Unix NIS domain of the new user
438        :param unixhome: RFC2307 Unix home directory of the new user
439        :param smartcard_required: set the UF_SMARTCARD_REQUIRED bit of the new user
440        """
441
442        displayname = ""
443        if givenname is not None:
444            displayname += givenname
445
446        if initials is not None:
447            displayname += ' %s.' % initials
448
449        if surname is not None:
450            displayname += ' %s' % surname
451
452        cn = username
453        if useusernameascn is None and displayname != "":
454            cn = displayname
455
456        user_dn = "CN=%s,%s,%s" % (cn, (userou or "CN=Users"), self.domain_dn())
457
458        dnsdomain = ldb.Dn(self, self.domain_dn()).canonical_str().replace("/", "")
459        user_principal_name = "%s@%s" % (username, dnsdomain)
460        # The new user record. Note the reliance on the SAMLDB module which
461        # fills in the default information
462        ldbmessage = {"dn": user_dn,
463                      "sAMAccountName": username,
464                      "userPrincipalName": user_principal_name,
465                      "objectClass": "user"}
466
467        if smartcard_required:
468            ldbmessage["userAccountControl"] = str(dsdb.UF_NORMAL_ACCOUNT |
469                                                   dsdb.UF_SMARTCARD_REQUIRED)
470            setpassword = False
471
472        if surname is not None:
473            ldbmessage["sn"] = surname
474
475        if givenname is not None:
476            ldbmessage["givenName"] = givenname
477
478        if displayname != "":
479            ldbmessage["displayName"] = displayname
480            ldbmessage["name"] = displayname
481
482        if initials is not None:
483            ldbmessage["initials"] = '%s.' % initials
484
485        if profilepath is not None:
486            ldbmessage["profilePath"] = profilepath
487
488        if scriptpath is not None:
489            ldbmessage["scriptPath"] = scriptpath
490
491        if homedrive is not None:
492            ldbmessage["homeDrive"] = homedrive
493
494        if homedirectory is not None:
495            ldbmessage["homeDirectory"] = homedirectory
496
497        if jobtitle is not None:
498            ldbmessage["title"] = jobtitle
499
500        if department is not None:
501            ldbmessage["department"] = department
502
503        if company is not None:
504            ldbmessage["company"] = company
505
506        if description is not None:
507            ldbmessage["description"] = description
508
509        if mailaddress is not None:
510            ldbmessage["mail"] = mailaddress
511
512        if internetaddress is not None:
513            ldbmessage["wWWHomePage"] = internetaddress
514
515        if telephonenumber is not None:
516            ldbmessage["telephoneNumber"] = telephonenumber
517
518        if physicaldeliveryoffice is not None:
519            ldbmessage["physicalDeliveryOfficeName"] = physicaldeliveryoffice
520
521        if sd is not None:
522            ldbmessage["nTSecurityDescriptor"] = ndr_pack(sd)
523
524        ldbmessage2 = None
525        if any(map(lambda b: b is not None, (uid, uidnumber, gidnumber, gecos,
526                                             loginshell, nisdomain, unixhome))):
527            ldbmessage2 = ldb.Message()
528            ldbmessage2.dn = ldb.Dn(self, user_dn)
529            if uid is not None:
530                ldbmessage2["uid"] = ldb.MessageElement(str(uid), ldb.FLAG_MOD_REPLACE, 'uid')
531            if uidnumber is not None:
532                ldbmessage2["uidNumber"] = ldb.MessageElement(str(uidnumber), ldb.FLAG_MOD_REPLACE, 'uidNumber')
533            if gidnumber is not None:
534                ldbmessage2["gidNumber"] = ldb.MessageElement(str(gidnumber), ldb.FLAG_MOD_REPLACE, 'gidNumber')
535            if gecos is not None:
536                ldbmessage2["gecos"] = ldb.MessageElement(str(gecos), ldb.FLAG_MOD_REPLACE, 'gecos')
537            if loginshell is not None:
538                ldbmessage2["loginShell"] = ldb.MessageElement(str(loginshell), ldb.FLAG_MOD_REPLACE, 'loginShell')
539            if unixhome is not None:
540                ldbmessage2["unixHomeDirectory"] = ldb.MessageElement(
541                    str(unixhome), ldb.FLAG_MOD_REPLACE, 'unixHomeDirectory')
542            if nisdomain is not None:
543                ldbmessage2["msSFU30NisDomain"] = ldb.MessageElement(
544                    str(nisdomain), ldb.FLAG_MOD_REPLACE, 'msSFU30NisDomain')
545                ldbmessage2["msSFU30Name"] = ldb.MessageElement(
546                    str(username), ldb.FLAG_MOD_REPLACE, 'msSFU30Name')
547                ldbmessage2["unixUserPassword"] = ldb.MessageElement(
548                    'ABCD!efgh12345$67890', ldb.FLAG_MOD_REPLACE,
549                    'unixUserPassword')
550
551        self.transaction_start()
552        try:
553            self.add(ldbmessage)
554            if ldbmessage2:
555                self.modify(ldbmessage2)
556
557            # Sets the password for it
558            if setpassword:
559                self.setpassword(("(distinguishedName=%s)" %
560                                  ldb.binary_encode(user_dn)),
561                                 password,
562                                 force_password_change_at_next_login_req)
563        except:
564            self.transaction_cancel()
565            raise
566        else:
567            self.transaction_commit()
568
569    def newcontact(self,
570                   fullcontactname=None,
571                   ou=None,
572                   surname=None,
573                   givenname=None,
574                   initials=None,
575                   displayname=None,
576                   jobtitle=None,
577                   department=None,
578                   company=None,
579                   description=None,
580                   mailaddress=None,
581                   internetaddress=None,
582                   telephonenumber=None,
583                   mobilenumber=None,
584                   physicaldeliveryoffice=None):
585        """Adds a new contact with additional parameters
586
587        :param fullcontactname: Optional full name of the new contact
588        :param ou: Object container for new contact
589        :param surname: Surname of the new contact
590        :param givenname: First name of the new contact
591        :param initials: Initials of the new contact
592        :param displayname: displayName of the new contact
593        :param jobtitle: Job title of the new contact
594        :param department: Department of the new contact
595        :param company: Company of the new contact
596        :param description: Description of the new contact
597        :param mailaddress: Email address of the new contact
598        :param internetaddress: Home page of the new contact
599        :param telephonenumber: Phone number of the new contact
600        :param mobilenumber: Primary mobile number of the new contact
601        :param physicaldeliveryoffice: Office location of the new contact
602        """
603
604        # Prepare the contact name like the RSAT, using the name parts.
605        cn = ""
606        if givenname is not None:
607            cn += givenname
608
609        if initials is not None:
610            cn += ' %s.' % initials
611
612        if surname is not None:
613            cn += ' %s' % surname
614
615        # Use the specified fullcontactname instead of the previously prepared
616        # contact name, if it is specified.
617        # This is similar to the "Full name" value of the RSAT.
618        if fullcontactname is not None:
619            cn = fullcontactname
620
621        if fullcontactname is None and cn == "":
622            raise Exception('No name for contact specified')
623
624        contactcontainer_dn = self.domain_dn()
625        if ou:
626            contactcontainer_dn = self.normalize_dn_in_domain(ou)
627
628        contact_dn = "CN=%s,%s" % (cn, contactcontainer_dn)
629
630        ldbmessage = {"dn": contact_dn,
631                      "objectClass": "contact",
632                      }
633
634        if surname is not None:
635            ldbmessage["sn"] = surname
636
637        if givenname is not None:
638            ldbmessage["givenName"] = givenname
639
640        if displayname is not None:
641            ldbmessage["displayName"] = displayname
642
643        if initials is not None:
644            ldbmessage["initials"] = '%s.' % initials
645
646        if jobtitle is not None:
647            ldbmessage["title"] = jobtitle
648
649        if department is not None:
650            ldbmessage["department"] = department
651
652        if company is not None:
653            ldbmessage["company"] = company
654
655        if description is not None:
656            ldbmessage["description"] = description
657
658        if mailaddress is not None:
659            ldbmessage["mail"] = mailaddress
660
661        if internetaddress is not None:
662            ldbmessage["wWWHomePage"] = internetaddress
663
664        if telephonenumber is not None:
665            ldbmessage["telephoneNumber"] = telephonenumber
666
667        if mobilenumber is not None:
668            ldbmessage["mobile"] = mobilenumber
669
670        if physicaldeliveryoffice is not None:
671            ldbmessage["physicalDeliveryOfficeName"] = physicaldeliveryoffice
672
673        self.add(ldbmessage)
674
675        return cn
676
677    def newcomputer(self, computername, computerou=None, description=None,
678                    prepare_oldjoin=False, ip_address_list=None,
679                    service_principal_name_list=None):
680        """Adds a new user with additional parameters
681
682        :param computername: Name of the new computer
683        :param computerou: Object container for new computer
684        :param description: Description of the new computer
685        :param prepare_oldjoin: Preset computer password for oldjoin mechanism
686        :param ip_address_list: ip address list for DNS A or AAAA record
687        :param service_principal_name_list: string list of servicePincipalName
688        """
689
690        cn = re.sub(r"\$$", "", computername)
691        if cn.count('$'):
692            raise Exception('Illegal computername "%s"' % computername)
693        samaccountname = "%s$" % cn
694
695        computercontainer_dn = "CN=Computers,%s" % self.domain_dn()
696        if computerou:
697            computercontainer_dn = self.normalize_dn_in_domain(computerou)
698
699        computer_dn = "CN=%s,%s" % (cn, computercontainer_dn)
700
701        ldbmessage = {"dn": computer_dn,
702                      "sAMAccountName": samaccountname,
703                      "objectClass": "computer",
704                      }
705
706        if description is not None:
707            ldbmessage["description"] = description
708
709        if service_principal_name_list:
710            ldbmessage["servicePrincipalName"] = service_principal_name_list
711
712        accountcontrol = str(dsdb.UF_WORKSTATION_TRUST_ACCOUNT |
713                             dsdb.UF_ACCOUNTDISABLE)
714        if prepare_oldjoin:
715            accountcontrol = str(dsdb.UF_WORKSTATION_TRUST_ACCOUNT)
716        ldbmessage["userAccountControl"] = accountcontrol
717
718        if ip_address_list:
719            ldbmessage['dNSHostName'] = '{}.{}'.format(
720                cn, self.domain_dns_name())
721
722        self.transaction_start()
723        try:
724            self.add(ldbmessage)
725
726            if prepare_oldjoin:
727                password = cn.lower()
728                self.setpassword(("(distinguishedName=%s)" %
729                                  ldb.binary_encode(computer_dn)),
730                                 password, False)
731        except:
732            self.transaction_cancel()
733            raise
734        else:
735            self.transaction_commit()
736
737    def deleteuser(self, username):
738        """Deletes a user
739
740        :param username: Name of the target user
741        """
742
743        filter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (ldb.binary_encode(username), "CN=Person,CN=Schema,CN=Configuration", self.domain_dn())
744        self.transaction_start()
745        try:
746            target = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
747                                 expression=filter, attrs=[])
748            if len(target) == 0:
749                raise Exception('Unable to find user "%s"' % username)
750            assert(len(target) == 1)
751            self.delete(target[0].dn)
752        except:
753            self.transaction_cancel()
754            raise
755        else:
756            self.transaction_commit()
757
758    def setpassword(self, search_filter, password,
759                    force_change_at_next_login=False, username=None):
760        """Sets the password for a user
761
762        :param search_filter: LDAP filter to find the user (eg
763            samccountname=name)
764        :param password: Password for the user
765        :param force_change_at_next_login: Force password change
766        """
767        self.transaction_start()
768        try:
769            res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
770                              expression=search_filter, attrs=[])
771            if len(res) == 0:
772                raise Exception('Unable to find user "%s"' % (username or search_filter))
773            if len(res) > 1:
774                raise Exception('Matched %u multiple users with filter "%s"' % (len(res), search_filter))
775            user_dn = res[0].dn
776            if not isinstance(password, text_type):
777                pw = password.decode('utf-8')
778            else:
779                pw = password
780            pw = ('"' + pw + '"').encode('utf-16-le')
781            setpw = """
782dn: %s
783changetype: modify
784replace: unicodePwd
785unicodePwd:: %s
786""" % (user_dn, base64.b64encode(pw).decode('utf-8'))
787
788            self.modify_ldif(setpw)
789
790            if force_change_at_next_login:
791                self.force_password_change_at_next_login(
792                    "(distinguishedName=" + str(user_dn) + ")")
793
794            #  modify the userAccountControl to remove the disabled bit
795            self.enable_account(search_filter)
796        except:
797            self.transaction_cancel()
798            raise
799        else:
800            self.transaction_commit()
801
802    def setexpiry(self, search_filter, expiry_seconds, no_expiry_req=False):
803        """Sets the account expiry for a user
804
805        :param search_filter: LDAP filter to find the user (eg
806            samaccountname=name)
807        :param expiry_seconds: expiry time from now in seconds
808        :param no_expiry_req: if set, then don't expire password
809        """
810        self.transaction_start()
811        try:
812            res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
813                              expression=search_filter,
814                              attrs=["userAccountControl", "accountExpires"])
815            if len(res) == 0:
816                raise Exception('Unable to find user "%s"' % search_filter)
817            assert(len(res) == 1)
818            user_dn = res[0].dn
819
820            userAccountControl = int(res[0]["userAccountControl"][0])
821            accountExpires     = int(res[0]["accountExpires"][0])
822            if no_expiry_req:
823                userAccountControl = userAccountControl | 0x10000
824                accountExpires = 0
825            else:
826                userAccountControl = userAccountControl & ~0x10000
827                accountExpires = samba.unix2nttime(expiry_seconds + int(time.time()))
828
829            setexp = """
830dn: %s
831changetype: modify
832replace: userAccountControl
833userAccountControl: %u
834replace: accountExpires
835accountExpires: %u
836""" % (user_dn, userAccountControl, accountExpires)
837
838            self.modify_ldif(setexp)
839        except:
840            self.transaction_cancel()
841            raise
842        else:
843            self.transaction_commit()
844
845    def set_domain_sid(self, sid):
846        """Change the domain SID used by this LDB.
847
848        :param sid: The new domain sid to use.
849        """
850        dsdb._samdb_set_domain_sid(self, sid)
851
852    def get_domain_sid(self):
853        """Read the domain SID used by this LDB. """
854        return dsdb._samdb_get_domain_sid(self)
855
856    domain_sid = property(get_domain_sid, set_domain_sid,
857                          doc="SID for the domain")
858
859    def set_invocation_id(self, invocation_id):
860        """Set the invocation id for this SamDB handle.
861
862        :param invocation_id: GUID of the invocation id.
863        """
864        dsdb._dsdb_set_ntds_invocation_id(self, invocation_id)
865
866    def get_invocation_id(self):
867        """Get the invocation_id id"""
868        return dsdb._samdb_ntds_invocation_id(self)
869
870    invocation_id = property(get_invocation_id, set_invocation_id,
871                             doc="Invocation ID GUID")
872
873    def get_oid_from_attid(self, attid):
874        return dsdb._dsdb_get_oid_from_attid(self, attid)
875
876    def get_attid_from_lDAPDisplayName(self, ldap_display_name,
877                                       is_schema_nc=False):
878        '''return the attribute ID for a LDAP attribute as an integer as found in DRSUAPI'''
879        return dsdb._dsdb_get_attid_from_lDAPDisplayName(self,
880                                                         ldap_display_name, is_schema_nc)
881
882    def get_syntax_oid_from_lDAPDisplayName(self, ldap_display_name):
883        '''return the syntax OID for a LDAP attribute as a string'''
884        return dsdb._dsdb_get_syntax_oid_from_lDAPDisplayName(self, ldap_display_name)
885
886    def get_systemFlags_from_lDAPDisplayName(self, ldap_display_name):
887        '''return the systemFlags for a LDAP attribute as a integer'''
888        return dsdb._dsdb_get_systemFlags_from_lDAPDisplayName(self, ldap_display_name)
889
890    def get_linkId_from_lDAPDisplayName(self, ldap_display_name):
891        '''return the linkID for a LDAP attribute as a integer'''
892        return dsdb._dsdb_get_linkId_from_lDAPDisplayName(self, ldap_display_name)
893
894    def get_lDAPDisplayName_by_attid(self, attid):
895        '''return the lDAPDisplayName from an integer DRS attribute ID'''
896        return dsdb._dsdb_get_lDAPDisplayName_by_attid(self, attid)
897
898    def get_backlink_from_lDAPDisplayName(self, ldap_display_name):
899        '''return the attribute name of the corresponding backlink from the name
900        of a forward link attribute. If there is no backlink return None'''
901        return dsdb._dsdb_get_backlink_from_lDAPDisplayName(self, ldap_display_name)
902
903    def set_ntds_settings_dn(self, ntds_settings_dn):
904        """Set the NTDS Settings DN, as would be returned on the dsServiceName
905        rootDSE attribute.
906
907        This allows the DN to be set before the database fully exists
908
909        :param ntds_settings_dn: The new DN to use
910        """
911        dsdb._samdb_set_ntds_settings_dn(self, ntds_settings_dn)
912
913    def get_ntds_GUID(self):
914        """Get the NTDS objectGUID"""
915        return dsdb._samdb_ntds_objectGUID(self)
916
917    def server_site_name(self):
918        """Get the server site name"""
919        return dsdb._samdb_server_site_name(self)
920
921    def host_dns_name(self):
922        """return the DNS name of this host"""
923        res = self.search(base='', scope=ldb.SCOPE_BASE, attrs=['dNSHostName'])
924        return str(res[0]['dNSHostName'][0])
925
926    def domain_dns_name(self):
927        """return the DNS name of the domain root"""
928        domain_dn = self.get_default_basedn()
929        return domain_dn.canonical_str().split('/')[0]
930
931    def forest_dns_name(self):
932        """return the DNS name of the forest root"""
933        forest_dn = self.get_root_basedn()
934        return forest_dn.canonical_str().split('/')[0]
935
936    def load_partition_usn(self, base_dn):
937        return dsdb._dsdb_load_partition_usn(self, base_dn)
938
939    def set_schema(self, schema, write_indices_and_attributes=True):
940        self.set_schema_from_ldb(schema.ldb, write_indices_and_attributes=write_indices_and_attributes)
941
942    def set_schema_from_ldb(self, ldb_conn, write_indices_and_attributes=True):
943        dsdb._dsdb_set_schema_from_ldb(self, ldb_conn, write_indices_and_attributes)
944
945    def set_schema_update_now(self):
946        ldif = """
947dn:
948changetype: modify
949add: schemaUpdateNow
950schemaUpdateNow: 1
951"""
952        self.modify_ldif(ldif)
953
954    def dsdb_DsReplicaAttribute(self, ldb, ldap_display_name, ldif_elements):
955        '''convert a list of attribute values to a DRSUAPI DsReplicaAttribute'''
956        return dsdb._dsdb_DsReplicaAttribute(ldb, ldap_display_name, ldif_elements)
957
958    def dsdb_normalise_attributes(self, ldb, ldap_display_name, ldif_elements):
959        '''normalise a list of attribute values'''
960        return dsdb._dsdb_normalise_attributes(ldb, ldap_display_name, ldif_elements)
961
962    def get_attribute_from_attid(self, attid):
963        """ Get from an attid the associated attribute
964
965        :param attid: The attribute id for searched attribute
966        :return: The name of the attribute associated with this id
967        """
968        if len(self.hash_oid_name.keys()) == 0:
969            self._populate_oid_attid()
970        if self.get_oid_from_attid(attid) in self.hash_oid_name:
971            return self.hash_oid_name[self.get_oid_from_attid(attid)]
972        else:
973            return None
974
975    def _populate_oid_attid(self):
976        """Populate the hash hash_oid_name.
977
978        This hash contains the oid of the attribute as a key and
979        its display name as a value
980        """
981        self.hash_oid_name = {}
982        res = self.search(expression="objectClass=attributeSchema",
983                          controls=["search_options:1:2"],
984                          attrs=["attributeID",
985                                 "lDAPDisplayName"])
986        if len(res) > 0:
987            for e in res:
988                strDisplay = str(e.get("lDAPDisplayName"))
989                self.hash_oid_name[str(e.get("attributeID"))] = strDisplay
990
991    def get_attribute_replmetadata_version(self, dn, att):
992        """Get the version field trom the replPropertyMetaData for
993        the given field
994
995        :param dn: The on which we want to get the version
996        :param att: The name of the attribute
997        :return: The value of the version field in the replPropertyMetaData
998            for the given attribute. None if the attribute is not replicated
999        """
1000
1001        res = self.search(expression="distinguishedName=%s" % dn,
1002                          scope=ldb.SCOPE_SUBTREE,
1003                          controls=["search_options:1:2"],
1004                          attrs=["replPropertyMetaData"])
1005        if len(res) == 0:
1006            return None
1007
1008        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
1009                          res[0]["replPropertyMetaData"][0])
1010        ctr = repl.ctr
1011        if len(self.hash_oid_name.keys()) == 0:
1012            self._populate_oid_attid()
1013        for o in ctr.array:
1014            # Search for Description
1015            att_oid = self.get_oid_from_attid(o.attid)
1016            if att_oid in self.hash_oid_name and\
1017               att.lower() == self.hash_oid_name[att_oid].lower():
1018                return o.version
1019        return None
1020
    def set_attribute_replmetadata_version(self, dn, att, value,
                                           addifnotexist=False):
        """Set the version field in the replPropertyMetaData of an object

        Every metadata entry whose attribute name matches ``att``
        (case-insensitively) gets its version set to ``value`` and its
        originating change time, invocation id and USNs refreshed. If
        anything was changed, the repacked blob is written back with a
        replace modification.

        :param dn: The DN of the object to update
        :param att: The name of the attribute
        :param value: The new version value
        :param addifnotexist: if set, and no entry matched while the
            metadata array is non-empty, a new entry is appended
        :return: None
        """
        res = self.search(expression="distinguishedName=%s" % dn,
                          scope=ldb.SCOPE_SUBTREE,
                          controls=["search_options:1:2"],
                          attrs=["replPropertyMetaData"])
        if len(res) == 0:
            return None

        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        # One timestamp is shared by every entry touched in this call.
        now = samba.unix2nttime(int(time.time()))
        found = False
        if len(self.hash_oid_name.keys()) == 0:
            self._populate_oid_attid()
        for o in ctr.array:
            # Resolve the entry's attid to a display name and compare it
            # with the requested attribute.
            att_oid = self.get_oid_from_attid(o.attid)
            if att_oid in self.hash_oid_name and\
               att.lower() == self.hash_oid_name[att_oid].lower():
                found = True
                # A fresh sequence number is requested per matching entry.
                seq = self.sequence_number(ldb.SEQ_NEXT)
                o.version = value
                o.originating_change_time = now
                o.originating_invocation_id = misc.GUID(self.get_invocation_id())
                o.originating_usn = seq
                o.local_usn = seq

        if not found and addifnotexist and len(ctr.array) > 0:
            o2 = drsblobs.replPropertyMetaData1()
            # NOTE(review): the appended entry always uses this fixed attid
            # instead of one derived from `att` - confirm which attribute
            # 589914 denotes and that this is intended.
            o2.attid = 589914
            # NOTE(review): att_oid is assigned here but not read again.
            att_oid = self.get_oid_from_attid(o2.attid)
            seq = self.sequence_number(ldb.SEQ_NEXT)
            o2.version = value
            o2.originating_change_time = now
            o2.originating_invocation_id = misc.GUID(self.get_invocation_id())
            o2.originating_usn = seq
            o2.local_usn = seq
            found = True
            # Append through a local reference and assign the array back,
            # keeping count in step with the array length.
            tab = ctr.array
            tab.append(o2)
            ctr.count = ctr.count + 1
            ctr.array = tab

        if found:
            replBlob = ndr_pack(repl)
            msg = ldb.Message()
            msg.dn = res[0].dn
            msg["replPropertyMetaData"] = \
                ldb.MessageElement(replBlob,
                                   ldb.FLAG_MOD_REPLACE,
                                   "replPropertyMetaData")
            # presumably this control is what permits writing
            # replPropertyMetaData directly - TODO confirm
            self.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])
1075
1076    def write_prefixes_from_schema(self):
1077        dsdb._dsdb_write_prefixes_from_schema_to_ldb(self)
1078
1079    def get_partitions_dn(self):
1080        return dsdb._dsdb_get_partitions_dn(self)
1081
1082    def get_nc_root(self, dn):
1083        return dsdb._dsdb_get_nc_root(self, dn)
1084
1085    def get_wellknown_dn(self, nc_root, wkguid):
1086        h_nc = self.hash_well_known.get(str(nc_root))
1087        dn = None
1088        if h_nc is not None:
1089            dn = h_nc.get(wkguid)
1090        if dn is None:
1091            dn = dsdb._dsdb_get_wellknown_dn(self, nc_root, wkguid)
1092            if dn is None:
1093                return dn
1094            if h_nc is None:
1095                self.hash_well_known[str(nc_root)] = {}
1096                h_nc = self.hash_well_known[str(nc_root)]
1097            h_nc[wkguid] = dn
1098        return dn
1099
1100    def set_minPwdAge(self, value):
1101        if not isinstance(value, binary_type):
1102            value = str(value).encode('utf8')
1103        m = ldb.Message()
1104        m.dn = ldb.Dn(self, self.domain_dn())
1105        m["minPwdAge"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdAge")
1106        self.modify(m)
1107
1108    def get_minPwdAge(self):
1109        res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdAge"])
1110        if len(res) == 0:
1111            return None
1112        elif "minPwdAge" not in res[0]:
1113            return None
1114        else:
1115            return int(res[0]["minPwdAge"][0])
1116
1117    def set_maxPwdAge(self, value):
1118        if not isinstance(value, binary_type):
1119            value = str(value).encode('utf8')
1120        m = ldb.Message()
1121        m.dn = ldb.Dn(self, self.domain_dn())
1122        m["maxPwdAge"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "maxPwdAge")
1123        self.modify(m)
1124
1125    def get_maxPwdAge(self):
1126        res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["maxPwdAge"])
1127        if len(res) == 0:
1128            return None
1129        elif "maxPwdAge" not in res[0]:
1130            return None
1131        else:
1132            return int(res[0]["maxPwdAge"][0])
1133
1134    def set_minPwdLength(self, value):
1135        if not isinstance(value, binary_type):
1136            value = str(value).encode('utf8')
1137        m = ldb.Message()
1138        m.dn = ldb.Dn(self, self.domain_dn())
1139        m["minPwdLength"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdLength")
1140        self.modify(m)
1141
1142    def get_minPwdLength(self):
1143        res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdLength"])
1144        if len(res) == 0:
1145            return None
1146        elif "minPwdLength" not in res[0]:
1147            return None
1148        else:
1149            return int(res[0]["minPwdLength"][0])
1150
1151    def set_pwdProperties(self, value):
1152        if not isinstance(value, binary_type):
1153            value = str(value).encode('utf8')
1154        m = ldb.Message()
1155        m.dn = ldb.Dn(self, self.domain_dn())
1156        m["pwdProperties"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "pwdProperties")
1157        self.modify(m)
1158
1159    def get_pwdProperties(self):
1160        res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["pwdProperties"])
1161        if len(res) == 0:
1162            return None
1163        elif "pwdProperties" not in res[0]:
1164            return None
1165        else:
1166            return int(res[0]["pwdProperties"][0])
1167
1168    def set_dsheuristics(self, dsheuristics):
1169        m = ldb.Message()
1170        m.dn = ldb.Dn(self, "CN=Directory Service,CN=Windows NT,CN=Services,%s"
1171                      % self.get_config_basedn().get_linearized())
1172        if dsheuristics is not None:
1173            m["dSHeuristics"] = \
1174                ldb.MessageElement(dsheuristics,
1175                                   ldb.FLAG_MOD_REPLACE,
1176                                   "dSHeuristics")
1177        else:
1178            m["dSHeuristics"] = \
1179                ldb.MessageElement([], ldb.FLAG_MOD_DELETE,
1180                                   "dSHeuristics")
1181        self.modify(m)
1182
1183    def get_dsheuristics(self):
1184        res = self.search("CN=Directory Service,CN=Windows NT,CN=Services,%s"
1185                          % self.get_config_basedn().get_linearized(),
1186                          scope=ldb.SCOPE_BASE, attrs=["dSHeuristics"])
1187        if len(res) == 0:
1188            dsheuristics = None
1189        elif "dSHeuristics" in res[0]:
1190            dsheuristics = res[0]["dSHeuristics"][0]
1191        else:
1192            dsheuristics = None
1193
1194        return dsheuristics
1195
1196    def create_ou(self, ou_dn, description=None, name=None, sd=None):
1197        """Creates an organizationalUnit object
1198        :param ou_dn: dn of the new object
1199        :param description: description attribute
1200        :param name: name atttribute
1201        :param sd: security descriptor of the object, can be
1202        an SDDL string or security.descriptor type
1203        """
1204        m = {"dn": ou_dn,
1205             "objectClass": "organizationalUnit"}
1206
1207        if description:
1208            m["description"] = description
1209        if name:
1210            m["name"] = name
1211
1212        if sd:
1213            m["nTSecurityDescriptor"] = ndr_pack(sd)
1214        self.add(m)
1215
1216    def sequence_number(self, seq_type):
1217        """Returns the value of the sequence number according to the requested type
1218        :param seq_type: type of sequence number
1219         """
1220        self.transaction_start()
1221        try:
1222            seq = super(SamDB, self).sequence_number(seq_type)
1223        except:
1224            self.transaction_cancel()
1225            raise
1226        else:
1227            self.transaction_commit()
1228        return seq
1229
1230    def get_dsServiceName(self):
1231        '''get the NTDS DN from the rootDSE'''
1232        res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["dsServiceName"])
1233        return str(res[0]["dsServiceName"][0])
1234
1235    def get_serverName(self):
1236        '''get the server DN from the rootDSE'''
1237        res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["serverName"])
1238        return str(res[0]["serverName"][0])
1239
1240    def dns_lookup(self, dns_name, dns_partition=None):
1241        '''Do a DNS lookup in the database, returns the NDR database structures'''
1242        if dns_partition is None:
1243            return dsdb_dns.lookup(self, dns_name)
1244        else:
1245            return dsdb_dns.lookup(self, dns_name,
1246                                   dns_partition=dns_partition)
1247
1248    def dns_extract(self, el):
1249        '''Return the NDR database structures from a dnsRecord element'''
1250        return dsdb_dns.extract(self, el)
1251
1252    def dns_replace(self, dns_name, new_records):
1253        '''Do a DNS modification on the database, sets the NDR database
1254        structures on a DNS name
1255        '''
1256        return dsdb_dns.replace(self, dns_name, new_records)
1257
1258    def dns_replace_by_dn(self, dn, new_records):
1259        '''Do a DNS modification on the database, sets the NDR database
1260        structures on a LDB DN
1261
1262        This routine is important because if the last record on the DN
1263        is removed, this routine will put a tombstone in the record.
1264        '''
1265        return dsdb_dns.replace_by_dn(self, dn, new_records)
1266
1267    def garbage_collect_tombstones(self, dn, current_time,
1268                                   tombstone_lifetime=None):
1269        '''garbage_collect_tombstones(lp, samdb, [dn], current_time, tombstone_lifetime)
1270        -> (num_objects_expunged, num_links_expunged)'''
1271
1272        if tombstone_lifetime is None:
1273            return dsdb._dsdb_garbage_collect_tombstones(self, dn,
1274                                                         current_time)
1275        else:
1276            return dsdb._dsdb_garbage_collect_tombstones(self, dn,
1277                                                         current_time,
1278                                                         tombstone_lifetime)
1279
1280    def create_own_rid_set(self):
1281        '''create a RID set for this DSA'''
1282        return dsdb._dsdb_create_own_rid_set(self)
1283
1284    def allocate_rid(self):
1285        '''return a new RID from the RID Pool on this DSA'''
1286        return dsdb._dsdb_allocate_rid(self)
1287
1288    def normalize_dn_in_domain(self, dn):
1289        '''return a new DN expanded by adding the domain DN
1290
1291        If the dn is already a child of the domain DN, just
1292        return it as-is.
1293
1294        :param dn: relative dn
1295        '''
1296        domain_dn = ldb.Dn(self, self.domain_dn())
1297
1298        if isinstance(dn, ldb.Dn):
1299            dn = str(dn)
1300
1301        full_dn = ldb.Dn(self, dn)
1302        if not full_dn.is_child_of(domain_dn):
1303            full_dn.add_base(domain_dn)
1304        return full_dn
1305