1# -*- encoding: utf-8 -*-
2# Samba traffic replay and learning
3#
4# Copyright (C) Catalyst IT Ltd. 2017
5#
6# This program is free software; you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation; either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18#
19from __future__ import print_function, division
20
21import time
22import os
23import random
24import json
25import math
26import sys
27import signal
28from errno import ECHILD, ESRCH
29
30from collections import OrderedDict, Counter, defaultdict, namedtuple
31from dns.resolver import query as dns_query
32
33from samba.emulate import traffic_packets
34from samba.samdb import SamDB
35import ldb
36from ldb import LdbError
37from samba.dcerpc import ClientConnection
38from samba.dcerpc import security, drsuapi, lsa
39from samba.dcerpc import netlogon
40from samba.dcerpc.netlogon import netr_Authenticator
41from samba.dcerpc import srvsvc
42from samba.dcerpc import samr
43from samba.drs_utils import drs_DsBind
44import traceback
45from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
46from samba.auth import system_session
47from samba.dsdb import (
48    UF_NORMAL_ACCOUNT,
49    UF_SERVER_TRUST_ACCOUNT,
50    UF_TRUSTED_FOR_DELEGATION,
51    UF_WORKSTATION_TRUST_ACCOUNT
52)
53from samba.dcerpc.misc import SEC_CHAN_BDC
54from samba import gensec
55from samba import sd_utils
56from samba.compat import get_string
57from samba.logger import get_samba_logger
58import bisect
59
# Model file versioning: models are saved tagged with CURRENT_MODEL_VERSION;
# loading accepts REQUIRED_MODEL_VERSION or greater.
CURRENT_MODEL_VERSION = 2   # save as this
REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
# Estimated fixed cost of a sleep call in seconds — presumably subtracted
# when pacing the replay; confirm against the wait/replay scheduling code.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs suggesting the sender is a client, with a
# confidence weight in the range 0..1 (consumed by Packet.client_score()).
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs suggesting the sender is a server; same weighting
# scheme as CLIENT_CLUES (Packet.client_score() negates these).
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols deliberately not replayed (see is_a_real_packet()).
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# Wait-time quantisation factors — NOTE(review): exact use is in code outside
# this chunk; confirm against the wait-generation/model code.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

# Module-level logger for this module's diagnostics.
LOGGER = get_samba_logger(name=__name__)
100
101
def debug(level, msg, *args):
    """Emit a formatted debug message on standard error.

    :param level: print only when this is <= the global DEBUG_LEVEL
                  (which scripts can raise with the -d option)
    :param msg:   the message; may contain C-style format specifiers
    :param args:  the values consumed by the format specifiers in msg
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
118
119
def debug_lineno(*args):
    """Write an unformatted log message to stderr, prefixed with the
    calling function's name and line number (line number highlighted).
    """
    caller = traceback.extract_stack(limit=2)[0]
    prefix = (" %s:" "\033[01;33m"
              "%s " "\033[00m" % (caller[2], caller[1]))
    print(prefix, end=' ', file=sys.stderr)
    for arg in args:
        print(arg, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
131
132
def random_colour_print(seeds):
    """Return a function that prints lines to stderr (only when
    DEBUG_LEVEL > 0). With seeds given, a colour derived from a simple
    hash of the integers is used; otherwise output is uncoloured."""
    if not seeds:
        def printer(*args):
            if DEBUG_LEVEL > 0:
                for arg in args:
                    print(arg, file=sys.stderr)
        return printer

    # fold the seeds into a value in [0, 214) to pick a terminal colour
    acc = 214
    for seed in seeds:
        acc = ((acc + 17) * seed) % 214
    colour_prefix = "\033[38;5;%dm" % (18 + acc)

    def printer(*args):
        if DEBUG_LEVEL > 0:
            for arg in args:
                print("%s%s\033[00m" % (colour_prefix, arg), file=sys.stderr)
    return printer
155
156
class FakePacketError(Exception):
    """Error relating to a simulated (fake) packet.

    Not raised in this module — presumably raised by the packet handlers
    in traffic_packets; confirm against that module.
    """
    pass
159
160
class Packet(object):
    """Details of a single network packet from a traffic summary.

    Instances are normally parsed from a tab-separated traffic-summary
    line via from_line().  The endpoints attribute holds the (src, dest)
    pair in sorted order, so both directions of a conversation share the
    same key.
    """
    __slots__ = ('timestamp',
                 'ip_protocol',
                 'stream_number',
                 'src',
                 'dest',
                 'protocol',
                 'opcode',
                 'desc',
                 'extra',
                 'endpoints')

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # normalise the endpoint pair so (a, b) and (b, a) match
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse a tab-separated traffic-summary line into a Packet.

        The first 8 fields are fixed; any further fields land in extra.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the timestamp before formatting
        :return: an (adjusted timestamp, summary line) tuple
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet with identical field values."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' string identifying this packet."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        Timing/result lines are printed to stdout for later statistics.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __lt__(self, other):
        # __cmp__ is ignored by Python 3, which made Packets unsortable
        # there; provide the rich comparison ordering by timestamp.
        return self.timestamp < other.timestamp

    def __cmp__(self, other):
        # Python 2 fallback ordering (unused under Python 3).
        return self.timestamp - other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """True if replaying this packet would send traffic (see
        is_a_real_packet())."""
        return is_a_real_packet(self.protocol, self.opcode)
305
306
def is_a_real_packet(protocol, opcode):
    """Return True if the packet matters for replay.

    Packets for which this returns False (skipped protocols, ldap
    continuation packets, packets with no handler or a null handler)
    can be removed with no effect on the replay.
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # Logger.debug() takes lazy %-style arguments and has no 'file'
        # keyword; the previous file=sys.stderr raised TypeError here.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    if fn is traffic_packets.null_packet:
        return False
    return True
327
328
def is_a_traffic_generating_packet(protocol, opcode):
    """Whether a packet sends network traffic in its own right.

    Some of these cause traffic only in certain contexts (e.g. an ldap
    unbind after a bind), so a conversation consisting solely of such
    packets generates nothing.
    """
    # 'wait' is a pacing pseudo-protocol, never sent on the wire.
    if protocol == 'wait':
        return False

    quiet_packets = {
        ('kerberos', ''),
        ('ldap', '2'),
        ('dcerpc', '15'),
        ('dcerpc', '16'),
    }
    if (protocol, opcode) in quiet_packets:
        return False

    return is_a_real_packet(protocol, opcode)
345
346
class ReplayContext(object):
    """State/Context for a conversation between a simulated client and a
       server. Some of the context is shared amongst all conversations
       and should be generated before the fork, while other context is
       specific to a particular conversation and should be generated
       *after* the fork, in generate_process_local_config().
    """
    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 total_conversations=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=os.environ.get("DOMAIN"),
                 domain_sid=None,
                 instance_id=None):
        self.server                   = server
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        if prefer_kerberos:
            self.kerberos_state = MUST_USE_KERBEROS
        else:
            self.kerberos_state = DONT_USE_KERBEROS
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')
        self.instance_id              = instance_id

        # Bad password attempt controls: the last_*_bad flags prevent
        # two consecutive bad-credential attempts on the same service
        # (see with_random_bad_credentials()).
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.total_conversations      = total_conversations
        self.generate_ldap_search_tables()

    def generate_ldap_search_tables(self):
        """Build the DN pattern map and canned search filters by querying
        the server (shared context; run once before forking)."""
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

        # pre-populate DN-based search filters (it's simplest to generate them
        # once, when the test starts). These are used by guess_search_filter()
        # to avoid full-scans
        self.search_filters = {}

        # lookup all the GPO DNs
        res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
                        expression='(objectclass=groupPolicyContainer)')
        gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)

        # a search for the 'gPCFileSysPath' attribute is probably a GPO search
        # (as per the MS-GPOL spec) which searches for GPOs by DN
        self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)

        # likewise, a search for gpLink is probably the Domain SOM search part
        # of the MS-GPOL, in which case it's looking up a few OUs by DN
        ou_str = ""
        for ou in ["Domain Controllers,", "traffic_replay,", ""]:
            ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
        self.search_filters['gpLink'] = "(|{0})".format(ou_str)

        # The CEP Web Service can query the AD DC to get pKICertificateTemplate
        # objects (as per MS-WCCE)
        self.search_filters['pKIExtendedKeyUsage'] = \
            '(objectCategory=pKICertificateTemplate)'

        # assume that anything querying the usnChanged is some kind of
        # synchronization tool, e.g. AD Change Detection Connector
        res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
        self.search_filters['usnChanged'] = \
            '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])

    # The traffic_learner script doesn't preserve the LDAP search filter, and
    # having no filter can result in a full DB scan. This is costly for a large
    # DB, and not necessarily representative of real world traffic. As there
    # several standard LDAP queries that get used by AD tools, we can apply
    # some logic and guess what the search filter might have been originally.
    def guess_search_filter(self, attrs, dn_sig, dn):
        """Return a plausible LDAP filter for a recorded search.

        :param attrs: the attributes requested by the search
        :param dn_sig: the DN pattern signature (e.g. 'DC,DC')
        :param dn: the search base DN (currently unused)
        """
        # there are some standard spec-based searches that query fairly unique
        # attributes. Check if the search is likely one of these
        for key in self.search_filters.keys():
            if key in attrs:
                return self.search_filters[key]

        # if it's the top-level domain, assume we're looking up a single user,
        # e.g. like powershell Get-ADUser or a similar tool
        if dn_sig == 'DC,DC':
            # random.random() returns a float in [0, 1), so the previous
            # "random.random() % self.total_conversations" was always that
            # same float and never selected a real user index; pick a whole
            # user id instead.
            random_user_id = random.randrange(self.total_conversations)
            account_name = user_name(self.instance_id, random_user_id)
            return '(&(sAMAccountName=%s)(objectClass=user))' % account_name

        # otherwise just return everything in the sub-tree
        return '(objectClass=*)'

    def generate_process_local_config(self, account, conversation):
        """Set up the per-conversation state (run after the fork)."""
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.
        """
        if not failed_last_time:
            if (self.badpassword_frequency and
                random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        self.user_creds.set_kerberos_state(self.kerberos_state)

        # NOTE(review): unlike user_creds, the bad-password variants below
        # do not set the domain — presumably intentional; confirm.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        self.user_creds_bad.set_kerberos_state(self.kerberos_state)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        self.machine_creds.set_kerberos_state(self.kerberos_state)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        self.machine_creds_bad.set_kerberos_state(self.kerberos_state)

    def get_matching_dn(self, pattern, attributes=None):
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a dcerpc connection, reusing the last one unless new."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection, reusing the last one unless new."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return an lsarpc connection over TCP, reusing unless new."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return an lsarpc connection over a named pipe, reusing unless
        new."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple"""
        # NOTE(review): the 'unbind' parameter is accepted but unused here;
        # kept for interface compatibility with callers.
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a SamDB LDAP connection (simple or SASL bind), reusing
        the last one unless new."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return a SamrContext, reusing the last one unless new."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (cached) netlogon connection."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, type) pair for a plausible DNS query."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) netr_Authenticator pair for
        netlogon calls."""
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        # credential bytes may arrive as ints or single-char strings
        # depending on the Python version; normalise to ints
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)

    def write_stats(self, filename, **kwargs):
        """Write arbitrary key/value pairs to a file in our stats directory in
        order for them to be picked up later by another process working out
        statistics."""
        filename = os.path.join(self.statsdir, filename)
        # context manager so the file is closed even if a write fails
        with open(filename, 'w') as f:
            for k, v in kwargs.items():
                print("%s: %s" % (k, v), file=f)
827
828
class SamrContext(object):
    """State/Context associated with a samr connection.

    Handles and rids are created lazily and cached on the instance.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server = server
        self.lp = lp
        self.creds = creds
        self.connection = None
        self.handle = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, creating it on first use."""
        if not self.connection:
            binding = "ncacn_ip_tcp:%s[seal]" % self.server
            self.connection = samr.samr(binding,
                                        lp_ctx=self.lp,
                                        credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the samr connect handle, creating it on first use."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
858
859
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    def __init__(self, start_time=None, endpoints=None, seq=(),
                 conversation_id=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print(endpoints)
        self.client_balance = 0.0
        self.conversation_id = conversation_id
        for p in seq:
            self.add_short_packet(*p)

    def __cmp__(self, other):
        # Python 2 rich comparison; orders by start time with None first.
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def __lt__(self, other):
        # Python 3 ignores __cmp__; provide __lt__ so lists of
        # conversations can still be sorted there.
        return self.__cmp__(other) < 0

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            # NOTE: the two string fragments previously concatenated
            # without a space ("matchpacket"); fixed here.
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True, skip_unused_packets=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        if skip_unused_packets and not is_a_real_packet(protocol, opcode):
            return

        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS.get(key, '')
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last packet (0 with
        fewer than two packets)."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Render every packet as a summary line relative to the
        conversation start."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_with_delay(self, start, context=None, account=None):
        """Replay the conversation at the right time.
        (We're already in a fork)."""
        # first we sleep until the first packet
        t = self.start_time
        now = time.time() - start
        gap = t - now
        sleep_time = gap - SLEEP_OVERHEAD
        if sleep_time > 0:
            time.sleep(sleep_time)

        miss = (time.time() - start) - t
        self.msg("starting %s [miss %.3f]" % (self, miss))

        max_gap = 0.0
        max_sleep_miss = 0.0
        # packet times are relative to conversation start
        p_start = time.time()
        for p in self.packets:
            now = time.time() - p_start
            gap = now - p.timestamp
            if gap > max_gap:
                max_gap = gap
            if gap < 0:
                sleep_time = -gap - SLEEP_OVERHEAD
                if sleep_time > 0:
                    time.sleep(sleep_time)
                    t = time.time() - p_start
                    if t - p.timestamp > max_sleep_miss:
                        max_sleep_miss = t - p.timestamp

            p.play(self, context)

        return max_gap, miss, max_sleep_miss

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
1025
1026
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration, query_file=None):
        # pre-compute sorted query times, uniformly spread over duration
        n = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration) for _ in range(n))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.query_choices = self._get_query_choices(query_file=query_file)

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def _get_query_choices(self, query_file=None):
        """
        Read dns query choices from a file, or return default

        rname may contain format string like `{realm}`
        realm can be fetched from context.realm

        Each choice is (opcode, rname, rtype, exist), where exist is
        'yes' when an answer is expected.
        """
        if query_file:
            with open(query_file, 'r') as f:
                text = f.read()
            choices = []
            for line in text.splitlines():
                line = line.strip()
                if line and not line.startswith('#'):
                    args = line.split(',')
                    # raise explicitly instead of assert: asserts are
                    # stripped when Python runs with -O, silently
                    # accepting malformed user-supplied files
                    if len(args) != 4:
                        raise ValueError(
                            "bad DNS query line %r: expected 4 "
                            "comma-separated fields, got %d" %
                            (line, len(args)))
                    choices.append(args)
            return choices
        else:
            return [
                (0, '{realm}', 'A', 'yes'),
                (1, '{realm}', 'NS', 'yes'),
                (2, '*.{realm}', 'A', 'no'),
                (3, '*.{realm}', 'NS', 'no'),
                (10, '_msdcs.{realm}', 'A', 'yes'),
                (11, '_msdcs.{realm}', 'NS', 'yes'),
                (20, 'nx.realm.com', 'A', 'no'),
                (21, 'nx.realm.com', 'NS', 'no'),
                (22, '*.nx.realm.com', 'A', 'no'),
                (23, '*.nx.realm.com', 'NS', 'no'),
            ]

    def replay(self, context=None):
        """Fire DNS queries at the pre-computed times, printing one
        result line per query (a query counts as failed if it raises or
        if an expected answer is empty)."""
        assert context
        assert context.realm
        start = time.time()
        for t in self.times:
            now = time.time() - start
            gap = t - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            opcode, rname, rtype, exist = random.choice(self.query_choices)
            rname = rname.format(realm=context.realm)
            success = True
            packet_start = time.time()
            try:
                answers = dns_query(rname, rtype)
                if exist == 'yes' and not len(answers):
                    # expect answers but didn't get, fail
                    success = False
            except Exception:
                success = False
            finally:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
1103
1104
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or open file objects
    :param dns_mode: unless 'include', DNS packets are tallied into a
        separate opcode count rather than added to conversations.

    Returns a 4-tuple (conversations, conversation rate, duration,
    dns opcode counts).
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # return the same 4-tuple arity as the success path below;
        # previously this returned a 2-tuple, breaking callers that
        # unpack four values.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for i, p in enumerate(packets):
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation(conversation_id=(i + 2))
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # NOTE(review): despite the name this is conversations per second
    # (a rate), not an interval — confirm before renaming.
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
1156
1157
def guess_server_address(conversations):
    """Return the single most common endpoint address, or None if there
    are no conversations.

    We guess the most common address is the server. Note that
    Counter.most_common(1)[0] is an (address, count) pair; the previous
    code returned that whole pair, so comparisons against a bare
    endpoint (e.g. as the server clue in guess_client_server) could
    never match. Return just the address.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        return addresses.most_common(1)[0][0]
1165
1166
def stringify_keys(x):
    """Return a copy of dict x with each tuple key joined into a single
    tab-separated string (making it JSON-serialisable)."""
    return {'\t'.join(key): value for key, value in x.items()}
1173
1174
1175def unstringify_keys(x):
1176    y = {}
1177    for k, v in x.items():
1178        t = tuple(str(k).split('\t'))
1179        y[t] = v
1180    return y
1181
1182
class TrafficModel(object):
    """An n-gram model of traffic packet sequences.

    The model records transitions between packet types (plus synthetic
    'wait:n' states for long pauses), per-packet-type argument samples,
    and an overall packet rate. It can be learnt from conversations,
    saved to or loaded from JSON, and used to generate synthetic
    conversation sequences.
    """
    def __init__(self, n=3):
        # ngrams maps an (n-1)-tuple of preceding states to a list of
        # observed following states (repetition provides the weighting)
        self.ngrams = {}
        # query_details maps a packet type to a list of observed
        # extra-argument tuples
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [packet count, elapsed seconds]
        self.packet_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Absorb a list of conversations (and optional DNS opcode
        counts) into the model."""
        # avoid a mutable default argument ({} shared between calls)
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            first = conversations[0].start_time
            total = 0
            last = first + 0.1
            for c in conversations:
                total += len(c)
                last = max(last, c.packets[-1].timestamp)

            self.packet_rate[0] = total
            self.packet_rate[1] = last - first

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # model the client side only
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model to f (a filename or an open file) as JSON."""
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'packet_rate': self.packet_rate,
            'version': CURRENT_MODEL_VERSION
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # we opened it, so we close it; a passed-in file object
            # remains the caller's responsibility
            with open(f, 'w') as out:
                json.dump(d, out, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a JSON model from f (a filename or an open file) into
        this model.

        :raises ValueError: if the file lacks a version number or its
            version is older than REQUIRED_MODEL_VERSION.
        """
        if isinstance(f, str):
            with open(f) as in_file:
                d = json.load(in_file)
        else:
            d = json.load(f)

        try:
            version = d["version"]
            if version < REQUIRED_MODEL_VERSION:
                raise ValueError("the model file is version %d; "
                                 "version %d is required" %
                                 (version, REQUIRED_MODEL_VERSION))
        except KeyError:
            raise ValueError("the model file lacks a version number; "
                             "version %d is required" %
                             (REQUIRED_MODEL_VERSION))

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)
            values.sort()

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)
            values.sort()

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.packet_rate = d['packet_rate']

    def construct_conversation_sequence(self, timestamp=0.0,
                                        hard_stop=None,
                                        replay_speed=1,
                                        ignore_before=0,
                                        persistence=0):
        """Construct an individual conversation packet sequence from the
        model.
        """
        c = []
        key = (NON_PACKET,) * (self.n - 1)
        if ignore_before is None:
            ignore_before = timestamp - 1

        while True:
            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
            if p == NON_PACKET:
                if timestamp < ignore_before:
                    break
                if random.random() > persistence:
                    print("ending after %s (persistence %.1f)" % (key, persistence),
                          file=sys.stderr)
                    break

                p = 'wait:%d' % random.randrange(5, 12)
                print("trying %s instead of end" % p, file=sys.stderr)

            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / replay_speed
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                if timestamp >= ignore_before:
                    c.append((timestamp, protocol, opcode, extra))

            key = key[1:] + (p,)
            if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
                # two waits in a row can only be caused by "persistence"
                # tricks, and will not result in any packets being found.
                # Instead we pretend this is a fresh start.
                key = (NON_PACKET,) * (self.n - 1)

        return c

    def scale_to_packet_rate(self, scale):
        """Convert a scale factor into packets per second."""
        rate_n, rate_t = self.packet_rate
        return scale * rate_n / rate_t

    def packet_rate_to_scale(self, pps):
        """Convert packets per second into a scale factor."""
        rate_n, rate_t = self.packet_rate
        return pps * rate_t / rate_n

    def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
                                        persistence=0):
        """Generate a list of conversation descriptions from the model."""

        # We run the simulation for ten times as long as our desired
        # duration, and take the section at the end.
        lead_in = 9 * duration
        target_packets = int(packet_rate * duration)
        conversations = []
        n_packets = 0

        while n_packets < target_packets:
            start = random.uniform(-lead_in, duration)
            c = self.construct_conversation_sequence(start,
                                                     hard_stop=duration,
                                                     replay_speed=replay_speed,
                                                     ignore_before=0,
                                                     persistence=persistence)
            # will these "packets" generate actual traffic?
            # some (e.g. ldap unbind) will not generate anything
            # if the previous packets are not there, and if the
            # conversation only has those it wastes a process doing nothing.
            for timestamp, protocol, opcode, extra in c:
                if is_a_traffic_generating_packet(protocol, opcode):
                    break
            else:
                continue

            conversations.append(c)
            n_packets += len(c)

        scale = self.packet_rate_to_scale(packet_rate)
        print(("we have %d packets (target %d) in %d conversations at %.1f/s "
               "(scale %f)" % (n_packets, target_packets, len(conversations),
                               packet_rate, scale)),
              file=sys.stderr)
        conversations.sort()  # sorts by first element == start time
        return conversations
1405
1406
def seq_to_conversations(seq, server=1, client=2):
    """Wrap each non-empty packet sequence in a Conversation, assigning
    successive client endpoint numbers against the given server."""
    conversations = []
    for s in seq:
        if not s:
            continue
        conversations.append(Conversation(s[0][0], (server, client), s))
        client += 1
    return conversations
1415
1416
# Map each application protocol name to its IP protocol number as a
# two-digit hex string ('06' is TCP, '11' is UDP).
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1435
# Map (protocol, opcode) pairs to the human-readable operation
# description used in packet summary lines.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1532
1533
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short 'protocol:opcode' packet description into a full
    tab-separated summary line."""
    protocol, opcode = p.split(':', 1)
    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields.extend(extra)
    return '\t'.join(fields)
1542
1543
def flushing_signal_handler(signal, frame):
    """Signal handler closes standard out and error.

    Triggered by a sigterm, ensures that the log messages are flushed
    to disk and not lost.
    """
    for stream in (sys.stderr, sys.stdout):
        stream.close()
    os._exit(0)
1553
1554
def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
    """Fork a new process and replay the conversation sequence.

    Returns the child PID in the parent; the child never returns (it
    exits via os._exit, with status 1 on failure).
    """
    # We will need to reseed the random number generator or all the
    # clients will end up using the same sequence of random
    # numbers. random.randint() is mixed in so the initial seed will
    # have an effect here.
    seed = client_id * 1000 + random.randint(0, 999)

    # flush our buffers so messages won't be written by both sides
    sys.stdout.flush()
    sys.stderr.flush()
    pid = os.fork()
    if pid != 0:
        return pid

    # we must never return, or we'll end up running parts of the
    # parent's clean-up code. So we work in a try...finally, and
    # try to print any exceptions.
    # initialise these before the try so the except/finally blocks
    # never hit an unbound name if an early statement raises
    c = None
    status = 0
    try:
        random.seed(seed)
        endpoints = (server_id, client_id)
        t = cs[0][0]
        c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
        signal.signal(signal.SIGTERM, flushing_signal_handler)

        context.generate_process_local_config(account, c)
        sys.stdin.close()
        os.close(0)
        filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                c.conversation_id)
        f = open(filename, 'w')
        try:
            sys.stdout.close()
            os.close(1)
        except IOError as e:
            LOGGER.info("stdout closing failed with %s" % e)

        # redirect this child's stdout to its per-conversation stats file
        sys.stdout = f
        now = time.time() - start
        gap = t - now
        sleep_time = gap - SLEEP_OVERHEAD
        if sleep_time > 0:
            time.sleep(sleep_time)

        max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
                                                                 context=context)
        print("Maximum lag: %f" % max_lag)
        print("Start lag: %f" % start_lag)
        print("Max sleep miss: %f" % max_sleep_miss)

    except Exception:
        status = 1
        print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
              file=sys.stderr)
        # print_exc takes the stream as the 'file' keyword; passing it
        # positionally would set the unrelated 'limit' argument
        traceback.print_exc(file=sys.stderr)
        sys.stderr.flush()
    finally:
        sys.stderr.close()
        sys.stdout.close()
        os._exit(status)
1617
1618
def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
    """Run a DnsHammer in a forked child process.

    Returns the child's pid to the parent. The child never returns: it
    redirects stdout to a stats file, runs the hammer, and exits via
    os._exit() with status 0 on success or 1 on failure.
    """
    # flush our buffers so messages won't be written by both sides
    sys.stdout.flush()
    sys.stderr.flush()
    pid = os.fork()
    if pid != 0:
        # parent
        return pid

    # child: detach stdin, and redirect stdout to the stats file
    sys.stdin.close()
    os.close(0)

    try:
        sys.stdout.close()
        os.close(1)
    except IOError as e:
        LOGGER.warning("stdout closing failed with %s" % e)

    filename = os.path.join(context.statsdir, 'stats-dns')
    sys.stdout = open(filename, 'w')

    try:
        status = 0
        signal.signal(signal.SIGTERM, flushing_signal_handler)
        hammer = DnsHammer(dns_rate, duration, query_file=query_file)
        hammer.replay(context=context)
    except Exception:
        status = 1
        print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
              file=sys.stderr)
        # the first positional argument of print_exc() is `limit`, so
        # the stream must be passed by keyword
        traceback.print_exc(file=sys.stderr)
    finally:
        sys.stderr.close()
        sys.stdout.close()
        os._exit(status)
1652
1653
def replay(conversation_seq,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           dns_query_file=None,
           duration=None,
           latency_timeout=1.0,
           stop_on_any_error=False,
           **kwargs):
    """Fork a child process per conversation (plus an optional DNS hammer
    child), wait for the run to finish or the deadline to pass, then kill
    off any stragglers and record stats via the ReplayContext.

    :param conversation_seq: a sequence of conversations, each a list of
        packet tuples whose first element is a start-time offset
    :param host: server to replay against (passed to ReplayContext)
    :param creds: credentials for the replay
    :param lp: loadparm context
    :param accounts: one ConversationAccounts entry per conversation
    :param dns_rate: if non-zero, also run a DNS query generator
    :param dns_query_file: optional file of queries for the DNS hammer
    :param duration: total runtime; defaults to the time of the last
        conversation's last packet plus latency_timeout
    :param latency_timeout: grace period used to derive the default
        duration
    :param stop_on_any_error: stop the whole run if any child fails
    :param kwargs: passed through to ReplayContext
    """

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            total_conversations=len(conversation_seq),
                            **kwargs)

    if len(accounts) < len(conversation_seq):
        raise ValueError(("we have %d accounts but %d conversations" %
                          (len(accounts), len(conversation_seq))))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    # we delay the start by a bit to allow all the forks to get up and
    # running.
    delay = len(conversation_seq) * 0.02
    start = time.time() + delay

    if duration is None:
        # end slightly after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = conversation_seq[-1][-1][0] + latency_timeout

    print("We will start in %.1f seconds" % delay,
          file=sys.stderr)
    print("We will stop after %.1f seconds" % (duration + delay),
          file=sys.stderr)
    print("runtime %.1f seconds" % duration,
          file=sys.stderr)

    # give one second grace for packets to finish before killing begins
    end = start + duration + 1.0

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
          % (len(conversation_seq), duration))

    context.write_stats('intentions',
                        Planned_conversations=len(conversation_seq),
                        Planned_packets=sum(len(x) for x in conversation_seq))

    # maps child pid -> client id; id 1 is reserved for the DNS hammer,
    # conversation children start at 2
    children = {}
    try:
        if dns_rate:
            pid = dnshammer_in_fork(dns_rate, duration, context,
                                    query_file=dns_query_file)
            children[pid] = 1

        for i, cs in enumerate(conversation_seq):
            account = accounts[i]
            client_id = i + 2
            pid = replay_seq_in_fork(cs, start, context, account, client_id)
            children[pid] = client_id

        # HERE, we are past all the forks
        t = time.time()
        print("all forks done in %.1f seconds, waiting %.1f" %
              (t - start + delay, t - start),
              file=sys.stderr)

        # reap children as they finish, until the deadline passes or
        # none remain
        while time.time() < end and children:
            time.sleep(0.003)
            try:
                pid, status = os.waitpid(-1, os.WNOHANG)
            except OSError as e:
                if e.errno != ECHILD:  # no child processes
                    raise
                break
            if pid:
                c = children.pop(pid, None)
                if DEBUG_LEVEL > 0:
                    print(("process %d finished conversation %d;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)
                if stop_on_any_error and status != 0:
                    break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        context.write_stats('unfinished',
                            Unfinished_conversations=len(children))

        # escalating kill: SIGTERM twice, then SIGKILL
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != ESRCH:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != ECHILD:
                        raise
                # NOTE(review): if waitpid raised ECHILD just above, pid
                # still holds its value from an earlier iteration here;
                # this appears to rely on that stale pid having already
                # been popped from children — confirm before changing.
                if pid != 0:
                    c = children.pop(pid, None)
                    if c is None:
                        print("children is %s, no pid found" % children)
                        sys.stderr.flush()
                        sys.stdout.flush()
                        os._exit(1)
                    print(("kill -%d %d KILLED conversation; "
                           "%d to go" %
                           (s, pid, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1799
1800
def openLdb(host, creds, lp):
    """Open an LDAP connection to the host and return the SamDB handle."""
    # note: a local name other than `ldb` avoids shadowing the module
    sam_db = SamDB(url="ldap://%s" % host,
                   session_info=system_session(),
                   options=['modules:paged_searches'],
                   credentials=creds,
                   lp=lp)
    return sam_db
1809
1810
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    domain = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain)
1815
1816
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    # LDB error code for "entry already exists"
    LDB_ERR_ENTRY_ALREADY_EXISTS = 68

    def _add_ou(dn):
        # Add the OU, treating "already exists" as success, so repeated
        # runs against the same server are idempotent.
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != LDB_ERR_ENTRY_ALREADY_EXISTS:
                raise

    ou = ou_name(ldb, instance_id)
    # create the parent (ou=traffic_replay,...) before the instance OU
    _add_ou(ou.split(',', 1)[1])
    _add_ou(ou)
    return ou
1840
1841
# ConversationAccounts holds details of the machine and user accounts
# associated with a conversation.
#
# A named tuple keeps the per-conversation footprint small, which
# matters because these objects are shared with many forked processes.
ConversationAccounts = namedtuple(
    'ConversationAccounts',
    ['netbios_name', 'machinepass', 'username', 'userpass'])
1851
1852
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    # machine and user accounts share the one password; the names are
    # numbered 1..number so they pair up with the conversations
    return [ConversationAccounts(machine_name(instance_id, n),
                                 password,
                                 user_name(instance_id, n),
                                 password)
            for n in range(1, number + 1)]
1865
1866
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap."""
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # traffic accounts need these bits set, otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        flags = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT
    else:
        # plain computer accounts just use workstation trust flags
        flags = UF_WORKSTATION_TRUST_ACCOUNT

    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": str(flags),
        "unicodePwd": utf16pw})
1890
1891
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    details = {
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    }
    ldb.add(details)

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1908
1909
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""
    dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    details = {
        "dn": dn,
        "objectclass": "group",
        "sAMAccountName": name,
    }
    ldb.add(details)
1920
1921
def user_name(instance_id, i):
    """Generate a user name based in the instance id"""
    return "STGU-{0}-{1}".format(instance_id, i)
1925
1926
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for objects of the given objectclass, returning the given
    attribute of each match as a set of strings."""
    expression = "(objectClass=%s)" % objectclass
    results = ldb.search(expression=expression, attrs=[attr])
    return set(str(obj[attr]) for obj in results)
1934
1935
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    existing = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        if name in existing:
            continue
        create_user_account(ldb, instance_id, name, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1949
1950
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from instance id."""
    # Traffic accounts correspond to a given user, and use different
    # userAccountControl flags to ensure packets get processed correctly
    # by the DC. Other computer accounts just simulate a semi-realistic
    # network; they use the default computer flags and get a distinct
    # prefix so we never pick them when generating packets.
    prefix = "STGM" if traffic_account else "PC"
    return "%s-%d-%d" % (prefix, instance_id, i)
1964
1965
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server"""
    existing = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i, traffic_account)
        # computer sAMAccountNames carry a trailing '$'
        if name + "$" in existing:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
1981
1982
def group_name(instance_id, i):
    """Generate a group name from instance id."""
    return "STGG-{0}-{1}".format(instance_id, i)
1986
1987
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server."""
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in existing:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
2001
2002
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    # deleting the instance OU with tree_delete removes everything in it
    try:
        ldb.delete(ou_name(ldb, instance_id), ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # status 32 is "no such object": nothing to clean up
        if status != 32:
            raise
2013
2014
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
       those groups."""
    memberships_added = 0
    groups_added = 0
    computers_added = 0

    # everything we create lives under one OU, so it can be removed in
    # a single tree delete later (see clean_up_accounts)
    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
                                                traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        # plan the memberships first, then write them to the DB in
        # batches via add_users_to_groups
        assignments = GroupAssignments(number_of_groups,
                                       groups_added,
                                       number_of_users,
                                       users_added,
                                       group_memberships,
                                       max_members)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    # NOTE(review): GroupAssignments only records memberships touching a
    # newly-added user or group, so new groups stay empty if all users
    # already existed
    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
                 memberships_added))
2058
2059
class GroupAssignments(object):
    """Plans (but does not write) user-to-group membership assignments.

    Weighted random sampling shapes the result so a few users belong to
    many groups, and a few groups contain most of the users.
    add_users_to_groups() applies the finished plan to the DB.
    """
    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):

        # running total of memberships assigned so far
        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.max_members = max_members
        # maps group number -> list of user numbers in that group
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        """Convert a list of weights into a cumulative distribution.

        Returns a list of monotonically increasing values ending at 1.0,
        or None if the weights sum to zero.
        """
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        dist = []
        total = sum(weights)
        if total == 0:
            return None

        cumulative = 0.0
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
            shape = 3.0
        elif num_memberships > 2000000:
            shape = 2.5
        elif num_memberships > 300000:
            shape = 2.25
        else:
            shape = 1.75

        weights = []
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # Assign a weighted probability to each user. Probability decreases
        # as the group-ID increases
        weights = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        # (the raw weights are kept so cap_group_membership can zero one
        # out and recompute)
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        """Return the list of user numbers assigned to the given group."""
        return self.assignments[group]

    def get_groups(self):
        """Return the group numbers present in the assignments."""
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))

            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        """Record that the given user belongs to the given group."""
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)
            self.count += 1

        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.
        """

        if group_memberships <= 0:
            return

        # Calculate the number of group memberships required, scaled by
        # the proportion of users that were actually added this run
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        if self.max_members:
            group_memberships = min(group_memberships,
                                    self.max_members * number_of_groups)

        existing_users  = number_of_users  - users_added  - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()

            # NOTE(review): only memberships involving a newly-added user
            # or group are recorded; pairs of pre-existing users/groups
            # are assumed to be in the DB already — confirm
            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)

    def total(self):
        """Return the number of memberships assigned so far."""
        return self.count
2206
2207
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""
    total = assignments.total()
    writes = 0
    added = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if not members:
            continue

        # Write at most 1000 members per modify operation. Fewer DB
        # modifies would be more efficient, but pushing 10K+ users to a
        # single group in one go becomes inefficient memory-wise.
        for offset in range(0, len(members), 1000):
            chunk = members[offset:offset + 1000]
            add_group_members(db, instance_id, group, chunk)

            added += len(chunk)
            writes += 1
            if writes % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (added, total))
2231
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""
    ou = ou_name(db, instance_id)

    def dn_of(name):
        return("cn=%s,%s" % (name, ou))

    msg = ldb.Message()
    msg.dn = ldb.Dn(db, dn_of(group_name(instance_id, group)))

    # each element needs its own unique message key; the "member"
    # attribute name in the MessageElement is what actually gets written
    for user in users_in_group:
        element = ldb.MessageElement(dn_of(user_name(instance_id, user)),
                                     ldb.FLAG_MOD_ADD, "member")
        msg["member-" + str(user)] = element

    db.modify(msg)
2250
2251
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads every file in statsdir; packet lines are aggregated into
    latency/success tallies and echoed to timing_file (if given), while
    "key: value" summary lines feed the float/int summary values.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = Counter()
    unique_conversations = set()
    # tw() passes raw packet lines through to the optional timing file
    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    # summary values written by the child processes; we keep the maximum
    # seen for each key across all the stats files
    float_values = {
        'Maximum lag': 0,
        'Start lag': 0,
        'Max sleep miss': 0,
    }
    int_values = {
        'Planned_conversations': 0,
        'Planned_packets': 0,
        'Unfinished_conversations': 0,
    }

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    # packet lines are tab-separated:
                    # time conv protocol type duration successful error
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    t = float(fields[0])
                    first        = min(t - latency, first)
                    last         = max(t, last)

                    op = (protocol, packet_type)
                    latencies.setdefault(op, []).append(latency)
                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[op] += 1

                    unique_conversations.add(conversation)

                    tw(line)
                except (ValueError, IndexError):
                    # not a packet line; it may be one of the
                    # "key: value" summary lines
                    if ':' in line:
                        k, v = line.split(':', 1)
                        if k in float_values:
                            float_values[k] = max(float(v),
                                                  float_values[k])
                        elif k in int_values:
                            int_values[k] = max(int(v),
                                                int_values[k])
                        else:
                            print(line, file=sys.stderr)
                    else:
                        # not a valid line print and ignore
                        print(line, file=sys.stderr)

    # NOTE(review): with no packet lines at all, duration is negative
    # (last=0, first=float max); the rates below assume at least one
    # packet was seen whenever successful/failed are non-zero
    duration = last - first
    if successful == 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    conversations = len(unique_conversations)

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    for k, v in sorted(float_values.items()):
        print("%-28s %f" % (k.replace('_', ' ') + ':', v))
    for k, v in sorted(int_values.items()):
        print("%-28s %d" % (k.replace('_', ' ') + ':', v))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    # group the packet types seen by protocol, so the table below can be
    # printed protocol by protocol with numerically sorted op codes
    ops = {}
    for proto, packet in latencies:
        if proto not in ops:
            ops[proto] = set()
        ops[proto].add(packet)
    protocols = sorted(ops.keys())

    for protocol in protocols:
        packet_types = sorted(ops[protocol], key=opcode_key)
        for packet_type in packet_types:
            op = (protocol, packet_type)
            values     = latencies[op]
            values     = sorted(values)
            count      = len(values)
            failed     = failures[op]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get(op, '')
            print("%-12s   %4s  %-35s %12d %12d %12.6f "
                  "%12.6f %12.6f %12.6f %12.6f"
                  % (protocol,
                     packet_type,
                     desc,
                     count,
                     failed,
                     mean,
                     median,
                     percentile,
                     rng,
                     maxv))
2380
2381
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically"""
    try:
        n = int(v)
    except ValueError:
        # non-numeric codes sort as plain strings
        return v
    return "%03d" % n
2388
2389
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.
    """

    if not values:
        return 0
    # fractional index of the requested percentile
    k = (len(values) - 1) * percentile
    lower = math.floor(k)
    upper = math.ceil(k)
    if lower == upper:
        # k lands exactly on an element
        return values[int(k)]
    # otherwise interpolate linearly between the two neighbours
    return (values[int(lower)] * (upper - k) +
            values[int(upper)] * (k - lower))
2406
2407
def mk_masked_dir(*path):
    """Join *path, create that directory with 0o700 permissions, and
    return it.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.
    """
    # py3 os.mkdir can do this
    d = os.path.join(*path)
    # umask is process-wide state: restore it even if mkdir fails
    # (e.g. when the directory already exists)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        os.umask(mask)
    return d
2417