1# Copyright 2007 Google Inc.
2#
3# This program is free software; you can redistribute it and/or
4# modify it under the terms of the GNU General Public License
5# as published by the Free Software Foundation; either version 2
6# of the License, or (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program; if not, write to the Free Software Foundation,
15# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
16"""An implementation of an ldap data source for nsscache."""
17
18__author__ = ('jaq@google.com (Jamie Wilkinson)',
19              'vasilios@google.com (Vasilios Hoffman)')
20
21import calendar
22import logging
23import time
24import ldap
25import ldap.sasl
26import re
27from binascii import b2a_hex
28from urllib.parse import quote
29from distutils.version import StrictVersion
30
31from nss_cache import error
32from nss_cache.maps import automount
33from nss_cache.maps import group
34from nss_cache.maps import netgroup
35from nss_cache.maps import passwd
36from nss_cache.maps import shadow
37from nss_cache.maps import sshkey
38from nss_cache.sources import source
39
# The paged-results control API differs between python-ldap releases before
# and after 2.4; the helpers below branch on this flag, computed once here.
IS_LDAP24_OR_NEWER = StrictVersion('2.4') <= StrictVersion(ldap.__version__)

# OID of the Simple Paged Results control.  ldap.LDAP_CONTROL_PAGE_OID is not
# available on every system, so the value is spelled out here.
LDAP_CONTROL_PAGE_OID = '1.2.840.113556.1.4.319'
44
45
def RegisterImplementation(registration_callback):
    """Hand the LdapSource class to the given registration callback.

    Args:
      registration_callback: callable that is invoked once with LdapSource.
    """
    source_class = LdapSource
    registration_callback(source_class)
48
49
def makeSimplePagedResultsControl(page_size):
    """Build a paged-results request control asking for page_size entries.

    The control is created with criticality True and an empty cookie, i.e.
    it asks for the first page.  The constructor signature differs between
    python-ldap releases, hence the version switch.

    Args:
      page_size: number of entries to request per page.

    Returns:
      An ldap.controls.SimplePagedResultsControl instance.
    """
    if not IS_LDAP24_OR_NEWER:
        # Older python-ldap: (oid, criticality, (size, cookie)).
        return ldap.controls.SimplePagedResultsControl(LDAP_CONTROL_PAGE_OID,
                                                       True, (page_size, ''))
    return ldap.controls.SimplePagedResultsControl(True,
                                                   size=page_size,
                                                   cookie='')
60
61
def getCookieFromControl(pctrl):
    """Return the paged-results cookie carried by the control pctrl.

    Newer python-ldap exposes it as the ``cookie`` attribute; older releases
    keep it as the second item of ``controlValue`` (a (size, cookie) pair).
    """
    return pctrl.cookie if IS_LDAP24_OR_NEWER else pctrl.controlValue[1]
67
68
def setCookieOnControl(control, cookie, page_size):
    """Store cookie on the paged-results control and return the cookie.

    Older python-ldap requires replacing the whole (page_size, cookie)
    controlValue; newer releases have a dedicated ``cookie`` attribute.

    Args:
      control: the paged-results control to update.
      cookie: the cookie to send with the next search ('' for the first page).
      page_size: page size; only used by the pre-2.4 representation.

    Returns:
      cookie, unchanged.
    """
    if not IS_LDAP24_OR_NEWER:
        control.controlValue = (page_size, cookie)
    else:
        control.cookie = cookie
    return cookie
76
77
def sidToStr(sid):
    """Convert a binary objectSid returned by an LDAP query to its string form.

    The result looks like S-1-5-21-1270288957-3800934213-3019856503-500.
    For more information about the objectSid binary structure:

    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-dtyp/78eb9013-1c3a-4970-ad1f-2b1dad588a25
    https://devblogs.microsoft.com/oldnewthing/?p=40253
    This function was based on:
    https://ldap3.readthedocs.io/_modules/ldap3/protocol/formatters/formatters.html#format_sid

    Args:
      sid: the objectSid value, normally bytes.

    Returns:
      The textual SID.  The input is returned unchanged when it is already
      textual (a str, or bytes starting with b'S-1'), is not bytes-like, or
      is too short for the number of sub-authorities it declares.
    """
    if not isinstance(sid, (bytes, bytearray)):
        # A str is already textual; anything else cannot be decoded.
        return sid
    if sid.startswith(b'S-1'):
        return sid

    # Fixed 8-byte header: revision (1 byte), sub-authority count (1 byte)
    # and a 48-bit big-endian identifier authority.
    if len(sid) < 8:
        return sid
    revision = sid[0]
    sub_authority_count = sid[1]
    # Each sub-authority is a 32-bit little-endian integer following the
    # header.  A short value would otherwise decode the missing ones as 0,
    # producing a plausible-looking but wrong SID (and RID).
    if len(sid) < 8 + 4 * sub_authority_count:
        return sid

    identifier_authority = int.from_bytes(sid[2:8], byteorder='big')
    if identifier_authority >= 2**32:
        # Authorities of 2**32 and above are rendered in hex.
        identifier_authority = hex(identifier_authority)

    parts = ['S', str(revision), str(identifier_authority)]
    parts.extend(
        str(int.from_bytes(sid[offset:offset + 4], byteorder='little'))
        for offset in range(8, 8 + 4 * sub_authority_count, 4))
    return '-'.join(parts)
127
128
class LdapSource(source.Source):
    """Source for data in LDAP.

    After initialisation, one can search the data source for 'objects'
    under a particular part of the LDAP tree, with some filter, and have it
    return only some set of attributes.

    'objects' in this sense means some structured blob of data, not a Python
    object.
    """
    # ldap defaults
    BIND_DN = ''
    BIND_PASSWORD = ''
    RETRY_DELAY = 5
    RETRY_MAX = 3
    SCOPE = 'one'
    TIMELIMIT = -1
    TLS_REQUIRE_CERT = 'demand'  # one of never, hard, demand, allow, try

    # for registration
    name = 'ldap'

    # Page size for paged LDAP requests
    # Value chosen based on default Active Directory MaxPageSize
    PAGE_SIZE = 1000

    def __init__(self, conf, conn=None):
        """Initialise the LDAP Data Source.

        Fills in configuration defaults, creates the paged-results control,
        opens a connection (unless one is supplied) and binds.

        Args:
          conf: config.Config instance; updated in place with defaults.
          conn: An instance of ldap.LDAPObject that'll be used as the
            connection.  When None, a ReconnectLDAPObject is created for
            conf['uri'].

        Raises:
          error.SourceUnavailable: binding failed retry_max times with
            ldap.SERVER_DOWN.
          error.ConfigurationError: use_sasl is set with an unsupported
            mechanism.
        """
        super(LdapSource, self).__init__(conf)
        self._dn_requested = False  # dn is a special-cased attribute

        self._SetDefaults(conf)
        self._conf = conf
        self.ldap_controls = makeSimplePagedResultsControl(self.PAGE_SIZE)

        # Used by _ReSearch:
        self._last_search_params = None

        if conn is None:
            # ReconnectLDAPObject should handle interrupted ldap transactions.
            rlo = ldap.ldapobject.ReconnectLDAPObject
            self.conn = rlo(uri=conf['uri'],
                            retry_max=conf['retry_max'],
                            retry_delay=conf['retry_delay'])
            if conf['tls_starttls'] == 1:
                self.conn.start_tls_s()
            if 'ldap_debug' in conf:
                self.conn.set_option(ldap.OPT_DEBUG_LEVEL, conf['ldap_debug'])
        else:
            self.conn = conn

        # TODO(v): We should bind on-demand instead.
        # (although binding here makes it easier to simulate a dropped network)
        self.Bind(conf)

    def _SetDefaults(self, configuration):
        """Set defaults if necessary and apply global python-ldap options.

        Missing keys are filled from the class defaults, tls_require_cert is
        translated into the matching ldap.OPT_X_TLS_* constant, tls_starttls
        is normalised to 1 or 0, and TLS, referral and protocol-version
        options are set globally with ldap.set_option().

        Args:
          configuration: the source configuration; modified in place.
        """
        # LDAPI URLs must be url escaped socket filenames; rewrite if necessary.
        if 'uri' in configuration:
            if configuration['uri'].startswith('ldapi://'):
                configuration['uri'] = 'ldapi://' + quote(
                    configuration['uri'][8:], '')

        defaults = (
            ('bind_dn', self.BIND_DN),
            ('bind_password', self.BIND_PASSWORD),
            ('retry_delay', self.RETRY_DELAY),
            ('retry_max', self.RETRY_MAX),
            ('scope', self.SCOPE),
            ('timelimit', self.TIMELIMIT),
            ('tls_require_cert', self.TLS_REQUIRE_CERT),
            ('tls_starttls', 0),
            ('sasl_authzid', ''),
        )
        for key, value in defaults:
            if key not in configuration:
                configuration[key] = value

        # TODO(jaq): XXX EVIL.  ldap client libraries change behaviour if we use
        # polling, and it's nasty.  So don't let the user poll.
        if configuration['timelimit'] == 0:
            configuration['timelimit'] = -1

        # Translate tls_require_cert into the python-ldap constant.  Names are
        # matched case-insensitively; any other value is left untouched.
        tls_require_cert_options = {
            'never': ldap.OPT_X_TLS_NEVER,
            'hard': ldap.OPT_X_TLS_HARD,
            'demand': ldap.OPT_X_TLS_DEMAND,
            'allow': ldap.OPT_X_TLS_ALLOW,
            'try': ldap.OPT_X_TLS_TRY,
        }
        require_cert = configuration['tls_require_cert']
        if isinstance(require_cert, str):
            configuration['tls_require_cert'] = tls_require_cert_options.get(
                require_cert.strip().lower(), require_cert)

        # Should we issue STARTTLS?  Accept the usual truthy spellings in any
        # letter case; everything else disables it.
        starttls = configuration['tls_starttls']
        if isinstance(starttls, str):
            starttls = starttls.strip().lower()
        if starttls in (1, '1', 'on', 'yes', 'true'):
            configuration['tls_starttls'] = 1
        else:
            configuration['tls_starttls'] = 0

        # Setting global ldap defaults.
        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT,
                        configuration['tls_require_cert'])
        ldap.set_option(ldap.OPT_REFERRALS, 0)
        optional_tls_files = (
            ('tls_cacertdir', ldap.OPT_X_TLS_CACERTDIR),
            ('tls_cacertfile', ldap.OPT_X_TLS_CACERTFILE),
            ('tls_certfile', ldap.OPT_X_TLS_CERTFILE),
            ('tls_keyfile', ldap.OPT_X_TLS_KEYFILE),
        )
        for key, option in optional_tls_files:
            if key in configuration:
                ldap.set_option(option, configuration[key])
        # This is hard-coded, we only support V3.  The protocol version must be
        # set as an option; assigning a module attribute configures nothing.
        ldap.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)

    def _SetCookie(self, cookie):
        """Store cookie on our paged-results control; returns the cookie."""
        return setCookieOnControl(self.ldap_controls, cookie, self.PAGE_SIZE)

    def Bind(self, configuration):
        """Bind to LDAP, retrying if necessary.

        Performs a SASL/GSSAPI bind when use_sasl is set, otherwise a simple
        bind with bind_dn/bind_password.  ldap.SERVER_DOWN is retried up to
        retry_max times, sleeping retry_delay seconds between attempts.

        Args:
          configuration: the source configuration (after _SetDefaults).

        Raises:
          error.SourceUnavailable: every attempt failed with SERVER_DOWN.
          error.ConfigurationError: use_sasl is set and sasl_mech is not
            GSSAPI.
        """
        # If the server is unavailable, we are going to find out now, as this
        # actually initiates the network connection.
        retry_max = int(configuration['retry_max'])
        retry_count = 0
        while retry_count < retry_max:
            self.log.debug('opening ldap connection and binding to %s',
                           configuration['uri'])
            try:
                if configuration.get('use_sasl'):
                    sasl_mech = configuration.get('sasl_mech') or ''
                    if sasl_mech.lower() == 'gssapi':
                        sasl = ldap.sasl.gssapi(configuration['sasl_authzid'])
                    # TODO: Add other sasl mechs
                    else:
                        raise error.ConfigurationError(
                            'SASL mechanism not supported')

                    self.conn.sasl_interactive_bind_s('', sasl)
                else:
                    self.conn.simple_bind_s(who=configuration['bind_dn'],
                                            cred=str(
                                                configuration['bind_password']))
                break
            except ldap.SERVER_DOWN as e:
                retry_count += 1
                self.log.warning('Failed LDAP connection: attempt #%s.',
                                 retry_count)
                self.log.debug('ldap error is %r', e)
                if retry_count == retry_max:
                    self.log.debug('max retries hit')
                    raise error.SourceUnavailable(e)
                self.log.debug('sleeping %d seconds',
                               configuration['retry_delay'])
                time.sleep(configuration['retry_delay'])

    def _ReSearch(self):
        """Performs self.Search again with the previously used parameters.

        Used to fetch the next page of a paged search: Search() sends the
        paged-results control, and with it whatever cookie is currently set.
        Returns nothing; results are read by iterating over the source.
        """
        self.Search(*self._last_search_params)

    def Search(self, search_base, search_filter, search_scope, attrs):
        """Search the data source.

        The search is asynchronous; data should be retrieved by iterating over
        the source object itself (see __iter__() below).

        Args:
         search_base: the base of the tree being searched
         search_filter: a filter on the objects to be returned
         search_scope: the scope of the search from ldap.SCOPE_*
         attrs: a list of attributes to be returned; the pseudo-attribute
           'dn' makes __iter__ include each entry's DN under the 'dn' key.

        Returns:
         nothing.
        """
        self._last_search_params = (search_base, search_filter, search_scope,
                                    attrs)

        self.log.debug('searching for base=%r, filter=%r, scope=%r, attrs=%r',
                       search_base, search_filter, search_scope, attrs)
        # dn is a special-cased attribute.  Recompute the flag on every search
        # so that one dn-requesting search does not leak into later ones.
        self._dn_requested = 'dn' in attrs
        self.message_id = self.conn.search_ext(base=search_base,
                                               filterstr=search_filter,
                                               scope=search_scope,
                                               attrlist=attrs,
                                               serverctrls=[self.ldap_controls])

    @staticmethod
    def _GetPagedResultsCookie(serverctrls):
        """Return the paged-results cookie found in serverctrls ('' if none)."""
        for control in serverctrls or ():
            if control.controlType == LDAP_CONTROL_PAGE_OID:
                # We only expect one control; just take the first match.
                return getCookieFromControl(control)
        return ''

    def __iter__(self):
        """Iterate over the data from the last search.

        Messages are read one at a time with result3().  When a page ends
        with a non-empty paged-results cookie, the search is re-issued with
        that cookie and iteration continues with the next page.  Byte values
        are decoded as UTF-8, except objectSid, which is left as bytes.

        Probably not threadsafe.

        Yields:
          Search results from the prior call to self.Search(): each entry's
          attribute dict, or, when 'dn' was requested, a dict that also
          carries the entry DN under 'dn'.
        """
        retry_max = int(self.conf['retry_max'])
        # Acquire data to yield:
        while True:
            result_type, data = None, None

            timeout_retries = 0
            while timeout_retries < retry_max:
                try:
                    result_type, data, _, serverctrls = self.conn.result3(
                        self.message_id, all=0, timeout=self.conf['timelimit'])
                    # we need to filter out AD referrals
                    if data and not data[0][0]:
                        continue

                    # Paged requests return a new cookie in serverctrls at the
                    # end of a page.  If there is one, request the next page
                    # and fetch again, so the new page's first message gets
                    # the same referral filtering and paging checks as any
                    # other message (an empty cookie means we are done).
                    cookie = self._GetPagedResultsCookie(serverctrls)
                    if cookie:
                        self._SetCookie(cookie)
                        self._ReSearch()
                        continue

                    # result3 didn't time out: reset the cookie so the next
                    # Search() starts from the first page again.
                    self._SetCookie('')
                    break
                except ldap.SIZELIMIT_EXCEEDED:
                    self.log.warning(
                        'LDAP server size limit exceeded; using page size %d.',
                        self.PAGE_SIZE)
                    return
                except ldap.NO_SUCH_OBJECT:
                    self.log.debug('Returning due to ldap.NO_SUCH_OBJECT')
                    return
                except ldap.TIMELIMIT_EXCEEDED:
                    # NOTE(review): a client-side result3() timeout may surface
                    # as ldap.TIMEOUT, which is not caught here -- confirm
                    # whether it should be retried as well.
                    timeout_retries += 1
                    self.log.warning('Timeout on LDAP results, attempt #%s.',
                                     timeout_retries)
                    if timeout_retries >= retry_max:
                        self.log.debug('max retries hit, returning')
                        return
                    self.log.debug('sleeping %d seconds',
                                   self.conf['retry_delay'])
                    time.sleep(self.conf['retry_delay'])

            if result_type == ldap.RES_SEARCH_RESULT:
                self.log.debug('Returning due to RES_SEARCH_RESULT')
                return

            if result_type != ldap.RES_SEARCH_ENTRY:
                self.log.info('Unknown result type %r, ignoring.', result_type)

            if not data:
                self.log.debug('Returning due to len(data) == 0')
                return

            for record in data:
                dn, attributes = record[0], record[1]
                for key, values in attributes.items():
                    if key == 'objectSid':
                        continue  # binary SID; keep the raw bytes
                    for i, value in enumerate(values):
                        if isinstance(value, bytes):
                            values[i] = value.decode('utf-8')
                # If the dn is requested, return it along with the payload,
                # otherwise ignore it.
                if self._dn_requested:
                    merged_records = {'dn': dn}
                    merged_records.update(attributes)
                    yield merged_records
                else:
                    yield attributes

    def GetSshkeyMap(self, since=None):
        """Return the sshkey map from this source.

        Args:
          since: Get data only changed since this timestamp (inclusive) or None
          for all data.

        Returns:
          instance of maps.SshkeyMap
        """
        return SshkeyUpdateGetter(self.conf).GetUpdates(
            source=self,
            search_base=self.conf['base'],
            search_filter=self.conf['filter'],
            search_scope=self.conf['scope'],
            since=since)

    def GetPasswdMap(self, since=None):
        """Return the passwd map from this source.

        Args:
          since: Get data only changed since this timestamp (inclusive) or None
          for all data.

        Returns:
          instance of maps.PasswdMap
        """
        return PasswdUpdateGetter(self.conf).GetUpdates(
            source=self,
            search_base=self.conf['base'],
            search_filter=self.conf['filter'],
            search_scope=self.conf['scope'],
            since=since)

    def GetGroupMap(self, since=None):
        """Return the group map from this source.

        Args:
          since: Get data only changed since this timestamp (inclusive) or None
          for all data.

        Returns:
          instance of maps.GroupMap
        """
        return GroupUpdateGetter(self.conf).GetUpdates(
            source=self,
            search_base=self.conf['base'],
            search_filter=self.conf['filter'],
            search_scope=self.conf['scope'],
            since=since)

    def GetShadowMap(self, since=None):
        """Return the shadow map from this source.

        Args:
          since: Get data only changed since this timestamp (inclusive) or None
          for all data.

        Returns:
          instance of ShadowMap
        """
        return ShadowUpdateGetter(self.conf).GetUpdates(
            source=self,
            search_base=self.conf['base'],
            search_filter=self.conf['filter'],
            search_scope=self.conf['scope'],
            since=since)

    def GetNetgroupMap(self, since=None):
        """Return the netgroup map from this source.

        Args:
          since: Get data only changed since this timestamp (inclusive) or None
          for all data.

        Returns:
          instance of NetgroupMap
        """
        return NetgroupUpdateGetter(self.conf).GetUpdates(
            source=self,
            search_base=self.conf['base'],
            search_filter=self.conf['filter'],
            search_scope=self.conf['scope'],
            since=since)

    def GetAutomountMap(self, since=None, location=None):
        """Return an automount map from this source.

        Note that automount maps are stored in multiple locations, thus we
        expect a caller to provide a location.  We also follow the automount
        spec and set our search scope to be 'one'.

        Args:
          since: Get data only changed since this timestamp (inclusive) or None
            for all data.
          location: Currently a string containing our search base, later we
            may support hostname and additional parameters.

        Returns:
          instance of AutomountMap

        Raises:
          error.EmptyMap: location is None.
        """
        if location is None:
            self.log.error(
                'A location is required to retrieve an automount map!')
            raise error.EmptyMap

        autofs_filter = '(objectclass=automount)'
        return AutomountUpdateGetter(self.conf).GetUpdates(
            source=self,
            search_base=location,
            search_filter=autofs_filter,
            search_scope='one',
            since=since)

    def GetAutomountMasterMap(self):
        """Return the automount master map from this source.

        The automount master map is a special-case map which points to a dynamic
        list of additional maps. We currently support only the schema outlined at
        http://docs.sun.com/source/806-4251-10/mapping.htm commonly used by linux
        automount clients, namely ou=auto.master and objectclass=automount entries.

        Returns:
          an instance of maps.AutomountMap whose entries' location holds the
          search base (DN) of each referenced map.

        Raises:
          error.EmptyMap: no ou=auto.master automountMap object was found.
        """
        search_base = self.conf['base']
        search_scope = ldap.SCOPE_SUBTREE

        # auto.master is stored under ou=auto.master with objectclass=automountMap
        search_filter = '(&(objectclass=automountMap)(ou=auto.master))'
        self.log.debug('retrieving automount master map.')
        self.Search(search_base=search_base,
                    search_filter=search_filter,
                    search_scope=search_scope,
                    attrs=['dn'])

        search_base = None
        for obj in self:
            # the dn of the matched object is our search base
            search_base = obj['dn']

        if search_base is None:
            self.log.critical('Could not find automount master map!')
            raise error.EmptyMap

        self.log.debug('found ou=auto.master at %s', search_base)
        master_map = self.GetAutomountMap(location=search_base)

        # fix our location attribute to contain the data we
        # expect returned to us later, namely the new search base(s)
        for map_entry in master_map:
            # we currently ignore hostname and just look for the dn which will
            # be the search_base for this map.  third field, colon delimited.
            map_entry.location = map_entry.location.split(':')[2]
            # and strip the space separated options
            map_entry.location = map_entry.location.split(' ')[0]
            self.log.debug('master map has: %s', map_entry.location)

        return master_map

    def Verify(self, since=None):
        """Verify that this source is contactable and can be queried for data.

        Args:
          since: timestamp to query from; defaults to one minute in the
            future so that normally nothing needs to be transferred.

        Returns:
          the number of entries returned by the passwd query, or by the group
          query if the passwd query raised KeyError.
        """
        if since is None:
            # one minute in the future
            since = int(time.time() + 60)
        try:
            results = self.GetPasswdMap(since=since)
        except KeyError:
            # AD groups don't have all attributes of AD users
            results = self.GetGroupMap(since=since)
        return len(results)
599
600
class UpdateGetter(object):
    """Base class that gets updates from LDAP.

    Subclasses provide ``attrs`` (attributes to request), ``essential_fields``
    (attributes every returned object must have), ``CreateMap()`` and
    ``Transform(obj)``, and may override ``PostProcess``.
    """

    def __init__(self, conf):
        super(UpdateGetter, self).__init__()
        self.conf = conf

    def FromLdapToTimestamp(self, ldap_ts_string):
        """Transforms a LDAP timestamp into the nss_cache internal timestamp.

        Args:
          ldap_ts_string: An LDAP timestamp string (str or bytes) in the format
            %Y%m%d%H%M%SZ, or %Y%m%d%H%M%S.0Z when conf['ad'] is set.  A
            fractional-seconds component is tolerated and discarded.

        Returns:
          number of seconds since epoch.

        Raises:
          ValueError: the string cannot be parsed as a timestamp.
        """
        if isinstance(ldap_ts_string, bytes):
            ldap_ts_string = ldap_ts_string.decode('utf-8')
        if self.conf.get('ad'):
            # AD timestamp has different format
            ts_format = '%Y%m%d%H%M%S.0Z'
        else:
            ts_format = '%Y%m%d%H%M%SZ'
        try:
            t = time.strptime(ldap_ts_string, ts_format)
        except ValueError:
            # Some systems add a decimal component; drop it and parse what is
            # left with the plain format (the '.0Z' form no longer applies).
            m = re.match(r'([0-9]*)(\.[0-9]*)?(Z)', ldap_ts_string)
            if m:
                ldap_ts_string = m.group(1) + m.group(3)
            t = time.strptime(ldap_ts_string, '%Y%m%d%H%M%SZ')
        return int(calendar.timegm(t))

    def FromTimestampToLdap(self, ts):
        """Transforms nss_cache internal timestamp into a LDAP timestamp.

        Args:
          ts: number of seconds since epoch

        Returns:
          LDAP format timestamp string (UTC); %Y%m%d%H%M%S.0Z when conf['ad']
          is set, %Y%m%d%H%M%SZ otherwise.
        """
        if self.conf.get('ad'):
            ts_format = '%Y%m%d%H%M%S.0Z'
        else:
            ts_format = '%Y%m%d%H%M%SZ'
        return time.strftime(ts_format, time.gmtime(ts))

    def GetUpdates(self, source, search_base, search_filter, search_scope,
                   since):
        """Get updates from a source.

        Args:
          source: a data source
          search_base: the LDAP base of the tree
          search_filter: the LDAP object filter
          search_scope:  the LDAP scope filter, one of 'base', 'one', or 'sub'.
          since: a timestamp to get updates since (None for 'get everything')

        Returns:
          the map built by CreateMap(), filled with the transformed objects,
          with its modify timestamp set to the newest object timestamp seen.

        Raises:
          error.ConfigurationError: scope is invalid
          ValueError: an object in the source map is malformed
        """
        if self.conf.get('ad'):
            # AD attribute for modifyTimestamp is whenChanged
            timestamp_attr = 'whenChanged'
        else:
            timestamp_attr = 'modifyTimestamp'
        # Request the timestamp attribute once, even if this getter is reused.
        if timestamp_attr not in self.attrs:
            self.attrs.append(timestamp_attr)

        if since is not None:
            # since openldap disallows modifyTimestamp "greater than" we have to
            # increment by one second.  Formatting since + 1 (rather than adding
            # 1 to the formatted digits) carries seconds into minutes, hours
            # and days correctly.
            ts = self.FromTimestampToLdap(since + 1)
            search_filter = '(&%s(%s>=%s))' % (search_filter, timestamp_attr, ts)

        if search_scope == 'base':
            search_scope = ldap.SCOPE_BASE
        elif search_scope in ['one', 'onelevel']:
            search_scope = ldap.SCOPE_ONELEVEL
        elif search_scope in ['sub', 'subtree']:
            search_scope = ldap.SCOPE_SUBTREE
        else:
            raise error.ConfigurationError('Invalid scope: %s' % search_scope)

        source.Search(search_base=search_base,
                      search_filter=search_filter,
                      search_scope=search_scope,
                      attrs=self.attrs)

        # Don't initialize with since, because we really want to get the
        # latest timestamp read, and if somehow a larger 'since' slips through
        # the checks in main(), we'd better catch it here.
        max_ts = None

        data_map = self.CreateMap()

        for obj in source:
            for field in self.essential_fields:
                if field not in obj:
                    logging.warning('invalid object passed: %r not in %r',
                                    field, obj)
                    raise ValueError('Invalid object passed: %r' % (obj,))

            if self.conf.get('ad'):
                obj_ts = self.FromLdapToTimestamp(obj['whenChanged'][0])
            else:
                try:
                    obj_ts = self.FromLdapToTimestamp(obj['modifyTimestamp'][0])
                except KeyError:
                    # Some servers return the attribute as modifyTimeStamp.
                    obj_ts = self.FromLdapToTimestamp(obj['modifyTimeStamp'][0])

            if max_ts is None or obj_ts > max_ts:
                max_ts = obj_ts

            try:
                if not data_map.Add(self.Transform(obj)):
                    logging.info('could not add obj: %r', obj)
            except AttributeError as e:
                logging.warning('error %r, discarding malformed obj: %r',
                                str(e), obj)
        # Perform some post processing on the data_map.
        self.PostProcess(data_map, source, search_filter, search_scope)

        data_map.SetModifyTimestamp(max_ts)

        return data_map

    def PostProcess(self, data_map, source, search_filter, search_scope):
        """Perform some post-process of the data; a no-op by default."""
        pass
744
745
class PasswdUpdateGetter(UpdateGetter):
    """Get passwd updates.

    Converts LDAP posixAccount objects (or Active Directory user objects when
    the 'ad' option is set) into PasswdMapEntry instances.
    """

    def __init__(self, conf):
        """Choose the LDAP attributes to request for passwd entries.

        Args:
          conf: the source configuration; the 'ad', 'uidattr', 'uidregex' and
            'use_rid' keys are consulted here.
        """
        super(PasswdUpdateGetter, self).__init__(conf)
        if self.conf.get('ad'):
            # attributes of AD user to be returned
            self.attrs = [
                'sAMAccountName', 'objectSid', 'displayName',
                'unixHomeDirectory', 'pwdLastSet', 'loginShell'
            ]
            self.essential_fields = ['sAMAccountName', 'objectSid']
        else:
            self.attrs = [
                'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory',
                'loginShell', 'fullName'
            ]
            if 'uidattr' in self.conf:
                self.attrs.append(self.conf['uidattr'])
            if 'uidregex' in self.conf:
                self.uidregex = re.compile(self.conf['uidregex'])
            self.essential_fields = ['uid', 'uidNumber', 'gidNumber']
            if self.conf.get('use_rid'):
                self.attrs.append('sambaSID')
                self.essential_fields.append('sambaSID')
        self.log = logging.getLogger(self.__class__.__name__)

    def CreateMap(self):
        """Returns a new PasswdMap instance to have PasswdMapEntries added to
        it."""
        return passwd.PasswdMap()

    def Transform(self, obj):
        """Transforms a LDAP posixAccount data structure into a
        PasswdMapEntry.

        Args:
          obj: one search result, mapping attribute names to lists of values.

        Returns:
          A populated PasswdMapEntry.

        Raises:
          ValueError: outside AD mode, when none of gecos, cn or fullName is
            present.
          KeyError: when an attribute required for the name or ids is missing.
        """
        pw = passwd.PasswdMapEntry()

        # gecos: in AD mode only displayName is used; otherwise fall back
        # through gecos -> cn -> fullName.
        if self.conf.get('ad'):
            if 'displayName' in obj:
                pw.gecos = obj['displayName'][0]
        elif 'gecos' in obj:
            pw.gecos = obj['gecos'][0]
        elif 'cn' in obj:
            pw.gecos = obj['cn'][0]
        elif 'fullName' in obj:
            pw.gecos = obj['fullName'][0]
        else:
            raise ValueError('Neither gecos, cn nor fullName found')

        # passwd(5) holds one entry per line, so strip embedded newlines.
        pw.gecos = pw.gecos.replace('\n', '')

        if self.conf.get('ad'):
            pw.name = obj['sAMAccountName'][0]
        elif 'uidattr' in self.conf:
            pw.name = obj[self.conf['uidattr']][0]
        else:
            pw.name = obj['uid'][0]

        if hasattr(self, 'uidregex'):
            # Keep only the parts of the name matched by uidregex.
            pw.name = ''.join(self.uidregex.findall(pw.name))

        if 'override_shell' in self.conf:
            pw.shell = self.conf['override_shell']
        elif 'loginShell' in obj:
            pw.shell = obj['loginShell'][0]
        else:
            pw.shell = ''

        if self.conf.get('ad'):
            # Use the user's RID (last SID component) for both uid and gid so
            # it lines up with the same-named group, whose gid is also the RID.
            rid = int(sidToStr(obj['objectSid'][0]).split('-')[-1])
            pw.uid = pw.gid = rid
        elif self.conf.get('use_rid'):
            rid = int(sidToStr(obj['sambaSID'][0]).split('-')[-1])
            pw.uid = pw.gid = rid
        else:
            pw.uid = int(obj['uidNumber'][0])
            pw.gid = int(obj['gidNumber'][0])

        if 'offset' in self.conf:
            # Shift uid and gid up to avoid clashing with local accounts.
            pw.uid = int(pw.uid + self.conf['offset'])
            pw.gid = int(pw.gid + self.conf['offset'])

        if self.conf.get('home_dir'):
            pw.dir = '/home/%s' % pw.name
        elif 'unixHomeDirectory' in obj:
            pw.dir = obj['unixHomeDirectory'][0]
        elif 'homeDirectory' in obj:
            pw.dir = obj['homeDirectory'][0]
        else:
            pw.dir = ''

        # hack: the real password hash is not exposed in passwd.
        pw.passwd = 'x'

        return pw
848
849
class GroupUpdateGetter(UpdateGetter):
    """Get group updates.

    Converts LDAP posixGroup objects (or Active Directory group objects when
    the 'ad' option is set) into GroupMapEntry instances.
    """

    def __init__(self, conf):
        """Choose the LDAP attributes to request for group entries.

        Args:
          conf: the source configuration; the 'ad', 'rfc2307bis',
            'rfc2307bis_alt', 'groupregex' and 'use_rid' keys are consulted.
        """
        super(GroupUpdateGetter, self).__init__(conf)
        # TODO: Merge multiple rcf2307bis[_alt] options into a single option.
        if self.conf.get('ad'):
            # attributes of AD group to be returned
            self.attrs = ['sAMAccountName', 'member', 'objectSid']
            self.essential_fields = ['sAMAccountName', 'objectSid']
        else:
            if conf.get('rfc2307bis'):
                self.attrs = ['cn', 'gidNumber', 'member', 'uid']
            elif conf.get('rfc2307bis_alt'):
                self.attrs = ['cn', 'gidNumber', 'uniqueMember', 'uid']
            else:
                self.attrs = ['cn', 'gidNumber', 'memberUid', 'uid']
            if 'groupregex' in conf:
                self.groupregex = re.compile(self.conf['groupregex'])
            self.essential_fields = ['cn']
            if conf.get('use_rid'):
                self.attrs.append('sambaSID')
                self.essential_fields.append('sambaSID')

        self.log = logging.getLogger(__name__)

    def CreateMap(self):
        """Return a GroupMap instance."""
        return group.GroupMap()

    def Transform(self, obj):
        """Transforms a LDAP posixGroup object into a group(5) entry.

        Args:
          obj: one search result, mapping attribute names to lists of values.

        Returns:
          A GroupMapEntry. Its groupmembers holds the leftmost RDN value of
          every 'member' DN, which PostProcess uses for nested expansion.
        """
        gr = group.GroupMapEntry()

        if self.conf.get('ad'):
            gr.name = obj['sAMAccountName'][0]
        # hack to map the users as the corresponding group with the same name
        elif 'uid' in obj:
            gr.name = obj['uid'][0]
        else:
            gr.name = obj['cn'][0]
        # group passwords are deferred to gshadow
        gr.passwd = '*'
        members = []
        group_members = []
        if 'memberUid' in obj:
            if hasattr(self, 'groupregex'):
                # Filter each uid on its own: findall() needs a single string,
                # and each filtered uid must stay one list element.
                members.extend(
                    ''.join(self.groupregex.findall(member_uid))
                    for member_uid in obj['memberUid'])
            else:
                members.extend(obj['memberUid'])
        elif 'member' in obj:
            for member_dn in obj['member']:
                # Value of the leftmost RDN, e.g. 'jdoe' from 'uid=jdoe,ou=..';
                # maxsplit keeps any '=' that appears inside the value.
                member_uid = member_dn.split(',')[0].split('=', 1)[1]
                # Note that there is not currently a way to consistently
                # distinguish a group from a person.
                group_members.append(member_uid)
                if hasattr(self, 'groupregex'):
                    members.append(''.join(self.groupregex.findall(member_uid)))
                else:
                    members.append(member_uid)
        elif 'uniqueMember' in obj:
            # These values are DNs; PostProcess resolves them to uids.
            members.extend(obj['uniqueMember'])
        members.sort()

        if self.conf.get('ad'):
            gr.gid = int(sidToStr(obj['objectSid'][0]).split('-')[-1])
        elif self.conf.get('use_rid'):
            gr.gid = int(sidToStr(obj['sambaSID'][0]).split('-')[-1])
        else:
            gr.gid = int(obj['gidNumber'][0])

        if 'offset' in self.conf:
            gr.gid = int(gr.gid + self.conf['offset'])

        gr.members = members
        gr.groupmembers = group_members

        return gr

    def PostProcess(self, data_map, source, search_filter, search_scope):
        """Resolve uniqueMember DNs to uids and optionally expand nested groups.

        Args:
          data_map: the GroupMap that has just been populated; its entries
            are modified in place.
          source: the LDAP source, used for one base-scope lookup per
            uniqueMember DN.
          search_filter: unused here.
          search_scope: unused here.
        """
        if 'uniqueMember' in self.attrs:
            for gr in data_map:
                uidmembers = []
                for member in gr.members:
                    source.Search(search_base=member,
                                  search_filter='(objectClass=*)',
                                  search_scope=ldap.SCOPE_BASE,
                                  attrs=['uid'])
                    for obj in source:
                        if 'uid' in obj:
                            uidmembers.extend(obj['uid'])
                del gr.members[:]
                gr.members.extend(uidmembers)

        if not self.conf.get("nested_groups"):
            return

        group_map = {i.name: i for i in data_map}

        def _expand_members(gr):
            """Append the members of every group reachable from gr."""
            known = set(gr.members)
            visited = {gr.name}
            # Walk the full transitive closure of gr.groupmembers. Appending
            # only to gr makes the result independent of the order in which
            # groups are processed. The explicit stack avoids recursion limits,
            # and `visited` stops membership cycles.
            stack = list(reversed(gr.groupmembers))
            while stack:
                name = stack.pop()
                if name in visited or name not in group_map:
                    continue
                visited.add(name)
                subgroup = group_map[name]
                for member in subgroup.members:
                    if member not in known:
                        known.add(member)
                        gr.members.append(member)
                stack.extend(reversed(subgroup.groupmembers))

        self.log.info("Expanding nested groups")
        for gr in data_map:
            _expand_members(gr)
970
971
class ShadowUpdateGetter(UpdateGetter):
    """Get Shadow updates from the LDAP Source."""

    # (LDAP attribute, ShadowMapEntry field) pairs copied verbatim as ints.
    _INT_FIELDS = (
        ('shadowMin', 'min'),
        ('shadowMax', 'max'),
        ('shadowWarning', 'warn'),
        ('shadowInactive', 'inact'),
        ('shadowExpire', 'expire'),
        ('shadowFlag', 'flag'),
    )

    def __init__(self, conf):
        """Choose the LDAP attributes to request for shadow entries.

        Args:
          conf: the source configuration; the 'ad', 'uidattr' and 'uidregex'
            keys are consulted here.
        """
        super(ShadowUpdateGetter, self).__init__(conf)
        self.attrs = [
            'uid', 'shadowLastChange', 'shadowMin', 'shadowMax',
            'shadowWarning', 'shadowInactive', 'shadowExpire', 'shadowFlag',
            'userPassword'
        ]
        if self.conf.get('ad'):
            # attributes of AD user to be returned for shadow
            self.attrs.extend(('sAMAccountName', 'pwdLastSet'))
            self.essential_fields = ['sAMAccountName', 'pwdLastSet']
        else:
            if 'uidattr' in self.conf:
                self.attrs.append(self.conf['uidattr'])
            if 'uidregex' in self.conf:
                self.uidregex = re.compile(self.conf['uidregex'])
            self.essential_fields = ['uid']
        self.log = logging.getLogger(self.__class__.__name__)

    def CreateMap(self):
        """Return a ShadowMap instance."""
        return shadow.ShadowMap()

    def Transform(self, obj):
        """Transforms an LDAP shadowAccount object into a shadow(5) entry.

        Args:
          obj: one search result, mapping attribute names to lists of values.

        Returns:
          A populated ShadowMapEntry. flag defaults to 0. passwd is '*' unless
          userPassword holds a {crypt} value.
        """
        shadow_ent = shadow.ShadowMapEntry()

        if self.conf.get('ad'):
            shadow_ent.name = obj['sAMAccountName'][0]
        elif 'uidattr' in self.conf:
            shadow_ent.name = obj[self.conf['uidattr']][0]
        else:
            shadow_ent.name = obj['uid'][0]

        if hasattr(self, 'uidregex'):
            # Keep only the parts of the name matched by uidregex.
            shadow_ent.name = ''.join(self.uidregex.findall(shadow_ent.name))

        # TODO(jaq): does nss_ldap check the contents of the userPassword
        # attribute?
        shadow_ent.passwd = '*'
        if self.conf.get('ad'):
            # AD pwdLastSet counts 100-nanosecond intervals since 1601-01-01.
            # Convert it to seconds and subtract 11644473600, the number of
            # seconds between 1601-01-01 and 1970-01-01. Then convert to whole
            # days since 1970-01-01, the unit shadow(5) uses for lstchg.
            # Integer division keeps the result exact; float division can round
            # across a day boundary.
            seconds = int(obj['pwdLastSet'][0]) // 10000000 - 11644473600
            shadow_ent.lstchg = seconds // 86400
        elif 'shadowLastChange' in obj:
            shadow_ent.lstchg = int(obj['shadowLastChange'][0])

        for attr, field in self._INT_FIELDS:
            if attr in obj:
                setattr(shadow_ent, field, int(obj[attr][0]))
        if shadow_ent.flag is None:
            shadow_ent.flag = 0

        if 'userPassword' in obj:
            user_password = obj['userPassword'][0]
            if user_password[:7].lower() == '{crypt}':
                shadow_ent.passwd = user_password[7:]
            else:
                logging.info('Ignored password that was not in crypt format')
        return shadow_ent
1048
1049
class NetgroupUpdateGetter(UpdateGetter):
    """Fetch nisNetgroup objects and turn them into netgroup map entries."""

    # Attributes whose values all become entries of the netgroup.
    _MEMBER_ATTRS = ('memberNisNetgroup', 'nisNetgroupTriple')

    def __init__(self, conf):
        """Request cn plus the two member attributes; only cn is mandatory."""
        super(NetgroupUpdateGetter, self).__init__(conf)
        self.attrs = ['cn'] + list(self._MEMBER_ATTRS)
        self.essential_fields = ['cn']

    def CreateMap(self):
        """Return a NetgroupMap instance."""
        return netgroup.NetgroupMap()

    def Transform(self, obj):
        """Convert one nisNetgroup search result into a netgroup(5) entry.

        Member netgroups and triples are merged, de-duplicated, sorted and
        stored as a single space-separated string.
        """
        ent = netgroup.NetgroupMapEntry()
        ent.name = obj['cn'][0]

        collected = set()
        for attr_name in self._MEMBER_ATTRS:
            if attr_name in obj:
                collected.update(obj[attr_name])

        ent.entries = ' '.join(sorted(collected))
        return ent
1077
1078
class AutomountUpdateGetter(UpdateGetter):
    """Get specific automount maps."""

    def __init__(self, conf):
        """Request cn and automountInformation; only cn is mandatory."""
        super(AutomountUpdateGetter, self).__init__(conf)
        self.attrs = ['cn', 'automountInformation']
        self.essential_fields = ['cn']

    def CreateMap(self):
        """Return a AutomountMap instance."""
        return automount.AutomountMap()

    def Transform(self, obj):
        """Transforms an LDAP automount object into an autofs(5) entry.

        Args:
          obj: one search result, mapping attribute names to lists of values.

        Returns:
          An AutomountMapEntry keyed by cn.

        Raises:
          ValueError: when automountInformation is empty or only whitespace.
        """
        automount_ent = automount.AutomountMapEntry()
        automount_ent.key = obj['cn'][0]

        automount_information = obj['automountInformation'][0]

        if automount_information.startswith('ldap'):
            # we are creating an automount master map, pointing to other maps
            # in LDAP
            automount_ent.location = automount_information
        else:
            # Normal map entries look like "[options] location". Split on any
            # whitespace so tabs or repeated spaces do not produce an empty
            # location. An entry without options is a single field.
            fields = automount_information.split()
            if len(fields) >= 2:
                automount_ent.options = fields[0]
                automount_ent.location = fields[1]
            elif fields:
                automount_ent.location = fields[0]
            else:
                raise ValueError('Empty automountInformation for %r' %
                                 automount_ent.key)

        return automount_ent
1107
1108
class SshkeyUpdateGetter(UpdateGetter):
    """Fetches SSH keys."""

    def __init__(self, conf):
        """Request uid and sshPublicKey (plus conf['uidattr'] if set).

        Args:
          conf: the source configuration; the 'uidattr' and 'uidregex' keys
            are consulted here.
        """
        super(SshkeyUpdateGetter, self).__init__(conf)
        self.attrs = ['uid', 'sshPublicKey']
        if 'uidattr' in self.conf:
            self.attrs.append(self.conf['uidattr'])
        if 'uidregex' in self.conf:
            self.uidregex = re.compile(self.conf['uidregex'])
        self.essential_fields = ['uid']

    def CreateMap(self):
        """Returns a new SshkeyMap instance to have SshkeyMapEntries added to
        it."""
        return sshkey.SshkeyMap()

    def Transform(self, obj):
        """Transforms a LDAP posixAccount data structure into a
        SshkeyMapEntry.

        Args:
          obj: one search result, mapping attribute names to lists of values.

        Returns:
          A SshkeyMapEntry. sshkey holds every sshPublicKey value, or '' when
          the attribute is absent.
        """
        skey = sshkey.SshkeyMapEntry()

        if 'uidattr' in self.conf:
            skey.name = obj[self.conf['uidattr']][0]
        else:
            skey.name = obj['uid'][0]

        if hasattr(self, 'uidregex'):
            # Keep only the parts of this entry's name matched by uidregex.
            skey.name = ''.join(self.uidregex.findall(skey.name))

        if 'sshPublicKey' in obj:
            # The whole value list is kept: a user may have several keys.
            skey.sshkey = obj['sshPublicKey']
        else:
            skey.sshkey = ''

        return skey
1146